diff --git a/examples/python/qqmm.py b/examples/python/qqmm.py index 5be7eae2f3..dc46fa29a0 100644 --- a/examples/python/qqmm.py +++ b/examples/python/qqmm.py @@ -33,48 +33,71 @@ def test_qqmm(): [64, 128, 256, 1024, 1024 * 8], # N [64, 128, 256, 1024, 1024 * 8], # K ) + layouts = ["TN", "NT", "TT", "NN"] for group_size, mode, bits in tests: for M, N, K in product(*shapes): for dtype in dtypes: - x = mx.random.normal(shape=(M, K), key=k1, dtype=dtype) - w = mx.random.normal(shape=(N, K), key=k2, dtype=dtype) - w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode) - w_dq = mx.dequantize( - w_q, - scales_w, - group_size=group_size, - bits=bits, - mode=mode, - dtype=dtype, - ) - y_q = mx.qqmm( - x, - w_q, - scales_w, - group_size=group_size, - bits=bits, - mode=mode, - ) - x_q, scales_x = mx.quantize( - x, group_size=group_size, bits=bits, mode=mode - ) - x_dq = mx.dequantize( - x_q, - scales_x, - group_size=group_size, - bits=bits, - mode=mode, - dtype=dtype, - ) - y_hat = mx.matmul(x_dq, mx.transpose(w_dq)) - ulp = ulp_bf16_at(y_hat) - error = (y_q - y_hat).abs() - if not (mx.logical_or(error < 1e-3, error <= ulp).all()): - raise AssertionError( - f"qqmm test failed for shape {(M, N, K)}, " - f"group_size={group_size}, bits={bits}, " - f"mode={mode}, dtype={dtype}" + for layout in layouts: + if layout == "NT": + x_shape = (M, K) + w_shape = (N, K) + elif layout == "TN": + x_shape = (K, M) + w_shape = (K, N) + elif layout == "TT": + x_shape = (K, M) + w_shape = (N, K) + else: # "NN" + x_shape = (M, K) + w_shape = (K, N) + + x = mx.random.normal(shape=x_shape, key=k1, dtype=dtype) + w = mx.random.normal(shape=w_shape, key=k2, dtype=dtype) + + if layout == "TT": + x = mx.transpose(x) + elif layout == "TN": + w = mx.transpose(w) + x = mx.transpose(x) + elif layout == "NN": + w = mx.transpose(w) + + y_q = mx.qqmm( + x, + w, + group_size=group_size, + bits=bits, + mode=mode, + ) + w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode) + w_dq = mx.dequantize( + w_q, + scales_w, + group_size=group_size, + bits=bits, + mode=mode, + dtype=dtype, + ) + x_q, scales_x = mx.quantize( + x, group_size=group_size, bits=bits, mode=mode + ) + x_dq = mx.dequantize( + x_q, + scales_x, + group_size=group_size, + bits=bits, + mode=mode, + dtype=dtype, ) + y_hat = mx.matmul(x_dq, mx.transpose(w_dq)) + ulp = ulp_bf16_at(y_hat) + error = (y_q - y_hat).abs() + if not (mx.logical_or(error < 1e-3, error <= ulp).all()): + raise AssertionError( + f"qqmm test failed for shape {(M, N, K)}, " + f"group_size={group_size}, bits={bits}, " + f"mode={mode}, dtype={dtype}, layout={layout}" + ) def test_qqmm_vjp(): diff --git a/mlx/backend/cuda/ptx.cuh b/mlx/backend/cuda/ptx.cuh new file mode 100644 index 0000000000..6ec7caedd6 --- /dev/null +++ b/mlx/backend/cuda/ptx.cuh @@ -0,0 +1,127 @@ +#pragma once + +#include +#include + +namespace mlx::core { + +namespace ptx { + +#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \ + defined(__CUDA_ARCH_SPECIFIC__) + +__device__ __forceinline__ void mbarrier_init(uint64_t* mbar, uint32_t count) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.init.shared.b64 [%0], %1;" + : + : "r"(mbar_ptr), "r"(count) + : "memory"); +} + +__device__ __forceinline__ void mbarrier_invalidate(uint64_t* mbar) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.inval.shared.b64 [%0];" : : "r"(mbar_ptr) : "memory"); +} + +// Arrive at barrier (non-master threads) +__device__ __forceinline__ void mbarrier_arrive(uint64_t* mbar) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.shared.b64 _, [%0];" + : + : "r"(mbar_ptr) + : "memory"); +} + +// Arrive at barrier and set expected transaction count (master thread) +__device__ __forceinline__ void mbarrier_arrive_expect_tx( + uint64_t* mbar, + uint32_t tx_count) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" + : + : "r"(mbar_ptr), "r"(tx_count) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-asynchronous-copy-completion-mechanisms-mbarrier +__device__ __forceinline__ void mbarrier_wait_parity( + uint64_t* mbar, + uint32_t parity) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile( + "{\n\t" + ".reg .pred P;\n\t" + "WAIT_LOOP:\n\t" + "mbarrier.try_wait.parity.shared.b64 P, [%0], %1;\n\t" + "@!P bra WAIT_LOOP;\n\t" + "}\n\t" + : + : "r"(mbar_ptr), "r"(parity) + : "memory"); +} + +// Async bulk tensor copy: global -> shared (2D) +__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( + void* dst_shmem, + const CUtensorMap* tensor_map, + uint32_t tile_x, + uint32_t tile_y, + uint64_t* mbar) { + uint32_t dst_ptr = __cvta_generic_to_shared(dst_shmem); + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + + asm volatile( + "cp.async.bulk.tensor.2d.shared::cluster.global.tile" + ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];" + : + : "r"(dst_ptr), "l"(tensor_map), "r"(tile_x), "r"(tile_y), "r"(mbar_ptr) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// shared::cta -> global +__device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( + const uint64_t* tensor_map_ptr, + const uint32_t offset_x, + const uint32_t offset_y, + uint64_t* src_shmem) { + uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); + asm volatile( + "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];" :: + "l"(tensor_map_ptr), + "r"(offset_x), + "r"(offset_y), + "r"(src_shmem_ptr) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group +template +__device__ __forceinline__ void cp_async_bulk_wait_group_read() { + if constexpr (N == 0) { + asm volatile("cp.async.bulk.wait_group.read 0;"); + } else if constexpr (N == 1) { + asm volatile("cp.async.bulk.wait_group.read 1;"); + } else if constexpr (N == 2) { + asm volatile("cp.async.bulk.wait_group.read 2;"); + } else if constexpr (N == 4) { + asm volatile("cp.async.bulk.wait_group.read 4;"); + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group +__device__ __forceinline__ void cp_async_bulk_commit_group() { + asm volatile("cp.async.bulk.commit_group;"); +} + +// Ccreates a memory ordering barrier between generic and async proxies +// details: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#proxies +__device__ __forceinline__ void fence_proxy_async_shared_cta() { + asm volatile("fence.proxy.async.shared::cta;"); +} + +#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && + // (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000) +} // namespace ptx +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index d36ae6581e..301e0660a2 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -2,8 +2,8 @@ #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh" -#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh" +#include "mlx/backend/cuda/quantized/fp_quantize.cuh" +#include "mlx/backend/cuda/quantized/fp_quantize_tma.cuh" #include "mlx/backend/cuda/quantized/quantized.h" #include "mlx/backend/cuda/quantized/quantized_utils.cuh" #include "mlx/backend/cuda/vector_types.cuh" @@ -14,353 +14,58 @@ #include #include -constexpr float F8E4M3_MAX = 448.0f; -constexpr float F4E2M1_MAX = 6.0f; - namespace mlx::core { namespace cu { -template -struct Dequantize { - __device__ float operator()(uint8_t x) { - if constexpr (bits == 8) { - return float(*(__nv_fp8_e4m3*)(&x)); - } else { - return float(*(__nv_fp4_e2m1*)(&x)); - } - } -}; - -namespace cg = cooperative_groups; - -template -__global__ void fp_quantize_dequantize( - T* w, - T* out, - size_t size, - float* global_scale = nullptr) { - const bool use_global_scale = global_scale != nullptr; - const float scale_enc = - use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; - const float inv_scale_enc = use_global_scale ? 1.0f / scale_enc : 1.0f; - - using Tx2 = Vector2_t; - using Tx4 = Vector4_t; - uint32_t rbits = 0; // reserved bits for future use - auto block_size = cg::this_thread_block().dim_threads(); - auto block_idx = cg::this_thread_block().group_index(); - auto idx_in_block = cg::this_thread_block().thread_index(); - auto tidx = block_idx.x * block_size.x + idx_in_block.x; - auto tidy = block_idx.y * block_size.y + idx_in_block.y; - auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; - - size_t thread_idx = tidx + grid_dim_x * size_t(tidy); - size_t base_idx = thread_idx * group_size; - - if (base_idx >= size) { - return; - } - - auto w_tile = load_vector(w, thread_idx); - float scale_dec_b = 0.0f; - - Tx2 amax_2x = Tx2{0.0f, 0.0f}; - -#pragma unroll - for (int i = 0; i < group_size; i += 2) { - auto pair = Tx2{w_tile[i], w_tile[i + 1]}; - absmax_x2(amax_2x, amax_2x, pair); - } - - scale_dec_b = static_cast( - max(fabsf(static_cast(amax_2x.x)), - fabsf(static_cast(amax_2x.y)))); - - scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; - scale_dec_b *= scale_enc; - // Convert to mx scale or nv scale - using ScaleType = - std::conditional_t; - auto s = ScaleType(scale_dec_b); - float scale_enc_b = scale_enc / float(s); - float scale_dec = float(s) * inv_scale_enc; - AlignedVector w_hat; - -#pragma unroll - for (int i = 0; i < group_size / 4; i++) { - Tx4 w_Tx4 = *reinterpret_cast(&w_tile[i * 4]); - float4 dq; - if constexpr (bits == 8) { - uint32_t quantized_val = - scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); - dq = dequant_fp8(quantized_val); - } else { - uint16_t quantized_val = - scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); - dq = dequant_fp4(quantized_val); - } - w_hat[i * 4] = static_cast(dq.x * scale_dec); - w_hat[i * 4 + 1] = static_cast(dq.y * scale_dec); - w_hat[i * 4 + 2] = static_cast(dq.z * scale_dec); - w_hat[i * 4 + 3] = static_cast(dq.w * scale_dec); - } - store_vector(out, thread_idx, w_hat); -} - -template -__global__ void fp_quantize_rowwise( - T* w, - uint8_t* out, - uint8_t* scales, - size_t size, - float* global_scale = nullptr) { - // NVFP4 conversion: - // Global encode scale: (448 × 6) / *global_scale - // Per-block decode scale: S_dec_b = (block_amax / 6) × S_enc → stored as FP8 - // E4M3 Per-block encode scale: S_enc_b = S_enc / S_dec_b - const bool use_global_scale = global_scale != nullptr; - const float scale_enc = - use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; - - using Tx2 = Vector2_t; - using Tx4 = Vector4_t; - uint32_t rbits = 0; // reserved bits for future use - auto block_size = cg::this_thread_block().dim_threads(); - auto block_idx = cg::this_thread_block().group_index(); - auto idx_in_block = cg::this_thread_block().thread_index(); - auto tidx = block_idx.x * block_size.x + idx_in_block.x; - auto tidy = block_idx.y * block_size.y + idx_in_block.y; - auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; - - size_t thread_idx = tidx + grid_dim_x * size_t(tidy); - size_t base_idx = thread_idx * group_size; - - if (base_idx >= size) { - return; - } - - auto w_tile = load_vector(w, thread_idx); - float scale_dec_b = 0.0f; - - Tx2 amax_2x = Tx2{0.0f, 0.0f}; - -#pragma unroll - for (int i = 0; i < group_size; i += 2) { - auto pair = Tx2{w_tile[i], w_tile[i + 1]}; - absmax_x2(amax_2x, amax_2x, pair); - } - - scale_dec_b = static_cast( - max(fabsf(static_cast(amax_2x.x)), - fabsf(static_cast(amax_2x.y)))); - - scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; - scale_dec_b *= scale_enc; - // Convert to mx scale or nv scale - using ScaleType = - std::conditional_t; - auto s = ScaleType(scale_dec_b); - uint8_t q_scale = s.__x; - float scale_enc_b = scale_enc / float(s); - - scales[thread_idx] = q_scale; - constexpr int elem_per_byte = bits == 8 ? 1 : 2; - AlignedVector quantized; - -#pragma unroll - for (int i = 0; i < group_size / 4; i++) { - Tx4 w_Tx4 = *reinterpret_cast(&w_tile[i * 4]); - if constexpr (bits == 8) { - uint32_t quantized_val = - scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); - *reinterpret_cast(&quantized[i * 4]) = quantized_val; - } else { - uint16_t quantized_val = - scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); - *reinterpret_cast(&quantized[i * 2]) = quantized_val; - } - } - store_vector(out, thread_idx, quantized); -} - -template -__global__ void fp_quantize_columnwise( - T* w, - uint8_t* out, - uint8_t* scales, - size_t size, - int M, - int K, - float* global_scale = nullptr) { - // Input: [M, K] with strides [1, M] (M-major) - // Quantized output: [M, K/elem_per_byte] row-major (K-major) - // Scales: [M, K/group_size] row-major (K-major) - // Quantize along K (last dimension, groups of group_size elements) - const bool use_global_scale = global_scale != nullptr; - const float scale_enc = - use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; - - using Tx2 = Vector2_t; - using Tx4 = Vector4_t; - uint32_t rbits = 0; - - auto block_idx = cg::this_thread_block().group_index(); - auto idx_in_block = cg::this_thread_block().thread_index(); - - constexpr int BLOCK_X = 16; - constexpr int BLOCK_Y = 32; - constexpr int elem_per_byte = (bits == 8) ? 1 : 2; - constexpr int bytes_per_group = group_size / elem_per_byte; - - constexpr int rows_per_block = BLOCK_X; - constexpr int cols_per_block = BLOCK_Y * group_size; - constexpr int local_cols = cols_per_block / elem_per_byte; - constexpr int bytes_per_block = rows_per_block * local_cols; - - constexpr int SMEM_PAD = 4; - constexpr int padded_local_cols = local_cols + SMEM_PAD; - - auto tidx = idx_in_block.x; - auto tidy = idx_in_block.y; - - int num_col_blocks = (K + cols_per_block - 1) / cols_per_block; - auto bidx = block_idx.x % num_col_blocks; - auto bidy = block_idx.x / num_col_blocks; - - T thread_data[group_size]; - - __shared__ uint8_t quantized_smem[rows_per_block * padded_local_cols]; - __shared__ uint8_t scales_smem[BLOCK_X][BLOCK_Y + SMEM_PAD]; - - int row_base = bidy * rows_per_block + tidx; - int col_base = bidx * cols_per_block + tidy * group_size; - - bool valid = (row_base < M) && (col_base + group_size <= K); - if (valid) { -#pragma unroll - for (int i = 0; i < group_size; i++) { - auto index = row_base + (col_base + i) * M; - thread_data[i] = w[index]; - } +template < + int TILE_M, + int TILE_K, + int THREADS_PER_BLOCK, + int STAGES, + int SCALES_PER_STAGE> +inline std::tuple get_tma_launch_args( + size_t grid_dim_x_size, // rows + size_t grid_dim_y_size, // cols + size_t block_size_x, // ROWS_PER_BLOCK + size_t block_size_y, // COL_PER_BLOCK + int in_size_bytes, // itemsize + int bits) { + dim3 grid; + grid.x = (grid_dim_x_size + block_size_x - 1) / block_size_x; + grid.y = (grid_dim_y_size + block_size_y - 1) / block_size_y; + grid.z = 1; - // Compute scale - Tx2 amax_2x = Tx2{0.0f, 0.0f}; -#pragma unroll - for (int r = 0; r < group_size; r += 2) { - auto pair = Tx2{thread_data[r], thread_data[r + 1]}; - absmax_x2(amax_2x, amax_2x, pair); - } - float scale_dec_b = - max(fabsf(static_cast(amax_2x.x)), - fabsf(static_cast(amax_2x.y))); - scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; - scale_dec_b *= scale_enc; - // Convert to mx scale or nv scale - using ScaleType = - std::conditional_t; - auto s = ScaleType(scale_dec_b); - float scale_enc_b = scale_enc / float(s); - scales_smem[tidx][tidy] = s.__x; - - int shared_idx = tidx * padded_local_cols + tidy * bytes_per_group; - -#pragma unroll - for (int j = 0; j < group_size / 4; j++) { - Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); - if constexpr (bits == 8) { - uint32_t quantized_val = - scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); - *reinterpret_cast(&quantized_smem[shared_idx + j * 4]) = - quantized_val; - } else { - uint16_t quantized_val = - scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); - *reinterpret_cast(&quantized_smem[shared_idx + j * 2]) = - quantized_val; - } - } - } - __syncthreads(); - - int output_cols = K / elem_per_byte; - int num_groups_per_row = K / group_size; - int linear_tid = tidx + tidy * BLOCK_X; - // Write back quantized values -#pragma unroll - for (int i = linear_tid; i < bytes_per_block; i += BLOCK_X * BLOCK_Y) { - int local_row = i / local_cols; - int local_col = i % local_cols; - - int global_row = bidy * rows_per_block + local_row; - int global_col = bidx * local_cols + local_col; - - if (global_row < M && global_col < output_cols) { - int physical_idx = local_row * padded_local_cols + local_col; - out[global_row * output_cols + global_col] = quantized_smem[physical_idx]; - } - } - // Write back scales - constexpr int num_scales = BLOCK_X * BLOCK_Y; -#pragma unroll - for (int i = linear_tid; i < num_scales; i += BLOCK_X * BLOCK_Y) { - int local_row = i / BLOCK_Y; - int local_col = i % BLOCK_Y; - - int global_row = bidy * BLOCK_X + local_row; - int global_col = bidx * BLOCK_Y + local_col; - - if (global_row < M && global_col < num_groups_per_row) { - scales[global_row * num_groups_per_row + global_col] = - scales_smem[local_row][local_col]; - } - } + dim3 block(THREADS_PER_BLOCK, 1, 1); + + const int elem_per_byte = bits == 8 ? 1 : 2; + constexpr size_t BUFF_ELEMS = TILE_M * TILE_K; + const size_t in_tile_size = BUFF_ELEMS * in_size_bytes; + const size_t in_buff_size_aligned = + ((in_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + + const size_t out_tile_elems = BUFF_ELEMS / elem_per_byte; + const size_t out_buff_size_aligned = + ((out_tile_elems * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + const size_t smem_size = + in_buff_size_aligned + out_buff_size_aligned + TMA_SHMEM_ALIGNMENT; + return std::make_tuple(grid, block, smem_size); } -template -__global__ void fp_dequantize( - const uint8_t* w, - const uint8_t* scales, - T* out, - size_t size, - float* global_scale = nullptr) { - auto block_size = cg::this_thread_block().dim_threads(); - auto block_idx = cg::this_thread_block().group_index(); - auto idx_in_block = cg::this_thread_block().thread_index(); - - auto tidx = block_idx.x * block_size.x + idx_in_block.x; - auto tidy = block_idx.y * block_size.y + idx_in_block.y; - - auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; - - constexpr int pack_factor = bits == 8 ? 1 : 2; - const bool use_global_scale = global_scale != nullptr; - const float inv_scale_enc = use_mx_scale - ? 1.0f - : (use_global_scale ? (*global_scale) / (F8E4M3_MAX * F4E2M1_MAX) : 1.0f); - size_t offset = tidx + grid_dim_x * size_t(tidy); - size_t oindex = offset * pack_factor; - - if (oindex >= size) { - return; - } - - size_t gindex = oindex / group_size; - using ScaleType = - std::conditional_t; - auto scale = float(((ScaleType*)(scales))[gindex]) * inv_scale_enc; - - out += oindex; - - uint32_t val = w[offset]; -#pragma clang loop unroll(full) - for (int i = 0; i < pack_factor; i++) { - uint8_t d; - if (bits == 4) { - d = (val >> (bits * i)) & 0x0f; - } else if (bits == 8) { - d = val; - } - out[i] = static_cast(scale * Dequantize{}(d)); +inline CUtensorMapDataType get_tma_dtype(Dtype dtype) { + switch (dtype) { + case float16: + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + case bfloat16: + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + case float32: + return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + default: + throw std::runtime_error( + "[fp_quantize_columnwise_tma] Unsupported dtype for TMA"); } } @@ -424,7 +129,7 @@ void fp_quantize_dequantize( }); } -void fp_quantize( +void fp_quantize_rowwise( const array& w, array& wq, array& scales, @@ -439,68 +144,218 @@ void fp_quantize( } enc.set_output_array(wq); enc.set_output_array(scales); - if (w.strides().back() != 1) { - dispatch_float_types(w.dtype(), "fp_quantize_columnwise", [&](auto type_tag) { - using T = cuda_type_t; - if constexpr (!std::is_same_v) { - auto M = w.shape(-2); - auto K = w.shape(-1); - auto kernel = cu::fp_quantize_columnwise; - if (bits == 8) { - kernel = cu::fp_quantize_columnwise; - } else if (group_size == 16) { - kernel = cu::fp_quantize_columnwise; - } - auto [num_blocks, block_dims] = - cu::get_columnwise_quantize_launch_args(w.size(), group_size, M, K); - enc.add_kernel_node( - kernel, - num_blocks, - block_dims, - 0, - gpu_ptr(w), - gpu_ptr(wq), - gpu_ptr(scales), - w.size(), - M, - K, - global_scale.has_value() ? gpu_ptr(global_scale.value()) - : nullptr); - } else { - throw std::runtime_error( - "[Quantize::eval_gpu] Can not quantize input with type float64."); + dispatch_float_types(w.dtype(), "fp_quantize_rowwise", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + auto kernel = cu::fp_quantize_rowwise; + if (bits == 8) { + kernel = cu::fp_quantize_rowwise; + } else if (group_size == 16) { + kernel = cu::fp_quantize_rowwise; } - }); - } else { - dispatch_float_types(w.dtype(), "fp_quantize_rowwise", [&](auto type_tag) { - using T = cuda_type_t; - if constexpr (!std::is_same_v) { - auto kernel = cu::fp_quantize_rowwise; - if (bits == 8) { - kernel = cu::fp_quantize_rowwise; - } else if (group_size == 16) { - kernel = cu::fp_quantize_rowwise; + bool large = w.size() > UINT_MAX; + auto [num_blocks, block_dims] = + get_launch_args(w.size(), w.shape(), w.strides(), large, group_size); + + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + 0, + gpu_ptr(w), + gpu_ptr(wq), + gpu_ptr(scales), + w.size(), + global_scale.has_value() ? gpu_ptr(global_scale.value()) + : nullptr); + } else { + throw std::runtime_error( + "[Quantize::eval_gpu] Can not quantize input with type float64."); + } + }); +} + +void fp_quantize_columnwise_fallback( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + const std::optional& global_scale /* = std::nullopt */, + cu::CommandEncoder& enc, + const Stream& s) { + enc.set_input_array(w); + if (global_scale.has_value()) { + enc.set_input_array(global_scale.value()); + } + enc.set_output_array(wq); + enc.set_output_array(scales); + dispatch_float_types( + w.dtype(), "fp_quantize_columnwise_fallback", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + auto M = w.shape(-2); + auto K = w.shape(-1); + auto kernel = + cu::fp_quantize_columnwise_fallback; + if (bits == 8) { + kernel = cu::fp_quantize_columnwise_fallback; + } else if (group_size == 16) { + kernel = + cu::fp_quantize_columnwise_fallback; + } + auto [num_blocks, block_dims] = + cu::get_columnwise_quantize_launch_args( + w.size(), group_size, M, K); + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + 0, + gpu_ptr(w), + gpu_ptr(wq), + gpu_ptr(scales), + w.size(), + M, + K, + global_scale.has_value() ? gpu_ptr(global_scale.value()) + : nullptr); + } else { + throw std::runtime_error( + "[Quantize::eval_gpu] Can not quantize input with type float64."); } - bool large = w.size() > UINT_MAX; - auto [num_blocks, block_dims] = get_launch_args( - w.size(), w.shape(), w.strides(), large, group_size); - - enc.add_kernel_node( - kernel, - num_blocks, - block_dims, - 0, - gpu_ptr(w), - gpu_ptr(wq), - gpu_ptr(scales), - w.size(), - global_scale.has_value() ? gpu_ptr(global_scale.value()) - : nullptr); - } else { - throw std::runtime_error( - "[Quantize::eval_gpu] Can not quantize input with type float64."); - } - }); + }); +} + +void fp_quantize_columnwise_tma( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + const std::optional& global_scale /* = std::nullopt */, + cu::CommandEncoder& enc, + const Stream& s) { + enc.set_input_array(w); + enc.set_output_array(wq); + enc.set_output_array(scales); + + size_t rows = w.shape(-1); + size_t cols = w.size() / rows; + size_t stride_bytes = w.strides(-1) * w.itemsize(); + + if (bits == 8 && group_size == 32) { + dispatch_float_types( + w.dtype(), "fp_quantize_columnwise_mxfp8", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + constexpr int THREADS_PER_BLOCK = 64; + constexpr int ROWS_PER_BLOCK = 64; + constexpr int COLS_PER_BLOCK = 64; + constexpr size_t TILE_M = 32; + constexpr size_t TILE_K = COLS_PER_BLOCK; + constexpr size_t STAGES = ROWS_PER_BLOCK / TILE_M; + + // For columnwise: grid.x = cols, grid.y = rows + // scales_per_stage = TILE_K (one scale per column per stage) + auto [grid, block, smem_size] = cu::get_tma_launch_args< + TILE_M, + TILE_K, + THREADS_PER_BLOCK, + STAGES, + TILE_K>( + cols, rows, COLS_PER_BLOCK, ROWS_PER_BLOCK, w.itemsize(), bits); + + CUtensorMap tensor_map_input; + CUtensorMap tensor_map_output; + + create_2D_tensor_map( + &tensor_map_input, + gpu_ptr(w), + cu::get_tma_dtype(w.dtype()), + rows, + cols, + TILE_M, + TILE_K, + stride_bytes); + + create_2D_tensor_map( + &tensor_map_output, + gpu_ptr(wq), + CU_TENSOR_MAP_DATA_TYPE_UINT8, + cols, + rows, + TILE_K, + TILE_M, + rows); + + auto kernel = cu::fp_quantize_columnwise_tma_mxfp8< + T, + false, + THREADS_PER_BLOCK, + COLS_PER_BLOCK, + ROWS_PER_BLOCK>; + enc.add_kernel_node( + kernel, + grid, + block, + smem_size, + tensor_map_input, + tensor_map_output, + gpu_ptr(scales), + rows, + cols); + } else { + throw std::runtime_error( + "[fp_quantize_columnwise_tma] Cannot quantize input with type float64."); + } + }); + } else { + throw std::runtime_error( + "[fp_quantize_columnwise_tma] TMA quantization only implemented for bits=8 and group_size=32."); + } +} + +void fp_quantize_columnwise( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + const std::optional& global_scale /* = std::nullopt */, + cu::CommandEncoder& enc, + const Stream& s) { + // Use TMA version for SM100+ with MXFP8 (bits=8, group_size=32) + // NVFP4 todo + const size_t rows = w.shape(-1); + const size_t cols = w.size() / rows; + const bool has_full_tma_tiles = ((rows % 128) == 0) && ((cols % 128) == 0); + bool use_tma = + (enc.device().compute_capability_major() >= 10 && bits == 8 && + group_size == 32 && has_full_tma_tiles); + if (use_tma) { + fp_quantize_columnwise_tma( + w, wq, scales, group_size, bits, global_scale, enc, s); + } else { + fp_quantize_columnwise_fallback( + w, wq, scales, group_size, bits, global_scale, enc, s); + } +} + +void fp_quantize( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + const std::optional& global_scale /* = std::nullopt */, + cu::CommandEncoder& enc, + const Stream& s) { + if (w.strides(-1) == 1) { + fp_quantize_rowwise(w, wq, scales, group_size, bits, global_scale, enc, s); + } else { + fp_quantize_columnwise( + w, wq, scales, group_size, bits, global_scale, enc, s); } } diff --git a/mlx/backend/cuda/quantized/fp_quantize.cuh b/mlx/backend/cuda/quantized/fp_quantize.cuh new file mode 100644 index 0000000000..83ff912868 --- /dev/null +++ b/mlx/backend/cuda/quantized/fp_quantize.cuh @@ -0,0 +1,363 @@ +// Copyright © 2025 Apple Inc. +#pragma once + +#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh" +#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh" +#include "mlx/backend/cuda/quantized/quantized_utils.cuh" +#include "mlx/backend/cuda/vector_types.cuh" +#include "mlx/dtype_utils.h" + +#include +#include +#include +#include + +namespace mlx::core { +namespace cu { + +template +struct Dequantize { + __device__ float operator()(uint8_t x) { + if constexpr (bits == 8) { + return float(*(__nv_fp8_e4m3*)(&x)); + } else { + return float(*(__nv_fp4_e2m1*)(&x)); + } + } +}; + +namespace cg = cooperative_groups; + +template +__global__ void fp_quantize_dequantize( + T* w, + T* out, + size_t size, + float* global_scale = nullptr) { + const bool use_global_scale = global_scale != nullptr; + const float scale_enc = + use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; + const float inv_scale_enc = use_global_scale ? 1.0f / scale_enc : 1.0f; + + using Tx2 = Vector2_t; + using Tx4 = Vector4_t; + uint32_t rbits = 0; // reserved bits for future use + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; + + size_t thread_idx = tidx + grid_dim_x * size_t(tidy); + size_t base_idx = thread_idx * group_size; + + if (base_idx >= size) { + return; + } + + auto w_tile = load_vector(w, thread_idx); + float scale_dec_b = 0.0f; + + Tx2 amax_2x = Tx2{0.0f, 0.0f}; + +#pragma unroll + for (int i = 0; i < group_size; i += 2) { + auto pair = Tx2{w_tile[i], w_tile[i + 1]}; + absmax_x2(amax_2x, amax_2x, pair); + } + + scale_dec_b = static_cast( + max(fabsf(static_cast(amax_2x.x)), + fabsf(static_cast(amax_2x.y)))); + + scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; + scale_dec_b *= scale_enc; + // Convert to mx scale or nv scale + using ScaleType = + std::conditional_t; + auto s = ScaleType(scale_dec_b); + float scale_enc_b = scale_enc / float(s); + float scale_dec = float(s) * inv_scale_enc; + AlignedVector w_hat; + +#pragma unroll + for (int i = 0; i < group_size / 4; i++) { + Tx4 w_Tx4 = *reinterpret_cast(&w_tile[i * 4]); + float4 dq; + if constexpr (bits == 8) { + uint32_t quantized_val = + scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); + dq = dequant_fp8(quantized_val); + } else { + uint16_t quantized_val = + scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); + dq = dequant_fp4(quantized_val); + } + w_hat[i * 4] = static_cast(dq.x * scale_dec); + w_hat[i * 4 + 1] = static_cast(dq.y * scale_dec); + w_hat[i * 4 + 2] = static_cast(dq.z * scale_dec); + w_hat[i * 4 + 3] = static_cast(dq.w * scale_dec); + } + store_vector(out, thread_idx, w_hat); +} + +template +__global__ void fp_quantize_rowwise( + T* w, + uint8_t* out, + uint8_t* scales, + size_t size, + float* global_scale = nullptr) { + // NVFP4 conversion: + // Global encode scale: (448 × 6) / *global_scale + // Per-block decode scale: S_dec_b = (block_amax / 6) × S_enc → stored as FP8 + // E4M3 Per-block encode scale: S_enc_b = S_enc / S_dec_b + const bool use_global_scale = global_scale != nullptr; + const float scale_enc = + use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; + + using Tx2 = Vector2_t; + using Tx4 = Vector4_t; + uint32_t rbits = 0; // reserved bits for future use + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; + + size_t thread_idx = tidx + grid_dim_x * size_t(tidy); + size_t base_idx = thread_idx * group_size; + + if (base_idx >= size) { + return; + } + + auto w_tile = load_vector(w, thread_idx); + float scale_dec_b = 0.0f; + + Tx2 amax_2x = Tx2{0.0f, 0.0f}; + +#pragma unroll + for (int i = 0; i < group_size; i += 2) { + auto pair = Tx2{w_tile[i], w_tile[i + 1]}; + absmax_x2(amax_2x, amax_2x, pair); + } + + scale_dec_b = static_cast( + max(fabsf(static_cast(amax_2x.x)), + fabsf(static_cast(amax_2x.y)))); + + scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; + scale_dec_b *= scale_enc; + // Convert to mx scale or nv scale + using ScaleType = + std::conditional_t; + auto s = ScaleType(scale_dec_b); + uint8_t q_scale = s.__x; + float scale_enc_b = scale_enc / float(s); + + scales[thread_idx] = q_scale; + constexpr int elem_per_byte = bits == 8 ? 1 : 2; + AlignedVector quantized; + +#pragma unroll + for (int i = 0; i < group_size / 4; i++) { + Tx4 w_Tx4 = *reinterpret_cast(&w_tile[i * 4]); + if constexpr (bits == 8) { + uint32_t quantized_val = + scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); + *reinterpret_cast(&quantized[i * 4]) = quantized_val; + } else { + uint16_t quantized_val = + scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); + *reinterpret_cast(&quantized[i * 2]) = quantized_val; + } + } + store_vector(out, thread_idx, quantized); +} + +template +__global__ void fp_quantize_columnwise_fallback( + T* w, + uint8_t* out, + uint8_t* scales, + size_t size, + int M, + int K, + float* global_scale = nullptr) { + // Input: [M, K] with strides [1, M] (M-major) + // Quantized output: [M, K/elem_per_byte] row-major (K-major) + // Scales: [M, K/group_size] row-major (K-major) + // Quantize along K (last dimension, groups of group_size elements) + const bool use_global_scale = global_scale != nullptr; + const float scale_enc = + use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; + + using Tx2 = Vector2_t; + using Tx4 = Vector4_t; + uint32_t rbits = 0; + + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + + constexpr int BLOCK_X = 16; + constexpr int BLOCK_Y = 32; + constexpr int elem_per_byte = (bits == 8) ? 1 : 2; + constexpr int bytes_per_group = group_size / elem_per_byte; + + constexpr int rows_per_block = BLOCK_X; + constexpr int cols_per_block = BLOCK_Y * group_size; + constexpr int local_cols = cols_per_block / elem_per_byte; + constexpr int bytes_per_block = rows_per_block * local_cols; + + constexpr int SMEM_PAD = 4; + constexpr int padded_local_cols = local_cols + SMEM_PAD; + + auto tidx = idx_in_block.x; + auto tidy = idx_in_block.y; + + int num_col_blocks = (K + cols_per_block - 1) / cols_per_block; + auto bidx = block_idx.x % num_col_blocks; + auto bidy = block_idx.x / num_col_blocks; + + T thread_data[group_size]; + + __shared__ uint8_t quantized_smem[rows_per_block * padded_local_cols]; + __shared__ uint8_t scales_smem[BLOCK_X][BLOCK_Y + SMEM_PAD]; + + int row_base = bidy * rows_per_block + tidx; + int col_base = bidx * cols_per_block + tidy * group_size; + + bool valid = (row_base < M) && (col_base + group_size <= K); + if (valid) { +#pragma unroll + for (int i = 0; i < group_size; i++) { + auto index = row_base + (col_base + i) * M; + thread_data[i] = w[index]; + } + + // Compute scale + Tx2 amax_2x = Tx2{0.0f, 0.0f}; +#pragma unroll + for (int r = 0; r < group_size; r += 2) { + auto pair = Tx2{thread_data[r], thread_data[r + 1]}; + absmax_x2(amax_2x, amax_2x, pair); + } + float scale_dec_b = + max(fabsf(static_cast(amax_2x.x)), + fabsf(static_cast(amax_2x.y))); + scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; + scale_dec_b *= scale_enc; + // Convert to mx scale or nv scale + using ScaleType = + std::conditional_t; + auto s = ScaleType(scale_dec_b); + float scale_enc_b = scale_enc / float(s); + scales_smem[tidx][tidy] = s.__x; + + int shared_idx = tidx * padded_local_cols + tidy * bytes_per_group; + +#pragma unroll + for (int j = 0; j < group_size / 4; j++) { + Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); + if constexpr (bits == 8) { + uint32_t quantized_val = + scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); + *reinterpret_cast(&quantized_smem[shared_idx + j * 4]) = + quantized_val; + } else { + uint16_t quantized_val = + scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); + *reinterpret_cast(&quantized_smem[shared_idx + j * 2]) = + quantized_val; + } + } + } + __syncthreads(); + + int output_cols = K / elem_per_byte; + int num_groups_per_row = K / group_size; + int linear_tid = tidx + tidy * BLOCK_X; + // Write back quantized values +#pragma unroll + for (int i = linear_tid; i < bytes_per_block; i += BLOCK_X * BLOCK_Y) { + int local_row = i / local_cols; + int local_col = i % local_cols; + + int global_row = bidy * rows_per_block + local_row; + int global_col = bidx * local_cols + local_col; + + if (global_row < M && global_col < output_cols) { + int physical_idx = local_row * padded_local_cols + local_col; + out[global_row * output_cols + global_col] = quantized_smem[physical_idx]; + } + } + // Write back scales + constexpr int num_scales = BLOCK_X * BLOCK_Y; +#pragma unroll + for (int i = linear_tid; i < num_scales; i += BLOCK_X * BLOCK_Y) { + int local_row = i / BLOCK_Y; + int local_col = i % BLOCK_Y; + + int global_row = bidy * BLOCK_X + local_row; + int global_col = bidx * BLOCK_Y + local_col; + + if (global_row < M && global_col < num_groups_per_row) { + scales[global_row * num_groups_per_row + global_col] = + scales_smem[local_row][local_col]; + } + } +} + +template +__global__ void fp_dequantize( + const uint8_t* w, + const uint8_t* scales, + T* out, + size_t size, + float* global_scale = nullptr) { + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + + auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; + + constexpr int pack_factor = bits == 8 ? 1 : 2; + const bool use_global_scale = global_scale != nullptr; + const float inv_scale_enc = use_mx_scale + ? 1.0f + : (use_global_scale ? (*global_scale) / (F8E4M3_MAX * F4E2M1_MAX) : 1.0f); + size_t offset = tidx + grid_dim_x * size_t(tidy); + size_t oindex = offset * pack_factor; + + if (oindex >= size) { + return; + } + + size_t gindex = oindex / group_size; + using ScaleType = + std::conditional_t; + auto scale = float(((ScaleType*)(scales))[gindex]) * inv_scale_enc; + + out += oindex; + + uint32_t val = w[offset]; +#pragma clang loop unroll(full) + for (int i = 0; i < pack_factor; i++) { + uint8_t d; + if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[i] = static_cast(scale * Dequantize{}(d)); + } +} + +} // namespace cu +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/fp_quantize_tma.cuh b/mlx/backend/cuda/quantized/fp_quantize_tma.cuh new file mode 100644 index 0000000000..459bbc4a82 --- /dev/null +++ b/mlx/backend/cuda/quantized/fp_quantize_tma.cuh @@ -0,0 +1,257 @@ +// Copyright © 2026 Apple Inc. +#pragma once + +#include "mlx/backend/cuda/ptx.cuh" +#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh" +#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh" +#include "mlx/backend/cuda/quantized/quantized_utils.cuh" +#include "mlx/backend/cuda/vector_types.cuh" +#include "mlx/dtype_utils.h" + +#include +#include +#include +#include + +namespace mlx::core { +namespace cu { + +constexpr size_t TMA_SHMEM_ALIGNMENT = 128; +constexpr size_t BUFFS_NUM = 2; + +namespace cg = cooperative_groups; + +template < + typename T, + bool USE_SR, + int THREADS_PER_BLOCK, + int COLS_PER_BLOCK, + int ROWS_PER_BLOCK> +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + fp_quantize_columnwise_tma_mxfp8( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + uint8_t* __restrict__ scales, + const size_t rows, + const size_t cols) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) + using Tx2 = Vector2_t; + using Tx4 = Vector4_t; + + constexpr size_t TILE_M = 32; + constexpr size_t TILE_K = COLS_PER_BLOCK; + constexpr size_t STEPS = ROWS_PER_BLOCK / TILE_M; + constexpr int elem_per_byte = 1; + + const auto block_idx = cg::this_thread_block().group_index(); + const auto idx_in_block = cg::this_thread_block().thread_index(); + const int tidx = idx_in_block.x; // Thread handles column tidx + const bool is_master = (tidx == 0); + + const size_t block_offset_col = block_idx.x * COLS_PER_BLOCK; + const size_t block_offset_row = block_idx.y * ROWS_PER_BLOCK; + + constexpr size_t BUFF_ELEMS = TILE_M * TILE_K; + constexpr size_t in_tile_size = BUFF_ELEMS * sizeof(T); + constexpr size_t in_buff_size_aligned = + ((in_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + + constexpr size_t out_tile_elems = BUFF_ELEMS / elem_per_byte; + constexpr size_t out_tile_size = out_tile_elems; + constexpr size_t out_buff_size_aligned = + ((out_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + + extern __shared__ char shared_mem[]; + uintptr_t aligned_shared = + (reinterpret_cast(shared_mem) + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + T* in_sh = reinterpret_cast(aligned_shared); + uint8_t* out_sh = + reinterpret_cast(aligned_shared + in_buff_size_aligned); + + constexpr uint32_t tile_bytes = static_cast(in_tile_size); + + __shared__ alignas(8) uint64_t mbar[STEPS]; + + T thread_data[TILE_M]; + uint32_t rbits = 0; // Reserved for stochastic rounding + const size_t scale_stride = rows / TILE_M; + + // Master thread init memory barriers for all steps + // fence for tma, synchronize threads so all see mbarrier + if (is_master) { +#pragma unroll + for (int iter = 0; iter < STEPS; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); + } + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + // Launch first async copy before entering the loop + copy_2d_to_shared( + &in_sh[0], + &tensor_map_input, + static_cast(block_offset_col), + static_cast(block_offset_row), + tile_bytes, + &mbar[0], + is_master); + +#pragma unroll + for (size_t step = 0; step < STEPS; ++step) { + // buffer memory offset in shared memory (we use double buffering for + // pipelining) + const size_t buff = step % BUFFS_NUM; + const size_t next_step = step + 1; + const size_t step_row_offset = step * TILE_M; + + if (next_step < STEPS) { + // before launching another async copy, check that there is less than 2 + // (to ensure that shared -> global synch is finished and buffer can be + // reused) + ptx::cp_async_bulk_wait_group_read<1>(); + const size_t next_buff = next_step % BUFFS_NUM; + const size_t next_row_offset = block_offset_row + next_step * TILE_M; + const size_t next_buff_elem_offset = next_buff * BUFF_ELEMS; + + copy_2d_to_shared( + &in_sh[next_buff_elem_offset], + &tensor_map_input, + static_cast(block_offset_col), + static_cast(next_row_offset), + tile_bytes, + &mbar[next_step], + is_master); + } + + ptx::fence_proxy_async_shared_cta(); + // Wait until the data is ready, parity is always 0 because for simplicity + // we dont reuse barriers between steps + ptx::mbarrier_wait_parity(&mbar[step], 0); + const size_t buff_offset = buff * BUFF_ELEMS; + // Read the data from shared to registers +#pragma unroll + for (int row = 0; row < TILE_M; ++row) { + thread_data[row] = in_sh[buff_offset + row * TILE_K + tidx]; + } + Tx2 amax_2x = Tx2{T(0.0f), T(0.0f)}; +#pragma unroll + for (int row = 0; row < TILE_M; row += 2) { + auto pair = Tx2{thread_data[row], thread_data[row + 1]}; + absmax_x2(amax_2x, amax_2x, pair); + } + + float scale = + max(fabsf(static_cast(amax_2x.x)), + fabsf(static_cast(amax_2x.y))); + + scale /= F8E4M3_MAX; + + using ScaleType = __nv_fp8_e8m0; + auto s = ScaleType(scale); + scale = float(s); + // Write scale directly to global memory + const size_t global_col = block_offset_col + tidx; + const size_t global_row_group = + (block_offset_row + step_row_offset) / TILE_M; + if (global_col < cols && (block_offset_row + step_row_offset) < rows) { + scales[global_col * scale_stride + global_row_group] = s.__x; + } + const size_t out_buff_offset = buff * out_tile_elems; + // Quantize to registers first + constexpr int GROUPS = TILE_M / 4; + uint32_t quantized_regs[GROUPS]; +#pragma unroll + for (int j = 0; j < GROUPS; ++j) { + Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); + quantized_regs[j] = + cu::scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); + } + // Write output to shared memory with swapped store order to reduce bank + // conflicts. Without swap: stride between threads is TILE_M=32 bytes + // = 8 banks, every 4th thread hits the same bank -> 8-way conflict. + const int lane = tidx % 32; + const int group = (lane / 4) % 2; + const size_t base = out_buff_offset + tidx * TILE_M; + switch (group) { + case 0: + *reinterpret_cast(&out_sh[base + 0]) = { + quantized_regs[0], + quantized_regs[1], + quantized_regs[2], + quantized_regs[3]}; + *reinterpret_cast(&out_sh[base + 16]) = { + quantized_regs[4], + quantized_regs[5], + quantized_regs[6], + quantized_regs[7]}; + break; + case 1: + *reinterpret_cast(&out_sh[base + 16]) = { + quantized_regs[4], + quantized_regs[5], + quantized_regs[6], + quantized_regs[7]}; + *reinterpret_cast(&out_sh[base + 0]) = { + quantized_regs[0], + quantized_regs[1], + quantized_regs[2], + quantized_regs[3]}; + break; + } + __syncthreads(); + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + + if (is_master) { + const size_t global_row = block_offset_row + step_row_offset; + const uint32_t out_x = static_cast(global_row); + const uint32_t out_y = static_cast(block_offset_col); + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), + out_x, + out_y, + reinterpret_cast(&out_sh[out_buff_offset])); + ptx::cp_async_bulk_commit_group(); + } + } + // Wait for all TMA stores to complete + ptx::cp_async_bulk_wait_group_read<0>(); + + __syncthreads(); + if (is_master) { +#pragma unroll + for (int iter = 0; iter < STEPS; ++iter) { + ptx::mbarrier_invalidate(&mbar[iter]); + } + } +#endif // __CUDA_ARCH__ >= 1000 +} + +template < + typename T, + bool USE_SR, + int THREADS_PER_BLOCK, + int COLS_PER_BLOCK, + int ROWS_PER_BLOCK> +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + fp_quantize_columnwise_tma_nvfp4( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + uint8_t* __restrict__ scales, + const size_t rows, + const size_t cols, + float* global_scale) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) + // placeholder TODO - NVFP4 TMA kernel not yet implemented +#endif // __CUDA_ARCH__ >= 1000 +} + +} // namespace cu +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qqmm_utils.cu b/mlx/backend/cuda/quantized/qqmm_utils.cu index d19865a3b3..831f69620c 100644 --- a/mlx/backend/cuda/quantized/qqmm_utils.cu +++ b/mlx/backend/cuda/quantized/qqmm_utils.cu @@ -10,11 +10,6 @@ namespace mlx::core { namespace cg = cooperative_groups; -constexpr int TILE_ROWS = 128; -constexpr int TILE_COLS = 4; -constexpr int TILES_PER_LANE = 1; -constexpr int LANES_PER_BLOCK = 32; - // To pass scales to tensor cores, they need to be repacked into a tiled layout // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout // Tiled layout for scale factors is very well described in CUTLASS @@ -48,6 +43,15 @@ constexpr int LANES_PER_BLOCK = 32; // [252, 253, 254, 255], // [380, 381, 382, 383], // [508, 509, 510, 511]]]]], +namespace cu { + +constexpr float F8E4M3_MAX = 448.0f; +constexpr float F4E2M1_MAX = 6.0f; + +constexpr int TILE_ROWS = 128; +constexpr int TILE_COLS = 4; +constexpr int TILES_PER_LANE = 1; +constexpr int LANES_PER_BLOCK = 32; inline std::tuple get_swizzle_launch_args( size_t M_swizzled, @@ -68,11 +72,6 @@ inline std::tuple get_swizzle_launch_args( return std::make_tuple(grid, block); } -namespace cu { - -constexpr float F8E4M3_MAX = 448.0f; -constexpr float F4E2M1_MAX = 6.0f; - __global__ void compute_qqmm_pointers( float* alpha_out, float* beta_out, @@ -225,7 +224,7 @@ void swizzle_scales( size_t output_cols = scales_tiled.shape(-1); auto [num_blocks, block_dims] = - get_swizzle_launch_args(output_rows, output_cols); + cu::get_swizzle_launch_args(output_rows, output_cols); enc.add_kernel_node( cu::swizzle_scales, num_blocks, diff --git a/mlx/backend/cuda/quantized/quantized_utils.cuh b/mlx/backend/cuda/quantized/quantized_utils.cuh index 8cbe9b297d..92bc315d4a 100644 --- a/mlx/backend/cuda/quantized/quantized_utils.cuh +++ b/mlx/backend/cuda/quantized/quantized_utils.cuh @@ -1,12 +1,17 @@ // Copyright © 2025 Apple Inc. +#pragma once #include #include +#include "mlx/backend/cuda/ptx.cuh" namespace mlx::core { namespace cu { +constexpr float F8E4M3_MAX = 448.0f; +constexpr float F4E2M1_MAX = 6.0f; + inline __device__ float4 dequant_fp8(uint32_t bits) { auto out = *(__nv_fp8x4_e4m3*)(&bits); return out.operator float4(); @@ -44,6 +49,28 @@ __device__ __forceinline__ void absmax_x2(T& out, const T& x1, const T& x2) { } } +__device__ __forceinline__ void copy_2d_to_shared( + void* dst, + const CUtensorMap* tensor_map, + uint32_t tile_x, + uint32_t tile_y, + uint32_t num_bytes, + uint64_t* barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Arrive and tell how many bytes are expected + ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + dst, tensor_map, tile_x, tile_y, barrier); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + } // namespace cu template @@ -85,4 +112,37 @@ void dispatch_bits(int bits, F&& f) { } } +// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html +inline void create_2D_tensor_map( + CUtensorMap* tensorMap, + void* input_ptr, + CUtensorMapDataType dtype, + uint64_t rows, + uint64_t cols, + uint32_t tile_y, + uint32_t tile_x, + uint64_t stride_bytes, + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE) { + constexpr uint32_t rank = 2; // 2D + uint64_t global_dim[rank] = {cols, rows}; + // For row-major layout + uint64_t strides[rank - 1] = {stride_bytes}; + uint32_t tile_dim[rank] = {tile_x, tile_y}; + uint32_t elem_stride[rank] = {1, 1}; + + CHECK_CUDA_ERROR(cuTensorMapEncodeTiled( + tensorMap, + dtype, + rank, + input_ptr, + global_dim, + strides, + tile_dim, + elem_stride, + CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, + CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); +} + } // namespace mlx::core