Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,15 @@ class SplitExprNode : public CanonicalExprNode {
return IntImm(dtype, 0);
}
if (this->upper_factor != SplitExprNode::kPosInf) {
res = ModImpl(res, MakeConst(dtype, this->upper_factor), div_mode);
res = ModImpl(res, IntImm(dtype, this->upper_factor), div_mode);
}
if (this->lower_factor != 1) {
res = DivImpl(res, MakeConst(dtype, this->lower_factor), div_mode);
res = DivImpl(res, IntImm(dtype, this->lower_factor), div_mode);
}
sscale *= this->scale;
if (sscale != 1) {
TVM_FFI_ICHECK(dtype.code() != DLDataTypeCode::kDLUInt || sscale > 0);
res = res * MakeConst(dtype, sscale);
res = res * IntImm(dtype, sscale);
}
return res;
}
Expand Down Expand Up @@ -172,20 +172,20 @@ class SplitExprNode : public CanonicalExprNode {
return false;
}
if (this->upper_factor != SplitExprNode::kPosInf) {
res = ModImpl(res, MakeConst(this->ty(), this->upper_factor), div_mode);
res = ModImpl(res, IntImm(this->ty(), this->upper_factor), div_mode);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
if (this->lower_factor != 1) {
res = DivImpl(res, MakeConst(this->ty(), this->lower_factor), div_mode);
res = DivImpl(res, IntImm(this->ty(), this->lower_factor), div_mode);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
}
if (this->scale != 1) {
TVM_FFI_ICHECK(this->ty().code() != DLDataTypeCode::kDLUInt || this->scale > 0);
res = res * MakeConst(this->ty(), this->scale);
res = res * IntImm(this->ty(), this->scale);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
Expand Down Expand Up @@ -252,7 +252,7 @@ class SumExprNode : public CanonicalExprNode {
PrimExpr Normalize() const final {
// quick path 1.
if (this->args.size() == 0) {
return MakeConst(this->ty(), this->base);
return IntImm(this->ty(), this->base);
}
return Normalize_(this->ty(), SimplifySplitExprs(args), base);
}
Expand Down Expand Up @@ -354,7 +354,7 @@ class SumExprNode : public CanonicalExprNode {
}
}
if (base > 0 || is_min_value) {
res = res + MakeConst(dtype, base);
res = res + IntImm(dtype, base);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
Expand All @@ -369,7 +369,7 @@ class SumExprNode : public CanonicalExprNode {
}
}
if (base < 0 && !is_min_value) {
res = res - MakeConst(dtype, -base);
res = res - IntImm(dtype, -base);
if (!CastIsSafe(dtype, res, analyzer)) {
return false;
}
Expand Down Expand Up @@ -507,7 +507,7 @@ class SumExprNode : public CanonicalExprNode {
}
}
if (base > 0 || is_min_value) {
res = res + MakeConst(dtype, base);
res = res + IntImm(dtype, base);
}
// negative scales follows using sub.
for (size_t i = 0; i < args.size(); ++i) {
Expand All @@ -516,7 +516,7 @@ class SumExprNode : public CanonicalExprNode {
}
}
if (base < 0 && !is_min_value) {
res = res - MakeConst(dtype, -base);
res = res - IntImm(dtype, -base);
}
return res;
}
Expand Down Expand Up @@ -837,8 +837,7 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval,
return ToSplitExpr(IntImm(lhs.ty(), 0));
} else {
// move the upper_factor modular into index.
lhs.CopyOnWrite()->index =
ModImpl(lhs->index, MakeConst(lhs.ty(), lhs->upper_factor), div_mode);
lhs.CopyOnWrite()->index = ModImpl(lhs->index, IntImm(lhs.ty(), lhs->upper_factor), div_mode);
lhs.CopyOnWrite()->upper_factor = SplitExprNode::kPosInf;
lhs.CopyOnWrite()->scale = 1;
lhs.CopyOnWrite()->lower_factor *= scaled_cval;
Expand All @@ -863,8 +862,8 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs,
// collect lhs products and try to eliminate by matching them to prod in rhs
ffi::Array<ffi::Optional<PrimExpr>> lhs_prods;
PrimType rhs_ty = prhs->ty();
PrimExpr new_rhs = MakeConst(rhs_ty, 1);
PrimExpr new_common_scale = MakeConst(rhs_ty, 1);
PrimExpr new_rhs = IntImm(rhs_ty, 1);
PrimExpr new_common_scale = IntImm(rhs_ty, 1);
int64_t lhs_cscale = 1, rhs_cscale = 1;
int num_elimination = 0;

Expand Down Expand Up @@ -907,13 +906,13 @@ bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs,

// construct prod via canonical form
PrimType lhs_ty = plhs->ty();
PrimExpr new_lhs = MakeConst(lhs_ty, 1);
PrimExpr new_lhs = IntImm(lhs_ty, 1);
for (ffi::Optional<PrimExpr> val : lhs_prods) {
if (val.defined()) new_lhs = new_lhs * val.value();
}
*plhs = new_lhs * MakeConst(lhs_ty, lhs_cscale);
*prhs = new_rhs * MakeConst(rhs_ty, rhs_cscale);
*common_scale = new_common_scale * MakeConst(rhs_ty, cscale_gcd);
*plhs = new_lhs * IntImm(lhs_ty, lhs_cscale);
*prhs = new_rhs * IntImm(rhs_ty, rhs_cscale);
*common_scale = new_common_scale * IntImm(rhs_ty, cscale_gcd);
return true;
}

Expand Down Expand Up @@ -1051,7 +1050,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
}
// Apply floormod(floordiv_result, m) to complete the identity
PrimExpr div_result = Normalize(lhs);
return this->VisitExpr(floormod(div_result, MakeConst(a.ty(), new_mod)));
return this->VisitExpr(floormod(div_result, IntImm(a.ty(), new_mod)));
}
}
}
Expand Down Expand Up @@ -1098,7 +1097,7 @@ SplitExpr CanonicalSimplifier::Impl::SplitModConst(SplitExpr lhs, int64_t cval,
// Do a recursive call to simplify the mod with the new factor.
if (new_upper_factor < lhs->upper_factor && lhs->upper_factor != SplitExprNode::kPosInf) {
auto updated = ToSplitExpr(
this->VisitExpr(ModImpl(lhs->index, MakeConst(lhs.ty(), new_upper_factor), div_mode)));
this->VisitExpr(ModImpl(lhs->index, IntImm(lhs.ty(), new_upper_factor), div_mode)));
// re-apply the lower_factor
if (lhs->lower_factor != 1) {
auto ret = SplitDivConst(updated, lhs->lower_factor, div_mode);
Expand Down Expand Up @@ -1416,7 +1415,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) {
PrimType dtype = divisible->ty();
TVM_FFI_ICHECK(extra->ty() == dtype);
PrimExpr normal_extra = extra->Normalize();
if (this->analyzer_->CanProve(normal_extra < MakeConst(dtype, gcd)) &&
if (this->analyzer_->CanProve(normal_extra < IntImm(dtype, gcd)) &&
this->analyzer_->CanProve(normal_extra >= IntImm(dtype, 0))) {
// Case 1. 0 <= xn < d
divisible.CopyOnWrite()->DivideBy(gcd);
Expand Down
6 changes: 2 additions & 4 deletions src/arith/const_fold.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,7 @@ inline ffi::Optional<PrimExpr> TryConstFold<tirx::Mod>(PrimExpr a, PrimExpr b) {
if (pa->value == 0) return a;
}
if (pb) {
// MakeConst can handle both vector and scalar types.
if (pb->value == 1) return tirx::MakeConst(result_ty, 0);
if (pb->value == 1) return IntImm(result_ty, 0);
TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero";
}
});
Expand Down Expand Up @@ -329,8 +328,7 @@ inline ffi::Optional<PrimExpr> TryConstFold<tirx::FloorMod>(PrimExpr a, PrimExpr
if (pa->value == 0) return a;
}
if (pb) {
// MakeConst can handle both vector and scalar types.
if (pb->value == 1) return tirx::MakeConst(result_ty, 0);
if (pb->value == 1) return IntImm(result_ty, 0);
TVM_FFI_ICHECK_NE(pb->value, 0) << "Divide by zero";
}
});
Expand Down
Loading