Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/driver/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -984,19 +984,26 @@ struct time_cmd : command<time_cmd>
{
compiler c;
unsigned n = 100;
unsigned nbuffers = 1;
// Registers the options of the time command on the argument parser.
void parse(argument_parser& ap)
{
    // Iteration count; the member default is 100.
    ap(n,
       {"--iterations", "-n"},
       ap.help("Number of iterations to run."));
    // Number of parameter sets rotated between iterations; the member default is 1.
    ap(nbuffers,
       {"--buffers", "-b"},
       ap.help("Number of rotated buffers to use."));
    // The compiler registers its own options last.
    c.parse(ap);
}

void run()
{
auto p = c.compile();
log::info() << "Allocating params ...";
auto m = c.params(p);
std::vector<parameter_map> ms;
for(auto i : range(nbuffers))
{
(void)i;
ms.push_back(c.params(p));
}
log::info() << "Running ...";
double t = time_run(p, m, n);
double t = time_run(p, ms, n);
std::cout << "Total time: " << t << "ms" << std::endl;
}
};
Expand Down
7 changes: 3 additions & 4 deletions src/driver/perf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,15 @@ bool is_offload_copy_set(const program& p)
return param_ins.empty();
}

double time_run(const program& p, const parameter_map& m, int n)
double time_run(const program& p, const std::vector<parameter_map>& ms, int n)
{
// Run once without timing
p.eval(m);
p.eval(ms.back());
p.finish();
double total = time<milliseconds>([&] {
for(auto i : range(n))
{
(void)i;
p.eval(m);
p.eval(ms[i % ms.size()]);
}
p.finish();
});
Expand Down
2 changes: 1 addition & 1 deletion src/driver/perf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ target get_target(bool gpu);
*/
bool is_offload_copy_set(const program& p);

double time_run(const program& p, const parameter_map& m, int n = 100);
/*
 * Evaluates p once without timing, then times n evaluations. Iteration i uses
 * ms[i % ms.size()], so the parameter maps are cycled in order.
 * ms must not be empty: the definition calls ms.back() and computes i % ms.size().
 * NOTE(review): the timed section is measured in milliseconds; whether the returned value
 * is the total or a per-iteration figure is not visible here -- confirm in perf.cpp.
 */
double time_run(const program& p, const std::vector<parameter_map>& ms, int n = 100);

} // namespace MIGRAPHX_INLINE_NS
} // namespace driver
Expand Down
42 changes: 41 additions & 1 deletion src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,54 @@
#include <migraphx/kernels/args.hpp>
#include <migraphx/kernels/tile.hpp>
#include <migraphx/kernels/tuple.hpp>
#include <migraphx/kernels/bit_cast.hpp>

namespace migraphx {

// Selects an unsigned integer type by sizeof(T): 1 -> uint8_t, 2 -> uint16_t,
// 4 -> uint32_t, and uint64_t for every other size. It is used to carry struct element
// types (eg fp8) through a builtin that only accepts arithmetic and vector types, so
// users should verify that the selected type really has the same size as T.
template <class T>
using nontemporal_storage =
    conditional_t<(sizeof(T) <= 2),
                  conditional_t<sizeof(T) == 1, uint8_t, uint16_t>,
                  conditional_t<sizeof(T) == 4, uint32_t, uint64_t>>;

// Reads one value of type T from ptr using a load that carries a nontemporal hint.
// Integral, floating-point and vector types go straight to the builtin. Any other type
// must be trivially copyable: it is read as an unsigned integer of the same size and the
// bits are then converted back to T.
template <class T>
__device__ T nontemporal_load(const T* ptr)
{
    if constexpr(not(is_integral<T>{} or is_floating_point<T>{} or is_any_vec<T>()))
    {
        // The builtin rejects this type, so read its raw bits instead.
        static_assert(is_trivially_copyable<T>{});
        using raw_bits = nontemporal_storage<T>;
        // Guards against sizes for which no equally sized integer was selected.
        static_assert(sizeof(raw_bits) == sizeof(T));
        const auto* raw_ptr = reinterpret_cast<const raw_bits*>(ptr);
        return bit_cast<T>(__builtin_nontemporal_load(raw_ptr));
    }
    else
    {
        return __builtin_nontemporal_load(ptr);
    }
}

// Fetches element i of input tensor x. When the tensor's compile-time shape is not
// broadcasted, each element is read only once, so the read uses a nontemporal load to
// avoid polluting the cache. A broadcasted tensor has the same element reused across
// threads, so it keeps the ordinary cached read.
template <class T, class I>
__device__ auto pointwise_load(const T& x, I i)
{
    if constexpr(not get_shape_c<T>{}.broadcasted())
    {
        return nontemporal_load(&x[i]);
    }
    else
    {
        return x[i];
    }
}

template <class Stride, class F, class Output, class T, class... Ts>
__device__ void pointwise_tensor(Stride stride, F f, Output out, T x, Ts... xs)
{
stride(x.get_shape().elements(), [&](auto i) {
auto r = f(x[i], xs[i]...);
auto r = f(pointwise_load(x, i), pointwise_load(xs, i)...);
out([&](auto... outs) {
r([&](auto... rs) {
static_assert(sizeof...(outs) == sizeof...(rs));
Expand Down
Loading