
Add Expert Parallelism for MoE inference#3158

Draft
0xDaizz wants to merge 4 commits into ml-explore:main from 0xDaizz:feat/ep-pr1

Conversation

@0xDaizz 0xDaizz commented Feb 23, 2026

Umbrella PR — broken into smaller, focused PRs per maintainer feedback.
Each sub-PR is independently reviewable and testable.

Sub-PR Checklist

  • PR1A — all_to_all collective primitive (Add all_to_all collective primitive #3164)
  • PR1B — MoE dispatch/combine C++ primitives + blocking comm infra
  • PR1C — Metal GPU runtime for MoE primitives
  • PR1D — Python MixtureOfExperts layer + tests/docs
  • PR2A — Production infra: auto backend policy, warmup, metrics
  • PR2B — Performance: batched expert FFN + zero-copy combine
  • PR2C — Benchmarks (EP diagnostic + EP-vs-TP comparison)

Original description below for reference.

Summary

Add Expert Parallelism (EP) for MoE inference on Apple Silicon.

EP distributes experts across devices; each device holds a subset and tokens are routed via all-to-all exchange. This enables scaling MoE models (Mixtral, DBRX, Kimi K2.5) beyond single-device memory.
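For intuition, the per-layer EP dataflow (route each token to its top-k experts, ship rows to the ranks that own those experts, apply the expert FFNs, and combine the outputs with the router weights) can be sketched as a single-process NumPy simulation. Everything here — the random "expert" matrices, the router, the ownership rule — is an illustrative stand-in, not the MLX API:

```python
import numpy as np

rng = np.random.default_rng(0)
world_size, n_experts, top_k, N, D = 2, 8, 2, 16, 4
experts_per_rank = n_experts // world_size

# Toy expert "FFN": one fixed random matrix per expert.
W = rng.standard_normal((n_experts, D, D))

x = rng.standard_normal((N, D))
# Router: pick top-k experts per token, softmax their logits into weights.
logits = rng.standard_normal((N, n_experts))
topk_idx = np.argsort(-logits, axis=1)[:, :top_k]            # (N, top_k)
topk_w = np.take_along_axis(logits, topk_idx, axis=1)
topk_w = np.exp(topk_w) / np.exp(topk_w).sum(1, keepdims=True)

out = np.zeros_like(x)
for t in range(N):
    for k in range(top_k):
        e = topk_idx[t, k]
        owner = e // experts_per_rank  # rank holding expert e (contiguous split)
        # In real EP, row x[t] is shipped to `owner` via all-to-all (dispatch),
        # the expert output is shipped back, and the weighted sum below is the
        # combine step. Here every "rank" is local, so we apply it directly.
        out[t] += topk_w[t, k] * (x[t] @ W[e])

print(out.shape)
```

The weighted sum at the end is exactly what the combine primitive performs after the reverse all-to-all; only the inter-rank shipping is elided in this simulation.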

Key components

  1. all_to_all collective — equal-chunk exchange between all ranks
  2. Fused MoE dispatch/combine primitives — C++ primitives with CPU + Metal GPU backends
  3. Blocking comm infra — blocking_exchange_v for variable-size packet exchange (RDMA/MPI)
  4. Auto backend policy — CPU for decode (N≤64), Metal for prefill (N≥320)
  5. Python MixtureOfExperts layer — drop-in replacement with EP support
  6. Production infra — warmup, metrics, GPU fallback
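The equal-chunk semantics of component 1 can be illustrated with a single-process NumPy simulation of what each rank sends and receives (this mirrors standard MPI_Alltoall semantics; it is not the MLX API):

```python
import numpy as np

world_size, chunk, D = 4, 2, 3
# inputs[r] is rank r's send buffer: world_size equal chunks of `chunk` rows,
# filled with the sender's rank so provenance is visible after the exchange.
inputs = [np.full((world_size * chunk, D), fill_value=r, dtype=np.int64)
          for r in range(world_size)]

# all_to_all: rank r's j-th chunk goes to rank j; rank j's output is the
# concatenation of chunk j from every rank, ordered by source rank.
outputs = [
    np.concatenate([inputs[src][dst * chunk:(dst + 1) * chunk]
                    for src in range(world_size)])
    for dst in range(world_size)
]

print(outputs[0][:, 0])  # chunk from rank 0, then ranks 1, 2, 3
```

Because every rank contributes an equally sized chunk, the output buffer has the same shape as the input buffer — which is what makes the fixed-size all_to_all a natural primitive to land first, with blocking_exchange_v covering the variable-size case.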

Performance (2-rank JACCL, E=384 D=7168 top_k=8)

| Regime | EP/TP latency ratio | Interpretation |
| --- | --- | --- |
| Decode (N≤64) | 0.32x | EP 3.1x faster |
| Prefill (N≥256) | 1.07x | TP 7% faster |
| Overall geomean | 0.55x | EP 1.8x faster |

0xDaizz and others added 4 commits February 24, 2026 03:01
After gpu::synchronize() inside MoE fused primitives, the previous
command buffer is committed. Re-acquire it so subsequent Metal
dispatches use a valid buffer. This is a no-op for non-MoE paths
since get_command_buffer returns the existing buffer if still valid.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Introduce Expert Parallelism (EP) infrastructure for MoE inference:

- all_to_all collective primitive with CPU eval and Metal/CUDA stubs
- GroupImpl blocking_send/recv/sendrecv and variable exchange (exchange_v)
  for deadlock-free inter-rank communication with rank-parity ordering
- MoeDispatchExchange and MoeCombineExchange fused primitives that combine
  local expert routing with inter-rank token exchange
- 7 Metal GPU kernels for O(N*D) data movement (dispatch_local,
  dispatch_scatter_remote, combine_gather_remote, combine_weighted_sum,
  packet_gather, packet_scatter)
- CPU and Metal backends with ws>2 automatic CPU fallback
- Row-contiguous input guards for safe pointer arithmetic

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
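The rank-parity ordering mentioned in this commit is a standard trick for deadlock-free pairwise exchange over blocking send/recv: in each pair, the lower rank sends first while the higher rank receives first, so the two blocking calls never wait on each other. A minimal sketch, with thread-safe queues standing in for the transport and hypothetical helper names (not the actual GroupImpl API):

```python
import threading
from queue import Queue

world_size = 2
# links[(src, dst)] stands in for a blocking channel from src to dst.
links = {(s, d): Queue() for s in range(world_size) for d in range(world_size)}

def blocking_sendrecv(rank, peer, payload):
    """Exchange payloads with `peer`, ordering operations by rank parity:
    the lower rank sends first, the higher rank receives first, so neither
    side blocks waiting for a send that hasn't been issued yet."""
    if rank < peer:
        links[(rank, peer)].put(payload)
        return links[(peer, rank)].get()
    else:
        received = links[(peer, rank)].get()
        links[(rank, peer)].put(payload)
        return received

results = {}

def worker(rank, peer):
    results[rank] = blocking_sendrecv(rank, peer, f"from-{rank}")

t0 = threading.Thread(target=worker, args=(0, 1))
t1 = threading.Thread(target=worker, args=(1, 0))
t0.start(); t1.start(); t0.join(); t1.join()
print(results)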
- MixtureOfExperts nn.Module with TopKRouter and SwiGLU expert FFN
- Python bindings for moe_dispatch_exchange and moe_combine_exchange
- ep_impl parameter ("python" for pure-Python, "cpp" for C++ fused path)
- Documentation for all_to_all, moe_dispatch_exchange, moe_combine_exchange
  in distributed.rst and MixtureOfExperts in layers.rst
- Docstrings with backend parameter docs and inference-only limitations

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- test_moe.py: 40+ tests covering TopKRouter, Expert, MixtureOfExperts,
  vectorized dispatch/combine, and C++ MoE exchange primitives
  (CPU/Metal roundtrip, overflow handling, dtype support, large batch)
- mlx_distributed_tests.py: all_to_all base distributed tests
- jaccl_test_distributed.py: 2-rank JACCL integration tests for
  dispatch/combine roundtrip correctness over RDMA

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

0xDaizz commented Feb 23, 2026

Follow-up (planned): After this PR is merged, a second PR will add production infrastructure:

  • Auto backend policy (CPU/Metal selection based on workload size)
  • Batched expert FFN (1.2x decode speedup)
  • Zero-copy combine Metal kernel
  • RDMA warmup, metrics, and fallback hardening
  • EP vs TP comparison benchmarks

The follow-up branch (feat/ep-pr2) is ready on the fork and will be rebased onto main once this PR merges.
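Based on the thresholds quoted in the summary (CPU for decode at N≤64, Metal for prefill at N≥320, automatic CPU fallback for world size > 2), the planned auto backend policy could look roughly like this. The function name, the string return values, and the mid-range tie-break are all assumptions for illustration, not the actual MLX API:

```python
def choose_backend(n_tokens: int, world_size: int) -> str:
    """Pick a MoE exchange backend from workload size (hypothetical sketch).

    Thresholds follow the PR summary: CPU wins for small decode batches
    (N <= 64), Metal wins for large prefill batches (N >= 320). The
    mid-range is not specified in the PR; this sketch resolves it
    toward Metal (an assumption).
    """
    if world_size > 2:
        return "cpu"    # ws > 2 falls back to CPU per the commit notes
    if n_tokens <= 64:
        return "cpu"    # decode regime
    return "metal"      # prefill regime (and unspecified mid-range)

print(choose_backend(32, 2), choose_backend(512, 2))
```

A policy like this keeps the decode path off the GPU, where per-dispatch overhead dominates at small N, while routing large prefill batches to Metal where the O(N*D) data movement amortizes.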

@angeloskath (Member)

Thanks for the great work @0xDaizz !

However, for this to be evaluated properly and merged, I think you need to break it into multiple PRs. You could start, for instance, with the all_to_all implementation. My advice is to either close this or turn it into a draft PR that is constantly rebased on main and becomes smaller and smaller as the constituent PRs get merged.


0xDaizz commented Feb 24, 2026

@angeloskath Thanks for the guidance — makes sense.
I’ll convert this PR to draft and split it into smaller PRs, starting with all_to_all.
I’ll link each follow-up PR here as they are opened.

@0xDaizz 0xDaizz marked this pull request as draft February 24, 2026 10:21