Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 148 additions & 50 deletions src/fuse_horizontal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,79 @@ static std::vector<instruction_ref> slice_gather_rows(module& m,
return results;
}

// Fuse a group of axis-0 gathers whose indices may have *different* shapes.
//
// Each index is flattened to 1-D and the flattened indices are concatenated, a single
// batched gather runs over `table`, and then each gather's element range is sliced back
// out of the batched result and reshaped to that gather's original output shape
// (index dims followed by the table's dims after axis 0).
//
// m         - module that receives the new instructions
// gathers   - the original gather instructions; their index operand (input 1) supplies
//             the per-gather element counts and the index dims of each output
// indices   - indices[i] is the (already offset-adjusted) index for gathers[i]; must
//             have the same size as `gathers`
// table     - the gather data operand (a single, possibly concatenated table)
// insert_pt - insertion point handed to every insert_instruction call
//
// Returns one instruction per gather, in the order of `gathers`. An empty group yields
// an empty result and leaves the module untouched.
static std::vector<instruction_ref>
fuse_gathers_flattened(module& m,
                       const std::vector<instruction_ref>& gathers,
                       const std::vector<instruction_ref>& indices,
                       instruction_ref table,
                       instruction_ref insert_pt)
{
    // The slicing below walks `gathers` while the concat is built from `indices`; the
    // two must describe the same group element-for-element.
    assert(gathers.size() == indices.size());
    // Nothing to fuse. Also guards slice_starts[0] / std::prev(end()) further down,
    // which are undefined on empty vectors.
    if(gathers.empty())
        return {};

    // Flatten every index to 1-D so indices of differing shapes can be concatenated.
    std::vector<instruction_ref> flat_inputs(indices.size());
    std::transform(indices.begin(),
                   indices.end(),
                   flat_inputs.begin(),
                   [&](instruction_ref idx) -> instruction_ref {
                       if(idx->get_shape().lens().size() == 1)
                           return idx;
                       const auto n = static_cast<std::int64_t>(idx->get_shape().elements());
                       return m.insert_instruction(
                           insert_pt, make_op("reshape", {{"dims", {n}}}), idx);
                   });

    auto big_idx = m.insert_instruction(insert_pt, make_op("concat", {{"axis", 0}}), flat_inputs);
    auto batched_gather =
        m.insert_instruction(insert_pt, make_op("gather", {{"axis", 0}}), table, big_idx);

    // Inclusive prefix sum of element counts gives each range's end; shift right and
    // prepend 0 for the (exclusive) start offsets into the batched gather.
    std::vector<std::size_t> slice_ends(gathers.size());
    transform_partial_sum(
        gathers.begin(), gathers.end(), slice_ends.begin(), std::plus<>{}, [](auto g) {
            return g->inputs().at(1)->get_shape().elements();
        });
    std::vector<std::size_t> slice_starts(gathers.size());
    slice_starts[0] = 0;
    std::copy(slice_ends.begin(), std::prev(slice_ends.end()), slice_starts.begin() + 1);

    // Dims contributed by the table to every gather output: everything after the
    // gathered axis 0. For the usual 2-D table this is just {embedding dim}.
    const auto& table_lens = table->get_shape().lens();
    assert(not table_lens.empty());
    const std::vector<std::int64_t> row_dims(table_lens.begin() + 1, table_lens.end());

    std::vector<instruction_ref> results(gathers.size());
    std::transform(gathers.begin(),
                   gathers.end(),
                   slice_starts.begin(),
                   results.begin(),
                   [&](instruction_ref g, std::size_t start) -> instruction_ref {
                       const auto& idx_lens = g->inputs().at(1)->get_shape().lens();
                       std::size_t n        = g->inputs().at(1)->get_shape().elements();
                       auto sliced          = m.insert_instruction(
                           insert_pt,
                           make_op("slice",
                                   {{"axes", {0}},
                                    {"starts", {static_cast<std::int64_t>(start)}},
                                    {"ends", {static_cast<std::int64_t>(start + n)}}}),
                           batched_gather);
                       // A 1-D index already yields {n, row dims...}; only multi-dim
                       // indices need a reshape back to (index dims + row dims).
                       if(idx_lens.size() == 1)
                           return sliced;
                       std::vector<std::int64_t> out_dims(idx_lens.begin(), idx_lens.end());
                       out_dims.insert(out_dims.end(), row_dims.begin(), row_dims.end());
                       return m.insert_instruction(
                           insert_pt, make_op("reshape", {{"dims", out_dims}}), sliced);
                   });

    return results;
}

// ---------------------------------------------------------------------------
// Same-table gather horizontal fusion
//
// Candidates: gather(axis=0) with 2D constant embedding table and a non-scalar
// index whose first dim is >= min_index_batch (worthwhile batch)
// Grouping: by (table instruction, index type, index trailing dims) so only
// gathers reading the *same* table are merged
// gathers reading the *same* table are merged. With merge_mixed_lengths
// the trailing dims are dropped, so same-table gathers with differing
// index shapes are merged too (via the flattened path).
// Fusion: concatenate the indices, single batched gather, slice rows back.
// No index offset adjustment is needed since the table is shared.
// ---------------------------------------------------------------------------
Expand All @@ -174,6 +240,11 @@ struct same_table_gather_horizontal_fusion
// Minimum first index dim for the batched gather to be worthwhile.
static constexpr std::size_t min_index_batch = 4;

// When set, merge same-table gathers whose indices differ only in shape by flattening
// each index (rather than requiring matching trailing dims). Because the table is
// shared, no index offset adjustment is introduced.
bool merge_mixed_lengths = false;

std::size_t min_group_size() const { return 2; }

bool is_candidate(instruction_ref ins) const
Expand Down Expand Up @@ -211,9 +282,12 @@ struct same_table_gather_horizontal_fusion
auto idx_type = idx->get_shape().type();
const auto& lens = idx->get_shape().lens();
assert(not lens.empty());
// Trailing index dims (all except first) — must match for concat on axis 0.
// Trailing index dims (all except first) — must match for concat on axis 0. When
// merging mixed lengths the indices are flattened, so trailing dims are ignored.
// Keying on the data instruction itself restricts grouping to one table.
std::vector<std::size_t> trailing(lens.begin() + 1, lens.end());
std::vector<std::size_t> trailing;
if(not merge_mixed_lengths)
trailing.assign(lens.begin() + 1, lens.end());
return std::make_tuple(data, idx_type, std::move(trailing));
}

Expand All @@ -224,11 +298,25 @@ struct same_table_gather_horizontal_fusion
auto data = gathers.front()->inputs().at(0);
assert(data->get_shape().lens().size() == 2);

// Concatenate the per-gather indices into a single index tensor.
// Collect the per-gather indices (the table is shared, so no offset adjustment).
std::vector<instruction_ref> idx_inputs(gathers.size());
std::transform(gathers.begin(), gathers.end(), idx_inputs.begin(), [](auto g) {
return g->inputs().at(1);
});

// Mixed index shapes cannot be concatenated on axis 0; flatten each instead. Only
// take the flattened path when the group is actually non-uniform, so uniform-shape
// groups produce the same (cheaper, reshape-free) IR regardless of the flag.
if(merge_mixed_lengths)
{
const auto& first_lens = gathers.front()->inputs().at(1)->get_shape().lens();
bool uniform = std::all_of(gathers.begin(), gathers.end(), [&](auto g) {
return g->inputs().at(1)->get_shape().lens() == first_lens;
});
if(not uniform)
return fuse_gathers_flattened(m, gathers, idx_inputs, data, insert_pt);
}

auto concat_idx =
m.insert_instruction(insert_pt, make_op("concat", {{"axis", 0}}), idx_inputs);

Expand All @@ -246,8 +334,9 @@ struct same_table_gather_horizontal_fusion
// Candidates: gather(axis=0) with 2D constant embedding table, static shapes,
// non-scalar index
// Grouping: by (embedding dimension, index type, index trailing dims)
// Fusion: concatenate embedding tables, adjust indices with offsets,
// single batched gather, slice results back
// Fusion: concatenate the *distinct* embedding tables (shared tables are kept
// once), adjust indices with per-table offsets, single batched gather,
// slice results back
// ---------------------------------------------------------------------------

struct gather_horizontal_fusion
Expand Down Expand Up @@ -296,51 +385,50 @@ struct gather_horizontal_fusion
{
auto idx_type = gathers.front()->inputs().at(1)->get_shape().type();

// Concatenate all embedding tables
std::vector<instruction_ref> emb_inputs(gathers.size());
std::transform(gathers.begin(), gathers.end(), emb_inputs.begin(), [](auto g) {
return g->inputs().at(0);
});
// Deduplicate shared tables. Several gathers in a group may read the *same*
// table instruction (e.g. after same-table siblings are folded in). Emitting
// one table copy per gather would replicate that data, bloating both the
// concatenated table literal and the batched gather. Instead keep one copy per
// distinct table (in first-appearance order) and record the row offset of each.
std::vector<instruction_ref> unique_tables;
std::unordered_map<instruction_ref, std::size_t> table_offset;
std::size_t running_offset = 0;
for(auto g : gathers)
{
auto data = g->inputs().at(0);
if(table_offset.emplace(data, running_offset).second)
{
unique_tables.push_back(data);
running_offset += data->get_shape().lens().front();
}
}

// Concatenate the distinct tables (skip the concat when only one remains).
auto concat_emb =
m.insert_instruction(insert_pt, make_op("concat", {{"axis", 0}}), emb_inputs);

// Compute cumulative embedding offsets using transform_partial_sum.
// Inclusive partial sum gives end offsets; shift right and prepend 0
// to get start (exclusive) offsets.
std::vector<std::size_t> cum_sizes(gathers.size());
transform_partial_sum(
gathers.begin(), gathers.end(), cum_sizes.begin(), std::plus<>{}, [](auto g) {
return g->inputs().at(0)->get_shape().lens().front();
});
unique_tables.size() == 1
? unique_tables.front()
: m.insert_instruction(insert_pt, make_op("concat", {{"axis", 0}}), unique_tables);

// Exclusive offsets: [0, cum_sizes[0], cum_sizes[1], ...]
std::vector<std::size_t> emb_offsets(gathers.size());
emb_offsets[0] = 0;
std::copy(cum_sizes.begin(), std::prev(cum_sizes.end()), emb_offsets.begin() + 1);

// Build adjusted indices (add offset to shift into concatenated table)
// Build adjusted indices (add each gather's table offset to shift into the
// concatenated table). Gathers whose table lands at offset 0 need no adjustment.
std::vector<instruction_ref> adjusted_idx_inputs;
adjusted_idx_inputs.reserve(gathers.size());

migraphx::for_each(
gathers.begin(), gathers.end(), emb_offsets.begin(), [&](auto g, auto offset) {
auto idx = g->inputs().at(1);
if(offset == 0)
{
adjusted_idx_inputs.push_back(idx);
}
else
{
auto offset_scalar = m.add_literal(literal{shape{idx_type}, {offset}});
auto offset_broadcast = m.insert_instruction(
insert_pt,
make_op("multibroadcast", {{"out_lens", idx->get_shape().lens()}}),
offset_scalar);
auto adjusted_idx =
m.insert_instruction(insert_pt, make_op("add"), idx, offset_broadcast);
adjusted_idx_inputs.push_back(adjusted_idx);
}
});
std::transform(gathers.begin(),
gathers.end(),
std::back_inserter(adjusted_idx_inputs),
[&](auto g) -> instruction_ref {
auto idx = g->inputs().at(1);
auto offset = table_offset.at(g->inputs().at(0));
if(offset == 0)
return idx;
auto offset_scalar = m.add_literal(literal{shape{idx_type}, {offset}});
auto offset_broadcast = m.insert_instruction(
insert_pt,
make_op("multibroadcast", {{"out_lens", idx->get_shape().lens()}}),
offset_scalar);
return m.insert_instruction(
insert_pt, make_op("add"), idx, offset_broadcast);
});

// Concatenate adjusted indices
auto concat_idx =
Expand Down Expand Up @@ -372,9 +460,19 @@ void fuse_horizontal::apply(module_pass_manager& mpm) const
{
auto& m = mpm.get_module();

// Collapse gathers that share the same table first; any sibling gathers left
// across *different* tables then fall through to cross-table fusion.
fuse_horizontal_ops(m, same_table_gather_horizontal_fusion{}, gather_horizontal_fusion{});
// Run cross-embedding fusion first so a group of gathers sharing an embedding
// dimension is bundled together even when it spans several tables (table dedup
// keeps each distinct table once, so shared tables are not replicated). Running
// same-table fusion first would greedily collapse the same-table subset and strand
// the remaining siblings below cross-table fusion's group-size threshold. Same-table
// fusion then mops up the smaller same-instance groups (size 2-3) that cross-table
// fusion's larger threshold skips. merge_mixed_lengths only affects same-table fusion:
// it lets same-table gathers with differing index shapes merge (via a flattened index)
// at no offset-add cost, without over-merging unrelated tables.
fuse_horizontal_ops(
m,
gather_horizontal_fusion{},
same_table_gather_horizontal_fusion{.merge_mixed_lengths = merge_mixed_lengths});
}

} // namespace MIGRAPHX_INLINE_NS
Expand Down
7 changes: 7 additions & 0 deletions src/include/migraphx/fuse_horizontal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,16 @@ inline namespace MIGRAPHX_INLINE_NS {
* embedding dimension and index layout, then fuses them into a single gather
* over a concatenated embedding table with offset-adjusted indices.
* The batched result is sliced back to produce the original outputs.
*
* When merge_mixed_lengths is set, same-table fusion also merges gathers on one
* table whose indices differ only in shape: each index is flattened to 1-D before
* the batched gather and each slice is reshaped back to its original output shape.
* Because the table is shared this introduces no index offset adjustment, and it
* never merges across distinct tables. Left off by default.
*/
// Horizontal-fusion pass object: carries the pass options and exposes the
// name()/apply() pair.
struct MIGRAPHX_EXPORT fuse_horizontal
{
// Opt-in switch, off by default. apply() forwards it to same-table gather fusion,
// where it allows gathers on one table with differing index shapes to be merged.
bool merge_mixed_lengths = false;
// Name reported for this pass.
std::string name() const { return "fuse_horizontal"; }
// Runs the fusion on the module obtained from `mpm`; defined in src/fuse_horizontal.cpp.
void apply(module_pass_manager& mpm) const;
};
Expand Down
4 changes: 4 additions & 0 deletions src/include/migraphx/simplify_reshapes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ struct MIGRAPHX_EXPORT simplify_reshapes
size_t depth = 4;
bool enable_op_shape_transform_op = false;
bool enable_gather_rewrite = false;
// Controls find_gather_slice_concat (rewrite of concat(slice(gather)) patterns). On by
// default; disabled only in the pre-fuse_horizontal GPU pass so the rewrite cannot reshape
// the gather/index graph that horizontal gather fusion groups on.
bool enable_gather_slice_concat = true;
std::string name() const { return "simplify_reshapes"; }
void apply(module& m) const;
};
Expand Down
8 changes: 7 additions & 1 deletion src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2092,6 +2092,7 @@ struct find_slice_squeeze
m.replace_instruction(op_ins, new_sq);
}
};

} // namespace

void simplify_reshapes::apply(module& m) const
Expand All @@ -2101,13 +2102,18 @@ void simplify_reshapes::apply(module& m) const
if(enable_gather_rewrite)
match::find_matches(m, find_gather{});
m.repeat_while_changes(depth, [&] {
// find_gather_slice_concat rewrites concat(slice(gather)) patterns into a gather with
// reordered indices. It is on by default but disabled in the pre-fuse_horizontal GPU
// pass, so it cannot reshape the gather/index graph that horizontal gather fusion groups
// on (running it before fusion can change which fusion path a lookup takes).
if(enable_gather_slice_concat)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we adding a flag to disable this?

match::find_matches(m, find_gather_slice_concat{});
match::find_matches(m,
find_nop_reshapes{},
find_flatten{},
find_reshape_cont{},
find_slice_shape_transforms{},
find_nested_shape_transforms{},
find_gather_slice_concat{},
find_concat_slice{},
find_concat_transpose{},
find_concat_reshape{},
Expand Down
5 changes: 3 additions & 2 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ struct pipeline_factory
eliminate_data_type_for_gpu{.disable_64bit = options.fast_math, .ctx = get_context()},
rewrite_resize{.affine_only = true},
dead_code_elimination{},
simplify_reshapes{.enable_gather_rewrite = true},
simplify_reshapes{.enable_gather_rewrite = true, .enable_gather_slice_concat = false},
eliminate_identity{},
eliminate_pad{},
dead_code_elimination{},
Expand All @@ -159,7 +159,8 @@ struct pipeline_factory
optimize_module{},
layout_convolution{.channels_last = enabled(MIGRAPHX_ENABLE_NHWC{})},
dead_code_elimination{},
enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), fuse_horizontal{}),
enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}),
fuse_horizontal{.merge_mixed_lengths = true}),
dead_code_elimination{},
prefuse_ops{get_context()},
dead_code_elimination{},
Expand Down
Loading
Loading