Improve horizontal gather fusion: table dedup, ordering, and mixed-length table gathers #5036

Open: TedThemistokleous wants to merge 1 commit into develop from unify_gather_dedupe
TedThemistokleous (Collaborator) commented Jul 5, 2026

Follow-up to the previous gather dedupe and horizontal concat fusion work in #4984.

Adds fuse_horizontal_test cases for table dedup and mixed-length same-table fusion, which allow us to further optimize and let both the gather horizontal fusion and the previously added matchers play better together.

The gather_slice_concat matcher caused a gather regression in customer models (9 -> 20 gathers). This PR ties things together so that gather fusion and same-table fusion work better, and it further refines the previous matching, which was obscuring fusion candidates before fusion could occur.

Solves the regression from 9 -> 20 gathers (back down to 9).
Further refines gather table fusion and sees a modest 5-10% boost that scales with batch size. Scratch usage also improves by adding the same-table/mixed-size approach.

Further improves how well we can merge gather operations, from 9 down to 7, which should help scaling at higher batch sizes since we avoid two additional large gathers as batch size grows.

Motivation

Improve gather de-dupe performance for static table lookups - Regressed on develop with #4984
Improve the overall gather bundling with mixed size tables

Technical Details

Reduced gathers in the customer test model from 20 to 9, then to 7
Better bundling of gathers while avoiding regressions with other matchers/fusion passes
Handle the pre-fusion matcher case better to avoid breaking the gather horizontal fusion passes

AI Summary

Improve fuse_horizontal so independent embedding gathers bundle more effectively, and keep the concat(slice(gather)) rewrite from interfering with that bundling. Four related changes, all scoped to gather fusion.

Changes

| # | Change | Files |
|---|--------|-------|
| 1 | Deduplicate shared tables in cross-embedding fusion. Concatenate only the distinct embedding tables (first-appearance order) and map each gather to its table's row offset, instead of one table copy per gather. Single-table groups skip the concat. | src/fuse_horizontal.cpp (gather_horizontal_fusion::fuse) |
| 2 | Run cross-embedding fusion before same-table fusion. Same-table fusion (min group 2) previously ran first and consumed the same-table subset, stranding same-emb-dim siblings below cross-table's min group of 4. | src/fuse_horizontal.cpp (fuse_horizontal::apply) |
| 3 | Same-table mixed-length merge (opt-in fuse_horizontal.merge_mixed_lengths, enabled in the GPU pipeline). Same-table gathers whose indices differ only in shape are merged by flattening each index to 1-D → one batched gather → slice + reshape back. Shared table ⇒ no index offset added; uniform-shape groups still use the cheaper concat path. | src/fuse_horizontal.cpp (fuse_gathers_flattened, same_table_gather_horizontal_fusion), src/include/migraphx/fuse_horizontal.hpp |
| 4 | Gate find_gather_slice_concat behind simplify_reshapes.enable_gather_slice_concat (on by default) and disable it only in the pre-fuse_horizontal GPU pass, so the rewrite can't reshape the gather/index graph fusion groups on. | src/simplify_reshapes.cpp, src/include/migraphx/simplify_reshapes.hpp, src/targets/gpu/target.cpp |
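
The table dedup in change 1 can be sketched outside MIGraphX as a small planning step. This is a minimal illustration, not the code in gather_horizontal_fusion::fuse: the `table` struct, `dedup_plan`, and `plan_dedup` are hypothetical names, and a table is reduced to an id plus a row count.

```cpp
#include <cstddef>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for an embedding table: an id plus its row count.
struct table
{
    int id;
    std::size_t rows;
};

struct dedup_plan
{
    std::vector<int> distinct_ids;        // tables to concatenate, first-appearance order
    std::vector<std::size_t> row_offsets; // per gather: first row of its table in the concat
};

// gather_tables[i] is the table read by the i-th gather in the group.
inline dedup_plan plan_dedup(const std::vector<table>& gather_tables)
{
    dedup_plan plan;
    std::unordered_map<int, std::size_t> offset_of; // table id -> row offset in the concat
    std::size_t next_offset = 0;
    for(const auto& t : gather_tables)
    {
        auto it = offset_of.find(t.id);
        if(it == offset_of.end())
        {
            // First time this table is seen: it gets the next free row range.
            it = offset_of.emplace(t.id, next_offset).first;
            plan.distinct_ids.push_back(t.id);
            next_offset += t.rows;
        }
        // Repeated tables reuse the offset recorded on first appearance.
        plan.row_offsets.push_back(it->second);
    }
    return plan;
}
```

In this model, a plan whose `distinct_ids` has a single entry corresponds to the case where the concat is skipped, since every offset is zero.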

Results

Batch 8 — this PR vs. the rejected "group by embedding dim alone" variant (shown to justify the same-table scoping in #3):

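| Variant | GPU kernels | distinct embedding gathers | standalone add kernels | scratch | throughput |
|---------|-------------|----------------------------|------------------------|---------|------------|
| This PR (same-table mixed-length) | 343 | 7 | 14 | 53.9 MB | 1.00x (ref) |
| Group-by-emb-dim (rejected) | 656 | 4 | 328 | 67.8 MB | 0.79x |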

The rejected variant produced fewer gathers but merged unrelated tables into one large gather, forcing a per-lookup int32 index offset-add that doesn't fuse (14 -> 328 add kernels, +90% kernel count, -21% throughput). Scoping mixed-length merging to a single shared table avoids the offset-adds entirely.

Batch 1: no change vs. baseline (mixed-length merge is gated by min_index_batch = 4, so it correctly does nothing where batching wouldn't help).

Design notes for reviewers

  • 1 is the enabler for 2: with tables deduped, bundling a group that spans tables no longer replicates a shared table.
  • 3 deliberately lives on the same-table finder, not cross-table — cross-table merging of distinct tables requires offset-adds that de-fuse; same-table merging is offset-free.
  • All new behavior is off by default except where explicitly enabled in the GPU pipeline; non-GPU pipelines and existing call sites are unchanged.
  • The reconstruction offsets use a prefix sum rather than a stateful std::transform accumulator (transform application order is unspecified).
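
As an illustration of the last note, one way to compute slice start offsets with a prefix sum is std::exclusive_scan from C++17. This is a sketch only: `slice_starts` and `index_counts` are hypothetical names, and the PR may compute the prefix sum differently.

```cpp
#include <cstddef>
#include <numeric>
#include <vector>

// index_counts[i] is the number of index elements of the i-th gather.
// The result holds the first row of each gather inside the batched gather output.
inline std::vector<std::size_t> slice_starts(const std::vector<std::size_t>& index_counts)
{
    std::vector<std::size_t> starts(index_counts.size());
    // Exclusive prefix sum: starts[i] = index_counts[0] + ... + index_counts[i - 1].
    std::exclusive_scan(
        index_counts.begin(), index_counts.end(), starts.begin(), std::size_t{0});
    return starts;
}
```

Unlike a lambda that mutates a running total inside std::transform, the scan's result does not depend on the order in which elements are visited.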

Tests

test/fuse_horizontal_test.cpp: table dedup (single + two shared tables) and same-table mixed-index-length fusion. Existing cross-table / same-table cases updated where the ordering change (#2) unifies a previously-split group.
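
To see why the ordering change can unify a previously-split group, here is a toy model of the two passes. It is a sketch under strong assumptions: gathers are reduced to the id of the table they read, all gathers share one embedding dim, and `cross_table_pass` / `same_table_pass` are hypothetical helpers. Only the minimum group sizes (4 and 2) come from the PR description.

```cpp
#include <cstddef>
#include <map>
#include <vector>

// Toy model: a gather is represented only by the id of the table it reads.
using gather_group = std::vector<int>;

constexpr std::size_t cross_table_min = 4; // min group for cross-table fusion
constexpr std::size_t same_table_min  = 2; // min group for same-table fusion
constexpr int fused_multi_table       = -1; // marker for a fused cross-table gather

// Cross-table pass: fuse the whole group into one gather if it is large enough.
inline gather_group cross_table_pass(const gather_group& g)
{
    if(g.size() < cross_table_min)
        return g;
    return {fused_multi_table};
}

// Same-table pass: gathers reading the same table collapse into one gather.
inline gather_group same_table_pass(const gather_group& g)
{
    std::map<int, std::size_t> count;
    for(int t : g)
        ++count[t];
    gather_group out;
    for(const auto& [t, n] : count)
    {
        const bool fuse        = t != fused_multi_table and n >= same_table_min;
        const std::size_t kept = fuse ? 1 : n;
        out.insert(out.end(), kept, t);
    }
    return out;
}
```

For a group reading tables {0, 0, 1, 2}, running the same-table pass first leaves three gathers, which is below the cross-table minimum of 4, while running the cross-table pass first leaves one.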

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

…ngth merge

Rework fuse_horizontal so independent embedding gathers bundle more effectively
without disturbing surrounding fusion.

- gather_horizontal_fusion now deduplicates shared embedding tables: it
  concatenates only the distinct tables (first-appearance order) and maps each
  gather to its table's row offset, instead of emitting one table copy per
  gather. Shared tables are no longer replicated, and a group that resolves to a
  single table skips the concat entirely.

- Run cross-embedding fusion before same-table fusion. Same-table fusion (min
  group 2) previously ran first and consumed the same-table subset of a group,
  stranding the remaining same-embedding-dim siblings below cross-table fusion's
  min group of 4. Running cross-table first lets a group spanning several tables
  bundle together (table dedup keeps each table once); same-table fusion then
  mops up the smaller same-instance groups.

- Add an opt-in same-table mixed-length merge (fuse_horizontal.merge_mixed_lengths,
  enabled in the GPU pipeline): same-table gathers whose indices differ only in
  shape are merged by flattening each index to 1-D, one batched gather, then
  slice + reshape back. Because the table is shared this adds no index offset,
  and uniform-shape groups still take the cheaper concat path.

- Gate find_gather_slice_concat with simplify_reshapes.enable_gather_slice_concat
  (on by default) and disable it only in the pre-fuse_horizontal GPU pass, so the
  concat(slice(gather)) rewrite cannot reshape the gather/index graph that
  horizontal gather fusion groups on.
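
The mixed-length merge described above can be modelled on plain vectors to show the data movement. This is a minimal sketch, not the MIGraphX implementation: `index_tensor`, `gathered`, and `merge_mixed_lengths` are hypothetical names, and the table is assumed to be row-major with shape {rows, emb_dim}.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical index tensor: its shape plus row-major data.
struct index_tensor
{
    std::vector<std::size_t> lens;
    std::vector<std::size_t> data;
};

// Hypothetical gather result: shape is (index lens + emb_dim).
struct gathered
{
    std::vector<std::size_t> lens;
    std::vector<float> data;
};

inline std::vector<gathered> merge_mixed_lengths(const std::vector<float>& table,
                                                 std::size_t emb_dim,
                                                 const std::vector<index_tensor>& indices)
{
    // 1. Flatten every index to 1-D and concatenate. No offset is added,
    //    because every gather reads the same table.
    std::vector<std::size_t> big_idx;
    for(const auto& idx : indices)
        big_idx.insert(big_idx.end(), idx.data.begin(), idx.data.end());

    // 2. One batched gather along axis 0.
    std::vector<float> batched;
    batched.reserve(big_idx.size() * emb_dim);
    for(std::size_t row : big_idx)
        for(std::size_t c = 0; c < emb_dim; ++c)
            batched.push_back(table[row * emb_dim + c]);

    // 3. Slice each gather's rows back out and restore its index shape.
    std::vector<gathered> results;
    std::size_t start = 0;
    for(const auto& idx : indices)
    {
        const std::size_t n = idx.data.size();
        gathered g;
        g.lens = idx.lens;
        g.lens.push_back(emb_dim);
        for(std::size_t i = start * emb_dim; i < (start + n) * emb_dim; ++i)
            g.data.push_back(batched[i]);
        results.push_back(g);
        start += n;
    }
    return results;
}
```

Because the flattened row-major order is preserved by the slice, the reshape in step 3 only changes the recorded shape, not the data.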

Adds fuse_horizontal_test cases for table dedup and mixed-length same-table
fusion.

Co-authored-by: Cursor <cursoragent@cursor.com>
@TedThemistokleous TedThemistokleous requested review from pfultz2 and shivadbhavsar and removed request for causten July 5, 2026 02:52
Comment thread src/fuse_horizontal.cpp
Comment on lines +174 to +180
std::transform(
indices.begin(), indices.end(), flat_inputs.begin(), [&](instruction_ref idx) {
if(idx->get_shape().lens().size() == 1)
return idx;
std::int64_t n = idx->get_shape().elements();
return m.insert_instruction(insert_pt, make_op("reshape", {{"dims", {n}}}), idx);
});

[format.py] reported by reviewdog 🐶

Suggested change
- std::transform(
- indices.begin(), indices.end(), flat_inputs.begin(), [&](instruction_ref idx) {
- if(idx->get_shape().lens().size() == 1)
- return idx;
- std::int64_t n = idx->get_shape().elements();
- return m.insert_instruction(insert_pt, make_op("reshape", {{"dims", {n}}}), idx);
- });
+ std::transform(indices.begin(), indices.end(), flat_inputs.begin(), [&](instruction_ref idx) {
+ if(idx->get_shape().lens().size() == 1)
+ return idx;
+ std::int64_t n = idx->get_shape().elements();
+ return m.insert_instruction(insert_pt, make_op("reshape", {{"dims", {n}}}), idx);
+ });

Comment thread src/fuse_horizontal.cpp
Comment on lines +182 to +183
auto big_idx =
m.insert_instruction(insert_pt, make_op("concat", {{"axis", 0}}), flat_inputs);

[format.py] reported by reviewdog 🐶

Suggested change
- auto big_idx =
- m.insert_instruction(insert_pt, make_op("concat", {{"axis", 0}}), flat_inputs);
+ auto big_idx = m.insert_instruction(insert_pt, make_op("concat", {{"axis", 0}}), flat_inputs);

Comment thread src/fuse_horizontal.cpp
Comment on lines +200 to +223
std::transform(
gathers.begin(),
gathers.end(),
slice_starts.begin(),
results.begin(),
[&](instruction_ref g, std::size_t start) -> instruction_ref {
const auto& idx_lens = g->inputs().at(1)->get_shape().lens();
std::size_t n = g->inputs().at(1)->get_shape().elements();
auto sliced = m.insert_instruction(
insert_pt,
make_op("slice",
{{"axes", {0}},
{"starts", {static_cast<std::int64_t>(start)}},
{"ends", {static_cast<std::int64_t>(start + n)}}}),
batched_gather);
// A 1-D index already yields {n, emb_dim}; only multi-dim indices need a
// reshape back to (index dims + embedding dim).
if(idx_lens.size() == 1)
return sliced;
std::vector<std::int64_t> out_dims(idx_lens.begin(), idx_lens.end());
out_dims.push_back(static_cast<std::int64_t>(emb_dim));
return m.insert_instruction(
insert_pt, make_op("reshape", {{"dims", out_dims}}), sliced);
});

[format.py] reported by reviewdog 🐶

Suggested change
- std::transform(
- gathers.begin(),
- gathers.end(),
- slice_starts.begin(),
- results.begin(),
- [&](instruction_ref g, std::size_t start) -> instruction_ref {
- const auto& idx_lens = g->inputs().at(1)->get_shape().lens();
- std::size_t n = g->inputs().at(1)->get_shape().elements();
- auto sliced = m.insert_instruction(
- insert_pt,
- make_op("slice",
- {{"axes", {0}},
- {"starts", {static_cast<std::int64_t>(start)}},
- {"ends", {static_cast<std::int64_t>(start + n)}}}),
- batched_gather);
- // A 1-D index already yields {n, emb_dim}; only multi-dim indices need a
- // reshape back to (index dims + embedding dim).
- if(idx_lens.size() == 1)
- return sliced;
- std::vector<std::int64_t> out_dims(idx_lens.begin(), idx_lens.end());
- out_dims.push_back(static_cast<std::int64_t>(emb_dim));
- return m.insert_instruction(
- insert_pt, make_op("reshape", {{"dims", out_dims}}), sliced);
- });
+ std::transform(gathers.begin(),
+ gathers.end(),
+ slice_starts.begin(),
+ results.begin(),
+ [&](instruction_ref g, std::size_t start) -> instruction_ref {
+ const auto& idx_lens = g->inputs().at(1)->get_shape().lens();
+ std::size_t n = g->inputs().at(1)->get_shape().elements();
+ auto sliced = m.insert_instruction(
+ insert_pt,
+ make_op("slice",
+ {{"axes", {0}},
+ {"starts", {static_cast<std::int64_t>(start)}},
+ {"ends", {static_cast<std::int64_t>(start + n)}}}),
+ batched_gather);
+ // A 1-D index already yields {n, emb_dim}; only multi-dim indices need a
+ // reshape back to (index dims + embedding dim).
+ if(idx_lens.size() == 1)
+ return sliced;
+ std::vector<std::int64_t> out_dims(idx_lens.begin(), idx_lens.end());
+ out_dims.push_back(static_cast<std::int64_t>(emb_dim));
+ return m.insert_instruction(
+ insert_pt, make_op("reshape", {{"dims", out_dims}}), sliced);
+ });

Comment thread src/simplify_reshapes.cpp
Comment on lines +2099 to +2104
auto squeeze = match::name("squeeze");
auto unsqueeze = match::name("unsqueeze");
auto same_shape_as_grandparent =
match::make_basic_pred_matcher([](instruction_ref ins) {
return ins->get_shape() == ins->inputs().front()->inputs().front()->get_shape();
});

[format.py] reported by reviewdog 🐶

Suggested change
- auto squeeze = match::name("squeeze");
- auto unsqueeze = match::name("unsqueeze");
- auto same_shape_as_grandparent =
- match::make_basic_pred_matcher([](instruction_ref ins) {
- return ins->get_shape() == ins->inputs().front()->inputs().front()->get_shape();
- });
+ auto squeeze = match::name("squeeze");
+ auto unsqueeze = match::name("unsqueeze");
+ auto same_shape_as_grandparent = match::make_basic_pred_matcher([](instruction_ref ins) {
+ return ins->get_shape() == ins->inputs().front()->inputs().front()->get_shape();
+ });

Comment on lines +590 to +592
auto concat_idx =
m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}),
std::vector<migraphx::instruction_ref>{idx_a1, adj_b1, idx_a2, adj_b2});

[format.py] reported by reviewdog 🐶

Suggested change
- auto concat_idx =
- m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}),
- std::vector<migraphx::instruction_ref>{idx_a1, adj_b1, idx_a2, adj_b2});
+ auto concat_idx = m2.add_instruction(
+ migraphx::make_op("concat", {{"axis", 0}}),
+ std::vector<migraphx::instruction_ref>{idx_a1, adj_b1, idx_a2, adj_b2});

Comment on lines +613 to +615
migraphx::run_passes(
m,
{migraphx::fuse_horizontal{.merge_mixed_lengths = true}, migraphx::dead_code_elimination{}});

[format.py] reported by reviewdog 🐶

Suggested change
- migraphx::run_passes(
- m,
- {migraphx::fuse_horizontal{.merge_mixed_lengths = true}, migraphx::dead_code_elimination{}});
+ migraphx::run_passes(m,
+ {migraphx::fuse_horizontal{.merge_mixed_lengths = true},
+ migraphx::dead_code_elimination{}});


auto big_idx = m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}),
std::vector<migraphx::instruction_ref>{f1, f2, f3, f4});
auto bg = m2.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb, big_idx);

[format.py] reported by reviewdog 🐶

Suggested change
auto bg = m2.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb, big_idx);
auto bg = m2.add_instruction(migraphx::make_op("gather", {{"axis", 0}}), emb, big_idx);

Comment on lines +1071 to +1073
auto concat_idx =
m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}),
std::vector<migraphx::instruction_ref>{idx_b1, idx_b2});
std::vector<migraphx::instruction_ref>{idx_a1, adj_b1, idx_a2, adj_b2});

[format.py] reported by reviewdog 🐶

Suggested change
- auto concat_idx =
- m2.add_instruction(migraphx::make_op("concat", {{"axis", 0}}),
- std::vector<migraphx::instruction_ref>{idx_b1, idx_b2});
- std::vector<migraphx::instruction_ref>{idx_a1, adj_b1, idx_a2, adj_b2});
+ auto concat_idx = m2.add_instruction(
+ migraphx::make_op("concat", {{"axis", 0}}),
+ std::vector<migraphx::instruction_ref>{idx_a1, adj_b1, idx_a2, adj_b2});

Comment thread src/fuse_horizontal.cpp

[format.py] reported by reviewdog 🐶

return m.insert_instruction(insert_pt, make_op("add"), idx, offset_broadcast);
});
// Concatenate adjusted indices
auto concat_idx =
m.insert_instruction(insert_pt, make_op("concat", {{"axis", 0}}), adjusted_idx_inputs);
// Single batched gather
auto batched_gather = m.insert_instruction(
insert_pt, make_op("gather", {{"axis", 0}}), concat_emb, concat_idx);
return slice_gather_rows(m, batched_gather, gathers, insert_pt);
}
};
// ---------------------------------------------------------------------------

codecov Bot commented Jul 5, 2026


Codecov Report

❌ Patch coverage is 97.01493% with 2 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|--------------------------|---------|-------|
| src/fuse_horizontal.cpp | 96.92% | 2 Missing ⚠️ |

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #5036      +/-   ##
===========================================
+ Coverage    92.71%   92.78%   +0.06%     
===========================================
  Files          596      596              
  Lines        31733    31971     +238     
===========================================
+ Hits         29421    29662     +241     
+ Misses        2312     2309       -3     

| Files with missing lines | Coverage Δ |
|--------------------------|------------|
| src/include/migraphx/fuse_horizontal.hpp | 100.00% <ø> (ø) |
| src/include/migraphx/simplify_reshapes.hpp | 100.00% <ø> (ø) |
| src/simplify_reshapes.cpp | 98.15% <100.00%> (+<0.01%) ⬆️ |
| src/fuse_horizontal.cpp | 98.34% <96.92%> (-0.91%) ⬇️ |

... and 15 files with indirect coverage changes

@gh-app-migraphx-bot-pr-write

| Test | Batch | New Rate (f027ad) | Old Rate (db2b92)* | Diff | Status |
|------|-------|-------------------|--------------------|------|--------|
| torchvision-resnet50 | 64 | 3,144.73 | 3,138.21 | 0.21% | |
| torchvision-resnet50_fp16 | 64 | 6,658.68 | 6,618.72 | 0.60% | |
| torchvision-densenet121 | 32 | 2,705.51 | 2,692.42 | 0.49% | |
| torchvision-densenet121_fp16 | 32 | 4,555.94 | 4,524.66 | 0.69% | |
| torchvision-inceptionv3 | 32 | 1,798.15 | 1,796.50 | 0.09% | |
| torchvision-inceptionv3_fp16 | 32 | 2,824.53 | 2,822.94 | 0.06% | |
| cadene-inceptionv4 | 16 | 823.16 | 821.97 | 0.15% | |
| cadene-resnext64x4 | 16 | 783.62 | 783.26 | 0.05% | |
| slim-mobilenet | 64 | 8,428.34 | 8,384.67 | 0.52% | |
| slim-nasnetalarge | 64 | 228.83 | 229.26 | -0.19% | |
| slim-resnet50v2 | 64 | 3,177.38 | 3,162.56 | 0.47% | |
| bert-mrpc-onnx | 8 | 1,169.35 | 1,171.62 | -0.19% | |
| bert-mrpc-tf | 1 | 485.41 | 490.19 | -0.98% | |
| pytorch-examples-wlang-gru | 1 | 329.04 | 477.77 | -31.13% | 🔴 |
| pytorch-examples-wlang-lstm | 1 | 486.17 | 407.03 | 19.44% | 🔆 |
| torchvision-resnet50_1 | 1 | 759.41 | 755.89 | 0.47% | |
| cadene-dpn92_1 | 1 | 444.39 | 444.00 | 0.09% | |
| cadene-resnext101_1 | 1 | 365.92 | 365.82 | 0.03% | |
| onnx-taau-downsample | 1 | 401.24 | 399.65 | 0.40% | |
| dlrm-criteoterabyte | 1 | 32.53 | 32.40 | 0.40% | |
| dlrm-criteoterabyte_fp16 | 1 | 52.67 | 51.81 | 1.65% | |
| agentmodel | 1 | 8,623.51 | 9,208.36 | -6.35% | 🔴 |
| unet_fp16 | 2 | 57.23 | 57.04 | 0.34% | |
| resnet50v1_fp16 | 1 | 937.22 | 946.18 | -0.95% | |
| resnet50v1_int8 | 1 | 937.11 | 939.13 | -0.22% | |
| bert_base_cased_fp16 | 64 | 1,102.26 | 1,097.88 | 0.40% | |
| bert_large_uncased_fp16 | 32 | 347.34 | 346.28 | 0.31% | |
| bert_large_fp16 | 1 | 204.31 | 205.00 | -0.34% | |
| distilgpt2_fp16 | 16 | 2,097.87 | 2,094.44 | 0.16% | |
| yolov5s | 1 | 594.77 | 561.59 | 5.91% | 🔆 |
| tinyllama | 1 | 45.98 | 45.98 | -0.00% | |
| vicuna-fastchat | 1 | 43.98 | 44.11 | -0.29% | |
| whisper-tiny-encoder | 1 | 419.95 | 413.15 | 1.65% | |
| whisper-tiny-decoder | 1 | 417.20 | 410.07 | 1.74% | |
| llama2_7b | 1 | 20.45 | 20.83 | -1.82% | |
| qwen1.5-7b | 1 | 23.65 | 23.60 | 0.21% | |
| phi3-3.8b | 1 | 26.81 | 26.80 | 0.06% | |
| llama3-8b | 1 | 21.81 | 21.72 | 0.45% | |
| whisper-large-encoder | 1 | 10.30 | 10.17 | 1.29% | |
| whisper-large-decoder | 1 | 104.11 | 104.83 | -0.69% | |
| mistral-7b | 1 | 23.84 | 23.83 | 0.06% | |
| FLUX.1-schnell | 1 | 751.03 | 760.85 | -1.29% | |

Regressions detected 🔴

* No develop baseline was found for this PR's branch point; compared against the latest available develop run instead.

@gh-app-migraphx-bot-pr-write

| Test | Status | Result |
|------|--------|--------|
| bert-mrpc-onnx | PASSED | MIGraphX meets tolerance |
| bert-mrpc-tf | PASSED | MIGraphX meets tolerance |
| pytorch-examples-wlang-gru | PASSED | MIGraphX meets tolerance |
| pytorch-examples-wlang-lstm | PASSED | MIGraphX meets tolerance |
| dlrm-criteoterabyte | PASSED | MIGraphX meets tolerance |
| agentmodel | PASSED | MIGraphX meets tolerance |
| unet | PASSED | MIGraphX meets tolerance |
| resnet50v1 | PASSED | MIGraphX meets tolerance |
| bert_base_cased_fp16 | PASSED | MIGraphX meets tolerance |
| bert_large_uncased_fp16 | 🔴 FAILED | MIGraphX is not within tolerance - check verbose output |
| bert_large | PASSED | MIGraphX meets tolerance |
| yolov5s | PASSED | MIGraphX meets tolerance |
| tinyllama | PASSED | MIGraphX meets tolerance |
| vicuna-fastchat | PASSED | MIGraphX meets tolerance |
| whisper-tiny-encoder | PASSED | MIGraphX meets tolerance |
| whisper-tiny-decoder | PASSED | MIGraphX meets tolerance |
| distilgpt2_fp16 | PASSED | MIGraphX meets tolerance |
| llama2_7b | PASSED | MIGraphX meets tolerance |
| qwen1.5-7b | PASSED | MIGraphX meets tolerance |
| phi3-3.8b | PASSED | MIGraphX meets tolerance |
| llama3-8b | PASSED | MIGraphX meets tolerance |
| whisper-large-encoder | PASSED | MIGraphX meets tolerance |
| whisper-large-decoder | PASSED | MIGraphX meets tolerance |
| mistral-7b | PASSED | MIGraphX meets tolerance |
| FLUX.1-schnell | PASSED | MIGraphX meets tolerance |
