Add Expert Parallelism for MoE inference#3158
0xDaizz wants to merge 4 commits into ml-explore:main from
Conversation
After gpu::synchronize() inside MoE fused primitives, the previous command buffer is committed. Re-acquire it so subsequent Metal dispatches use a valid buffer. This is a no-op for non-MoE paths since get_command_buffer returns the existing buffer if still valid. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Introduce Expert Parallelism (EP) infrastructure for MoE inference:
- `all_to_all` collective primitive with CPU eval and Metal/CUDA stubs
- GroupImpl `blocking_send`/`recv`/`sendrecv` and variable-size exchange (`exchange_v`) for deadlock-free inter-rank communication with rank-parity ordering
- `MoeDispatchExchange` and `MoeCombineExchange` fused primitives that combine local expert routing with inter-rank token exchange
- 7 Metal GPU kernels for O(N*D) data movement (`dispatch_local`, `dispatch_scatter_remote`, `combine_gather_remote`, `combine_weighted_sum`, `packet_gather`, `packet_scatter`)
- CPU and Metal backends, with automatic CPU fallback for ws > 2
- Row-contiguous input guards for safe pointer arithmetic

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
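The equal-chunk `all_to_all` semantics described above can be illustrated with a small single-process reference sketch (names and shapes here are illustrative, not the MLX API): each rank's buffer is split into `world_size` equal chunks, and chunk `j` of rank `r` is delivered to rank `j`.

```python
def all_to_all_reference(inputs):
    """Reference semantics of an equal-chunk all_to_all.

    inputs[r] is rank r's list of rows; it is split into world_size
    equal chunks, and chunk j of rank r is delivered to rank j.
    """
    world_size = len(inputs)
    chunk = len(inputs[0]) // world_size
    return [
        [row
         for src in range(world_size)
         for row in inputs[src][r * chunk:(r + 1) * chunk]]
        for r in range(world_size)
    ]

# Two ranks, four rows each: rows 0-1 target rank 0, rows 2-3 target rank 1.
rank0 = ["a0", "a1", "a2", "a3"]
rank1 = ["b0", "b1", "b2", "b3"]
out0, out1 = all_to_all_reference([rank0, rank1])
# out0 == ["a0", "a1", "b0", "b1"], out1 == ["a2", "a3", "b2", "b3"]
```

Each output rank concatenates its chunk from every source in rank order, which is the standard MPI-style `Alltoall` layout.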
- MixtureOfExperts nn.Module with TopKRouter and SwiGLU expert FFN
- Python bindings for moe_dispatch_exchange and moe_combine_exchange
- ep_impl parameter ("python" for pure-Python, "cpp" for C++ fused path)
- Documentation for all_to_all, moe_dispatch_exchange, and moe_combine_exchange in distributed.rst, and for MixtureOfExperts in layers.rst
- Docstrings with backend parameter docs and inference-only limitations
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
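The pure-Python path selected by `ep_impl="python"` boils down to top-k routing plus a gate-weighted combine. A minimal single-token sketch of that idea (stdlib-only; the function and parameter names are illustrative, not the bound API):

```python
import math

def topk_route(logits, k):
    """Pick the top-k experts for one token, softmax-normalize their scores."""
    idx = sorted(range(len(logits)), key=lambda e: logits[e], reverse=True)[:k]
    top = max(logits[e] for e in idx)
    exps = [math.exp(logits[e] - top) for e in idx]  # stable softmax
    total = sum(exps)
    return idx, [w / total for w in exps]

def moe_token(x, logits, experts, k):
    """Dispatch one token vector to its top-k experts, combine with gates."""
    idx, gates = topk_route(logits, k)
    out = [0.0] * len(x)
    for e, g in zip(idx, gates):
        y = experts[e](x)  # expert e's FFN applied to the token
        out = [o + g * yi for o, yi in zip(out, y)]
    return out
```

A handy sanity check: with identity experts the output equals the input, because the gate weights sum to 1.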
- test_moe.py: 40+ tests covering TopKRouter, Expert, MixtureOfExperts, vectorized dispatch/combine, and the C++ MoE exchange primitives (CPU/Metal roundtrip, overflow handling, dtype support, large batches)
- mlx_distributed_tests.py: base distributed tests for all_to_all
- jaccl_test_distributed.py: 2-rank JACCL integration tests for dispatch/combine roundtrip correctness over RDMA

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
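The dispatch/combine roundtrip that these tests exercise is, at its core, a stable permutation and its inverse: dispatch groups tokens by destination expert, combine restores the original token order. A minimal single-device sketch of that invariant (illustrative names, not the tested API):

```python
def dispatch(tokens, expert_ids):
    """Stably sort tokens by destination expert; return the permutation."""
    order = sorted(range(len(tokens)), key=lambda t: expert_ids[t])
    return [tokens[t] for t in order], order

def combine(routed, order):
    """Invert the dispatch permutation, restoring original token order."""
    out = [None] * len(order)
    for pos, t in enumerate(order):
        out[t] = routed[pos]
    return out

tokens = ["t0", "t1", "t2", "t3"]
routed, order = dispatch(tokens, [2, 0, 1, 0])
# routed groups tokens by expert id: ["t1", "t3", "t2", "t0"]
restored = combine(routed, order)
```

A roundtrip test then simply asserts `restored == tokens`, which is what the CPU/Metal roundtrip cases above check end to end (with real token buffers and inter-rank exchange in between).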
Follow-up (planned): after this PR is merged, a second PR will add production infrastructure.
Thanks for the great work @0xDaizz! However, in order for this to be evaluated properly and merged, I think you need to break it into multiple PRs. You can start, for instance, with the
@angeloskath Thanks for the guidance — makes sense.
Sub-PR Checklist
- `all_to_all` collective primitive (Add all_to_all collective primitive #3164)
- `MixtureOfExperts` layer + tests/docs

Original description below for reference.
Summary
Add Expert Parallelism (EP) for MoE inference on Apple Silicon.
EP distributes experts across devices; each device holds a subset and tokens are routed via all-to-all exchange. This enables scaling MoE models (Mixtral, DBRX, Kimi K2.5) beyond single-device memory.
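One concrete consequence of this layout: before the all-to-all exchange, each rank must know how many of its routed tokens belong to each peer. The sketch below counts that, assuming experts are sharded contiguously across ranks (a common convention; the PR does not state the exact layout, so treat the sharding rule here as an illustrative assumption):

```python
def tokens_per_rank(expert_ids, num_experts, world_size):
    """Count routed tokens per destination rank.

    Assumes contiguous sharding: rank r owns experts
    [r * E/ws, (r + 1) * E/ws), where E = num_experts, ws = world_size.
    """
    per_rank = num_experts // world_size
    counts = [0] * world_size
    for e in expert_ids:
        counts[e // per_rank] += 1  # owner rank of expert e
    return counts

# 8 experts over 2 ranks: experts 0-3 live on rank 0, experts 4-7 on rank 1.
counts = tokens_per_rank([0, 5, 3, 7], num_experts=8, world_size=2)
# counts == [2, 2]
```

These per-rank counts are exactly the variable chunk sizes that a variable-size exchange such as `exchange_v` needs to agree on before moving token payloads.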
Key components
- `all_to_all` collective — equal-chunk exchange between all ranks
- `blocking_exchange_v` for variable-size packet exchange (RDMA/MPI)
- `MixtureOfExperts` layer — drop-in replacement with EP support

Performance (2-rank JACCL, E=384 D=7168 top_k=8)