[Metal] [Performance] Add radix select implementation for efficient partition operations #3069
base: main
Conversation
Pull request overview
This PR attempts to add a radix select implementation for efficient partition operations in the CUDA backend to address performance issues with topk/partition operations reported in issue #3064.
Changes:
- Added new header file `mlx/backend/cuda/device/radix_select.cuh` with radix traits and utility functions for radix-based selection
- Implemented radix select kernels in `mlx/backend/cuda/sort.cu`, including histogram, bin-finding, filtering, and collection kernels
- Modified partition and argpartition GPU evaluation to use the new radix select algorithm
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| mlx/backend/cuda/device/radix_select.cuh | Adds radix traits for type conversion, utility functions for radix operations, and helper functions for NaN handling and histogram operations |
| mlx/backend/cuda/sort.cu | Implements radix select kernels and dispatching logic; modifies ArgPartition and Partition eval_gpu methods to call new gpu_radix_partition function |
On Metal I am getting these speedups:
The code in issue #3064 gives the following output with Metal and MPS-backed torch: MLX ms=967.725
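For reference, a minimal timing harness in the spirit of the issue's benchmark might look like the sketch below. It uses NumPy's `np.partition` as a CPU stand-in (the actual benchmark in issue #3064 uses `mlx.core` on Metal); the shapes b=2048, v=8192, k=32 are the ones quoted in this thread, and the function name is illustrative.

```python
import time
import numpy as np

def bench_partition(b=2048, v=8192, k=32, iters=10):
    """Time a batched partition along the last axis, in ms per call."""
    x = np.random.default_rng(0).random((b, v), dtype=np.float32)
    np.partition(x, k, axis=-1)  # warmup
    t0 = time.perf_counter()
    for _ in range(iters):
        np.partition(x, k, axis=-1)
    return (time.perf_counter() - t0) / iters * 1e3

ms = bench_partition()
print(f"partition: {ms:.3f} ms")
```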
This commit introduces an optimized radix-based selection algorithm for ArgPartition and Partition operations on CUDA, replacing the previous approach of doing a full sort.

Key changes:
- Add mlx/backend/cuda/device/radix_select.cuh with:
  - RadixTraits for IEEE 754 bit manipulation (preserves sort order)
  - Support for all numeric types (float, double, half, bfloat16, integers)
  - Hierarchical atomics utilities for histogram building
  - NaN handling that places NaNs at the end
- Add radix select kernels in sort.cu:
  - radix_histogram_kernel: build per-row histograms in shared memory
  - radix_find_bin_kernel: find the target bin containing the kth element
  - radix_filter_kernel: filter candidates with a flush-efficient write buffer
  - radix_collect_topk_kernel: final collection of partitioned elements
  - radix_select_small_kernel: optimized single-pass kernel for small arrays
- Update ArgPartition::eval_gpu and Partition::eval_gpu to use radix select

Algorithm complexity:
- Previous: O(n log n) merge sort
- New: O(n) expected for radix select

For bfloat16/float16 with n=8192, k=32:
- Only 2 passes maximum needed (16 bits / 8 bits per pass)
- Expected ~6-10x speedup over full sort

Based on RadiK paper (Li et al., ICS'24) optimizations.
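The "preserves sort order" bit manipulation mentioned above is the standard radix trick: remap float bit patterns to unsigned keys whose integer order matches the floats' numeric order. A NumPy sketch for float32 (the same idea applies to half/bfloat16 with narrower widths; the function name is illustrative):

```python
import numpy as np

def float_to_ordered_u32(x):
    """Map float32 bit patterns to uint32 keys whose unsigned order matches
    the floats' numeric order: negative floats have all bits flipped (which
    reverses their order and places them below positives), non-negative
    floats get the sign bit set. NaNs end up above +inf, i.e. at the end."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    neg = (bits >> 31) == 1
    mask = np.where(neg, 0xFFFFFFFF, 0x80000000).astype(np.uint32)
    return bits ^ mask

# Unsigned order of the keys matches numeric order of the floats.
vals = np.array([-3.5, -1e-8, -0.0, 0.0, 1.25, 7.0], dtype=np.float32)
keys = float_to_ordered_u32(vals)
assert list(np.argsort(keys, kind="stable")) == list(np.argsort(vals, kind="stable"))
```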
This commit adds an optimized radix-based selection algorithm for ArgPartition and Partition operations on Metal (Apple Silicon).

Key changes:
- Add mlx/backend/metal/kernels/radix_select.h with:
  - RadixTraits for IEEE 754 bit manipulation (float, half, bfloat16)
  - Support for all integer types (signed/unsigned, 8-64 bit)
  - Threadgroup-level histogram building with atomic operations
  - RadixSelectSmall kernel for arrays up to 2048 elements
- Add mlx/backend/metal/kernels/radix_select.metal:
  - Kernel instantiations for all supported types
  - Both contiguous and non-contiguous variants
- Update mlx/backend/metal/sort.cpp:
  - Add gpu_radix_partition() dispatch function
  - Update ArgPartition::eval_gpu and Partition::eval_gpu
- Update JIT compilation support:
  - Add get_radix_select_kernel() in jit_kernels.cpp
  - Register radix_select in includes.h and CMakeLists.txt

Algorithm:
- Iterates through digits from MSB to LSB (8 bits at a time)
- Builds a histogram in threadgroup memory
- Finds the target bin via prefix sum
- Outputs the partitioned array in three phases:
  1. Elements less than the pivot
  2. Elements equal to the pivot
  3. Elements greater than the pivot

For bfloat16/float16 with n=2048, k=32:
- Only 2 passes needed (16 bits / 8 bits per pass)
- Expected significant speedup over full merge sort

Based on RadiK paper (Li et al., ICS'24) optimizations.
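The MSB-to-LSB digit loop described above (histogram, prefix sum to find the target bin, narrow the candidate set) can be written as a compact CPU reference. This is a NumPy sketch of the selection logic, not the Metal kernel; it operates on unsigned keys, so floats would first pass through the order-preserving bit transform, and the function name is illustrative:

```python
import numpy as np

def radix_kth_smallest(keys, k, radix_bits=8, total_bits=32):
    """Multi-pass radix select: value of the k-th smallest (0-based) element
    of unsigned integer `keys`, scanning digits from MSB to LSB."""
    keys = np.asarray(keys, dtype=np.uint32)
    mask = (1 << radix_bits) - 1
    cand = np.ones(len(keys), dtype=bool)  # candidates surviving earlier passes
    for shift in range(total_bits - radix_bits, -1, -radix_bits):
        digits = ((keys[cand] >> shift) & mask).astype(np.int64)
        hist = np.bincount(digits, minlength=mask + 1)  # per-digit histogram
        csum = np.cumsum(hist)
        # Target bin: first bin whose cumulative count exceeds k.
        bin_ = int(np.searchsorted(csum, k + 1))
        if bin_ > 0:
            k -= int(csum[bin_ - 1])  # rebase k to a rank within the target bin
        cand &= ((keys >> shift) & mask) == bin_
    # All survivors share every digit, hence are equal.
    return int(keys[cand][0])

assert radix_kth_smallest([5, 1, 4, 2, 3], 2) == 3
```

Each loop iteration corresponds to one GPU pass: the histogram build maps to the threadgroup histogram kernel, and `searchsorted` over the cumulative sum plays the role of the prefix-sum bin search.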
- Fix as_type cast for half and bfloat16 by using an intermediate variable
- Remove unused RadixHistogram struct that used CUDA-specific gridDim
- Add get_radix_select_kernel to nojit_kernels.cpp
- Add radix_select to the non-JIT kernel build in CMakeLists.txt
- Add radix_histogram_kernel for building global histograms
- Add radix_find_bin_kernel for finding the target bin
- Add radix_filter_kernel for filtering candidates
- Add radix_collect_kernel for final output collection
- Add benchmark_radix_select.py for testing performance

Note: Multi-pass dispatch is not yet implemented in host code; it currently falls back to merge sort for arrays > 2048 elements.
This commit adds a complete multi-pass radix select implementation that provides 4-5x speedup over full merge sort for large arrays.

Metal kernels (radix_select.h):
- Add radix_histogram_kernel: builds histogram with prefix filtering
- Add radix_find_bin_kernel: finds target bin from histogram
- Add radix_partition_output_kernel: outputs elements < pivot
- Add radix_partition_equal_kernel: outputs elements == pivot
- Add radix_partition_greater_kernel: outputs elements > pivot
- Refactored RadixTraits for cleaner code
- Support for prefix_mask and target_prefix in histogram building

Host-side dispatch (sort.cpp):
- Add gpu_radix_partition_small(): single-pass for arrays <= 2048
- Add gpu_radix_partition_large(): multi-pass for larger arrays
- Add get_radix_bits() helper for dtype bit width
- Proper temporary buffer allocation for histograms and counters
- Multi-pass loop iterating from MSB to LSB

Performance results (b=2048, v=8192, k=32):
- bfloat16: 0.92ms vs 4.47ms (sort) = 4.84x speedup
- float16: 0.82ms vs 4.18ms (sort) = 5.08x speedup
- float32: 1.27ms vs 5.28ms (sort) = 4.17x speedup

Algorithm complexity:
- Radix select: O(n) expected, with 2-4 passes for 16-32 bit types
- Merge sort: O(n log n)

For the benchmark case with n=8192, this is log2(8192) = 13 merge levels vs 2-4 radix passes, explaining the ~4-5x speedup.
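The pass-count arithmetic behind that last paragraph can be checked directly. A small sketch (8 radix bits per pass, as in the commit; function names are illustrative):

```python
import math

def radix_passes(dtype_bits, radix_bits=8):
    """Number of MSB-to-LSB digit passes over a key of `dtype_bits` bits."""
    return math.ceil(dtype_bits / radix_bits)

def merge_sort_levels(n):
    """Number of merge levels a full merge sort performs on a row of length n."""
    return math.ceil(math.log2(n))

assert radix_passes(16) == 2          # float16 / bfloat16: two 8-bit digits
assert radix_passes(32) == 4          # float32
assert merge_sort_levels(8192) == 13  # the benchmark row length in this PR
```

Each merge level and each radix pass reads the whole row, so 13 levels vs 2-4 passes is consistent with the observed ~4-5x speedup.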
…g kernel

This commit introduces significant improvements to the radix select algorithm for partition operations in Metal. Key changes include:
- Updated kernel implementations in `radix_select.h` and `radix_select.metal` to support a new streaming approach for large arrays, allowing all radix passes to be processed in a single dispatch.
- Enhanced performance through SIMD-optimized histogram building and coalesced memory access patterns.
- Refactored the `gpu_radix_partition_large` function in `sort.cpp` to use the new streaming kernel, improving efficiency for large datasets.
- Added documentation for the new kernel functionality and optimizations.

These changes improve performance and scalability for partition operations on large arrays.
This commit introduces a new function, `gpu_radix_partition_large_nc`, to handle non-contiguous arrays in the radix select algorithm. Key changes include:
- Implemented a non-contiguous streaming kernel in `radix_select.h` and `radix_select.metal`, allowing efficient partitioning of large arrays with proper multi-dimensional indexing.
- Refactored the `gpu_radix_partition` function in `sort.cpp` to use the new non-contiguous kernel when necessary, supporting different array layouts.
- Added kernel instantiations for the supported data types.

These enhancements improve performance and usability for partition operations on non-contiguous datasets in Metal.
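The multi-dimensional indexing a non-contiguous kernel must reproduce can be illustrated on the CPU: given a shape/strides description, decompose a flat row index over the non-axis dimensions, then step along the sort axis by its stride. A NumPy sketch (`row_offsets` and its parameters are illustrative, not the kernel's actual helpers):

```python
import numpy as np

def row_offsets(row_index, shape, strides, axis):
    """Element offsets (into the underlying buffer) of one sort row of a
    strided array: decompose `row_index` over the non-axis dims, then step
    along `axis` by its stride."""
    off = 0
    for d in reversed(range(len(shape))):
        if d == axis:
            continue
        off += (row_index % shape[d]) * strides[d]
        row_index //= shape[d]
    return [off + i * strides[axis] for i in range(shape[axis])]

# Check against a genuinely non-contiguous NumPy view.
base = np.arange(24)
x = base.reshape(2, 3, 4).transpose(1, 0, 2)    # shape (3, 2, 4), strided
strides = [s // x.itemsize for s in x.strides]  # byte strides -> element strides
for r in range(x.shape[0] * x.shape[1]):
    i0, i1 = r // x.shape[1], r % x.shape[1]
    offs = row_offsets(r, x.shape, strides, axis=2)
    assert [int(base[o]) for o in offs] == list(x[i0, i1, :])
```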
Force-pushed from 0cb7965 to 5144eb6.
A few shape/type combinations are a lot slower (especially a long vector with batch size 1): Pre: Post: (mlx.core.bfloat16, b=1, v=128000): ms=317.180. We could fall back to sort for those cases or try to make the radix kernel faster.
Yeah, this kinda makes sense. Radix select uses a single threadgroup per row, so if the number of rows is small and the number of elements per row is large, it can be inefficient. Let me see if there is any way to optimize the kernel, but my best bet might just be to fall back to merge sort in cases like this.
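The fallback being discussed amounts to a simple dispatch heuristic. A hypothetical sketch (the function name and thresholds are illustrative, not the PR's actual values): one threadgroup per row means few, very long rows leave the GPU underoccupied, so merge sort wins there.

```python
def use_radix_select(n_rows, row_size, min_rows=8, max_row_size=65536):
    """Hypothetical dispatch heuristic: prefer radix select unless the
    batch is tiny AND rows are very long (one threadgroup per row would
    leave most of the GPU idle). Thresholds are illustrative only."""
    if n_rows < min_rows and row_size > max_row_size:
        return False  # fall back to merge sort
    return True

assert use_radix_select(2048, 8192) is True   # the fast benchmark case (b=2048, v=8192)
assert use_radix_select(1, 128000) is False   # the slow case above (b=1, v=128000)
```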
@awni |
The way to do that would be with a second kernel.
The overhead of multiple kernel dispatches might not be worth it, tbh. Merge sort might just be more efficient in these cases.
For very long vectors (especially half precision) it might be worth it; sorting is quite a bit more expensive than radix select. But in the interest of keeping this PR manageable, I think a fallback is good for now, and we can follow up if it makes sense.
Out of curiosity, have you already started working on CUDA on this branch, or shall we create a separate PR for the CUDA implementation? I might have to spin up an EC2 instance to work on it, but we could quite likely get a performance boost with CUDA.
I am not working on CUDA. A separate PR would indeed be better.
Let me know if you plan to work on it; otherwise I will look into it.
I have honestly gotten way too invested in this 😅 I will create a draft PR in a bit and we can work on it together.
I removed the benchmark as well. I will update the one in the issue, which is what we should use; it's more representative of the workload we want to optimize.
@awni NVM, it would not do great; we might have to implement everything ourselves, as topk from CUDA does not have batch support.
…performance for small arrays. Removed fallback to merge sort for small sizes, optimizing the handling of contiguous data in the radix select algorithm. Enhanced histogram building for contiguous data to improve memory throughput.

Implementation for Metal, addressing #3064.