Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Full documentation for MIGraphX is available at
* Added `--log-stdout` flag to migraphx-driver to log to stdout instead of stderr (#4959).
* Added slice squeeze matcher to propogate squeeze downstream and allow for parallel branches to merge together (#5004)
* Added GPU kernel for ONNX `NonMaxSuppression` operation and redesigned the `nonmaxsuppression` operation to better represent the data-dependent output shape in the MIGraphX IR (#4893).
* Added opt-in hipGraph capture/replay for the decode eval path (`MIGRAPHX_ENABLE_HIPGRAPH`, default off): the single-context per-token kernel sequence is captured into a hipGraph once and replayed with one launch per eval, cutting host dispatch overhead. Gated to fp16; quantized (int4/fp4) programs run the eager path.

### Changed

Expand Down
43 changes: 43 additions & 0 deletions src/include/migraphx/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ bool is_cross_compile_context(const T&)
return false;
}

// default execution hook -- a target without its own execute() just runs the
// eval kernel loop eagerly. The gpu target overrides this (via a member execute())
// to optionally capture/replay the loop as a hipGraph.
template <class T>
void execute_context(T&, const std::function<void()>& run_kernels)
{
run_kernels();
}

#ifdef TYPE_ERASED_DECLARATION

// Type-erased interface for:
Expand All @@ -114,6 +123,8 @@ struct MIGRAPHX_EXPORT context
void finish_on(any_ptr queue);
// (optional)
bool is_cross_compile() const;
// (optional)
void execute(const std::function<void()>& run_kernels);
//
void finish() const;
};
Expand Down Expand Up @@ -231,6 +242,21 @@ struct context
return is_cross_compile_context(private_detail_te_self);
}

template <class T>
static auto private_detail_te_default_execute(
char, T&& private_detail_te_self, const std::function<void()>& run_kernels)
-> decltype(private_detail_te_self.execute(run_kernels))
{
private_detail_te_self.execute(run_kernels);
}

template <class T>
static void private_detail_te_default_execute(
float, T&& private_detail_te_self, const std::function<void()>& run_kernels)
{
execute_context(private_detail_te_self, run_kernels);
}

template <class PrivateDetailTypeErasedT>
struct private_te_unwrap_reference
{
Expand Down Expand Up @@ -264,6 +290,10 @@ struct context
char(0), std::declval<PrivateDetailTypeErasedT>(), std::declval<any_ptr>()),
private_detail_te_default_is_cross_compile(
char(0), std::declval<PrivateDetailTypeErasedT>()),
private_detail_te_default_execute(
char(0),
std::declval<PrivateDetailTypeErasedT>(),
std::declval<const std::function<void()>&>()),
std::declval<PrivateDetailTypeErasedT>().finish(),
void());

Expand Down Expand Up @@ -387,6 +417,12 @@ struct context
return (*this).private_detail_te_get_handle().is_cross_compile();
}

void execute(const std::function<void()>& run_kernels)
{
assert((*this).private_detail_te_handle_mem_var);
(*this).private_detail_te_get_handle().execute(run_kernels);
}

void finish() const
{
assert((*this).private_detail_te_handle_mem_var);
Expand Down Expand Up @@ -414,6 +450,7 @@ struct context
virtual void wait_for(any_ptr queue) = 0;
virtual void finish_on(any_ptr queue) = 0;
virtual bool is_cross_compile() const = 0;
virtual void execute(const std::function<void()>& run_kernels) = 0;
virtual void finish() const = 0;
};

Expand Down Expand Up @@ -492,6 +529,12 @@ struct context
return private_detail_te_default_is_cross_compile(char(0), private_detail_te_value);
}

void execute(const std::function<void()>& run_kernels) override
{

private_detail_te_default_execute(char(0), private_detail_te_value, run_kernels);
}

void finish() const override { private_detail_te_value.finish(); }

PrivateDetailTypeErasedT private_detail_te_value;
Expand Down
23 changes: 23 additions & 0 deletions src/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ struct program_impl
std::unordered_map<std::string, module> modules;
std::vector<context> contexts;
std::vector<target> targets;
// cached eval results for the hipGraph replay path. On the capture eval we
// record the output arguments (which point at fixed device buffers); on every
// subsequent replay (hipGraphLaunch) the host-side generic_eval loop is skipped,
// so we return these cached arguments -- valid because graph replay reuses the
// same fixed input/output device buffers (static-shape decode).
std::vector<argument> graph_cached_results;
};

program::program() : impl(std::make_unique<program_impl>()) { this->create_module("main"); }
Expand Down Expand Up @@ -704,6 +710,23 @@ std::vector<argument> program::eval(const parameter_map& params,
return result;
});
}
else if(contexts.size() == 1)
{
// route the single-context eval (the EP decode path, async or not)
// through context.execute, which optionally captures the kernel loop into a
// hipGraph (gated by MIGRAPHX_ENABLE_HIPGRAPH inside the gpu context) and
// replays it on later evals. On a replay the host-side generic_eval is
// skipped, so we cache and reuse the output arguments -- valid because graph
// replay reuses the same fixed input/output device buffers (static-shape
// decode). When the env flag is off, context.execute just runs the loop
// eagerly, i.e. behavior is byte-identical to the prior path.
contexts.front().execute([&] {
ret = generic_eval(*this, contexts, params, [&](auto&&, auto f) { return f(); });
impl->graph_cached_results = ret;
});
if(ret.empty())
ret = impl->graph_cached_results;
Comment on lines +723 to +728
}
else
{
ret = generic_eval(*this, contexts, params, [&](auto&&, auto f) { return f(); });
Expand Down
22 changes: 22 additions & 0 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,28 @@
const auto& device_name = ctx == nullptr ? "" : ctx->get_current_device().get_gfx_name();
const bool is_navi = starts_with(device_name, "gfx11") or starts_with(device_name, "gfx12");

// mark this program non-capturable for hipGraph if it contains any quantized /
// low-bit op. hipGraph capture/replay regresses int4/fp4 decode substantially (up to
// ~2x slower than the eager path on discrete GPUs). Allowlist-by-absence: any quantized
// program (int4 nibble unpack, fp4, or the int8/int4 dequantize+quant_dot path that has
// no nibble unpack) takes the eager path; only fp16 (none of these ops) captures.
Comment on lines +1509 to +1511
// Scanned here -- before fusion consumes these into a code_object whose name no longer
// reveals them -- and recorded on the (shared) context so the hipGraph path (gated by
// context::is_graph_enabled) skips capture.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We want to capture all kernels and not certain non fused ones, otherwise you'll still get idle bubbles in the pipeline of runs.

if(ctx != nullptr)
{
static const std::array<std::string, 4> low_bit_ops = {
{"unpack_int4", "unpack_fp4", "dequantizelinear", "quant_dot"}};
for(const auto& ins : mpm.get_module())
{
if(contains(low_bit_ops, ins.name()))
{

Check warning on line 1522 in src/targets/gpu/fuse_mlir.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

style: Consider using std::any_of algorithm instead of a raw loop. [useStlAlgorithm]
ctx->set_graph_not_capturable();
break;
}
}
Comment on lines +1517 to +1526
}

auto get_mode = [&](std::string_view option, mlir_mode m1, mlir_mode m2 = mlir_mode::fast) {
if(specific_op<rejected>(option))
return mlir_mode::none;
Expand Down
112 changes: 110 additions & 2 deletions src/targets/gpu/include/migraphx/gpu/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,25 @@
#include <unordered_map>
#include <memory>
#include <optional>
#include <functional>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NULL_STREAM)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_NSTREAMS)

using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy);
// opt-in hipGraph capture/replay of the eval kernel sequence. When set, a
// steady-state (static-shape, fixed-buffer) eval is captured into a hipGraph once
// and replayed with a single hipGraphLaunch per subsequent eval -- collapsing the
// ~52 per-token individual kernel dispatches into one host submit (cuts dispatch
// overhead + the GPU-clock throttle the per-dispatch bubbles cause). Additive: the
// default per-op eval path is unchanged when this is off.
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPGRAPH)

using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy);
using hip_graph_ptr = MIGRAPHX_MANAGE_PTR(hipGraph_t, hipGraphDestroy);
using hip_graph_exec_ptr = MIGRAPHX_MANAGE_PTR(hipGraphExec_t, hipGraphExecDestroy);

struct hip_device
{
Expand Down Expand Up @@ -482,6 +492,98 @@ struct context
pc->auto_save = true;
}

// hipGraph capture/replay -------------------------------------------
// is_graph_enabled(): opt-in via MIGRAPHX_ENABLE_HIPGRAPH, only when not
// cross-compiling, and only for a capturable program (see graph_capturable).
// has_graph(): a graph has already been captured+instantiated for this context,
// so eval should replay instead of re-running the per-op loop.
bool is_graph_enabled() const
{
return enabled(MIGRAPHX_ENABLE_HIPGRAPH{}) and not is_cross_compile() and graph_capturable;
}
bool has_graph() const { return graph_exec != nullptr; }

