Skip to content
Draft
99 changes: 61 additions & 38 deletions examples/python/qqmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,48 +33,71 @@ def test_qqmm():
[64, 128, 256, 1024, 1024 * 8], # N
[64, 128, 256, 1024, 1024 * 8], # K
)
layouts = ["TN", "NT", "TT", "NN"]
for group_size, mode, bits in tests:
for M, N, K in product(*shapes):
for dtype in dtypes:
x = mx.random.normal(shape=(M, K), key=k1, dtype=dtype)
w = mx.random.normal(shape=(N, K), key=k2, dtype=dtype)
w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode)
w_dq = mx.dequantize(
w_q,
scales_w,
group_size=group_size,
bits=bits,
mode=mode,
dtype=dtype,
)
y_q = mx.qqmm(
x,
w_q,
scales_w,
group_size=group_size,
bits=bits,
mode=mode,
)
x_q, scales_x = mx.quantize(
x, group_size=group_size, bits=bits, mode=mode
)
x_dq = mx.dequantize(
x_q,
scales_x,
group_size=group_size,
bits=bits,
mode=mode,
dtype=dtype,
)
y_hat = mx.matmul(x_dq, mx.transpose(w_dq))
ulp = ulp_bf16_at(y_hat)
error = (y_q - y_hat).abs()
if not (mx.logical_or(error < 1e-3, error <= ulp).all()):
raise AssertionError(
f"qqmm test failed for shape {(M, N, K)}, "
f"group_size={group_size}, bits={bits}, "
f"mode={mode}, dtype={dtype}"
for layout in layouts:
if layout == "NT":
x_shape = (M, K)
w_shape = (N, K)
elif layout == "TN":
x_shape = (K, M)
w_shape = (K, N)
elif layout == "TT":
x_shape = (K, M)
w_shape = (N, K)
else: # "NN"
x_shape = (M, K)
w_shape = (K, N)

x = mx.random.normal(shape=x_shape, key=k1, dtype=dtype)
w = mx.random.normal(shape=w_shape, key=k2, dtype=dtype)

if layout == "TT":
x = mx.transpose(x)
elif layout == "TN":
w = mx.transpose(w)
x = mx.transpose(x)
elif layout == "NN":
w = mx.transpose(w)

y_q = mx.qqmm(
x,
w,
group_size=group_size,
bits=bits,
mode=mode,
)
w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode)
w_dq = mx.dequantize(
w_q,
scales_w,
group_size=group_size,
bits=bits,
mode=mode,
dtype=dtype,
)
x_q, scales_x = mx.quantize(
x, group_size=group_size, bits=bits, mode=mode
)
x_dq = mx.dequantize(
x_q,
scales_x,
group_size=group_size,
bits=bits,
mode=mode,
dtype=dtype,
)
y_hat = mx.matmul(x_dq, mx.transpose(w_dq))
ulp = ulp_bf16_at(y_hat)
error = (y_q - y_hat).abs()
if not (mx.logical_or(error < 1e-3, error <= ulp).all()):
raise AssertionError(
f"qqmm test failed for shape {(M, N, K)}, "
f"group_size={group_size}, bits={bits}, "
f"mode={mode}, dtype={dtype}, layout={layout}"
)


def test_qqmm_vjp():
Expand Down
127 changes: 127 additions & 0 deletions mlx/backend/cuda/ptx.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#pragma once

#include <cuda.h>
#include <cuda_runtime.h>

namespace mlx::core {

namespace ptx {

#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
defined(__CUDA_ARCH_SPECIFIC__)

// Initialize a shared-memory mbarrier with the given arrival count.
__device__ __forceinline__ void mbarrier_init(uint64_t* mbar, uint32_t count) {
  // PTX shared-state-space instructions take a 32-bit shared address.
  auto const smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
  asm volatile("mbarrier.init.shared.b64 [%0], %1;"
               :
               : "r"(smem_addr), "r"(count)
               : "memory");
}

// Invalidate an mbarrier so its shared-memory slot can be safely reused.
__device__ __forceinline__ void mbarrier_invalidate(uint64_t* mbar) {
  auto const smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
  asm volatile("mbarrier.inval.shared.b64 [%0];"
               :
               : "r"(smem_addr)
               : "memory");
}

// Arrive at barrier (non-master threads); the returned arrival token is
// discarded ("_" destination operand).
__device__ __forceinline__ void mbarrier_arrive(uint64_t* mbar) {
  auto const smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
  asm volatile("mbarrier.arrive.shared.b64 _, [%0];"
               :
               : "r"(smem_addr)
               : "memory");
}

// Arrive at barrier and set the expected transaction byte count
// (master thread). The arrival token is discarded ("_").
__device__ __forceinline__ void mbarrier_arrive_expect_tx(
    uint64_t* mbar,
    uint32_t tx_count) {
  auto const smem_addr =
      static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
  asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
               :
               : "r"(smem_addr), "r"(tx_count)
               : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-asynchronous-copy-completion-mechanisms-mbarrier
// Spin until the mbarrier's completed-phase parity matches `parity`.
// mbarrier.try_wait is a potentially-blocking test that may return before
// the phase completes, so the predicate P gates a retry branch back to
// WAIT_LOOP until the wait succeeds.
__device__ __forceinline__ void mbarrier_wait_parity(
    uint64_t* mbar,
    uint32_t parity) {
  // PTX shared-state-space instructions take a 32-bit shared address.
  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
  asm volatile(
      "{\n\t"
      ".reg .pred P;\n\t"
      "WAIT_LOOP:\n\t"
      "mbarrier.try_wait.parity.shared.b64 P, [%0], %1;\n\t"
      "@!P bra WAIT_LOOP;\n\t"
      "}\n\t"
      :
      : "r"(mbar_ptr), "r"(parity)
      : "memory");
}

// Async bulk tensor copy: global -> shared (2D).
// Issues a TMA load of the tile at coordinates (tile_x, tile_y) of the
// tensor described by `tensor_map` into `dst_shmem`; completion is signaled
// on `mbar` via the mbarrier transaction-count mechanism
// (mbarrier::complete_tx::bytes), so callers pair this with
// mbarrier_arrive_expect_tx / mbarrier_wait_parity.
// NOTE(review): dst_shmem and mbar are assumed to satisfy TMA's shared-memory
// alignment requirements — confirm at the call sites.
__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
    void* dst_shmem,
    const CUtensorMap* tensor_map,
    uint32_t tile_x,
    uint32_t tile_y,
    uint64_t* mbar) {
  // PTX shared-state-space operands take 32-bit shared addresses.
  uint32_t dst_ptr = __cvta_generic_to_shared(dst_shmem);
  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);

  asm volatile(
      "cp.async.bulk.tensor.2d.shared::cluster.global.tile"
      ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];"
      :
      : "r"(dst_ptr), "l"(tensor_map), "r"(tile_x), "r"(tile_y), "r"(mbar_ptr)
      : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// shared::cta -> global
// Issues a TMA store of `src_shmem` to the tile at (offset_x, offset_y) of
// the tensor described by `tensor_map_ptr`. The copy joins the current bulk
// async-group: callers must follow with cp_async_bulk_commit_group() and
// later cp_async_bulk_wait_group_read<N>() to observe completion.
// NOTE(review): takes the tensor map as `const uint64_t*` and puts it first,
// while the global->shared variant takes `const CUtensorMap*` and puts the
// destination first — consider unifying the two signatures.
__device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global(
    const uint64_t* tensor_map_ptr,
    const uint32_t offset_x,
    const uint32_t offset_y,
    uint64_t* src_shmem) {
  // PTX shared-state-space operands take 32-bit shared addresses.
  uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem);
  asm volatile(
      "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];" ::
      "l"(tensor_map_ptr),
      "r"(offset_x),
      "r"(offset_y),
      "r"(src_shmem_ptr)
      : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
// Wait until at most N bulk async-groups remain pending, and all reads of
// shared memory by completed groups are visible.
//
// N is emitted as a PTX immediate via the "n" asm constraint, so any
// non-negative count is supported. (The previous if-constexpr chain only
// handled N in {0, 1, 2, 4} and silently compiled to a NO-OP — i.e. no wait
// at all — for any other N, such as 3.)
template <int N>
__device__ __forceinline__ void cp_async_bulk_wait_group_read() {
  static_assert(N >= 0, "pending group count must be non-negative");
  // "memory" clobber: this is a synchronization point; the compiler must not
  // reorder shared-memory accesses across it.
  asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(N) : "memory");
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
// Commit all previously-issued cp.async.bulk operations of this thread into
// a new bulk async-group, to be awaited via cp_async_bulk_wait_group_read.
__device__ __forceinline__ void cp_async_bulk_commit_group() {
  // "memory" clobber keeps the compiler from sinking shared-memory stores
  // (the data the committed copies read) past the commit point.
  asm volatile("cp.async.bulk.commit_group;" : : : "memory");
}

// Creates a memory ordering barrier between the generic and async proxies
// (e.g. makes generic-proxy shared-memory writes visible to subsequent TMA
// operations in the async proxy).
// details:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#proxies
__device__ __forceinline__ void fence_proxy_async_shared_cta() {
  // "memory" clobber: without it the compiler may reorder memory accesses
  // across the asm statement, defeating the fence at the compiler level.
  asm volatile("fence.proxy.async.shared::cta;" : : : "memory");
}

#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) &&
       // defined(__CUDA_ARCH_SPECIFIC__)
} // namespace ptx
} // namespace mlx::core
Loading
Loading