[CUDA] Implement MaskedScatter #3151

Open
Lyxot wants to merge 4 commits into ml-explore:main from Lyxot:cuda/masked_scatter

Conversation

@Lyxot commented Feb 20, 2026

Proposed changes

This PR adds CUDA support for MaskedScatter.

Changed files

  • mlx/backend/cuda/indexing.cpp: implemented CUDA MaskedScatter::eval_gpu using the CUDA JIT module path.
  • mlx/backend/cuda/device/scatter.cuh: added the JIT device kernel masked_scatter_assign<...> used by CUDA masked scatter.
  • mlx/backend/cuda/scan.cu: refactored scan execution into reusable scan_gpu_inplace(...) and updated Scan::eval_gpu to delegate to it.
  • mlx/backend/cuda/scan.h: added declaration for scan_gpu_inplace(...).
  • python/tests/cuda_skip.py, tests/ops_tests.cpp, tests/autograd_tests.cpp: removed CUDA skip entries for masked-scatter-related tests.

Validation

  • Python:
    • python -m pytest python/tests/test_ops.py -k masked_scatter -q passed.
    • python -m pytest python/tests/test_vmap.py -k vmap_masked_scatter -q passed.
    • python -m pytest python/tests/test_array.py -k setitem_with_boolean_mask -q passed.
  • C++:
    • build/tests/tests -tc="test masked_scatter,test masked_scatter autograd" passed.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Copilot AI review requested due to automatic review settings February 20, 2026 19:41

Copilot AI left a comment

Pull request overview

This PR implements CUDA support for the MaskedScatter operation, which scatters values from a source array into a destination array at positions specified by a boolean mask. The implementation follows the existing Metal backend pattern and properly integrates with the CUDA backend infrastructure.

Changes:

  • Converted indexing.cpp to indexing.cu and added full CUDA MaskedScatter::eval_gpu implementation with masked_assign kernel
  • Refactored scan launch logic into reusable scan_gpu_inplace function with new header file
  • Removed CUDA skip entries for masked-scatter-related tests

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated no comments.

Summary per file:

  • mlx/backend/cuda/indexing.cu: implemented masked_assign CUDA kernel and MaskedScatter::eval_gpu method; converted from .cpp to .cu
  • mlx/backend/cuda/scan.cu: refactored scan logic into scan_gpu_inplace function for reuse in MaskedScatter
  • mlx/backend/cuda/scan.h: added header declaring scan_gpu_inplace function
  • mlx/backend/cuda/primitives.cpp: removed the NO_GPU(MaskedScatter) macro to enable CUDA support
  • mlx/backend/cuda/CMakeLists.txt: updated the build to compile indexing.cu instead of indexing.cpp
  • tests/ops_tests.cpp: removed the CUDA skip guard from masked_scatter tests
  • tests/autograd_tests.cpp: removed the CUDA skip guard from masked_scatter autograd tests
  • python/tests/cuda_skip.py: removed three masked-scatter-related test entries from the skip list

Comments suppressed due to low confidence (1)

mlx/backend/cuda/indexing.cu:80

  • The masked_assign kernel uses a signed 32-bit IdxT together with stride = static_cast<IdxT>(blockDim.x) * gridDim.x * gridDim.y * gridDim.z, which can overflow when mask_flat.size() approaches INT32_MAX, causing stride to wrap negative while total remains positive. In that case, the loop for (IdxT idx = thread_id; idx < total; idx += stride) can revisit with negative idx values and read/write mask[idx], scatter_offsets[idx], and out[idx] out of bounds, leading to GPU memory corruption and potential data exposure or code execution in contexts that rely on untrusted shapes. To address this, ensure the index type used in this kernel cannot overflow for the chosen grid/block configuration (e.g., use an unsigned or 64-bit index consistently for IdxT when computing block_id, stride, and indexing, or otherwise constrain gridDim/blockDim so their product fits safely in the index type).


@zcbenz (Collaborator) left a comment

Looks good to me; I would like another review before merging.

@Lyxot force-pushed the cuda/masked_scatter branch from 39962fe to f5693f7 on February 27, 2026 10:50
@Lyxot requested a review from zcbenz on February 27, 2026 10:50
@nastya236 self-requested a review on February 27, 2026 16:05
@nastya236 previously approved these changes Feb 27, 2026
@nastya236 (Collaborator) left a comment

Looks good to me as well!

@nastya236 self-requested a review on February 27, 2026 20:19
@nastya236 dismissed their stale review on February 28, 2026 00:25

Re-review

@nastya236 (Collaborator) left a comment

As I said, this looks great; thanks for your contribution.
Could you please provide bandwidth numbers for the masked scatter kernel across a range of shapes?
