Skip to content

feat: pull-based NIXL weight broadcast with lazy-tensor plan discovery#2779

Draft
S1ro1 wants to merge 8 commits into
mainfrom
nixl-transfer
Draft

feat: pull-based NIXL weight broadcast with lazy-tensor plan discovery#2779
S1ro1 wants to merge 8 commits into
mainfrom
nixl-transfer

Conversation

@S1ro1

@S1ro1 S1ro1 commented Jun 11, 2026

Copy link
Copy Markdown
Collaborator

Summary

Sharded, zero-copy weight transfer from trainer to vLLM over pure NIXL RDMA, with Model Express as the metadata store. Combines the lazy-tensor slice-discovery of vllm-project/vllm#43375 (mimicked with a custom worker extension — no Ray, no RDT) with the sharded pure-NIXL transport of #2598 — but driven entirely by the lazy op-chains, so there are no per-model conversion specs and no all-gather.

Design

Trainer = passive, sharded weight store (trainer/rl/broadcast/nixl.py)

  • Registers each rank's local DTensor shards (to_local()) — never gathers, never converts naming/format. Serves the native prime-format state dict as-is.
  • Publishes one TrainerTable to Model Express: per state-dict tensor, which dim-0 row range lives on which rank's NIXL buffer. Everything is sharded on dim 0 (FSDP shards output/vocab, expert-parallel shards experts); expert weights use vLLM's FusedMoE global-expert numbering, everything else takes its range straight from the DTensor placement.
  • Serving dtype is bf16 (vLLM's): zero per-sync work when params are already bf16 (live shard storage is registered once and read in place); when they're fp32 master weights (optimization_dtype default), each rank casts only its local shard into a persistent bf16 buffer (a sharded cast, still no gather).
  • Never learns who its consumers are — no manifests, no per-worker plan — so inference can scale out / restart / fail with zero trainer-side coordination.

vLLM worker = sharded pull client (inference/vllm/worker/nixl.py)

  • Bake (once): drive the model's own load_weights with zero-storage LazyWeight placeholders. vLLM's loaders (fused QKV, merged gate/up, FusedMoE EP routing) slice them and copy_ into views of live params; each copy_ records (trainer tensor, op chain, destination view). Disallowed ops raise — loud failure over wrong bytes.
  • Route: resolve each chain to a region of the full logical trainer tensor, decompose into runs, and map onto the trainer shards that own those dim-0 rows (weight_transfer/sharding.py). Each worker pulls only its slices from whichever ranks hold them — e.g. an EP worker pulls its experts from the matching trainer EP ranks; a tp=1 dense param is pulled row-block-by-row-block from every FSDP rank.
  • Pull (per sync): one batched NIXL READ per trainer rank straight into live parameter storage, then process_weights_after_loading so any kernel-format repacking happens in place — exactly as a normal vLLM weight reload. The receive is zero-copy: NIXL writes directly into the destination param (no intermediate buffer — the "zero-copy receive" the vLLM PR aspires to but couldn't get from Ray RDT).

The only naming bridge is a generic, metadata-only adapter (weight_transfer/adapter.py): prime's stacked experts (experts.w1/w2/w3) explode into per-expert HF lazy views, and mlp.router.gatemlp.gate. Zero per-model code; dense models need nothing.

No artificial limitations. Kernel/quantized formats are handled by running process_weights_after_loading after the pull (the whole point of the lazy-tensor path) — there is no ban on quantized or non-identity MoE backends. The only hard requirements are an RDMA-registrable allocator (no expandable_segments) and no EPLB (runtime expert rearrangement would invalidate the baked plan).

Coordination rides the existing STABLE/NCCL_READY markers; additions are a step-scoped NIXL_DONE (workers block on it before pulling) and per-rank NIXL_PULLED acks (the trainer waits for all acks before returning, so no shard buffer is overwritten — or the trainer process exits — under an in-flight pull). Model Express carries only the one-time table.

Deployment

  • [weight_broadcast] type = "nixl" (shared config propagates to trainer/orchestrator/inference).
  • The multi-node SLURM template launches the Model Express server + redis on the trainer head node, loads the CUDA-enabled UCX from third_party/ucx, and pins UCX_NET_DEVICES to active RDMA ports. Build the MX server binaries once via scripts/install_modelexpress.sh; build CUDA UCX + NIXL via scripts/install_nixl_from_source.sh.
  • Validation config: configs/qwen3_30b_math_nixl_2node/ (NIXL + an NCCL-baseline twin).

Validation

2-node (1 train + 1 infer) Qwen3-30B-A3B, EP=8 (tp=1 dp=8 enable_expert_parallel), Hendrycks math, 10 steps, against an identical NCCL-broadcast baseline. Acceptance: Mismatch KL stable, not growing, and ≈ the NCCL baseline.

(see Validation result below)

Tests

tests/unit/weight_transfer/test_pull_plan.py — CPU end-to-end of the sharded mechanism: lazy bake against vLLM-style loaders (fused QKV, EP expert routing with skips, TP source narrows, padded vocab), chain resolution, region→shard routing across multiple trainer ranks, run matching, byte-level pull simulation verified bit-identical; plus loud-failure paths (unsupported ops, copy-materializing chains, length mismatches). Full unit suite passes.

🤖 Generated with Claude Code


Validation result ✅

2-node (1 train + 1 infer, 8 GPUs each), Qwen3-30B-A3B, EP=8 (tp=1 dp=8 enable_expert_parallel), Hendrycks math, 10 steps. NIXL run vs an identical NCCL-broadcast baseline.

  • Trainer published a sharded table: 579 tensors across 8 ranks (61 GB), no gather.
  • Each EP worker baked its plan and pulled only its slices (~10.3 GB) from all 8 trainer ranks in ~0.24 s/sync, straight into live params.
  • All 10 steps completed cleanly (RL trainer finished!), no NIXL errors.

Mismatch KL — flat, not growing, ≈ baseline:

Step 0 1 2 3 4 5 6 7 8 9 mean
NIXL 0.0014 0.0016 0.0013 0.0018 0.0012 0.0012 0.0018 0.0015 0.0019 0.0013 0.00150
NCCL 0.0015 0.0015 0.0014 0.0017 0.0011 0.0013 0.0021 0.0016 0.0017 0.0013 0.00152

Online FP8 (block-wise) inference ✅

Re-validated with the worker switched to the upstream-RDT mechanism (bake through vLLM's layerwise reload + process_weights_after_loading), pure NIXL+MX. Same 2-node EP=8 Qwen3-30B-A3B math run, inference under online block-FP8 (quantization="fp8_per_block", DeepGEMM block-FP8 MoE backend); trainer still streams bf16 shards and the inference side re-quantizes bf16→fp8-blockwise on every reload.

  • Bake covered 387/387 leaf modules; sharded pull moved 10.33 GB bf16 into staging across 8 ranks; per-sync pull + re-quantize ≈ 0.95 s (first sync 17 s was one-time DeepGEMM warmup). No NIXL errors; RL trainer finished!.
  • Mismatch KL over 10 steps: 0.0107, 0.0118, 0.0100, 0.0132, 0.0098, 0.0081, 0.0137, 0.0102, 0.0117, 0.0095 — stable, not growing, oscillating around a ~0.011 floor (the inherent gap between the bf16-precision trainer logprobs and the fp8-quantized inference weights; ~7× the bf16 path's 0.0015, as expected for fp8).
  • Unquantized path re-validated on the same (rewritten) worker: KL flat 0.0011–0.0020 ≈ NCCL baseline, confirming no regression.

The same code path covers both: process_weights_after_loading re-quantizes for fp8 and is a no-op for unquantized. Config: configs/qwen3_30b_math_nixl_2node/rl_fp8.toml.

S1ro1 and others added 8 commits June 11, 2026 22:30
Combines the approaches of vLLM PR #43375 (lazy-tensor weight-slice
discovery) and prime-rl PR #2598 (pure NIXL + Model Express) into a
converter-free, pull-based weight transport:

- trainer serves its native state dict as a passive NIXL pull source:
  per-rank persistent registered store, refreshed layer-by-layer per
  sync, one TrainerTable published to Model Express; it never tracks
  consumers, so inference can scale/restart freely
- each vLLM worker bakes its pull plan once by driving the model's own
  load_weights with zero-storage LazyWeight placeholders (op-chain
  recording), resolves chains to strided trainer regions, and issues
  batched NIXL READs straight into live parameter memory
- generic metadata-only adapter bridges prime-rl's stacked-expert naming
  (experts.w1/w2/w3, router.gate) to HF per-expert names; no per-model
  conversion code anywhere in the path
- per-sync coordination rides the existing STABLE/NCCL_READY markers
  plus a step-scoped NIXL_DONE marker; direct-write safety (unquantized,
  identity-format MoE backends, no EPLB) enforced at worker init

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…ommand

pkill exits 1 when nothing matches; with set -e at the top level the
whole job died before launch. Group them with the other pkills instead.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…XL broadcast

The nixl wheel's bundled UCX has no CUDA support; both trainer (pull
source) and inference (pull client) nodes must load the CUDA-enabled UCX
from third_party/ucx and pin UCX_NET_DEVICES to working RDMA ports.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…ning

Without the ack, the trainer process exits immediately after its final
broadcast, tearing down its NIXL agents while workers are still pulling
the last update (NIXL_ERR_REMOTE_DISCONNECT). The ack wait also
guarantees the store is never refreshed under an in-flight pull.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…nel-format limitation

Redesign per review: the transfer is now fully sharded and the false
direct-write restriction is gone.

Trainer (broadcast/nixl.py): registers each rank's LOCAL DTensor shards
(to_local), never gathers. Publishes a per-tensor dim-0 shard table
(which rank owns which rows). Expert weights use vLLM's FusedMoE global
numbering; everything else takes its dim-0 range from the DTensor
placement. Serves bf16 live storage when params are already bf16 (zero
per-sync work); casts local shards into a persistent bf16 buffer when
they are fp32 master weights (sharded cast, still no gather).

Worker (inference/vllm/worker/nixl.py): bakes lazy op-chains against the
live params, resolves each to a region of the full logical tensor, and
routes it onto the owning trainer shards — one batched NIXL READ per
trainer rank straight into live param storage. Then runs
process_weights_after_loading so any kernel-format repacking happens in
place (mirrors vLLM's reload). Removed _check_direct_write_safe: the
lazy-tensor + processing path is exactly what makes kernel formats work,
so quantized / non-identity MoE backends are no longer rejected.

New: weight_transfer/sharding.py (region->shard routing), chains.region_elem_runs.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
named_parameters(recurse=False) yields local names; the bake records
full dotted names, so the lookup never matched and
process_weights_after_loading was skipped for every module. Harmless for
identity (unquantized) backends but would skip required repacking for
kernel-format ones.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…supports online fp8

Switch the worker to the upstream-RDT mechanism (vLLM #43375) with pure
NIXL+MX: bake through vLLM's layerwise reload so each copy's destination is
recorded against the load-time (pre-process) layout, which for an online-fp8
model is bf16. Per sync, the sharded NIXL pull fills persistent bf16 staging
buffers, then each baked layer is driven through process_weights_after_loading
— re-quantizing bf16->fp8-blockwise into the persistent kernel storage, exactly
as a vLLM weight reload. For unquantized models the step is a no-op.

This replaces the direct-into-live-params shortcut (which skipped processing
and so couldn't support kernel formats). Staging buffers are registered once,
so the per-rank READ plan stays static.

Adds configs/qwen3_30b_math_nixl_2node/rl_fp8.toml (quantization=fp8_per_block).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant