Improve horizontal gather fusion: table dedup, ordering, and mixed-length table gathers #5036
TedThemistokleous wants to merge 1 commit into
…ngth merge

Rework fuse_horizontal so independent embedding gathers bundle more effectively without disturbing surrounding fusion.

- gather_horizontal_fusion now deduplicates shared embedding tables: it concatenates only the distinct tables (first-appearance order) and maps each gather to its table's row offset, instead of emitting one table copy per gather. Shared tables are no longer replicated, and a group that resolves to a single table skips the concat entirely.
- Run cross-embedding fusion before same-table fusion. Same-table fusion (min group 2) previously ran first and consumed the same-table subset of a group, stranding the remaining same-embedding-dim siblings below cross-table fusion's min group of 4. Running cross-table first lets a group spanning several tables bundle together (table dedup keeps each table once); same-table fusion then mops up the smaller same-instance groups.
- Add an opt-in same-table mixed-length merge (fuse_horizontal.merge_mixed_lengths, enabled in the GPU pipeline): same-table gathers whose indices differ only in shape are merged by flattening each index to 1-D, one batched gather, then slice + reshape back. Because the table is shared this adds no index offset, and uniform-shape groups still take the cheaper concat path.
- Gate find_gather_slice_concat with simplify_reshapes.enable_gather_slice_concat (on by default) and disable it only in the pre-fuse_horizontal GPU pass, so the concat(slice(gather)) rewrite cannot reshape the gather/index graph that horizontal gather fusion groups on.

Adds fuse_horizontal_test cases for table dedup and mixed-length same-table fusion.

Co-authored-by: Cursor <cursoragent@cursor.com>
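The table dedup described in the commit message can be modelled without any MIGraphX types. The sketch below is only an illustration under stated assumptions: `dedup_tables`, `dedup_result`, and the representation of tables as integer ids with row counts are hypothetical names chosen here, not the PR's code. It shows the bookkeeping that the message describes: keep each distinct table once, in first-appearance order, and give every gather the row offset of its table inside the concatenated table.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical result of deduplicating the tables used by a group of gathers.
struct dedup_result
{
    std::vector<std::size_t> distinct_tables; // table ids, first-appearance order
    std::vector<std::size_t> row_offsets;     // one entry per gather
};

// table_rows[t]   : number of rows in table t
// gather_table[i] : id of the table that gather i reads from
dedup_result dedup_tables(const std::vector<std::size_t>& table_rows,
                          const std::vector<std::size_t>& gather_table)
{
    constexpr auto unseen = std::numeric_limits<std::size_t>::max();
    std::vector<std::size_t> offset_of(table_rows.size(), unseen);
    dedup_result r;
    std::size_t next_offset = 0;
    for(std::size_t t : gather_table)
    {
        assert(t < table_rows.size());
        if(offset_of[t] == unseen)
        {
            // First time this table appears: append it to the concat.
            offset_of[t] = next_offset;
            r.distinct_tables.push_back(t);
            next_offset += table_rows[t];
        }
        // Every gather on the same table reuses the same row offset.
        r.row_offsets.push_back(offset_of[t]);
    }
    return r;
}
```

In this model, a group whose `distinct_tables` has exactly one entry corresponds to the case where the concat can be skipped, since every offset is zero.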
| std::transform( | ||
| indices.begin(), indices.end(), flat_inputs.begin(), [&](instruction_ref idx) { | ||
| if(idx->get_shape().lens().size() == 1) | ||
| return idx; | ||
| std::int64_t n = idx->get_shape().elements(); | ||
| return m.insert_instruction(insert_pt, make_op("reshape", {{"dims", {n}}}), idx); | ||
| }); |
[format.py] reported by reviewdog 🐶
Suggested change:
| std::transform(indices.begin(), indices.end(), flat_inputs.begin(), [&](instruction_ref idx) { | |
| if(idx->get_shape().lens().size() == 1) | |
| return idx; | |
| std::int64_t n = idx->get_shape().elements(); | |
| return m.insert_instruction(insert_pt, make_op("reshape", {{"dims", {n}}}), idx); | |
| }); |
| auto big_idx = | ||
| m.insert_instruction(insert_pt, make_op("concat", {{"axis", 0}}), flat_inputs); |
[format.py] reported by reviewdog 🐶
Suggested change:
| auto big_idx = m.insert_instruction(insert_pt, make_op("concat", {{"axis", 0}}), flat_inputs); |
| std::transform( | ||
| gathers.begin(), | ||
| gathers.end(), | ||
| slice_starts.begin(), | ||
| results.begin(), | ||
| [&](instruction_ref g, std::size_t start) -> instruction_ref { | ||
| const auto& idx_lens = g->inputs().at(1)->get_shape().lens(); | ||
| std::size_t n = g->inputs().at(1)->get_shape().elements(); | ||
| auto sliced = m.insert_instruction( | ||
| insert_pt, | ||
| make_op("slice", | ||
| {{"axes", {0}}, | ||
| {"starts", {static_cast<std::int64_t>(start)}}, | ||
| {"ends", {static_cast<std::int64_t>(start + n)}}}), | ||
| batched_gather); | ||
| // A 1-D index already yields {n, emb_dim}; only multi-dim indices need a | ||
| // reshape back to (index dims + embedding dim). | ||
| if(idx_lens.size() == 1) | ||
| return sliced; | ||
| std::vector<std::int64_t> out_dims(idx_lens.begin(), idx_lens.end()); | ||
| out_dims.push_back(static_cast<std::int64_t>(emb_dim)); | ||
| return m.insert_instruction( | ||
| insert_pt, make_op("reshape", {{"dims", out_dims}}), sliced); | ||
| }); |
[format.py] reported by reviewdog 🐶
Suggested change:
| std::transform(gathers.begin(), | |
| gathers.end(), | |
| slice_starts.begin(), | |
| results.begin(), | |
| [&](instruction_ref g, std::size_t start) -> instruction_ref { | |
| const auto& idx_lens = g->inputs().at(1)->get_shape().lens(); | |
| std::size_t n = g->inputs().at(1)->get_shape().elements(); | |
| auto sliced = m.insert_instruction( | |
| insert_pt, | |
| make_op("slice", | |
| {{"axes", {0}}, | |
| {"starts", {static_cast<std::int64_t>(start)}}, | |
| {"ends", {static_cast<std::int64_t>(start + n)}}}), | |
| batched_gather); | |
| // A 1-D index already yields {n, emb_dim}; only multi-dim indices need a | |
| // reshape back to (index dims + embedding dim). | |
| if(idx_lens.size() == 1) | |
| return sliced; | |
| std::vector<std::int64_t> out_dims(idx_lens.begin(), idx_lens.end()); | |
| out_dims.push_back(static_cast<std::int64_t>(emb_dim)); | |
| return m.insert_instruction( | |
| insert_pt, make_op("reshape", {{"dims", out_dims}}), sliced); | |
| }); |
| auto squeeze = match::name("squeeze"); | ||
| auto unsqueeze = match::name("unsqueeze"); | ||
| auto same_shape_as_grandparent = | ||
| match::make_basic_pred_matcher([](instruction_ref ins) { | ||
| return ins->get_shape() == ins->inputs().front()->inputs().front()->get_shape(); | ||
| }); |
[format.py] reported by reviewdog 🐶
Suggested change:
| auto squeeze = match::name("squeeze"); | |
| auto unsqueeze = match::name("unsqueeze"); | |
| auto same_shape_as_grandparent = match::make_basic_pred_matcher([](instruction_ref ins) { | |
| return ins->get_shape() == ins->inputs().front()->inputs().front()->get_shape(); | |
| }); |
| auto concat_idx = | ||
| m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), | ||
| std::vector<migraphx::instruction_ref>{idx_a1, adj_b1, idx_a2, adj_b2}); |
[format.py] reported by reviewdog 🐶
Suggested change:
| auto concat_idx = m2.add_instruction( | |
| migraphx::make_op("concat", {{"axis", 0}}), | |
| std::vector<migraphx::instruction_ref>{idx_a1, adj_b1, idx_a2, adj_b2}); |
| migraphx::run_passes( | ||
| m, | ||
| {migraphx::fuse_horizontal{.merge_mixed_lengths = true}, migraphx::dead_code_elimination{}}); |
[format.py] reported by reviewdog 🐶
Suggested change:
| migraphx::run_passes(m, | |
| {migraphx::fuse_horizontal{.merge_mixed_lengths = true}, | |
| migraphx::dead_code_elimination{}}); |
| auto big_idx = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), | ||
| std::vector<migraphx::instruction_ref>{f1, f2, f3, f4}); | ||
| auto bg = m2.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb, big_idx); |
[format.py] reported by reviewdog 🐶
Suggested change:
| auto bg = m2.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb, big_idx); |
| auto concat_idx = | ||
| m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), | ||
| std::vector<migraphx::instruction_ref>{idx_b1, idx_b2}); | ||
| std::vector<migraphx::instruction_ref>{idx_a1, adj_b1, idx_a2, adj_b2}); |
[format.py] reported by reviewdog 🐶
Suggested change:
| auto concat_idx = m2.add_instruction( | |
| migraphx::make_op("concat", {{"axis", 0}}), | |
| std::vector<migraphx::instruction_ref>{idx_a1, adj_b1, idx_a2, adj_b2}); |
[format.py] reported by reviewdog 🐶
AMDMIGraphX/src/fuse_horizontal.cpp
Lines 433 to 448 in f027adc
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #5036      +/-   ##
===========================================
+ Coverage    92.71%   92.78%   +0.06%
===========================================
  Files          596      596
  Lines        31733    31971     +238
===========================================
+ Hits         29421    29662     +241
+ Misses        2312     2309       -3
Regressions detected 🔴
* No develop baseline was found for this PR's branch point; compared against the latest available develop run instead.
Follow-up to the previous gather dedupe and horizontal concat fusion work (#4984).
Adds fuse_horizontal_test cases for table dedup and mixed-length same-table fusion, so that we can further optimize and let both the gather horizontal fusion and the previously added matchers play better with the fusion.
The gather_slice_concat matcher regressed gather counts in customer models (9 -> 20 gathers). This PR ties things together so that gather fusion and same-table fusion work better, while also refining the previous matching, which was obscuring fusion candidates before fusion could occur.
Solves the regression from 9 -> 20 (back down to 9).
Further refines gather table fusions and sees a modest 5-10% boost that scales with batch. We also improve scratch usage by adding the same-table/mixed-size approach.
Further improves how well we can merge gather operations, down to 7 from 9, which should help scaling at higher batch sizes since we avoid two additional larger gathers as the batch size grows.
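As a rough illustration of the same-table/mixed-size approach mentioned above, the following self-contained sketch models the flatten, concat, single gather, slice, and reshape steps on plain vectors. The names `index_tensor`, `gathered`, and `merged_gather` are hypothetical, and the model assumes non-negative indices and a table stored as one `std::vector<float>` per row; it is not the MIGraphX implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using row = std::vector<float>;

// Hypothetical index tensor: row-major data plus its dimensions.
struct index_tensor
{
    std::vector<std::size_t> lens;
    std::vector<std::int64_t> data;
};

// Gather result for one index tensor: lens = index dims + embedding dim.
struct gathered
{
    std::vector<std::size_t> lens;
    std::vector<row> rows; // one table row per index element, row-major
};

std::vector<gathered> merged_gather(const std::vector<row>& table,
                                    const std::vector<index_tensor>& indices)
{
    // 1. Flatten every index to 1-D and concatenate along axis 0.
    std::vector<std::int64_t> big_idx;
    for(const auto& idx : indices)
        big_idx.insert(big_idx.end(), idx.data.begin(), idx.data.end());

    // 2. One batched gather over the shared table. Because the table is
    //    shared, no per-index offset has to be added.
    std::vector<row> batched;
    batched.reserve(big_idx.size());
    for(std::int64_t i : big_idx)
    {
        assert(i >= 0 && static_cast<std::size_t>(i) < table.size());
        batched.push_back(table[static_cast<std::size_t>(i)]);
    }

    // 3. Slice each gather's rows back out and restore (index dims + emb dim).
    const std::size_t emb_dim = table.empty() ? 0 : table.front().size();
    std::vector<gathered> results;
    std::ptrdiff_t start = 0;
    for(const auto& idx : indices)
    {
        const auto n = static_cast<std::ptrdiff_t>(idx.data.size());
        gathered g;
        g.lens = idx.lens;
        g.lens.push_back(emb_dim);
        g.rows.assign(batched.begin() + start, batched.begin() + start + n);
        results.push_back(std::move(g));
        start += n;
    }
    return results;
}
```

In this model the reshape is only a change of `lens`, which mirrors why a 1-D index needs no reshape after the slice.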
Motivation
- Improve gather de-dupe performance for static table lookups (regressed on develop with #4984)
- Improve the overall gather bundling with mixed-size tables
Technical Details
- Reduced gathers in the customer test model from 20 -> 9 -> 7
- Better bundling of gathers, avoiding regressions with other matchers/fusion passes
- Handle the pre-fusion matcher case better to avoid breaking the gather horizontal fusion passes
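Why pass ordering affects bundling can be seen in a toy model. The sketch below is an assumption-laden illustration, not the PR's code: `gather_desc`, `fuse_pass`, and `count_gathers` are hypothetical, cross-table fusion is modelled as grouping by embedding dimension with a minimum group of 4, same-table fusion as grouping by table with a minimum group of 2, and a fused gather is assumed not to be a candidate for the later pass.

```cpp
#include <cstddef>
#include <map>
#include <vector>

// Hypothetical description of one gather: the table it reads and that
// table's embedding dimension.
struct gather_desc
{
    int table;
    int emb_dim;
};

// One fusion pass: group the still-unfused gathers by key(g) and fuse every
// group with at least min_group members. Returns the number of fused
// gathers emitted by this pass.
template <class KeyFn>
std::size_t fuse_pass(const std::vector<gather_desc>& gathers,
                      std::vector<bool>& fused,
                      std::size_t min_group,
                      KeyFn key)
{
    std::map<int, std::vector<std::size_t>> groups;
    for(std::size_t i = 0; i < gathers.size(); ++i)
    {
        if(!fused[i])
            groups[key(gathers[i])].push_back(i);
    }
    std::size_t emitted = 0;
    for(const auto& entry : groups)
    {
        if(entry.second.size() < min_group)
            continue;
        for(std::size_t i : entry.second)
            fused[i] = true;
        ++emitted;
    }
    return emitted;
}

// Number of gathers left in the graph after both passes, in either order.
std::size_t count_gathers(const std::vector<gather_desc>& gathers, bool cross_table_first)
{
    auto by_dim   = [](const gather_desc& g) { return g.emb_dim; };
    auto by_table = [](const gather_desc& g) { return g.table; };
    std::vector<bool> fused(gathers.size(), false);
    std::size_t total = 0;
    if(cross_table_first)
    {
        total += fuse_pass(gathers, fused, 4, by_dim);
        total += fuse_pass(gathers, fused, 2, by_table);
    }
    else
    {
        total += fuse_pass(gathers, fused, 2, by_table);
        total += fuse_pass(gathers, fused, 4, by_dim);
    }
    for(bool f : fused)
    {
        if(!f)
            ++total; // unfused gathers stay as they are
    }
    return total;
}
```

For four gathers with the same embedding dimension on tables A, A, B, C, this model leaves three gathers when same-table fusion runs first (A and A fuse, B and C fall below the group of 4) and one gather when cross-table fusion runs first. The counts are a property of the toy model only and are not the customer-model numbers quoted above.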
AI Summary
Improve fuse_horizontal so independent embedding gathers bundle more effectively, and keep the concat(slice(gather)) rewrite from interfering with that bundling. Four related changes, all scoped to gather fusion.

Changes
1. Table dedup: concatenate only the distinct embedding tables (first-appearance order) and map each gather to its table's row offset; a group that resolves to a single table skips the concat. Files: src/fuse_horizontal.cpp (gather_horizontal_fusion::fuse)
2. Ordering: run cross-table fusion before same-table fusion, so that same-table fusion (min group 2) no longer strands the remaining siblings below cross-table fusion's min group of 4. Files: src/fuse_horizontal.cpp (fuse_horizontal::apply)
3. Opt-in same-table mixed-length merge (fuse_horizontal.merge_mixed_lengths, enabled in the GPU pipeline). Same-table gathers whose indices differ only in shape are merged by flattening each index to 1-D, then one batched gather, then slice + reshape back. The table is shared, so no index offset is added; uniform-shape groups still use the cheaper concat path. Files: src/fuse_horizontal.cpp (fuse_gathers_flattened, same_table_gather_horizontal_fusion), src/include/migraphx/fuse_horizontal.hpp
4. Gate find_gather_slice_concat behind simplify_reshapes.enable_gather_slice_concat (on by default) and disable it only in the pre-fuse_horizontal GPU pass, so the rewrite can't reshape the gather/index graph that fusion groups on. Files: src/simplify_reshapes.cpp, src/include/migraphx/simplify_reshapes.hpp, src/targets/gpu/target.cpp

Results
Batch 8: this PR vs. the rejected "group by embedding dim alone" variant (shown to justify the same-table scoping in #3):
[Results table not preserved; only the row label "add kernels" survives.]
The rejected variant produced fewer gathers but merged unrelated tables into one large gather, forcing a per-lookup int32 index offset-add that doesn't fuse (14 -> 328 add kernels, +90% kernel count, -21% throughput). Scoping mixed-length merging to a single shared table avoids the offset-adds entirely.
Batch 1: no change vs. baseline (mixed-length merge is gated by min_index_batch = 4, so it correctly does nothing where batching wouldn't help).

Design notes for reviewers
… std::transform accumulator (transform application order is unspecified).

Tests
test/fuse_horizontal_test.cpp: table dedup (single + two shared tables) and same-table mixed-index-length fusion. Existing cross-table / same-table cases updated where the ordering change (#2) unifies a previously-split group.

Changelog Category
Add a CHANGELOG.md entry for any option other than Not Applicable
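The design note about avoiding a std::transform accumulator refers to the fact that std::transform does not guarantee in-order application of its function object, so a lambda that mutates a running offset is fragile. One possible way to precompute slice start offsets without such an accumulator is an exclusive prefix sum. The helper below is a hypothetical sketch (the name `slice_starts` is borrowed from the diff, but how the PR actually fills it is not shown on this page) and assumes C++17 for std::exclusive_scan.

```cpp
#include <cstddef>
#include <numeric>
#include <vector>

// counts[i] is the number of index elements of gather i. The result holds,
// for each gather, the first row of its slice in the batched gather output.
std::vector<std::size_t> slice_starts(const std::vector<std::size_t>& counts)
{
    std::vector<std::size_t> starts(counts.size());
    // Exclusive prefix sum: starts[i] = counts[0] + ... + counts[i - 1].
    std::exclusive_scan(counts.begin(), counts.end(), starts.begin(), std::size_t{0});
    return starts;
}
```

With the starts computed up front, a later std::transform over the gathers and starts only reads its inputs, so the order in which elements are visited no longer matters.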