Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .clang-tidy
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ CheckOptions:
- key: misc-const-correctness.AnalyzeValues
value: 'false'
- key: performance-for-range-copy.AllowedTypes
value: 'shape;operation;iterator;literal;tensor_view;match'
value: 'shape;operation;iterator;literal;tensor_view;match;sym::expr'
- key: performance-unnecessary-copy-initialization.AllowedTypes
value: 'shape;operation;iterator;literal;tensor_view;match'
value: 'shape;operation;iterator;literal;tensor_view;match;sym::expr'
- key: performance-unnecessary-value-param.AllowedTypes
value: 'shape;operation;iterator;literal;tensor_view;match'
value: 'shape;operation;iterator;literal;tensor_view;match;sym::expr'
- key: readability-function-size.BranchThreshold
value: '15'
- key: readability-function-size.LineThreshold
Expand Down
6 changes: 6 additions & 0 deletions src/include/migraphx/dim_like.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ struct dim_like_picker
// A dim attribute entry that may be either a plain int64_t or a dynamic_dimension.
using dim_like = picked_variant<dim_like_picker, int64_t, shape::dynamic_dimension>;

inline bool is_symbolic(const dim_like& d)
{
return std::holds_alternative<shape::dynamic_dimension>(d) and
std::get<shape::dynamic_dimension>(d).is_symbolic();
}

inline std::ostream& operator<<(std::ostream& os, const dim_like& d)
{
visit([&](const auto& x) { os << x; }, d);
Expand Down
36 changes: 25 additions & 11 deletions src/include/migraphx/op/flatten.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -27,6 +27,7 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/sym.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/dyn_output.hpp>
Expand Down Expand Up @@ -54,11 +55,32 @@ struct flatten
}

std::string name() const { return "flatten"; }

// Collapses the dims on each side of axis into a standard 2D shape, for
// static and symbolic input through one path.
shape symbolic_compute_shape(const shape& s) const
{
auto sym_in = s.to_symbolic();
const auto& dds = sym_in.dyn_dims();
auto x = std::accumulate(dds.begin(),
dds.begin() + axis,
shape::dynamic_dimension{sym::lit(1)},
std::multiplies<>{});
auto y = std::accumulate(dds.begin() + axis,
dds.end(),
shape::dynamic_dimension{sym::lit(1)},
std::multiplies<>{});
shape result{s.type(), {x, y}};
if(not s.symbolic())
return result.to_static();
return result;
}

shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
const auto& s = inputs[0];
if(s.dynamic())
if(s.dynamic() and not s.symbolic())
{
// Doesn't handle optimals
auto min_lens = s.min_lens();
Expand All @@ -78,15 +100,7 @@ struct flatten
{}};
return {s.type(), {x, y}};
}
else
{
auto&& lens = s.lens();
auto x = std::accumulate(
lens.begin(), lens.begin() + axis, std::size_t{1}, std::multiplies<>{});
auto y = std::accumulate(
lens.begin() + axis, lens.end(), std::size_t{1}, std::multiplies<>{});
return {s.type(), {x, y}};
}
return symbolic_compute_shape(s);
}
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
Expand Down
15 changes: 10 additions & 5 deletions src/include/migraphx/op/layout.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -58,10 +58,15 @@ struct layout : unary<layout>

shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).only_dims(permutation.size());
auto lens = inputs.at(0).lens();
auto t = inputs.at(0).type();
return shape::from_permutation(t, lens, permutation);
check_shapes{inputs, *this, true}.has(1).only_dims(permutation.size());
const auto& input = inputs.at(0);
auto t = input.type();
// A range-based dynamic shape has no strides, so a permuted layout is not representable.
if(input.symbolic())
return shape::from_permutation(t, input.dyn_dims(), permutation);
if(input.dynamic())
MIGRAPHX_THROW("LAYOUT: non-symbolic dynamic shapes are not supported");
return shape::from_permutation(t, input.lens(), permutation);
}

auto apply() const
Expand Down
107 changes: 44 additions & 63 deletions src/include/migraphx/op/reshape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_RESHAPE_HPP
#define MIGRAPHX_GUARD_OPERATORS_RESHAPE_HPP

#include <numeric>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
Expand Down Expand Up @@ -136,82 +135,64 @@ struct reshape
return {s0.type(), output_dyn_dims};
}

shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
// Resolves the output dims for static and symbolic input through one path.
shape symbolic_compute_shape(const shape& s0) const
{
check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.size());
std::transform(dims.begin(), dims.end(), rdims.begin(), [](const dim_like& d) {
return std::get<int64_t>(d);
});

for(std::size_t i = 0; i < dims.size(); i++)
// Lift static input to symbolic literals so the same dd arithmetic resolves both.
auto sym_in = s0.to_symbolic();
auto output_dyn_dims = resolve_reshape_dims(sym_in, dims);
const bool has_inferred_dim =
std::find(dims.begin(), dims.end(), dim_like{-1}) != dims.end();
const bool dims_have_symbolic = std::any_of(dims.begin(), dims.end(), is_symbolic);

// Preserve the input layout when reshape_dims can derive it; else standard.
std::vector<sym::expr> target(output_dyn_dims.size());
std::transform(output_dyn_dims.begin(),
output_dyn_dims.end(),
target.begin(),
[](const auto& dd) { return dd.sym_expr; });
auto result = reshape_dims(sym_in, target, {.lazy = false})
.value_or(shape{s0.type(), output_dyn_dims});

// An inferred -1 over a symbolic input is a floor division, so its element
// count is only resolvable at runtime; otherwise throw on a provably
// mismatched count (strict_less either way), letting indeterminate ones pass.
if(not(s0.symbolic() and has_inferred_dim))
{
if(dims[i] == dim_like{0})
rdims[i] = idims[i];

// convert -1 to 1 for rdims since rdims uses size_t (-1 is max_int for size_t)
if(dims[i] == dim_like{-1})
rdims[i] = 1;
}

if(n_neg_dims > 0)
{
size_t missing_dim =
inputs.front().elements() /
std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
for(std::size_t i = 0; i < rdims.size(); i++)
{
if(dims[i] == dim_like{-1})
rdims[i] = missing_dim;
}
auto out_elems = result.sym_elements();
auto in_elems = s0.sym_elements();
if(sym::strict_less(out_elems, in_elems).value_or(false) or
sym::strict_less(in_elems, out_elems).value_or(false))
MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " +
to_string(out_elems) + " elements whereas the input has " +
to_string(in_elems));
}

auto nelements =
std::accumulate(rdims.begin(), rdims.end(), std::size_t{1}, std::multiplies<>{});

if(nelements != inputs.front().elements())
MIGRAPHX_THROW("Reshape: Wrong number of elements for reshape: reshape has " +
std::to_string(nelements) + " elements whereas the input has " +
std::to_string(inputs.front().elements()));

auto s = reshape_dims(inputs.front(), rdims, {.lazy = false});
if(not s.has_value())
return shape{inputs.front().type(), rdims};

return s.value();
// Only a static input with integer dims is fully literal; evaluate it back to
// the concrete layout. Anything symbolic stays symbolic.
if(not s0.symbolic() and not dims_have_symbolic)
return result.to_static();
return result;
}

shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1, 2);
if(inputs.size() == 2)
return inputs.back();

if(std::any_of(dims.begin(), dims.end(), [](const auto& d) {
return std::holds_alternative<shape::dynamic_dimension>(d);
}))
MIGRAPHX_THROW("Reshape: dynamic_dimension dim entries are not currently supported");

auto n_neg_dims = std::count(dims.begin(), dims.end(), dim_like{-1});
if(n_neg_dims > 1)
MIGRAPHX_THROW("Reshape: Dimensions for reshape can only have one -1 dim but given {" +
to_string_range(dims) + "} with " + to_string(n_neg_dims) + " -1 dims");
validate_reshape_dims(name(), dims);

const auto& s0 = inputs.front();
if(inputs.size() == 1)
{
if(s0.dynamic())
{
return dyn_1arg_compute_shape(s0);
}
else
{
return static_compute_shape(inputs, n_neg_dims);
}
}
else
if(s0.dynamic() and not s0.symbolic())
{
return inputs.back();
// A symbolic dim has no range interpretation, so it cannot target a
// range-based input.
if(std::any_of(dims.begin(), dims.end(), is_symbolic))
MIGRAPHX_THROW("reshape: range-based input only supports int64 dim entries");
return dyn_1arg_compute_shape(s0);
}
return symbolic_compute_shape(s0);
}

argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
Expand Down
109 changes: 54 additions & 55 deletions src/include/migraphx/op/reshape_lazy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,75 +101,74 @@ struct reshape_lazy
return {s0.type(), output_dyn_dims};
}

shape static_compute_shape(std::vector<shape> inputs, std::size_t n_neg_dims) const
// Resolves the output layout for static and symbolic input through one path.
shape symbolic_compute_shape(const shape& s0) const
{
check_shapes{inputs, *this}.has(1);
auto&& idims = inputs.front().lens();
std::vector<std::size_t> rdims(dims.size());
std::transform(dims.begin(), dims.end(), rdims.begin(), [](const dim_like& d) {
return std::get<int64_t>(d);
});

for(std::size_t i = 0; i < dims.size(); i++)
{
if(dims[i] == dim_like{0})
rdims[i] = idims[i];
// Lift static input to symbolic literals so the same dd arithmetic resolves both.
auto sym_in = s0.to_symbolic();
auto output_dyn_dims = resolve_reshape_dims(sym_in, dims);
const bool has_inferred_dim =
std::find(dims.begin(), dims.end(), dim_like{-1}) != dims.end();

std::vector<sym::expr> target(output_dyn_dims.size());
std::transform(output_dyn_dims.begin(),
output_dyn_dims.end(),
target.begin(),
[](const auto& dd) { return dd.sym_expr; });

// Lazy reshape is a no-copy view: when the permutation can't be preserved we
// cannot fall back to a repacked standard layout the way reshape does.
auto s = reshape_dims(sym_in, target, {.lazy = true});
if(not s.has_value())
MIGRAPHX_THROW("reshape_lazy on axis that is not packed.");

// since rdims using size_t type, -1 is the max value
// is size_t that cause later compuation incorrect
if(dims[i] == dim_like{-1})
rdims[i] = 1;
const bool dims_have_symbolic = std::any_of(dims.begin(), dims.end(), is_symbolic);
// Only a static input with integer dims is fully literal; evaluate it back to
// the concrete layout (static results stay byte-identical). Else stays symbolic.
if(not s0.symbolic() and not dims_have_symbolic)
{
auto result = s->to_static();
if(result.elements() != s0.elements())
MIGRAPHX_THROW(
"reshape_lazy: Wrong number of elements for reshape_lazy: reshape_lazy has " +
std::to_string(result.elements()) + " elements whereas the input has " +
std::to_string(s0.elements()));
assert(result.bytes() == s0.bytes());
return result;
}

if(n_neg_dims > 0)
// An inferred -1 over a symbolic input is a floor division, so its element
// count is only resolvable at runtime; otherwise throw on a provably
// mismatched count (strict_less either way), letting indeterminate ones pass.
if(not(s0.symbolic() and has_inferred_dim))
{
size_t missing_dim =
inputs.front().elements() /
std::accumulate(rdims.begin(), rdims.end(), 1, std::multiplies<int64_t>());
for(std::size_t i = 0; i < rdims.size(); i++)
{
if(dims[i] == dim_like{-1})
rdims[i] = missing_dim;
}
auto out_elems = s->sym_elements();
auto in_elems = s0.sym_elements();
if(sym::strict_less(out_elems, in_elems).value_or(false) or
sym::strict_less(in_elems, out_elems).value_or(false))
MIGRAPHX_THROW(
"reshape_lazy: Wrong number of elements for reshape_lazy: reshape_lazy has " +
to_string(out_elems) + " elements whereas the input has " +
to_string(in_elems));
}

auto s = reshape_dims(inputs.front(), rdims, {.lazy = true});
if(not s.has_value())
MIGRAPHX_THROW("reshape_lazy on axis that is not packed.");

if(s->elements() != inputs.front().elements())
MIGRAPHX_THROW(
"reshape_lazy: Wrong number of elements for reshape_lazy: reshape_lazy has " +
std::to_string(s->elements()) + " elements whereas the input has " +
std::to_string(inputs.front().elements()));

assert(s->bytes() == inputs.front().bytes());
return *s;
}

shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this, true}.has(1);
if(std::any_of(dims.begin(), dims.end(), [](const auto& d) {
return std::holds_alternative<shape::dynamic_dimension>(d);
}))
MIGRAPHX_THROW(
"reshape_lazy: dynamic_dimension dim entries are not currently supported");

auto n_neg_dims = std::count(dims.begin(), dims.end(), dim_like{-1});
if(n_neg_dims > 1)
MIGRAPHX_THROW("reshape_lazy: Dimensions for reshape_lazy can only have one -1 dim but "
"given {" +
to_string_range(dims) + "} with " + to_string(n_neg_dims) + " -1 dims");
const auto& s0 = inputs[0];
if(s0.dynamic())

validate_reshape_dims(name(), dims);

const auto& s0 = inputs.front();
if(s0.dynamic() and not s0.symbolic())
{
// A symbolic dim has no range interpretation, so it cannot target a
// range-based input.
if(std::any_of(dims.begin(), dims.end(), is_symbolic))
MIGRAPHX_THROW("reshape_lazy: range-based input only supports int64 dim entries");
return dyn_compute_shape(s0);
}
else
{
return static_compute_shape(inputs, n_neg_dims);
}
return symbolic_compute_shape(s0);
}

argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
Expand Down
Loading
Loading