Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion include/tvm/ir/base_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,16 @@ class PrimType final : public Type {
* This uses the same packed sub-byte dtype sizing rule as runtime tensors.
* Scalable vector types have no compile-time storage size and are rejected.
*/
TVM_DLL size_t StorageBytes() const;
TVM_FFI_INLINE size_t StorageBytes() const {
DLDataType dtype = get()->dtype;
int16_t encoded_lanes = static_cast<int16_t>(dtype.lanes);
if (TVM_FFI_PREDICT_FALSE(encoded_lanes < 0)) {
TVM_FFI_THROW(InternalError)
<< "Cannot compute compile-time storage bytes for non-fixed vector type " << dtype;
}
return static_cast<size_t>(
(static_cast<uint64_t>(dtype.bits) * static_cast<uint64_t>(dtype.lanes) + 7) / 8);
}

/*! \brief Return the same type with a different dtype code, preserving bits and lanes. */
TVM_FFI_INLINE PrimType WithCode(DLDataTypeCode code) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

from __future__ import annotations

from tvm.runtime import DataType
from tvm.script import tirx as T
from tvm.tirx import PrimFunc, TilePrimitiveCall
from tvm.tirx.operator.tile_primitive import DispatchContext
Expand Down Expand Up @@ -100,10 +99,10 @@ def check(op_call: TilePrimitiveCall, sctx: DispatchContext) -> tuple[bool, str
def _max_layout_vec(plan, total: int, thread_cnt: int) -> int:
"""Widest vec_chunk dividing all operands' innermost extents AND
``total / thread_cnt``, within dtype-bit candidates ``{128,64,32,16,8}``."""
max_bits = DataType(plan.dst.buffer.dtype.dtype).bits
max_bits = plan.dst.buffer.dtype.dtype.bits
for s in plan.srcs:
if s.buf_region is not None:
max_bits = max(max_bits, DataType(s.buf_region.buffer.dtype.dtype).bits)
max_bits = max(max_bits, s.buf_region.buffer.dtype.dtype.bits)
per_thread = total // thread_cnt if thread_cnt > 0 else total
if total % thread_cnt != 0:
return 1
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/backend/metal/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,19 @@ def __init__(self):
def simd_shuffle(var, lane):
if isinstance(var, Buffer):
var = var[0]
return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle", var, lane)
return _tir_op.call_intrin(var.ty, "tirx.metal.simd_shuffle", var, lane)

@staticmethod
def simd_shuffle_up(var, delta):
if isinstance(var, Buffer):
var = var[0]
return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle_up", var, delta)
return _tir_op.call_intrin(var.ty, "tirx.metal.simd_shuffle_up", var, delta)

@staticmethod
def simd_shuffle_down(var, delta):
if isinstance(var, Buffer):
var = var[0]
return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle_down", var, delta)
return _tir_op.call_intrin(var.ty, "tirx.metal.simd_shuffle_down", var, delta)


__all__ = ["MetalNamespace"]
3 changes: 1 addition & 2 deletions python/tvm/backend/trn/transform/naive_allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import functools

from tvm import DataType
from tvm.tirx import AllocBuffer, IntImm
from tvm.tirx.buffer import Buffer
from tvm.tirx.stmt_functor import StmtVisitor
Expand Down Expand Up @@ -48,7 +47,7 @@ def get_buffer_size(buffer: Buffer) -> int:
raise ValueError(
f"Buffer {buffer.name} has non-constant shape. Do not know how to allocate it."
)
return int(num_elem * DataType(buffer.dtype.dtype).itemsize)
return int(num_elem * buffer.dtype.dtype.itemsize)


class AllocInfoCollector(StmtVisitor):
Expand Down
10 changes: 0 additions & 10 deletions src/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
* \file src/ir/type.cc
* \brief Common type system AST nodes throughout the IR.
*/
#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/type.h>
Expand Down Expand Up @@ -136,15 +135,6 @@ PrimType PrimType::ScalableVector(DLDataTypeCode code, int bits, int lanes) {
return PrimType(ScalableVectorDType(code, bits, lanes));
}

size_t PrimType::StorageBytes() const {
int16_t encoded_lanes = static_cast<int16_t>(get()->dtype.lanes);
if (TVM_FFI_PREDICT_FALSE(encoded_lanes < 0)) {
TVM_FFI_THROW(InternalError)
<< "Cannot compute compile-time storage bytes for non-fixed vector type " << get()->dtype;
}
return ffi::GetDataSize(1, get()->dtype);
}

TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("ir.PrimType", [](DLDataType dtype) { return PrimType(dtype); });
Expand Down
8 changes: 8 additions & 0 deletions tests/python/tirx/test_op_namespace_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,17 +295,23 @@ def device_namespaces(dst: T.handle, src: T.handle):
T.cuda.copy_bytes(dst, src, 16)
T.ptx.ldg32(R[0], 1, A[0], 0)
T.metal.simd_shuffle(A[0], 0)
T.metal.simd_shuffle_up(A[0], 1)
T.metal.simd_shuffle_down(A[0], 1)

calls = _expr_calls(device_namespaces)
assert [call.op.name for call in calls] == [
"tirx.cuda.copy_bytes",
"tirx.ptx.ldg32",
"tirx.metal.simd_shuffle",
"tirx.metal.simd_shuffle_up",
"tirx.metal.simd_shuffle_down",
]
for op_name, namespace in [
("tirx.cuda.copy_bytes", "cuda"),
("tirx.ptx.ldg32", "ptx"),
("tirx.metal.simd_shuffle", "metal"),
("tirx.metal.simd_shuffle_up", "metal"),
("tirx.metal.simd_shuffle_down", "metal"),
]:
assert _op_attr(op_name, "TIRxOpCategory") == "device_intrin"
assert _op_attr(op_name, "TDeviceIntrinsicNamespace") == namespace
Expand All @@ -315,6 +321,8 @@ def device_namespaces(dst: T.handle, src: T.handle):
assert "T.cuda.copy_bytes(" in code
assert "T.ptx.ldg32(" in code
assert "T.metal.simd_shuffle(" in code
assert "T.metal.simd_shuffle_up(" in code
assert "T.metal.simd_shuffle_down(" in code
assert "T.tirx." not in code
reparsed = tvm.script.from_source(code)
assert reparsed.script() == code
Expand Down
Loading