[REFACTOR][IR] Clean up PrimType follow-ups#19884
Conversation
The PrimExpr type migration requires TVMScript intrinsic helpers to read expression types through PrimType rather than the removed dtype shortcut. This updates the Metal shuffle helpers to match the CUDA namespace pattern and extends namespace coverage for the additional shuffle forms. The backend raw dtype audit also removes redundant DataType reconstruction at the two named byte/bit calculation boundaries while keeping the explicit runtime dtype unwraps local to those raw-size decisions.
PrimType storage sizing is a small value helper and does not need an exported out-of-line symbol. Inline StorageBytes in the header and remove the compiler-library definition.
There was a problem hiding this comment.
Code Review
This pull request inlines the StorageBytes() method of PrimType in include/tvm/ir/base_expr.h, removes its out-of-line definition in src/ir/type.cc, and cleans up unused DataType imports and dtype property accesses in Python backends. It also updates Metal SIMD shuffle intrinsics to use var.ty instead of var.dtype and adds corresponding unit tests. Regarding the feedback, the implementation of StorageBytes() contains a redundant check for 1-bit, 1-lane kDLUInt that can be removed since the general formula already handles this case correctly.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| if (dtype.code == kDLUInt && dtype.bits == 1 && dtype.lanes == 1) { | ||
| return 1; | ||
| } | ||
| return static_cast<size_t>( | ||
| (static_cast<uint64_t>(dtype.bits) * static_cast<uint64_t>(dtype.lanes) + 7) / 8); | ||
| } |
There was a problem hiding this comment.
The special case check for kDLUInt with 1 bit and 1 lane is redundant. The general formula (bits * lanes + 7) / 8 already correctly evaluates to 1 when bits == 1 and lanes == 1 (i.e., (1 * 1 + 7) / 8 = 1). Removing this check simplifies the code and avoids unnecessary branching.
return static_cast<size_t>(
(static_cast<uint64_t>(dtype.bits) * static_cast<uint64_t>(dtype.lanes) + 7) / 8);The general byte-rounding formula already returns one byte for uint1 scalars, so remove the redundant special case from StorageBytes.
This PR follows up on the PrimType refactor by cleaning up a few residual type-boundary cases.
Summary:
PrimType::StorageBytes()in the header and remove the exported out-of-line symbol.Validation:
git diff --checkpre-commit run --files python/tvm/backend/metal/script.py tests/python/tirx/test_op_namespace_cleanup.py python/tvm/backend/cuda/operator/tile_primitive/elementwise/smem.py python/tvm/backend/trn/transform/naive_allocator.pypre-commit run --files include/tvm/ir/base_expr.h src/ir/type.cccmake --build build --target tvm_runtime tvm_compiler -j$(nproc)python -m pytest tests/python/tirx/test_op_namespace_cleanup.py -xvs -rs(passes with 10 skips due the suite's sm_100a device gate)libtvm_compiler.sono longer exportstvm::PrimType::StorageBytes