Conversation
There was a problem hiding this comment.
Pull request overview
This PR implements CUDA support for the MaskedScatter operation, which scatters values from a source array into a destination array at positions specified by a boolean mask. The implementation follows the existing Metal backend pattern and properly integrates with the CUDA backend infrastructure.
Changes:
- Converted
indexing.cpptoindexing.cuand added full CUDAMaskedScatter::eval_gpuimplementation withmasked_assignkernel - Refactored scan launch logic into reusable
scan_gpu_inplacefunction with new header file - Removed CUDA skip entries for masked-scatter-related tests
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated no comments.
Show a summary per file
| File | Description |
|---|---|
| mlx/backend/cuda/indexing.cu | Implemented masked_assign CUDA kernel and MaskedScatter::eval_gpu method; converted from .cpp to .cu |
| mlx/backend/cuda/scan.cu | Refactored scan logic into scan_gpu_inplace function for reuse in MaskedScatter |
| mlx/backend/cuda/scan.h | Added header declaring scan_gpu_inplace function |
| mlx/backend/cuda/primitives.cpp | Removed NO_GPU(MaskedScatter) macro to enable CUDA support |
| mlx/backend/cuda/CMakeLists.txt | Updated build to compile indexing.cu instead of indexing.cpp |
| tests/ops_tests.cpp | Removed CUDA skip guard from masked_scatter tests |
| tests/autograd_tests.cpp | Removed CUDA skip guard from masked_scatter autograd tests |
| python/tests/cuda_skip.py | Removed three masked-scatter-related test entries from skip list |
Comments suppressed due to low confidence (1)
mlx/backend/cuda/indexing.cu:80
- The
masked_assignkernel uses a signed 32-bitIdxTtogether withstride = static_cast<IdxT>(blockDim.x) * gridDim.x * gridDim.y * gridDim.z, which can overflow whenmask_flat.size()approachesINT32_MAX, causingstrideto wrap negative whiletotalremains positive. In that case, the loopfor (IdxT idx = thread_id; idx < total; idx += stride)can revisit with negativeidxvalues and read/writemask[idx],scatter_offsets[idx], andout[idx]out of bounds, leading to GPU memory corruption and potential data exposure or code execution in contexts that rely on untrusted shapes. To address this, ensure the index type used in this kernel cannot overflow for the chosen grid/block configuration (e.g., use an unsigned or 64-bit index consistently forIdxTwhen computingblock_id,stride, and indexing, or otherwise constraingridDim/blockDimso their product fits safely in the index type).
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
zcbenz
left a comment
There was a problem hiding this comment.
Looks good to me, would like another review before merging.
39962fe to
f5693f7
Compare
nastya236
left a comment
There was a problem hiding this comment.
Looks good to me as well!
nastya236
left a comment
There was a problem hiding this comment.
As I said looks great, thanks for your contribution.
Could you please provide bandwidth numbers for masked scatter kernel for a range of shapes?
Proposed changes
This PR adds CUDA support for
MaskedScatter.Changed files
mlx/backend/cuda/indexing.cpp: implemented CUDAMaskedScatter::eval_gpuusing the CUDA JIT module path.mlx/backend/cuda/device/scatter.cuh: added the JIT device kernelmasked_scatter_assign<...>used by CUDA masked scatter.mlx/backend/cuda/scan.cu: refactored scan execution into reusablescan_gpu_inplace(...)and updatedScan::eval_gputo delegate to it.mlx/backend/cuda/scan.h: added declaration forscan_gpu_inplace(...).python/tests/cuda_skip.py,tests/ops_tests.cpp,tests/autograd_tests.cpp: removed CUDA skip entries for masked-scatter-related tests.Validation
python -m pytest python/tests/test_ops.py -k masked_scatter -qpassed.python -m pytest python/tests/test_vmap.py -k vmap_masked_scatter -qpassed.python -m pytest python/tests/test_array.py -k setitem_with_boolean_mask -qpassed.build/tests/tests -tc="test masked_scatter,test masked_scatter autograd"passed.Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes