diff --git a/CHANGELOG.md b/CHANGELOG.md index 7fed2159bb3..87673f00161 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ Full documentation for MIGraphX is available at * Added `--log-stdout` flag to migraphx-driver to log to stdout instead of stderr (#4959). * Added slice squeeze matcher to propogate squeeze downstream and allow for parallel branches to merge together (#5004) * Added GPU kernel for ONNX `NonMaxSuppression` operation and redesigned the `nonmaxsuppression` operation to better represent the data-dependent output shape in the MIGraphX IR (#4893). +* Added opt-in hipGraph capture/replay for the decode eval path (`MIGRAPHX_ENABLE_HIPGRAPH`, default off): the single-context per-token kernel sequence is captured into a hipGraph once and replayed with one launch per eval, cutting host dispatch overhead. Programs containing quantized ops (`unpack_int4`, `unpack_fp4`, `dequantizelinear`, `quant_dot`) are not captured and run the eager path. ### Changed diff --git a/src/include/migraphx/context.hpp b/src/include/migraphx/context.hpp index bcc679493c1..f9c563dec17 100644 --- a/src/include/migraphx/context.hpp +++ b/src/include/migraphx/context.hpp @@ -93,6 +93,15 @@ bool is_cross_compile_context(const T&) return false; } +// default execution hook -- a target without its own execute() just runs the +// eval kernel loop eagerly. The gpu target overrides this (via a member execute()) +// to optionally capture/replay the loop as a hipGraph. 
+template +void execute_context(T&, const std::function& run_kernels) +{ + run_kernels(); +} + #ifdef TYPE_ERASED_DECLARATION // Type-erased interface for: @@ -114,6 +123,8 @@ struct MIGRAPHX_EXPORT context void finish_on(any_ptr queue); // (optional) bool is_cross_compile() const; + // (optional) + void execute(const std::function& run_kernels); // void finish() const; }; @@ -231,6 +242,21 @@ struct context return is_cross_compile_context(private_detail_te_self); } + template + static auto private_detail_te_default_execute( + char, T&& private_detail_te_self, const std::function& run_kernels) + -> decltype(private_detail_te_self.execute(run_kernels)) // viable only if T has a member execute(run_kernels) + { + private_detail_te_self.execute(run_kernels); + } + + template + static void private_detail_te_default_execute( + float, T&& private_detail_te_self, const std::function& run_kernels) // fallback: the char(0) tag prefers the char overload when it is viable + { + execute_context(private_detail_te_self, run_kernels); + } + template struct private_te_unwrap_reference { @@ -264,6 +290,10 @@ struct context char(0), std::declval(), std::declval()), private_detail_te_default_is_cross_compile( char(0), std::declval()), + private_detail_te_default_execute( + char(0), + std::declval(), + std::declval&>()), std::declval().finish(), void()); @@ -387,6 +417,12 @@ struct context return (*this).private_detail_te_get_handle().is_cross_compile(); } + void execute(const std::function& run_kernels) + { + assert((*this).private_detail_te_handle_mem_var); + (*this).private_detail_te_get_handle().execute(run_kernels); + } + void finish() const { assert((*this).private_detail_te_handle_mem_var); @@ -414,6 +450,7 @@ struct context virtual void wait_for(any_ptr queue) = 0; virtual void finish_on(any_ptr queue) = 0; virtual bool is_cross_compile() const = 0; + virtual void execute(const std::function& run_kernels) = 0; virtual void finish() const = 0; }; @@ -492,6 +529,12 @@ struct context return private_detail_te_default_is_cross_compile(char(0), private_detail_te_value); } + void execute(const std::function& 
run_kernels) override + { + + private_detail_te_default_execute(char(0), private_detail_te_value, run_kernels); + } + void finish() const override { private_detail_te_value.finish(); } PrivateDetailTypeErasedT private_detail_te_value; diff --git a/src/program.cpp b/src/program.cpp index 4e1caa81fa1..5eaa76cbd1c 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -78,6 +78,12 @@ struct program_impl std::unordered_map modules; std::vector contexts; std::vector targets; + // cached eval results for the hipGraph replay path. Every eval that actually runs + // generic_eval records its output arguments here; when context.execute replays a + // captured graph instead, generic_eval is skipped and these cached arguments are + // returned. This is only valid while every eval binds the same input/output device + // buffers (static-shape decode): the replay branch in program::eval does not read params. + std::vector graph_cached_results; }; program::program() : impl(std::make_unique()) { this->create_module("main"); } @@ -704,6 +710,23 @@ std::vector program::eval(const parameter_map& params, return result; }); } + else if(contexts.size() == 1) + { + // route the single-context eval (the EP decode path, async or not) + // through context.execute, which optionally captures the kernel loop into a + // hipGraph (gated by MIGRAPHX_ENABLE_HIPGRAPH inside the gpu context) and + // replays it on later evals. On a replay the host-side generic_eval is + // skipped (so params is not read), and we reuse the cached output arguments -- + // valid only while callers keep binding the same input/output device buffers + // (static-shape decode). When the env flag is off, context.execute just runs the + // loop eagerly as on the prior path; the only addition is the cached copy of ret. 
+ contexts.front().execute([&] { + ret = generic_eval(*this, contexts, params, [&](auto&&, auto f) { return f(); }); + impl->graph_cached_results = ret; + }); + if(ret.empty()) + ret = impl->graph_cached_results; + } else { ret = generic_eval(*this, contexts, params, [&](auto&&, auto f) { return f(); }); diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 61827ca4838..23bafe1564a 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -1504,6 +1504,28 @@ void fuse_mlir::apply(module_pass_manager& mpm) const const auto& device_name = ctx == nullptr ? "" : ctx->get_current_device().get_gfx_name(); const bool is_navi = starts_with(device_name, "gfx11") or starts_with(device_name, "gfx12"); + // mark this program non-capturable for hipGraph if it contains any quantized / + // low-bit op. hipGraph capture/replay regresses int4/fp4 decode substantially (up to + // ~2x slower than the eager path on discrete GPUs). Allowlist-by-absence: any quantized + // program (int4 nibble unpack, fp4, or the int8/int4 dequantize+quant_dot path that has + // no nibble unpack) takes the eager path; only fp16 (none of these ops) captures. + // Scanned here -- before fusion consumes these into a code_object whose name no longer + // reveals them -- and recorded on the (shared) context so the hipGraph path (gated by + // context::is_graph_enabled) skips capture. 
+ if(ctx != nullptr) + { + static const std::array low_bit_ops = { + {"unpack_int4", "unpack_fp4", "dequantizelinear", "quant_dot"}}; + for(const auto& ins : mpm.get_module()) + { + if(contains(low_bit_ops, ins.name())) + { + ctx->set_graph_not_capturable(); + break; + } + } + } + auto get_mode = [&](std::string_view option, mlir_mode m1, mlir_mode m2 = mlir_mode::fast) { if(specific_op(option)) return mlir_mode::none; diff --git a/src/targets/gpu/include/migraphx/gpu/context.hpp b/src/targets/gpu/include/migraphx/gpu/context.hpp index 75c00176437..a8458334b17 100644 --- a/src/targets/gpu/include/migraphx/gpu/context.hpp +++ b/src/targets/gpu/include/migraphx/gpu/context.hpp @@ -44,6 +44,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -51,8 +52,17 @@ namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NULL_STREAM) MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_NSTREAMS) - -using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy); +// opt-in hipGraph capture/replay of the eval kernel sequence. When set, a +// steady-state (static-shape, fixed-buffer) eval is captured into a hipGraph once +// and replayed with a single hipGraphLaunch per subsequent eval -- collapsing the +// ~52 per-token individual kernel dispatches into one host submit (cuts dispatch +// overhead + the GPU-clock throttle the per-dispatch bubbles cause). Additive: the +// default per-op eval path is unchanged when this is off. 
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPGRAPH) + +using hip_event_ptr = MIGRAPHX_MANAGE_PTR(hipEvent_t, hipEventDestroy); +using hip_graph_ptr = MIGRAPHX_MANAGE_PTR(hipGraph_t, hipGraphDestroy); +using hip_graph_exec_ptr = MIGRAPHX_MANAGE_PTR(hipGraphExec_t, hipGraphExecDestroy); struct hip_device { @@ -482,6 +492,114 @@ struct context pc->auto_save = true; } + // hipGraph capture/replay ------------------------------------------- + // is_graph_enabled(): opt-in via MIGRAPHX_ENABLE_HIPGRAPH, only when not + // cross-compiling, and only for a capturable program (see graph_capturable). + // has_graph(): a graph has already been captured+instantiated for this context, + // so eval should replay instead of re-running the per-op loop. + bool is_graph_enabled() const + { + return enabled(MIGRAPHX_ENABLE_HIPGRAPH{}) and not is_cross_compile() and graph_capturable; + } + bool has_graph() const { return graph_exec != nullptr; } + + // Capture eligibility. Cleared at compile time (fuse_mlir) for programs containing + // quantized/low-bit ops, because hipGraph capture/replay regresses that decode path + // (up to ~2x slower than the eager per-op path on discrete GPUs); fp16 still captures. + // Also cleared by execute() when a capture attempt fails, so later evals stay eager. + bool is_graph_capturable() const { return graph_capturable; } + void set_graph_not_capturable() { graph_capturable = false; } + + // Begin capturing all kernel launches on the current stream into a hipGraph. + // ThreadLocal mode limits the capture-unsafe-API restriction to the calling thread. + // Caller runs the normal eval kernel loop between begin/end; the terminal stream + // sync stays OUTSIDE the capture (capture must contain only async work). 
+ void begin_graph_capture() + { + auto status = hipStreamBeginCapture(get_stream().get(), hipStreamCaptureModeThreadLocal); + if(status != hipSuccess) + MIGRAPHX_THROW("hipStreamBeginCapture failed: " + hip_error(status)); + } + + // End capture, instantiate the executable graph, and cache it. Returns false + // (without throwing) if capture produced no usable graph; nothing has executed at + // that point (launches are only recorded), so the caller must run the work eagerly. + bool end_graph_capture() + { + hipGraph_t g = nullptr; + auto status = hipStreamEndCapture(get_stream().get(), &g); + if(status != hipSuccess or g == nullptr) + return false; + captured_graph = hip_graph_ptr{g}; + hipGraphExec_t exec = nullptr; + status = hipGraphInstantiate(&exec, g, nullptr, nullptr, 0); + if(status != hipSuccess or exec == nullptr) + { + captured_graph.reset(); + return false; + } + graph_exec = hip_graph_exec_ptr{exec}; + return true; + } + + // Replay the captured graph (one host submit for the whole kernel sequence). + void replay_graph() + { + auto status = hipGraphLaunch(graph_exec.get(), get_stream().get()); + if(status != hipSuccess) + MIGRAPHX_THROW("hipGraphLaunch failed: " + hip_error(status)); + } + + // single execution entry point used by program::eval. When hipGraph mode + // is off (env flag unset, cross-compiling, or the program is not capturable -- see + // is_graph_enabled / graph_capturable), just runs the eval kernel loop (run_kernels) + // eagerly -- identical to the prior behavior. When on: the first call records the loop + // into a hipGraph, instantiates it and launches it once to produce that eval's output; + // subsequent calls replay the cached graph with one host submit and SKIP the loop + // entirely. If capture fails the loop is run eagerly and capture is not retried. The + // terminal stream sync is the caller's job and stays outside the captured region. 
+ void execute(const std::function& run_kernels) + { + if(not is_graph_enabled()) + { + run_kernels(); + return; + } + if(has_graph()) + { + replay_graph(); + return; + } + // First eval: capture the loop into a graph. Under hipStreamBeginCapture the kernel + // launches are RECORDED, not executed -- so run_kernels() here produces no output; the + // instantiated graph must be launched once to actually compute this first token. + begin_graph_capture(); + try + { + run_kernels(); + } + catch(...) + { + // run_kernels threw while the stream was capturing: close the capture and drop any + // partial graph so the stream is usable again, then let the error propagate. + hipGraph_t partial = nullptr; + static_cast<void>(hipStreamEndCapture(get_stream().get(), &partial)); + if(partial != nullptr) + static_cast<void>(hipGraphDestroy(partial)); + throw; + } + if(end_graph_capture()) + { + replay_graph(); + return; + } + // Capture/instantiate failed: stop attempting capture on later evals (otherwise every + // eval would record the loop and then run it a second time), and run eagerly now so + // this eval still produces its output. + graph_capturable = false; + run_kernels(); + } + private: // TODO: Make this a vector to support multiple devices std::shared_ptr current_device; @@ -495,6 +613,12 @@ struct context shared begin_event = nullptr; shared finish_event = nullptr; std::shared_ptr pc = nullptr; + // hipGraph capture/replay state (lifetime tied to the context, which the + // EP keeps alive across decode-step evals on the same program). + shared captured_graph = nullptr; + shared graph_exec = nullptr; + // capture eligibility (cleared by fuse_mlir for low-bit programs, or after a failed capture). + bool graph_capturable = true; }; inline void migraphx_to_value(value& v, const context& ctx) { v = ctx.to_value(); }