
Conversation

@NripeshN
Contributor

@NripeshN NripeshN commented Jan 27, 2026

Implementation for Metal #3064

Copilot AI review requested due to automatic review settings January 27, 2026 02:33

Copilot AI left a comment


Pull request overview

This PR attempts to add a radix select implementation for efficient partition operations in the CUDA backend to address performance issues with topk/partition operations reported in issue #3064.

Changes:

  • Added new header file mlx/backend/cuda/device/radix_select.cuh with radix traits and utility functions for radix-based selection
  • Implemented radix select kernels in mlx/backend/cuda/sort.cu including histogram, bin-finding, filtering, and collection kernels
  • Modified partition and argpartition GPU evaluation to use the new radix select algorithm

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 9 comments.

File Description
mlx/backend/cuda/device/radix_select.cuh Adds radix traits for type conversion, utility functions for radix operations, and helper functions for NaN handling and histogram operations
mlx/backend/cuda/sort.cu Implements radix select kernels and dispatching logic; modifies ArgPartition and Partition eval_gpu methods to call the new gpu_radix_partition function


@NripeshN NripeshN changed the title Add radix select implementation for efficient partition operations [Experiment] Add radix select implementation for efficient partition operations Jan 27, 2026
@NripeshN NripeshN marked this pull request as draft January 27, 2026 03:04
@NripeshN
Contributor Author

@awni

On Metal I am getting these speedups:

Results for the benchmark case (b=2048, v=8192, k=32):
bfloat16: 0.923ms vs 4.466ms (sort) = 4.84x speedup
float16: 0.822ms vs 4.182ms (sort) = 5.08x speedup
float32: 1.266ms vs 5.280ms (sort) = 4.17x speedup

@NripeshN
Contributor Author

@awni

[image attachment]

@NripeshN
Contributor Author

The code in issue #3064 gives the following output with Metal and MPS-backed torch:

MLX ms=967.725
PyTorch ms=7261.384

@NripeshN NripeshN marked this pull request as ready for review January 27, 2026 05:50
@NripeshN NripeshN changed the title [Experiment] Add radix select implementation for efficient partition operations Add radix select implementation for efficient partition operations Jan 27, 2026
NripeshN and others added 12 commits January 27, 2026 11:16
This commit introduces an optimized radix-based selection algorithm for
ArgPartition and Partition operations on CUDA, replacing the previous
approach of doing a full sort.

Key changes:
- Add mlx/backend/cuda/device/radix_select.cuh with:
  - RadixTraits for IEEE 754 bit manipulation (preserves sort order)
  - Support for all numeric types (float, double, half, bfloat16, integers)
  - Hierarchical atomics utilities for histogram building
  - NaN handling that places NaNs at the end

- Add radix select kernels in sort.cu:
  - radix_histogram_kernel: Build per-row histograms in shared memory
  - radix_find_bin_kernel: Find target bin containing kth element
  - radix_filter_kernel: Filter candidates with flush-efficient write buffer
  - radix_collect_topk_kernel: Final collection of partitioned elements
  - radix_select_small_kernel: Optimized single-pass kernel for small arrays

- Update ArgPartition::eval_gpu and Partition::eval_gpu to use radix select

Algorithm complexity:
- Previous: O(n log n) merge sort
- New: O(n) expected for radix select

For bfloat16/float16 with n=8192, k=32:
- Only 2 passes maximum needed (16 bits / 8 bits per pass)
- Expected ~6-10x speedup over full sort

Based on RadiK paper (Li et al., ICS'24) optimizations.
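As a CPU-side illustration of the IEEE 754 bit manipulation this commit describes, the standard order-preserving trick looks like the sketch below. This is a generic illustration, not the PR's actual RadixTraits code:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Illustrative sketch of the standard order-preserving bit transform
// (not the PR's actual RadixTraits implementation). After the transform,
// unsigned integer comparison of the radix keys matches float ordering:
// non-negative floats get their sign bit set, and negative floats have
// all bits flipped, so more-negative values map to smaller keys.
uint32_t float_to_radix(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits)); // bit-cast without UB
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}
```

Note that a transform like this alone does not pin NaNs to the end (a NaN with the sign bit set would map to a small key), which is presumably why the header needs its dedicated NaN handling.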
This commit adds an optimized radix-based selection algorithm for
ArgPartition and Partition operations on Metal (Apple Silicon).

Key changes:
- Add mlx/backend/metal/kernels/radix_select.h with:
  - RadixTraits for IEEE 754 bit manipulation (float, half, bfloat16)
  - Support for all integer types (signed/unsigned, 8-64 bit)
  - Threadgroup-level histogram building with atomic operations
  - RadixSelectSmall kernel for arrays up to 2048 elements

- Add mlx/backend/metal/kernels/radix_select.metal:
  - Kernel instantiations for all supported types
  - Both contiguous and non-contiguous variants

- Update mlx/backend/metal/sort.cpp:
  - Add gpu_radix_partition() dispatch function
  - Update ArgPartition::eval_gpu and Partition::eval_gpu

- Update JIT compilation support:
  - Add get_radix_select_kernel() in jit_kernels.cpp
  - Register radix_select in includes.h and CMakeLists.txt

Algorithm:
- Iterates through digits from MSB to LSB (8 bits at a time)
- Builds histogram in threadgroup memory
- Finds target bin via prefix sum
- Outputs partitioned array in three phases:
  1. Elements less than pivot
  2. Elements equal to pivot
  3. Elements greater than pivot

For bfloat16/float16 with n=2048, k=32:
- Only 2 passes needed (16 bits / 8 bits per pass)
- Expected significant speedup over full merge sort

Based on RadiK paper (Li et al., ICS'24) optimizations.
- Fix as_type cast for half and bfloat16 by using intermediate variable
- Remove unused RadixHistogram struct that used CUDA-specific gridDim
- Add get_radix_select_kernel to nojit_kernels.cpp
- Add radix_select to non-JIT kernel build in CMakeLists.txt
- Add radix_histogram_kernel for building global histograms
- Add radix_find_bin_kernel for finding target bin
- Add radix_filter_kernel for filtering candidates
- Add radix_collect_kernel for final output collection
- Add benchmark_radix_select.py for testing performance

Note: Multi-pass dispatch not yet implemented in host code.
Currently falls back to merge sort for arrays > 2048 elements.
This commit adds a complete multi-pass radix select implementation that
provides 4-5x speedup over full merge sort for large arrays.

Key changes:

Metal Kernels (radix_select.h):
- Add radix_histogram_kernel: builds histogram with prefix filtering
- Add radix_find_bin_kernel: finds target bin from histogram
- Add radix_partition_output_kernel: outputs elements < pivot
- Add radix_partition_equal_kernel: outputs elements == pivot
- Add radix_partition_greater_kernel: outputs elements > pivot
- Refactored RadixTraits for cleaner code
- Support for prefix_mask and target_prefix in histogram building

Host-side dispatch (sort.cpp):
- Add gpu_radix_partition_small(): single-pass for arrays <= 2048
- Add gpu_radix_partition_large(): multi-pass for larger arrays
- Add get_radix_bits() helper for dtype bit width
- Proper temporary buffer allocation for histograms and counters
- Multi-pass loop iterating from MSB to LSB

Performance results (b=2048, v=8192, k=32):
- bfloat16: 0.92ms vs 4.47ms (sort) = 4.84x speedup
- float16:  0.82ms vs 4.18ms (sort) = 5.08x speedup
- float32:  1.27ms vs 5.28ms (sort) = 4.17x speedup

Algorithm complexity:
- Radix select: O(n) expected with 2-4 passes for 16-32 bit types
- Merge sort: O(n log n)

For the benchmark case with n=8192, this is log2(8192)=13 vs 2-4 passes,
explaining the ~4-5x speedup.
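The back-of-the-envelope pass counts above can be written out as a small helper (illustrative only; merge sort's constant factors per pass differ from radix select's):

```cpp
#include <cassert>
#include <cmath>

// Rough pass-count comparison behind the speedup estimate: merge sort
// makes ~log2(n) passes over the data, while radix select makes at most
// ceil(key_bits / radix_bits) passes, often fewer in practice.
int merge_passes(int n) {
  return static_cast<int>(std::ceil(std::log2(static_cast<double>(n))));
}

int radix_passes(int key_bits, int radix_bits) {
  return (key_bits + radix_bits - 1) / radix_bits; // ceiling division
}
```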
…g kernel

This commit introduces significant improvements to the radix select algorithm for partition operations in Metal. Key changes include:

- Updated kernel implementations in `radix_select.h` and `radix_select.metal` to support a new streaming approach for large arrays, allowing all radix passes to be processed in a single dispatch.
- Enhanced performance through SIMD-optimized histogram building and coalesced memory access patterns.
- Refactored the `gpu_radix_partition_large` function in `sort.cpp` to utilize the new streaming kernel, improving efficiency for large datasets.
- Added comprehensive documentation for the new kernel functionalities and optimizations.

These changes aim to provide better performance and scalability for partition operations on large arrays, aligning with the latest advancements in GPU computing.
This commit introduces a new function, `gpu_radix_partition_large_nc`, to handle non-contiguous arrays in the radix select algorithm. Key changes include:

- Implementation of a non-contiguous streaming kernel in `radix_select.h` and `radix_select.metal`, allowing for efficient partitioning of large arrays with proper multi-dimensional indexing.
- Refactoring of the `gpu_radix_partition` function in `sort.cpp` to utilize the new non-contiguous kernel when necessary, enhancing flexibility for different array layouts.
- Added kernel instantiations for various data types to support the new functionality.

These enhancements aim to improve performance and usability for partition operations on non-contiguous datasets in Metal.
@awni awni force-pushed the cuda-radix-select branch from 0cb7965 to 5144eb6 Compare January 27, 2026 19:16
@awni
Member

awni commented Jan 27, 2026

A few shape/type combinations are a lot slower (esp a long vector with batch size 1):

Pre:
(mlx.core.bfloat16, b=1, v=128000): ms=99.249
(mlx.core.bfloat16, b=1, v=512): ms=21.032
(mlx.core.bfloat16, b=2048, v=4096): ms=1399.125
(mlx.core.bfloat16, b=2048, v=8192): ms=3214.426
(mlx.core.float32, b=1, v=128000): ms=100.609
(mlx.core.float32, b=1, v=512): ms=20.622
(mlx.core.float32, b=2048, v=4096): ms=1631.850
(mlx.core.float32, b=2048, v=8192): ms=3833.621

Post:

(mlx.core.bfloat16, b=1, v=128000): ms=317.180
(mlx.core.bfloat16, b=1, v=512): ms=21.306
(mlx.core.bfloat16, b=2048, v=4096): ms=386.231
(mlx.core.bfloat16, b=2048, v=8192): ms=721.148
(mlx.core.float32, b=1, v=128000): ms=479.955
(mlx.core.float32, b=1, v=512): ms=35.616
(mlx.core.float32, b=2048, v=4096): ms=761.181
(mlx.core.float32, b=2048, v=8192): ms=1352.038

We could fall back to sort for those cases or try to make the radix kernel faster.

@NripeshN
Contributor Author

A few shape/type combinations are a lot slower (esp a long vector with batch size 1):

Yeah, this kind of makes sense: radix select uses a single threadgroup per row, so if the number of rows is small and the number of elements per row is large, it can be inefficient. Let me see if there is any way I could optimize the kernel, but my best bet might just be to fall back to merge sort in cases like this.

@NripeshN
Contributor Author

@awni
Does metal have global synchronization between threadgroups?

@awni
Member

awni commented Jan 27, 2026

The way to do that would be with a second kernel

@NripeshN
Contributor Author

The overhead of multiple kernel dispatches might not be worth it, tbh. Merge sort might just be more efficient in these cases.

@awni
Member

awni commented Jan 27, 2026

For very long vectors (especially half precision) it might be worth it. Sorting is quite a bit more expensive than the radix select.

But in the interest of keeping this PR manageable, I think a fallback is good for now and we can follow up if it makes sense.

@NripeshN
Contributor Author

Out of curiosity, have you already started working on CUDA on this branch, or shall we create a different PR for the CUDA implementation? I might have to spin up an EC2 instance to work on it, but we could quite definitely get some performance boost with CUDA.

@awni
Member

awni commented Jan 27, 2026

I am not working on CUDA. A separate PR would indeed be better.

@awni
Member

awni commented Jan 27, 2026

Let me know if you plan to work on it, otherwise I will look into it.

@NripeshN
Contributor Author

I have honestly gotten way too invested in this 😅 I will create a draft PR in a bit and we can work on it together.

@NripeshN
Contributor Author

@awni
I am not too sure what heuristics I should set for the fallback, tbh. This is what I have done for now:
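For reference, a fallback heuristic could take roughly the shape sketched below. The structure of the condition follows the benchmark discussion (radix select wins with many rows, loses on a single very long row), but the specific threshold values are invented for illustration and are not the values used in the PR:

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical fallback heuristic (thresholds are illustrative, not the
// PR's actual values). Radix select uses one threadgroup per row, so it
// needs enough rows in flight to fill the GPU; otherwise the merge-sort
// path is likely faster.
bool use_radix_select(size_t n_rows, size_t row_size) {
  if (row_size <= 2048) {
    return true; // single-pass small kernel handles this directly
  }
  // Multi-pass path: only profitable with enough rows to occupy the GPU.
  return n_rows >= 64;
}
```

With these assumed thresholds, the slow b=1, v=128000 case from the benchmarks would fall back to sort, while the b=2048 cases would take the radix path.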

@NripeshN NripeshN changed the title Add radix select implementation for efficient partition operations [Metal] [Performance] Add radix select implementation for efficient partition operations Jan 27, 2026
@awni
Member

awni commented Jan 27, 2026

I removed the benchmark as well. I will update the one in the issue, which is what we should use; it's more representative of the workload we want to optimize.

@NripeshN
Contributor Author

NripeshN commented Jan 27, 2026

@awni
I made a push for CUDA in this branch itself. If you have a CUDA machine, could you quickly run the test from the issue so we know whether this already works? If we need a complete implementation, I will just create a new branch.

NVM, it would not do great; we might have to implement everything ourselves, as topk from CUDA does not have batch support.

…performance for small arrays. Removed fallback to merge sort for small sizes, optimizing the handling of contiguous data in the radix select algorithm. Enhanced histogram building for contiguous data to improve memory throughput.