
Add all_to_all collective primitive#3164

Open
0xDaizz wants to merge 1 commit into ml-explore:main from 0xDaizz:feat/ep-pr1a

Conversation


@0xDaizz 0xDaizz commented Feb 24, 2026

Summary

Add an `all_to_all` collective primitive for exchanging equal-sized chunks between all ranks. This is the first in a series of PRs adding Expert Parallelism support.

  • mx.distributed.all_to_all(x) — splits input along axis 0, sends chunk i to rank i, concatenates received chunks
  • CPU eval via backend-specific GroupImpl::all_to_all
  • MPI and JACCL backends implemented
  • VJP support (all_to_all is its own transpose)
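The chunk routing described above can be sketched single-process in NumPy. Everything below (`simulated_all_to_all`, the rank arrays) is illustrative only, not part of the MLX API:

```python
# Single-process sketch of all_to_all semantics using NumPy.
import numpy as np

def simulated_all_to_all(inputs):
    """inputs[r] is rank r's array; axis 0 must be divisible by the world size."""
    world_size = len(inputs)
    # Split each rank's input into world_size equal chunks along axis 0.
    chunks = [np.split(x, world_size, axis=0) for x in inputs]
    # Rank r receives chunk r from every rank, concatenated in source-rank order.
    return [
        np.concatenate([chunks[src][r] for src in range(world_size)], axis=0)
        for r in range(world_size)
    ]

# Two "ranks", each holding 4 rows.
x0 = np.arange(4).reshape(4, 1)      # rank 0: rows 0..3
x1 = np.arange(4, 8).reshape(4, 1)   # rank 1: rows 4..7
y0, y1 = simulated_all_to_all([x0, x1])
print(y0.ravel())  # rank 0 gets the first halves:  [0 1 4 5]
print(y1.ravel())  # rank 1 gets the second halves: [2 3 6 7]
```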

Follow-ups:

Test plan

  • pytest python/tests/mlx_distributed_tests.py -k all_to_all — 4 passed, 1 skipped (shape validation requires world size > 1)
  • 2-rank JACCL integration verified
  • Linux CPU-only build (CI)
  • MPI multi-rank (CI)

🤖 Generated with Claude Code

Add `mx.distributed.all_to_all(x)` — splits input along axis 0,
sends chunk i to rank i, and concatenates received chunks.

- CPU eval via backend-specific GroupImpl::all_to_all
- MPI backend (MPI_Alltoall) and JACCL backend (RDMA pipelined)
- VJP support (all_to_all is its own transpose)
- GPU/CUDA stubs (not-implemented exceptions)
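The VJP claim above rests on the exchange being an involution: because chunk i goes to rank i, applying the same exchange to the received output routes every chunk back to its source, so the transpose of all_to_all is all_to_all itself. A quick single-process check of that property, using a simulated exchange rather than the real backend (all names here are illustrative):

```python
# Check that the all_to_all permutation is its own inverse, which is why
# the VJP of all_to_all is all_to_all applied to the cotangent.
import numpy as np

def all_to_all_sim(inputs):
    n = len(inputs)
    chunks = [np.split(x, n, axis=0) for x in inputs]
    return [np.concatenate([chunks[s][r] for s in range(n)], axis=0) for r in range(n)]

# Three "ranks", each holding 6 distinct rows.
xs = [np.arange(6).reshape(6, 1) + 10 * r for r in range(3)]
ys = all_to_all_sim(xs)   # forward exchange
zs = all_to_all_sim(ys)   # applying it again routes every chunk home
for x, z in zip(xs, zs):
    assert np.array_equal(x, z)
print("all_to_all applied twice is the identity")
```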

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@angeloskath
Member

Great, I'll take a look soon.