// Capture eligibility, set at compile time (fuse_mlir). hipGraph capture/replay
// regresses low-bit-quantized (int4/fp4) decode substantially (up to ~2x slower than
// the eager per-op path on discrete GPUs), so a program that fuses any low-bit dequant
// op is marked non-capturable and runs the eager path; fp16 still captures.
bool is_graph_capturable() const { return graph_capturable; }

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shouldn't occur and signals a bug - We should be agnostic between what the lower level library is doing on the capture. Since MIGraphX is working on the higher level than MLIR and other libraries we should see a reduction in launch execution overall.

We do fusions to further optimize the model and keep the GPU queue full so it doesnt go idle.

void set_graph_not_capturable() { graph_capturable = false; }

// Begin capturing all kernel launches on the current stream into a hipGraph.
// Uses ThreadLocal mode so only this stream's work is captured. Caller runs the
// normal eval kernel loop between begin/end; the terminal stream sync stays
// OUTSIDE the capture (capture must contain only async work).
void begin_graph_capture()
{
auto status = hipStreamBeginCapture(get_stream().get(), hipStreamCaptureModeThreadLocal);
if(status != hipSuccess)
MIGRAPHX_THROW("hipStreamBeginCapture failed: " + hip_error(status));
}

// End capture, instantiate the executable graph, and cache it. Returns false
// (without throwing) if capture produced no usable graph, so the caller can
// fall back to having just executed the work eagerly during capture.
bool end_graph_capture()
{
hipGraph_t g = nullptr;
auto status = hipStreamEndCapture(get_stream().get(), &g);
if(status != hipSuccess or g == nullptr)
return false;
captured_graph = hip_graph_ptr{g};
hipGraphExec_t exec = nullptr;
status = hipGraphInstantiate(&exec, g, nullptr, nullptr, 0);
if(status != hipSuccess or exec == nullptr)
{
captured_graph.reset();
return false;
}
graph_exec = hip_graph_exec_ptr{exec};
return true;
}

// Replay the captured graph (one host submit for the whole kernel sequence).
void replay_graph()
{
auto status = hipGraphLaunch(graph_exec.get(), get_stream().get());
if(status != hipSuccess)
MIGRAPHX_THROW("hipGraphLaunch failed: " + hip_error(status));
}

// single execution entry point used by program::eval. When hipGraph mode
// is off (env flag unset, cross-compiling, or the program is not capturable -- see
// is_graph_enabled / graph_capturable), just runs the eval kernel loop (run_kernels)
// eagerly -- identical to the prior behavior. When on: the first call captures the loop
// into a hipGraph and instantiates it (still executing it eagerly during capture);
// subsequent calls replay the cached graph with one host submit and SKIP the loop
// entirely. The terminal stream sync is the caller's job and stays outside the
// captured region.
void execute(const std::function<void()>& run_kernels)
{
if(not is_graph_enabled())
{
run_kernels();
return;
}
if(has_graph())
{
replay_graph();
return;
}
// First eval: capture the loop into a graph. NOTE under hipStreamBeginCapture
// the kernel launches are RECORDED, not executed -- so run_kernels() here
// produces no output; we must launch the instantiated graph once to actually
// compute this first token. If capture/instantiate fails, fall back to a real
// eager run so the first token is still correct (and future evals stay eager
// since graph_exec remains null).
Comment on lines +573 to +578
begin_graph_capture();
run_kernels();
if(end_graph_capture())
replay_graph();
else
run_kernels(); // capture failed -> eager fallback (graph_exec stays null)
}

private:
// TODO: Make this a vector to support multiple devices
std::shared_ptr<hip_device> current_device;
Expand All @@ -495,6 +597,12 @@ struct context
shared<hip_event_ptr> begin_event = nullptr;
shared<hip_event_ptr> finish_event = nullptr;
std::shared_ptr<auto_save_problem_cache> pc = nullptr;
// hipGraph capture/replay state (lifetime tied to the context, which the
// EP keeps alive across decode-step evals on the same program).
shared<hip_graph_ptr> captured_graph = nullptr;
shared<hip_graph_exec_ptr> graph_exec = nullptr;
// capture eligibility (set false at compile time for low-bit/int4 programs).
bool graph_capturable = true;
};

inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }
Expand Down
Loading