[REFACTOR][IR] Clean up PrimType follow-ups #19884

Merged: tqchen merged 3 commits into apache:main from tqchen:tvm-cleanup-primtype-refactor-follow-up-issues on Jun 25, 2026

@tqchen tqchen commented Jun 25, 2026

Member

This PR follows up on the PrimType refactor by cleaning up a few residual type-boundary cases.

Summary:

  • Use expression PrimType access for Metal TVMScript shuffle helpers.
  • Extend namespace cleanup coverage for Metal shuffle up/down helpers.
  • Simplify narrow backend Python raw dtype boundaries while leaving true raw ABI/codegen boundaries intact.
  • Inline PrimType::StorageBytes() in the header and remove the exported out-of-line symbol.

Validation:

  • git diff --check
  • pre-commit run --files python/tvm/backend/metal/script.py tests/python/tirx/test_op_namespace_cleanup.py python/tvm/backend/cuda/operator/tile_primitive/elementwise/smem.py python/tvm/backend/trn/transform/naive_allocator.py
  • pre-commit run --files include/tvm/ir/base_expr.h src/ir/type.cc
  • cmake --build build --target tvm_runtime tvm_compiler -j$(nproc)
  • python -m pytest tests/python/tirx/test_op_namespace_cleanup.py -xvs -rs (passes with 10 skips due to the suite's sm_100a device gate)
  • checked libtvm_compiler.so no longer exports tvm::PrimType::StorageBytes

tqchen added 2 commits June 25, 2026 02:17
The PrimExpr type migration requires TVMScript intrinsic helpers to read expression types through PrimType rather than the removed dtype shortcut. This updates the Metal shuffle helpers to match the CUDA namespace pattern and extends namespace coverage for the additional shuffle forms.

The backend raw dtype audit also removes redundant DataType reconstruction at the two named byte/bit calculation boundaries while keeping the explicit runtime dtype unwraps local to those raw-size decisions.
PrimType storage sizing is a small value helper and does not need an exported out-of-line symbol. Inline StorageBytes in the header and remove the compiler-library definition.
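
As a rough illustration of what "inline in the header" means here, the sketch below defines the sizing helper inside a class body, so every translation unit that includes the header sees the definition and the library does not need to export a symbol for it. `MiniDType` and `MiniPrimType` are hypothetical stand-ins, not TVM's actual `DLDataType` or `PrimType`, and the type code `2` for float is an assumption made for the examples only.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for the raw dtype record (code, bits, lanes).
struct MiniDType {
  std::uint8_t code;
  std::uint8_t bits;
  std::uint16_t lanes;
};

// Hypothetical stand-in for PrimType; only the sizing helper is modeled.
class MiniPrimType {
 public:
  explicit constexpr MiniPrimType(MiniDType dtype) : dtype_(dtype) {}

  // Defined in the class body, hence implicitly inline: no out-of-line
  // definition in a .cc file and no exported library symbol is required.
  constexpr std::size_t StorageBytes() const {
    return static_cast<std::size_t>(
        (static_cast<std::uint64_t>(dtype_.bits) *
             static_cast<std::uint64_t>(dtype_.lanes) +
         7) /
        8);
  }

 private:
  MiniDType dtype_;
};

// Compile-time spot checks: 32 bits x 4 lanes = 128 bits = 16 bytes;
// 4 bits x 3 lanes = 12 bits, rounded up to 2 bytes.
static_assert(MiniPrimType(MiniDType{2, 32, 4}).StorageBytes() == 16, "32x4");
static_assert(MiniPrimType(MiniDType{2, 4, 3}).StorageBytes() == 2, "4x3");
```

Because the helper is a few integer operations on a small value type, defining it in the header lets callers inline it, which is consistent with the commit's rationale for dropping the compiler-library definition.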
@tqchen tqchen changed the title [REFRACTOR][IR] Clean up PrimType follow-ups [REFACTOR][IR] Clean up PrimType follow-ups Jun 25, 2026

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request inlines the StorageBytes() method of PrimType in include/tvm/ir/base_expr.h, removes its out-of-line definition in src/ir/type.cc, and cleans up unused DataType imports and dtype property accesses in Python backends. It also updates Metal SIMD shuffle intrinsics to use var.ty instead of var.dtype and adds corresponding unit tests. Regarding the feedback, the implementation of StorageBytes() contains a redundant check for 1-bit, 1-lane kDLUInt that can be removed since the general formula already handles this case correctly.

Comment thread include/tvm/ir/base_expr.h Outdated
Comment on lines +224 to +229
    if (dtype.code == kDLUInt && dtype.bits == 1 && dtype.lanes == 1) {
      return 1;
    }
    return static_cast<size_t>(
        (static_cast<uint64_t>(dtype.bits) * static_cast<uint64_t>(dtype.lanes) + 7) / 8);
  }

Contributor

medium

The special case check for kDLUInt with 1 bit and 1 lane is redundant. The general formula (bits * lanes + 7) / 8 already correctly evaluates to 1 when bits == 1 and lanes == 1 (i.e., (1 * 1 + 7) / 8 = 1). Removing this check simplifies the code and avoids unnecessary branching.

    return static_cast<size_t>(
        (static_cast<uint64_t>(dtype.bits) * static_cast<uint64_t>(dtype.lanes) + 7) / 8);
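
The redundancy claim can be checked mechanically. The following is a minimal sketch, not code from the PR: it compares the branchy and branch-free forms over a sampled grid of type codes, bit widths, and lane counts. The function names are hypothetical, and `kUIntCode = 1` is an assumption standing in for `kDLUInt`.

```cpp
#include <cstddef>
#include <cstdint>

// Assumed stand-in for the kDLUInt type code; illustrative only.
constexpr std::uint8_t kUIntCode = 1;

// General rule: total bit count rounded up to whole bytes.
inline std::size_t BytesGeneral(std::uint8_t bits, std::uint16_t lanes) {
  return static_cast<std::size_t>(
      (static_cast<std::uint64_t>(bits) * static_cast<std::uint64_t>(lanes) + 7) / 8);
}

// Same shape as the reviewed snippet, including the uint1 scalar branch.
inline std::size_t BytesWithSpecialCase(std::uint8_t code, std::uint8_t bits,
                                        std::uint16_t lanes) {
  if (code == kUIntCode && bits == 1 && lanes == 1) {
    return 1;
  }
  return BytesGeneral(bits, lanes);
}

// Returns true if both forms agree for codes 0..2, bits 0..255, lanes 1..256.
inline bool SpecialCaseIsRedundant() {
  for (int code = 0; code <= 2; ++code) {
    for (int bits = 0; bits <= 255; ++bits) {
      for (int lanes = 1; lanes <= 256; ++lanes) {
        const auto c = static_cast<std::uint8_t>(code);
        const auto b = static_cast<std::uint8_t>(bits);
        const auto l = static_cast<std::uint16_t>(lanes);
        if (BytesWithSpecialCase(c, b, l) != BytesGeneral(b, l)) {
          return false;
        }
      }
    }
  }
  return true;
}
```

The branch only fires for the single input `(kUIntCode, 1, 1)`, where the general rule gives `(1 * 1 + 7) / 8 = 1` as well, so the two forms cannot differ; the loop simply confirms this over the sampled range.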

The general byte-rounding formula already returns one byte for uint1 scalars, so remove the redundant special case from StorageBytes.
@tqchen tqchen merged commit 59516c1 into apache:main Jun 25, 2026
10 checks passed