
[CUDA][Performance] Add radix select implementation for efficient partition operations#3117

Open
Lyxot wants to merge 16 commits into ml-explore:main from Lyxot:cuda-radix-select

Conversation

@Lyxot Lyxot commented Feb 9, 2026

Proposed changes

This adds a CUDA radix-select based path for argpartition/partition and introduces multi-block-per-row and multi-row-per-block tiling for shapes where plain radix select underperforms. #3064

What changed

  • Added CUDA radix-select kernels in mlx/backend/cuda/device/radix_select.cuh:
    • Small shared-memory path for smaller sorted axes
    • Large streaming path
    • Tiled large-array path with launch planning (blocks_per_row, rows_per_block)
    • Deterministic scatter for equal keys (stable partition ordering behavior)
  • Integrated new radix partition dispatch in mlx/backend/cuda/sort.cu:
    • ArgPartition::eval_gpu / Partition::eval_gpu now call radix partition path
    • Size-based dispatch between small and large kernels
    • Tiled launch used for large contiguous workloads when beneficial
  • Added a benchmark and verification script, benchmarks/python/radix_select_bench.py, with correctness checks, determinism checks, and performance sweep utilities

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Lyxot added 11 commits February 9, 2026 02:02
Fix two correctness issues in CUDA radix partition/argpartition:

- In the large contiguous radix path, stop deriving row bases from
  `row * min(non-axis stride)` and compute row offsets with `elem_to_loc(...)`
  using non-axis shape/strides (matching merge-sort indexing behavior).
- Keep stride arguments 64-bit end-to-end in radix-select kernels and launches
  (remove narrowing to `int` and related `INT32_MAX` guard).

This fixes incorrect row addressing for valid contiguous non-linear layouts
(e.g. column-major with axis=0) and avoids silent misindexing on large strides.
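The corrected addressing is an elem_to_loc-style walk over the non-axis shape/strides; a host-side sketch of the mapping (illustrative names, not the actual helper):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative elem_to_loc-style mapping: convert a flat row index
// into a memory offset using the non-axis shape and strides. This is
// the indexing the fix switches to, instead of `row * min(stride)`.
int64_t row_offset(
    int64_t row,
    const std::vector<int32_t>& nc_shape,
    const std::vector<int64_t>& nc_strides) {
  int64_t loc = 0;
  for (int i = static_cast<int>(nc_shape.size()) - 1; i >= 0; --i) {
    loc += (row % nc_shape[i]) * nc_strides[i];
    row /= nc_shape[i];
  }
  return loc;
}
```

For example, a column-major (2, 3, 4) array partitioned along axis 1 has non-axis shape {2, 4} and strides {1, 6}; row 1 then maps to offset 6, whereas `row * min(stride)` would miscompute it as 1.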
Eliminate MAX_NDIM-based rank limits in CUDA radix partition/argpartition by
switching radix kernels from fixed-size __grid_constant__ shape/stride params
to dynamic device pointers for non-axis metadata.

Changes:
- Update radix kernels to take dynamic NC metadata pointers:
  - radix_select_small_nc_kernel
  - radix_select_large_streaming_kernel
  - radix_select_large_streaming_nc_kernel
- In gpu_radix_partition_small/gpu_radix_partition_large:
  - allocate device buffers for nc_shape/in_nc_strides/out_nc_strides
  - copy host metadata with cudaMemcpyAsync
  - pass pointers into kernel launches
- Remove MAX_NDIM-dependent routing so high-rank tensors can still use radix
  partition path.
- Keep stride handling 64-bit end-to-end in radix launches/kernels.

Also slightly widens fallback-model threshold range (without changing max_rows).
Copilot AI review requested due to automatic review settings February 9, 2026 19:23

Lyxot commented Feb 9, 2026

I got the following benchmark results on a 4070 Super:

Dtype: bfloat16
Config                      ArgPartition      ArgSort    Speedup
--------------------------------------------------------------------------------
b=2048, v=8192, k=32             0.315ms      1.564ms   4.97x
b=2048, v=4096, k=32             0.173ms      0.517ms   2.98x
b=1024, v=4096, k=16             0.097ms      0.256ms   2.64x
b=512, v=2048, k=64              0.051ms      0.078ms   1.53x
b=256, v=1024, k=32              0.029ms      0.036ms   1.22x
b=128, v=512, k=16               0.027ms      0.028ms   1.02x
b=1, v=128000, k=64              0.058ms      0.076ms   1.30x
b=1, v=512, k=32                 0.026ms      0.027ms   1.04x
b=16, v=8192, k=32               0.049ms      0.046ms   0.94x
b=32, v=8192, k=32               0.054ms      0.061ms   1.13x
b=64, v=8192, k=32               0.072ms      0.075ms   1.04x

Dtype: float32
Config                      ArgPartition      ArgSort    Speedup
--------------------------------------------------------------------------------
b=2048, v=8192, k=32             0.376ms      1.851ms   4.93x
b=2048, v=4096, k=32             0.211ms      0.628ms   2.98x
b=1024, v=4096, k=16             0.119ms      0.325ms   2.73x
b=512, v=2048, k=64              0.060ms      0.095ms   1.57x
b=256, v=1024, k=32              0.034ms      0.039ms   1.14x
b=128, v=512, k=16               0.034ms      0.029ms   0.86x
b=1, v=128000, k=64              0.083ms      0.084ms   1.01x
b=1, v=512, k=32                 0.028ms      0.027ms   0.96x
b=16, v=8192, k=32               0.076ms      0.051ms   0.67x
b=32, v=8192, k=32               0.078ms      0.067ms   0.86x
b=64, v=8192, k=32               0.114ms      0.088ms   0.78x


Lyxot commented Feb 9, 2026

Performance is mostly fine, but some dtypes still need further optimization (notably float32).

Dtype=int16  k=vocab*0.004

            v=512       v=1024      v=2048      v=4096      v=8192     v=16384     v=32768     v=65536     v=131072 
b=1          1.19x       1.14x       1.43x       0.96x       0.99x       1.10x       1.23x       1.30x       1.26x  
b=2          1.10x       1.18x       1.24x       0.97x       1.03x       1.13x       1.20x       1.35x       1.76x  
b=4          1.10x       1.11x       1.24x       0.98x       1.00x       1.08x       1.21x       1.60x       2.27x  
b=8          1.24x       1.52x       1.51x       0.89x       1.03x       1.19x       1.67x       2.14x       2.90x  
b=16         1.39x       1.60x       1.67x       0.89x       1.05x       1.23x       2.11x       2.97x       4.45x  
b=32         1.08x       1.13x       1.17x       0.95x       1.36x       1.81x       1.96x       3.65x       5.43x  
b=48         1.42x       1.28x       1.14x       0.99x       1.12x       2.05x       2.80x       3.74x       8.90x  
b=64         1.13x       1.04x       1.29x       1.18x       1.39x       2.13x       3.60x       4.96x      10.46x  
b=96         1.16x       1.24x       0.98x       1.38x       1.72x       2.22x       3.67x       7.91x       9.29x  
b=128        1.96x       1.39x       1.10x       1.30x       1.93x       2.98x       4.37x       8.74x       9.06x  
b=256        1.14x       1.36x       1.39x       1.97x       3.23x       4.67x       9.05x       8.68x       6.98x  
b=512        1.20x       1.56x       1.54x       2.85x       4.21x       7.15x       7.93x       8.58x       7.72x  
b=1024       1.33x       1.94x       1.79x       3.55x       6.09x       7.26x       8.90x       8.42x       7.25x  
b=2048       1.53x       1.90x       1.94x       4.14x       5.41x       7.96x       9.28x       8.76x       7.58x  
Dtype=float32  k=vocab*0.004

            v=512       v=1024      v=2048      v=4096      v=8192     v=16384     v=32768     v=65536     v=131072 
b=1          0.91x       1.15x       1.07x       0.84x       0.74x       0.77x       0.91x       0.93x       1.09x  
b=2          0.88x       0.96x       1.01x       0.81x       0.69x       0.69x       0.78x       0.90x       1.24x  
b=4          0.82x       0.86x       1.06x       0.81x       0.78x       0.78x       0.90x       1.29x       2.15x  
b=8          0.84x       0.92x       0.97x       0.89x       0.71x       0.84x       1.02x       1.60x       2.21x  
b=16         0.87x       0.93x       0.97x       0.71x       0.68x       1.00x       1.38x       1.99x       3.15x  
b=32         0.87x       0.79x       0.96x       0.84x       0.88x       1.21x       1.80x       2.82x       6.07x  
b=48         0.93x       0.90x       0.75x       0.74x       0.62x       1.37x       2.24x       3.20x       7.27x  
b=64         0.84x       1.00x       1.09x       1.13x       0.78x       1.55x       2.53x       5.14x       7.46x  
b=96         0.68x       0.95x       1.08x       0.93x       1.23x       1.80x       3.16x       6.69x       6.16x  
b=128        0.91x       0.99x       1.20x       1.41x       1.30x       2.20x       4.37x       6.35x       3.91x  
b=256        0.90x       1.12x       1.30x       1.55x       2.80x       4.82x       6.60x       3.58x       4.22x  
b=512        0.88x       1.29x       1.51x       2.19x       3.82x       5.29x       5.89x       4.02x       4.10x  
b=1024       0.84x       1.25x       1.68x       2.79x       4.25x       5.89x       5.97x       3.65x       4.15x  
b=2048       0.93x       1.41x       1.84x       2.81x       4.94x       7.04x       6.35x       3.64x       4.24x  


Lyxot commented Feb 9, 2026

Benchmark results may vary with hardware; further testing is required.


@Lyxot Lyxot requested a review from Copilot February 9, 2026 19:43

Copilot AI left a comment

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.



Comment on lines 1091 to 1118
if (!contiguous && nc_dim > 0) {
  array nc_shape_dev({nc_dim}, int32, nullptr, {});
  array in_nc_strides_dev({nc_dim}, int64, nullptr, {});
  array out_nc_strides_dev({nc_dim}, int64, nullptr, {});
  nc_shape_dev.set_data(cu::malloc_async(nc_shape_dev.nbytes(), encoder));
  in_nc_strides_dev.set_data(
      cu::malloc_async(in_nc_strides_dev.nbytes(), encoder));
  out_nc_strides_dev.set_data(
      cu::malloc_async(out_nc_strides_dev.nbytes(), encoder));

  CHECK_CUDA_ERROR(cudaMemcpyAsync(
      gpu_ptr<int32_t>(nc_shape_dev),
      nc_shape.data(),
      nc_shape_dev.nbytes(),
      cudaMemcpyHostToDevice,
      encoder.stream()));
  CHECK_CUDA_ERROR(cudaMemcpyAsync(
      gpu_ptr<int64_t>(in_nc_strides_dev),
      in_nc_str.data(),
      in_nc_strides_dev.nbytes(),
      cudaMemcpyHostToDevice,
      encoder.stream()));
  CHECK_CUDA_ERROR(cudaMemcpyAsync(
      gpu_ptr<int64_t>(out_nc_strides_dev),
      out_nc_str.data(),
      out_nc_strides_dev.nbytes(),
      cudaMemcpyHostToDevice,
      encoder.stream()));

Copilot AI Feb 10, 2026


In the non-contiguous path, this allocates temporary device buffers for nc_shape and stride arrays and copies them from host every call. Elsewhere in this file (e.g., sort kernels) similar shape/stride metadata is passed via const_param(...) into fixed-size cu::Shape/cu::Strides, avoiding async mallocs and H2D copies on the hot path. Consider switching the radix kernels to accept __grid_constant__ Shape/Strides (plus nc_dim) and pass them with const_param to reduce overhead.

Lyxot (Author) replied

This is limited by MAX_NDIM. I’ve updated behavior so when nc_dim > MAX_NDIM, argpartition/partition now fall back to merge-sort. 47ca908


Lyxot commented Feb 10, 2026

@zcbenz Could you please review this PR?

zcbenz (Collaborator) commented Feb 12, 2026

@Lyxot Thanks for your contributions! I'm not familiar with GPU radix sort, so I'll need to do some homework before I can review, and I'm currently stuck on a hard problem, so it will take some time before I can look into this. Maybe other maintainers can take a look before I do.

angeloskath (Member) commented

Hi @Lyxot thanks for the PR! I think part of my comment on #3069 also applies here.

In short I think a PR that tries to address a smaller use case but is consistently better than the fallback would be much better. It would be shorter, the code would be simpler and more importantly we wouldn't have to either accept regressions or make some complicated heuristic for routing to the fallback.

My suggestion to begin with is to only tackle the use case of small axes that fit in shared memory. That would cover, for instance, MoE expert selection, since the number of tokens can vary from 1 to tens of thousands while the axis is fairly small (8 to a few hundred). This is also a use case where the current implementation is slow (as is also the case in #3069).


Lyxot commented Feb 16, 2026

@angeloskath I tuned the small kernel, which fits in shared memory.

While doing this, I found that shared-memory capacity differs a lot across GPUs, so I switched the strategy to query the device's shared memory and route to the small kernel when the row fits.

On my 4070 Super, the opt-in shared memory is ~99KB, and with the current small-kernel configuration it can run float32 up to v=8192 on the small path.
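A hypothetical capacity check consistent with those limits, assuming each row keeps its keys plus int32 indices in shared memory (my guess at the layout, not necessarily the PR's exact budget):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical fit test: one row's keys plus an int32 index per
// element must fit in the device's opt-in shared memory. The payload
// layout here is an assumption, not the PR's actual accounting.
bool fits_small_kernel(size_t axis_size, size_t key_bytes, size_t smem_bytes) {
  size_t needed = axis_size * (key_bytes + sizeof(int32_t));
  return needed <= smem_bytes;
}
```

With ~99KB of opt-in shared memory this gives roughly v=12288 for float32 and v=16384 for bfloat16, which happens to match the small-kernel limits reported below.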

If you prefer a simpler scope for this PR, I can remove the large-kernel path and keep only the small-kernel with fallback to merge sort for the rest.


Lyxot commented Feb 16, 2026

Current performance of the small kernel:

Dtype=bfloat16  k=vocab*0.004  small-kernel-limit≈16384  smem=99.0KB

             v=32        v=64        v=96       v=160       v=256       v=384       v=512       v=1024      v=2048      v=4096      v=8192     v=16384  
b=1          1.20x       1.14x       0.91x       1.06x       1.11x       1.03x       1.40x       1.14x       1.23x       1.05x       1.42x       1.17x  
b=4          1.12x       1.17x       1.83x       1.75x       1.27x       1.09x       0.94x       0.90x       0.99x       1.02x       1.41x       1.05x  
b=8          0.94x       1.20x       1.52x       1.63x       1.23x       1.44x       1.06x       1.11x       1.15x       1.02x       1.35x       0.99x  
b=16         1.04x       1.14x       1.37x       1.04x       1.02x       1.09x       0.93x       0.98x       1.21x       1.42x       1.24x       1.32x  
b=32         1.73x       1.72x       1.26x       1.04x       0.91x       0.98x       1.36x       1.11x       1.25x       1.06x       1.82x       1.86x  
b=64         1.13x       0.96x       1.08x       1.09x       1.22x       1.26x       1.35x       1.15x       1.04x       1.67x       1.88x       1.67x  
b=128        1.16x       0.95x       1.07x       1.52x       1.09x       1.18x       1.09x       1.50x       1.39x       1.98x       1.75x       1.98x  
b=256        1.16x       1.19x       1.12x       1.08x       1.35x       1.17x       1.11x       1.16x       1.68x       1.98x       2.33x       2.17x  
b=512        1.20x       0.91x       1.13x       1.11x       1.20x       1.23x       0.97x       1.76x       1.99x       2.41x       2.80x       2.89x  
b=1024       1.62x       1.40x       1.27x       1.37x       1.13x       1.23x       1.56x       1.81x       2.34x       2.53x       3.47x       2.87x  
b=2048       1.60x       1.24x       1.49x       1.09x       1.23x       1.48x       1.61x       1.97x       2.66x       2.78x       3.44x       2.99x  
b=4096       2.25x       1.95x       1.64x       1.54x       1.03x       2.05x       1.59x       2.42x       2.64x       2.71x       3.56x       3.02x  
b=8192       2.80x       2.31x       2.22x       1.73x       1.40x       2.49x       2.15x       2.56x       2.81x       2.73x       3.60x       3.07x  

Dtype=float32  k=vocab*0.004  small-kernel-limit≈12288  smem=99.0KB

             v=32        v=64        v=96       v=160       v=256       v=384       v=512       v=1024      v=2048      v=4096      v=8192  
b=1          1.11x       1.38x       1.04x       0.96x       1.22x       1.05x       1.00x       1.21x       1.16x       1.42x       1.22x  
b=4          0.91x       1.10x       1.20x       1.23x       0.96x       1.02x       0.97x       0.93x       1.36x       1.00x       1.39x  
b=8          1.04x       1.15x       1.05x       0.94x       0.95x       0.90x       0.98x       1.05x       1.41x       1.33x       1.24x  
b=16         1.10x       0.96x       1.19x       1.27x       0.94x       1.07x       0.92x       0.99x       1.18x       1.26x       1.56x  
b=32         1.02x       1.08x       1.30x       0.98x       0.94x       1.18x       1.01x       0.98x       1.31x       1.41x       1.65x  
b=64         0.95x       0.98x       1.07x       0.97x       0.94x       1.05x       0.90x       1.08x       1.45x       1.46x       1.31x  
b=128        0.90x       1.09x       1.23x       0.99x       1.10x       1.09x       0.93x       1.53x       1.36x       1.97x       1.68x  
b=256        1.08x       0.95x       0.94x       1.29x       0.99x       1.11x       1.11x       1.26x       2.00x       2.33x       1.88x  
b=512        0.96x       1.72x       1.98x       0.93x       1.02x       1.15x       1.07x       1.91x       2.24x       2.54x       2.02x  
b=1024       1.36x       1.55x       1.18x       1.37x       1.12x       1.29x       1.46x       1.89x       2.26x       2.72x       2.19x  
b=2048       1.36x       1.82x       1.51x       1.35x       1.26x       1.56x       1.44x       2.05x       2.63x       2.80x       2.29x  
b=4096       2.10x       2.03x       1.78x       1.41x       1.40x       2.19x       1.72x       2.53x       2.71x       2.93x       2.32x  
b=8192       2.45x       2.33x       2.18x       1.80x       1.37x       2.52x       2.04x       2.57x       2.98x       3.19x       2.32x
