From 3591501ce30ee7a2597d9533c0bdcbf15eee524a Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Thu, 25 Jun 2026 02:40:56 +0000 Subject: [PATCH] [ARITH] Use IntImm in canonical scalar hot paths Canonical simplify operates on scalar index expressions in these paths, so direct IntImm construction avoids the generic MakeConst scalar/vector dispatch. This keeps MakeConst for generic helper sites while streamlining the focused split normal-form constants. --- src/arith/canonical_simplify.cc | 43 ++++++++++++++++----------------- src/arith/const_fold.h | 6 ++--- 2 files changed, 23 insertions(+), 26 deletions(-) diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index ce906ff143bb..1c7c979ba459 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -133,15 +133,15 @@ class SplitExprNode : public CanonicalExprNode { return IntImm(dtype, 0); } if (this->upper_factor != SplitExprNode::kPosInf) { - res = ModImpl(res, MakeConst(dtype, this->upper_factor), div_mode); + res = ModImpl(res, IntImm(dtype, this->upper_factor), div_mode); } if (this->lower_factor != 1) { - res = DivImpl(res, MakeConst(dtype, this->lower_factor), div_mode); + res = DivImpl(res, IntImm(dtype, this->lower_factor), div_mode); } sscale *= this->scale; if (sscale != 1) { TVM_FFI_ICHECK(dtype.code() != DLDataTypeCode::kDLUInt || sscale > 0); - res = res * MakeConst(dtype, sscale); + res = res * IntImm(dtype, sscale); } return res; } @@ -172,20 +172,20 @@ class SplitExprNode : public CanonicalExprNode { return false; } if (this->upper_factor != SplitExprNode::kPosInf) { - res = ModImpl(res, MakeConst(this->ty(), this->upper_factor), div_mode); + res = ModImpl(res, IntImm(this->ty(), this->upper_factor), div_mode); if (!CastIsSafe(dtype, res, analyzer)) { return false; } } if (this->lower_factor != 1) { - res = DivImpl(res, MakeConst(this->ty(), this->lower_factor), div_mode); + res = DivImpl(res, IntImm(this->ty(), this->lower_factor), div_mode); if (!CastIsSafe(dtype, res, analyzer)) { return false; } } if (this->scale != 1) { TVM_FFI_ICHECK(this->ty().code() != DLDataTypeCode::kDLUInt || this->scale > 0); - res = res * MakeConst(this->ty(), this->scale); + res = res * IntImm(this->ty(), this->scale); if (!CastIsSafe(dtype, res, analyzer)) { return false; } @@ -252,7 +252,7 @@ class SumExprNode : public CanonicalExprNode { PrimExpr Normalize() const final { // quick path 1. if (this->args.size() == 0) { - return MakeConst(this->ty(), this->base); + return IntImm(this->ty(), this->base); } return Normalize_(this->ty(), SimplifySplitExprs(args), base); } @@ -354,7 +354,7 @@ class SumExprNode : public CanonicalExprNode { } } if (base > 0 || is_min_value) { - res = res + MakeConst(dtype, base); + res = res + IntImm(dtype, base); if (!CastIsSafe(dtype, res, analyzer)) { return false; } @@ -369,7 +369,7 @@ class SumExprNode : public CanonicalExprNode { } } if (base < 0 && !is_min_value) { - res = res - MakeConst(dtype, -base); + res = res - IntImm(dtype, -base); if (!CastIsSafe(dtype, res, analyzer)) { return false; } @@ -507,7 +507,7 @@ class SumExprNode : public CanonicalExprNode { } } if (base > 0 || is_min_value) { - res = res + MakeConst(dtype, base); + res = res + IntImm(dtype, base); } // negative scales follows using sub. for (size_t i = 0; i < args.size(); ++i) { @@ -516,7 +516,7 @@ class SumExprNode : public CanonicalExprNode { } } if (base < 0 && !is_min_value) { - res = res - MakeConst(dtype, -base); + res = res - IntImm(dtype, -base); } return res; } @@ -837,8 +837,7 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval, return ToSplitExpr(IntImm(lhs.ty(), 0)); } else { // move the upper_factor modular into index. - lhs.CopyOnWrite()->index = - ModImpl(lhs->index, MakeConst(lhs.ty(), lhs->upper_factor), div_mode); + lhs.CopyOnWrite()->index = ModImpl(lhs->index, IntImm(lhs.ty(), lhs->upper_factor), div_mode); lhs.CopyOnWrite()->upper_factor = SplitExprNode::kPosInf; lhs.CopyOnWrite()->scale = 1; lhs.CopyOnWrite()->lower_factor *= scaled_cval; @@ -863,8 +862,8 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, // collect lhs products and try to eliminate by matching them to prod in rhs ffi::Array> lhs_prods; PrimType rhs_ty = prhs->ty(); - PrimExpr new_rhs = MakeConst(rhs_ty, 1); - PrimExpr new_common_scale = MakeConst(rhs_ty, 1); + PrimExpr new_rhs = IntImm(rhs_ty, 1); + PrimExpr new_common_scale = IntImm(rhs_ty, 1); int64_t lhs_cscale = 1, rhs_cscale = 1; int num_elimination = 0; @@ -907,13 +906,13 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, // construct prod via canonical form PrimType lhs_ty = plhs->ty(); - PrimExpr new_lhs = MakeConst(lhs_ty, 1); + PrimExpr new_lhs = IntImm(lhs_ty, 1); for (ffi::Optional val : lhs_prods) { if (val.defined()) new_lhs = new_lhs * val.value(); } - *plhs = new_lhs * MakeConst(lhs_ty, lhs_cscale); - *prhs = new_rhs * MakeConst(rhs_ty, rhs_cscale); - *common_scale = new_common_scale * MakeConst(rhs_ty, cscale_gcd); + *plhs = new_lhs * IntImm(lhs_ty, lhs_cscale); + *prhs = new_rhs * IntImm(rhs_ty, rhs_cscale); + *common_scale = new_common_scale * IntImm(rhs_ty, cscale_gcd); return true; } @@ -1051,7 +1050,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { } // Apply floormod(floordiv_result, m) to complete the identity PrimExpr div_result = Normalize(lhs); - return this->VisitExpr(floormod(div_result, MakeConst(a.ty(), new_mod))); + return this->VisitExpr(floormod(div_result, IntImm(a.ty(), new_mod))); } } } @@ -1098,7 +1097,7 @@ SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval, // Do a recursive call to simplify the mod with the new factor. if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) { auto updated = ToSplitExpr( - this->VisitExpr(ModImpl(lhs->index, MakeConst(lhs.ty(), new_upper_factor), div_mode))); + this->VisitExpr(ModImpl(lhs->index, IntImm(lhs.ty(), new_upper_factor), div_mode))); // re-apply the lower_factor if (lhs->lower_factor != 1) { auto ret = SplitDivConst(updated, lhs->lower_factor, div_mode); @@ -1416,7 +1415,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) { PrimType dtype = divisible->ty(); TVM_FFI_ICHECK(extra->ty() == dtype); PrimExpr normal_extra = extra->Normalize(); - if (this->analyzer_->CanProve(normal_extra < MakeConst(dtype, gcd)) && + if (this->analyzer_->CanProve(normal_extra < IntImm(dtype, gcd)) && this->analyzer_->CanProve(normal_extra >= IntImm(dtype, 0))) { // Case 1. 0 <= xn < d divisible.CopyOnWrite()->DivideBy(gcd); diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h index 4793538316a3..0f1bd4d8a681 100644 --- a/src/arith/const_fold.h +++ b/src/arith/const_fold.h @@ -272,8 +272,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr b) { if (pa->value == 0) return a; } if (pb) { - // MakeConst can handle both vector and scalar types. - if (pb->value == 1) return tirx::MakeConst(result_ty, 0); + if (pb->value == 1) return IntImm(result_ty, 0); TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; } }); @@ -329,8 +328,7 @@ inline ffi::Optional TryConstFold(PrimExpr a, PrimExpr if (pa->value == 0) return a; } if (pb) { - // MakeConst can handle both vector and scalar types. - if (pb->value == 1) return tirx::MakeConst(result_ty, 0); + if (pb->value == 1) return IntImm(result_ty, 0); TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero"; } });