[CUDA][Performance] Add radix select implementation for efficient partition operations#3117
Lyxot wants to merge 16 commits into ml-explore:main
Conversation
Fix two correctness issues in CUDA radix partition/argpartition:

- In the large contiguous radix path, stop deriving row bases from `row * min(non-axis stride)` and compute row offsets with `elem_to_loc(...)` using non-axis shape/strides (matching merge-sort indexing behavior).
- Keep stride arguments 64-bit end-to-end in radix-select kernels and launches (remove narrowing to `int` and the related `INT32_MAX` guard).

This fixes incorrect row addressing for valid contiguous non-linear layouts (e.g. column-major with axis=0) and avoids silent misindexing on large strides.
Eliminate MAX_NDIM-based rank limits in CUDA radix partition/argpartition by switching radix kernels from fixed-size `__grid_constant__` shape/stride params to dynamic device pointers for non-axis metadata.

Changes:
- Update radix kernels to take dynamic NC metadata pointers:
  - `radix_select_small_nc_kernel`
  - `radix_select_large_streaming_kernel`
  - `radix_select_large_streaming_nc_kernel`
- In `gpu_radix_partition_small`/`gpu_radix_partition_large`:
  - allocate device buffers for nc_shape/in_nc_strides/out_nc_strides
  - copy host metadata with `cudaMemcpyAsync`
  - pass pointers into kernel launches
- Remove MAX_NDIM-dependent routing so high-rank tensors can still use the radix partition path.
- Keep stride handling 64-bit end-to-end in radix launches/kernels.

Also slightly widens the fallback-model threshold range (without changing max_rows).
remove fallback strategy
I got the following benchmark results on the 4070 Super:
Performance is basically OK in most cases, but some dtypes (notably float32) still need further optimization.
Benchmark results may vary with hardware; further testing is required.
Pull request overview
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
mlx/backend/cuda/sort.cu
```cpp
if (!contiguous && nc_dim > 0) {
  array nc_shape_dev({nc_dim}, int32, nullptr, {});
  array in_nc_strides_dev({nc_dim}, int64, nullptr, {});
  array out_nc_strides_dev({nc_dim}, int64, nullptr, {});
  nc_shape_dev.set_data(cu::malloc_async(nc_shape_dev.nbytes(), encoder));
  in_nc_strides_dev.set_data(
      cu::malloc_async(in_nc_strides_dev.nbytes(), encoder));
  out_nc_strides_dev.set_data(
      cu::malloc_async(out_nc_strides_dev.nbytes(), encoder));

  CHECK_CUDA_ERROR(cudaMemcpyAsync(
      gpu_ptr<int32_t>(nc_shape_dev),
      nc_shape.data(),
      nc_shape_dev.nbytes(),
      cudaMemcpyHostToDevice,
      encoder.stream()));
  CHECK_CUDA_ERROR(cudaMemcpyAsync(
      gpu_ptr<int64_t>(in_nc_strides_dev),
      in_nc_str.data(),
      in_nc_strides_dev.nbytes(),
      cudaMemcpyHostToDevice,
      encoder.stream()));
  CHECK_CUDA_ERROR(cudaMemcpyAsync(
      gpu_ptr<int64_t>(out_nc_strides_dev),
      out_nc_str.data(),
      out_nc_strides_dev.nbytes(),
      cudaMemcpyHostToDevice,
      encoder.stream()));
```
In the non-contiguous path, this allocates temporary device buffers for nc_shape and stride arrays and copies them from host every call. Elsewhere in this file (e.g., sort kernels) similar shape/stride metadata is passed via const_param(...) into fixed-size cu::Shape/cu::Strides, avoiding async mallocs and H2D copies on the hot path. Consider switching the radix kernels to accept __grid_constant__ Shape/Strides (plus nc_dim) and pass them with const_param to reduce overhead.
This is limited by MAX_NDIM. I’ve updated behavior so when nc_dim > MAX_NDIM, argpartition/partition now fall back to merge-sort. 47ca908
@zcbenz Could you please review this PR?
@Lyxot Thanks for your contributions. I'm not familiar with GPU radix sort and have to do some homework before I can review, and I'm currently stuck solving a hard problem, so it will take some time before I can look into this. Maybe other maintainers can take a look before I do.
Hi @Lyxot, thanks for the PR! I think part of my comment on #3069 also applies here. In short, a PR that addresses a smaller use case but is consistently better than the fallback would be much better: it would be shorter, the code would be simpler, and more importantly we wouldn't have to either accept regressions or build a complicated heuristic for routing to the fallback. My suggestion to begin with is to only tackle the use case of small axes that fit in shared memory. That would cover, for instance, MoE expert selection, where the number of tokens can vary from 1 to tens of thousands but the axis is fairly small, 8 to a few hundred. This is also a use case where the current implementation is slow (as is also the case in #3069).
based on estimated shared-memory usage and device limits
@angeloskath I tuned the small kernel that fits in shared memory. While doing this, I found that shared-memory capacity differs a lot across GPUs, so I switched the strategy to query device shared memory and route to the small kernel when it fits. On my 4070 Super, the opt-in shared memory is ~99KB, and with the current small-kernel configuration it can run. If you prefer a simpler scope for this PR, I can remove the large-kernel path and keep only the small kernel, with fallback to merge sort for the rest.
Current performance of the small kernel is:
Proposed changes
This adds a CUDA radix-select based path for `argpartition` and `partition`, and introduces multi-block-per-row and multi-row-per-block strategies for shapes where plain radix select underperforms (#3064).

What changed
- `mlx/backend/cuda/device/radix_select.cuh`: radix-select kernels (`blocks_per_row`, `rows_per_block`)
- `mlx/backend/cuda/sort.cu`: `ArgPartition::eval_gpu` / `Partition::eval_gpu` now call the radix partition path
- `benchmarks/python/radix_select_bench.py`: correctness checks, determinism checks, and performance sweep utilities

Checklist
Put an `x` in the boxes that apply.

- `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes