Skip to content

symbolic reshape ops#4977

Open
shivadbhavsar wants to merge 9 commits into
developfrom
sym_reshape_ops
Open

symbolic reshape ops#4977
shivadbhavsar wants to merge 9 commits into
developfrom
sym_reshape_ops

Conversation

@shivadbhavsar

Copy link
Copy Markdown
Contributor

Motivation

Allow symbolic computation in compute_shapes for reshape ops

Technical Details

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

@shivadbhavsar shivadbhavsar self-assigned this Jun 16, 2026
@codecov

codecov Bot commented Jun 17, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 94.75983% with 12 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/include/migraphx/op/reshape_lazy.hpp 73.08% 7 Missing ⚠️
src/include/migraphx/op/layout.hpp 71.43% 2 Missing ⚠️
src/reshape_dims.cpp 97.62% 2 Missing ⚠️
src/shape.cpp 0.00% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4977      +/-   ##
===========================================
- Coverage    92.69%   92.67%   -0.02%     
===========================================
  Files          596      596              
  Lines        31603    31655      +52     
===========================================
+ Hits         29292    29334      +42     
- Misses        2311     2321      +10     
Files with missing lines Coverage Δ
src/include/migraphx/dim_like.hpp 100.00% <100.00%> (ø)
src/include/migraphx/op/flatten.hpp 81.82% <100.00%> (+4.04%) ⬆️
src/include/migraphx/op/reshape.hpp 100.00% <100.00%> (+1.19%) ⬆️
src/include/migraphx/op/squeeze.hpp 97.92% <100.00%> (ø)
src/include/migraphx/op/unsqueeze.hpp 100.00% <100.00%> (ø)
src/include/migraphx/shape.hpp 90.59% <ø> (ø)
src/shape.cpp 93.64% <0.00%> (-0.13%) ⬇️
src/include/migraphx/op/layout.hpp 75.00% <71.43%> (-15.00%) ⬇️
src/reshape_dims.cpp 93.75% <97.62%> (-0.69%) ⬇️
src/include/migraphx/op/reshape_lazy.hpp 86.67% <73.08%> (-10.26%) ⬇️

... and 1 file with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@shivadbhavsar shivadbhavsar marked this pull request as ready for review June 18, 2026 18:43
@shivadbhavsar shivadbhavsar requested a review from causten as a code owner June 18, 2026 18:43
Copilot AI review requested due to automatic review settings June 18, 2026 18:43

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR extends shape inference for reshape-like operators to support symbolic (migraphx::sym) dimensions, enabling symbolic computation in compute_shape() for flatten, reshape/reshape_lazy, squeeze, and unsqueeze.

Changes:

  • Add symbolic shape inference paths for flatten, reshape/reshape_lazy, squeeze, and unsqueeze (preserving layout where possible).
  • Generalize reshape_dims() to operate on sym::expr and add helpers to resolve/validate reshape dims (including 0 and -1 semantics) for symbolic inputs.
  • Add extensive unit tests covering symbolic reshape/flatten/squeeze/unsqueeze shape inference and stride/layout behavior.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
test/op_shape_test.cpp Adds test coverage for symbolic shape inference across flatten/reshape/squeeze/unsqueeze.
src/reshape_dims.cpp Implements symbolic-aware reshape_dims plus helpers to resolve/validate symbolic reshape dims.
src/include/migraphx/reshape_dims.hpp Exposes new symbolic reshape helpers and updates reshape_dims signature to sym::expr.
src/include/migraphx/op/unsqueeze.hpp Adds unified symbolic/static shape inference path for unsqueeze (including symbolic strides).
src/include/migraphx/op/squeeze.hpp Adds unified symbolic/static shape inference path for squeeze with stride preservation.
src/include/migraphx/op/reshape.hpp Switches static handling to symbolic resolution; attempts to preserve layout via reshape_dims.
src/include/migraphx/op/reshape_lazy.hpp Switches static handling to symbolic resolution and uses reshape_dims in lazy mode.
src/include/migraphx/op/flatten.hpp Adds symbolic flatten shape inference path.
src/include/migraphx/dim_like.hpp Adds is_symbolic(dim_like) helper used by reshape validation/resolution.

Comment thread src/include/migraphx/op/reshape.hpp Outdated
Comment thread src/include/migraphx/op/reshape_lazy.hpp Outdated
@gh-app-migraphx-bot-pr-write

gh-app-migraphx-bot-pr-write Bot commented Jun 19, 2026

Copy link
Copy Markdown
Test Batch New Rate (a99756) Old Rate (799492)* Diff Status
torchvision-resnet50 64 3,138.21 3,137.65 0.02%
torchvision-resnet50_fp16 64 6,624.87 6,626.32 -0.02%
torchvision-densenet121 32 2,694.12 2,696.62 -0.09%
torchvision-densenet121_fp16 32 4,533.28 4,528.92 0.10%
torchvision-inceptionv3 32 1,796.14 1,795.66 0.03%
torchvision-inceptionv3_fp16 32 2,820.43 2,821.40 -0.03%
cadene-inceptionv4 16 822.74 823.65 -0.11%
cadene-resnext64x4 16 782.77 782.53 0.03%
slim-mobilenet 64 8,387.14 8,378.48 0.10%
slim-nasnetalarge 64 228.86 228.88 -0.01%
slim-resnet50v2 64 3,163.31 3,163.63 -0.01%
bert-mrpc-onnx 8 1,172.43 1,171.00 0.12%
bert-mrpc-tf 1 483.88 488.93 -1.03%
pytorch-examples-wlang-gru 1 323.34 331.99 -2.61%
pytorch-examples-wlang-lstm 1 467.12 444.73 5.03% 🔆
torchvision-resnet50_1 1 765.87 752.34 1.80%
cadene-dpn92_1 1 474.83 441.73 7.50% 🔆
cadene-resnext101_1 1 365.92 366.76 -0.23%
onnx-taau-downsample 1 399.23 400.06 -0.21%
dlrm-criteoterabyte 1 32.43 32.41 0.05%
dlrm-criteoterabyte_fp16 1 51.79 51.75 0.07%
agentmodel 1 9,663.14 9,570.39 0.97%
unet_fp16 2 56.96 56.95 0.03%
resnet50v1_fp16 1 948.96 940.64 0.89%
resnet50v1_int8 1 932.69 942.01 -0.99%
bert_base_cased_fp16 64 1,097.70 1,097.60 0.01%
bert_large_uncased_fp16 32 346.17 346.47 -0.09%
bert_large_fp16 1 204.04 204.16 -0.06%
distilgpt2_fp16 16 2,086.43 2,091.67 -0.25%
yolov5s 1 592.74 597.86 -0.86%
tinyllama 1 45.94 45.98 -0.08%
vicuna-fastchat 1 44.22 44.16 0.15%
whisper-tiny-encoder 1 417.63 417.63 0.00%
whisper-tiny-decoder 1 418.00 412.69 1.29%
llama2_7b 1 20.33 20.34 -0.02%
qwen1.5-7b 1 23.57 23.58 -0.01%
phi3-3.8b 1 26.78 26.79 -0.01%
llama3-8b 1 21.76 21.73 0.14%
whisper-large-encoder 1 10.28 10.28 -0.05%
whisper-large-decoder 1 106.13 105.04 1.04%
mistral-7b 1 23.78 23.79 -0.03%
FLUX.1-schnell 1 762.22 767.41 -0.68%

Check flagged results 🔆

* No develop baseline was found for this PR's branch point; compared against the latest available develop run instead.

@gh-app-migraphx-bot-pr-write

gh-app-migraphx-bot-pr-write Bot commented Jun 19, 2026

Copy link
Copy Markdown
Test Status Result
bert-mrpc-onnx PASSED: MIGraphX meets tolerance
bert-mrpc-tf PASSED: MIGraphX meets tolerance
pytorch-examples-wlang-gru PASSED: MIGraphX meets tolerance
pytorch-examples-wlang-lstm PASSED: MIGraphX meets tolerance
dlrm-criteoterabyte PASSED: MIGraphX meets tolerance
agentmodel PASSED: MIGraphX meets tolerance
unet PASSED: MIGraphX meets tolerance
resnet50v1 PASSED: MIGraphX meets tolerance
bert_base_cased_fp16 PASSED: MIGraphX meets tolerance
bert_large_uncased_fp16 🔴 FAILED: MIGraphX is not within tolerance - check verbose output
bert_large PASSED: MIGraphX meets tolerance
yolov5s PASSED: MIGraphX meets tolerance
tinyllama PASSED: MIGraphX meets tolerance
vicuna-fastchat PASSED: MIGraphX meets tolerance
whisper-tiny-encoder PASSED: MIGraphX meets tolerance
whisper-tiny-decoder PASSED: MIGraphX meets tolerance
distilgpt2_fp16 PASSED: MIGraphX meets tolerance
llama2_7b PASSED: MIGraphX meets tolerance
qwen1.5-7b PASSED: MIGraphX meets tolerance
phi3-3.8b PASSED: MIGraphX meets tolerance
llama3-8b PASSED: MIGraphX meets tolerance
whisper-large-encoder PASSED: MIGraphX meets tolerance
whisper-large-decoder PASSED: MIGraphX meets tolerance
mistral-7b PASSED: MIGraphX meets tolerance
FLUX.1-schnell PASSED: MIGraphX meets tolerance

Comment thread src/reshape_dims.cpp Outdated
return not *x_lt;
});
if(x != dim)
if(indeterminate or x != dim)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should use same_value instead of x != dim.

Comment thread src/reshape_dims.cpp Outdated
return nullopt;
// Broadcasted check to avoid division by zero
if(stride2 == 0)
if(stride2 == zero)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these equality comparisons should use same_value because it might come from a variable with constraints.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean same_symbol?

Comment thread src/reshape_dims.cpp
auto n = it - start;
assert((r + n) <= rdims.size());
auto stride = istrides[i] * idim;
std::for_each(start, it + 1, [&](auto dim) {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to add sym::expr to the AllowedTypes for copying in the .clang-tidy config.

@shivadbhavsar shivadbhavsar requested a review from pfultz2 June 27, 2026 00:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants