Conversation
Pull request overview
This PR implements CUDA support for the Hadamard transform (mx.hadamard_transform), following the same staged decomposition strategy as the Metal backend. The implementation decomposes the Hadamard transform into three stages (n1, n2, and m) to efficiently handle large transforms while respecting GPU memory constraints.
Changes:
- Added CUDA kernel implementation for Hadamard transform with JIT compilation support
- Enabled CUDA tests by removing skip entries for `test_hadamard` and `test_hadamard_grad_vmap`
- Integrated the implementation into the CUDA backend build system
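The staged decomposition rests on the Kronecker identity H_{n1·n2} = H_{n1} ⊗ H_{n2}: a large transform can be applied as two smaller stages over a reshaped view of the data. A minimal NumPy sketch of that identity (illustrative only, not MLX's kernel code):

```python
import numpy as np

# Illustrative sketch of the Kronecker identity behind the staged
# decomposition; this is NOT MLX's actual kernel code.

def h(k):
    # Dense Sylvester Hadamard matrix of size 2**k via Kronecker powers of H_2.
    H2 = np.array([[1.0, 1.0], [1.0, -1.0]])
    M = np.array([[1.0]])
    for _ in range(k):
        M = np.kron(M, H2)
    return M

def staged_hadamard(x, k1, k2):
    # View the length-(n1*n2) vector as an (n1, n2) matrix, then transform
    # each axis separately: (H_n1 ⊗ H_n2) x == H_n1 @ X @ H_n2^T row-major.
    n1, n2 = 2 ** k1, 2 ** k2
    y = x.reshape(n1, n2)
    y = y @ h(k2).T   # stage 1: n2-sized transform along rows
    y = h(k1) @ y     # stage 2: n1-sized transform across rows
    return y.reshape(n1 * n2)

x = np.random.default_rng(0).standard_normal(64)
direct = h(6) @ x                  # one 64-point transform
staged = staged_hadamard(x, 3, 3)  # two 8-point stages
assert np.allclose(direct, staged)
```

The same idea extends to a third `m` stage for sizes m·2^k with non-power-of-two m.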
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `mlx/backend/cuda/hadamard.cu` | Implements the main CUDA evaluation logic with staged kernel launches and JIT code generation for non-power-of-two radices |
| `mlx/backend/cuda/device/hadamard.cuh` | Provides device-side kernel templates for n-stage and m-stage transforms with vectorized memory access |
| `mlx/backend/cuda/primitives.cpp` | Removes `NO_GPU(Hadamard)` to enable the GPU evaluation path |
| `mlx/backend/cuda/jit_module.cpp` | Registers the `hadamard.cuh` header for JIT compilation |
| `mlx/backend/cuda/CMakeLists.txt` | Adds `hadamard.cu` to build sources |
| `python/tests/cuda_skip.py` | Removes CUDA skip entries to enable Hadamard tests |
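For intuition on how the stages are chosen, here is a hedged sketch of a staged split: factor n as m·2^k with m ∈ {1, 12, 20, 28}, then divide 2^k into n1·n2 so the larger power-of-two stage fits a per-block element budget. The function name `split_hadamard` and the `max_elems` bound are made up for illustration; MLX's real logic lives in `decompose_hadamard(...)` and may differ.

```python
# Hedged sketch (hypothetical helper, not MLX's decompose_hadamard):
# factor n = m * 2**k with m in {1, 12, 20, 28}, then split 2**k into
# n1 * n2 so each power-of-two stage stays under `max_elems` elements.

def split_hadamard(n, max_elems=1024):
    for m in (28, 20, 12, 1):
        pow2 = n // m
        if n % m == 0 and pow2 & (pow2 - 1) == 0:
            break
    else:
        raise ValueError("n must be m * 2**k with m in {1, 12, 20, 28}")
    n1 = 1
    while pow2 // n1 > max_elems:
        n1 *= 2
    return n1, pow2 // n1, m

# e.g. split_hadamard(12 * 4096) -> (4, 1024, 12)
```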
Looks great, thanks for the contribution. Could you please share bandwidth numbers for the proposed kernel across a range of shapes? I'm also particularly interested in the case where the Hadamard transform is applied to tiled inputs:

```python
x = mx.random.uniform(shape=(4096, 4096))
mx.hadamard_transform(x.reshape(4096, 4096 // N, N))
```

where …
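One way the requested numbers could be reported is as effective bandwidth: total bytes read plus bytes written, divided by elapsed time. The helper below is hypothetical (not part of MLX); on a real GPU you would time `mx.hadamard_transform` with a synchronized timer (e.g. `mx.eval` plus `time.perf_counter`).

```python
# Hypothetical helper for reporting effective bandwidth, not an MLX API.

def effective_bandwidth_gbs(nbytes_read, nbytes_written, seconds):
    # Count every byte crossing the memory bus: one read plus one write
    # of the tensor for an out-of-place transform.
    return (nbytes_read + nbytes_written) / seconds / 1e9

# Example: a (4096, 4096) float32 tensor read once and written once
# in 1 ms corresponds to ~134 GB/s.
nbytes = 4096 * 4096 * 4
bw = effective_bandwidth_gbs(nbytes, nbytes, 1e-3)
```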
Proposed changes

This PR adds CUDA support for `Hadamard` (`mx.hadamard_transform`) with the same staged decomposition strategy used by the Metal backend.

Changed files

- `mlx/backend/cuda/hadamard.cu`: implemented the CUDA `Hadamard::eval_gpu` and JIT launch flow (`n1`/`n2`/`m` staged execution), reusing `decompose_hadamard(...)`.
- `mlx/backend/cuda/device/hadamard.cuh`: added JIT device kernels `hadamard_n<...>` and `hadamard_m<...>` plus radix helpers.
- `python/tests/cuda_skip.py`: removed CUDA skip entries.

Validation
- `python -m pytest python/tests/test_ops.py -k test_hadamard -q` passed.
- `python -m pytest python/tests/test_ops.py -k test_hadamard_grad_vmap -q` passed.

Checklist
Put an `x` in the boxes that apply.

- [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes