feat: pull-based NIXL weight broadcast with lazy-tensor plan discovery #2779
Draft
S1ro1 wants to merge 8 commits into
Combines the approaches of vLLM PR #43375 (lazy-tensor weight-slice discovery) and prime-rl PR #2598 (pure NIXL + Model Express) into a converter-free, pull-based weight transport:
- trainer serves its native state dict as a passive NIXL pull source: per-rank persistent registered store, refreshed layer-by-layer per sync, one TrainerTable published to Model Express; it never tracks consumers, so inference can scale/restart freely
- each vLLM worker bakes its pull plan once by driving the model's own load_weights with zero-storage LazyWeight placeholders (op-chain recording), resolves chains to strided trainer regions, and issues batched NIXL READs straight into live parameter memory
- generic metadata-only adapter bridges prime-rl's stacked-expert naming (experts.w1/w2/w3, router.gate) to HF per-expert names; no per-model conversion code anywhere in the path
- per-sync coordination rides the existing STABLE/NCCL_READY markers plus a step-scoped NIXL_DONE marker; direct-write safety (unquantized, identity-format MoE backends, no EPLB) enforced at worker init

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…ommand pkill exits 1 when nothing matches; with set -e at the top level the whole job died before launch. Group them with the other pkills instead. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…XL broadcast The nixl wheel's bundled UCX has no CUDA support; both trainer (pull source) and inference (pull client) nodes must load the CUDA-enabled UCX from third_party/ucx and pin UCX_NET_DEVICES to working RDMA ports. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…ning Without the ack, the trainer process exits immediately after its final broadcast, tearing down its NIXL agents while workers are still pulling the last update (NIXL_ERR_REMOTE_DISCONNECT). The ack wait also guarantees the store is never refreshed under an in-flight pull. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…nel-format limitation

Redesign per review: the transfer is now fully sharded and the false direct-write restriction is gone.

Trainer (broadcast/nixl.py): registers each rank's LOCAL DTensor shards (to_local), never gathers. Publishes a per-tensor dim-0 shard table (which rank owns which rows). Expert weights use vLLM's FusedMoE global numbering; everything else takes its dim-0 range from the DTensor placement. Serves bf16 live storage when params are already bf16 (zero per-sync work); casts local shards into a persistent bf16 buffer when they are fp32 master weights (sharded cast, still no gather).

Worker (inference/vllm/worker/nixl.py): bakes lazy op-chains against the live params, resolves each to a region of the full logical tensor, and routes it onto the owning trainer shards: one batched NIXL READ per trainer rank straight into live param storage. Then runs process_weights_after_loading so any kernel-format repacking happens in place (mirrors vLLM's reload).

Removed _check_direct_write_safe: the lazy-tensor + processing path is exactly what makes kernel formats work, so quantized / non-identity MoE backends are no longer rejected.

New: weight_transfer/sharding.py (region->shard routing), chains.region_elem_runs.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
named_parameters(recurse=False) yields local names; the bake records full dotted names, so the lookup never matched and process_weights_after_loading was skipped for every module. Harmless for identity (unquantized) backends but would skip required repacking for kernel-format ones. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
…supports online fp8

Switch the worker to the upstream-RDT mechanism (vLLM #43375) with pure NIXL+MX: bake through vLLM's layerwise reload so each copy's destination is recorded against the load-time (pre-process) layout, which for an online-fp8 model is bf16.

Per sync, the sharded NIXL pull fills persistent bf16 staging buffers, then each baked layer is driven through process_weights_after_loading, re-quantizing bf16->fp8-blockwise into the persistent kernel storage, exactly as a vLLM weight reload. For unquantized models the step is a no-op.

This replaces the direct-into-live-params shortcut (which skipped processing and so couldn't support kernel formats). Staging buffers are registered once, so the per-rank READ plan stays static.

Adds configs/qwen3_30b_math_nixl_2node/rl_fp8.toml (quantization=fp8_per_block).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
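The ack handshake described in the commit notes above (a step-scoped NIXL_DONE marker, then one NIXL_PULLED ack per worker rank) can be sketched with plain marker files. This is only an illustration of the ordering guarantee, not the PR's code: the file layout, the helper names (`wait_for`, `trainer_sync`, `worker_sync`, `run_demo`), and the list copy that stands in for a NIXL READ are all assumptions.

```python
import tempfile
import threading
import time
from pathlib import Path


def wait_for(path: Path, timeout: float = 5.0, poll: float = 0.01) -> None:
    """Block until a marker file exists (hypothetical polling helper)."""
    deadline = time.monotonic() + timeout
    while not path.exists():
        if time.monotonic() > deadline:
            raise TimeoutError(f"marker never appeared: {path}")
        time.sleep(poll)


def trainer_sync(root: Path, step: int, num_workers: int, store: list, new_weights: list) -> None:
    """Refresh the store, publish the step marker, then wait for every ack."""
    store[:] = new_weights                      # safe: the previous step was fully acked
    step_dir = root / f"step_{step}"
    step_dir.mkdir(parents=True, exist_ok=True)
    (step_dir / "NIXL_DONE").touch()            # workers may start pulling now
    for rank in range(num_workers):             # do not return (or exit) under a pull
        wait_for(step_dir / f"NIXL_PULLED_{rank}")


def worker_sync(root: Path, step: int, rank: int, store: list, out: dict) -> None:
    """Wait for the step marker, pull, then acknowledge."""
    step_dir = root / f"step_{step}"
    wait_for(step_dir / "NIXL_DONE")
    out[rank] = list(store)                     # stand-in for the NIXL READ
    (step_dir / f"NIXL_PULLED_{rank}").touch()


def run_demo(num_workers: int = 2, steps: int = 3) -> list:
    """Run a few syncs and return what each worker pulled at each step."""
    seen = []
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        store = [0.0] * 4
        for step in range(steps):
            out: dict = {}
            workers = [
                threading.Thread(target=worker_sync, args=(root, step, r, store, out))
                for r in range(num_workers)
            ]
            for w in workers:
                w.start()
            trainer_sync(root, step, num_workers, store, [float(step)] * 4)
            for w in workers:
                w.join()
            seen.append(out)
    return seen
```

Because the trainer only refreshes the store after all acks of the previous step have landed, every worker in this sketch observes exactly the weights published for its step.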
Summary
Sharded, zero-copy weight transfer from trainer to vLLM over pure NIXL RDMA, with Model Express as the metadata store. Combines the lazy-tensor slice-discovery of vllm-project/vllm#43375 (mimicked with a custom worker extension — no Ray, no RDT) with the sharded pure-NIXL transport of #2598 — but driven entirely by the lazy op-chains, so there are no per-model conversion specs and no all-gather.
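As a rough illustration of what "driven by lazy op-chains" means, the sketch below records slicing ops on a zero-storage placeholder and folds them into a region of the source tensor. It is a toy under stated assumptions: `ToyLazyWeight`, `PlanEntry`, `resolve` and `toy_qkv_loader` are hypothetical names, only `narrow` on 2-D shapes is modelled, and none of this is the PR's actual `LazyWeight` implementation.

```python
from dataclasses import dataclass
from typing import List, NamedTuple, Tuple


@dataclass(frozen=True)
class ToyLazyWeight:
    """Zero-storage stand-in for a 2-D trainer tensor: name, shape, recorded ops."""
    name: str
    shape: Tuple[int, int]
    ops: Tuple[Tuple[int, int, int], ...] = ()   # each op is narrow(dim, start, length)

    def narrow(self, dim: int, start: int, length: int) -> "ToyLazyWeight":
        if dim not in (0, 1):
            raise ValueError(f"unsupported dim {dim}")
        if start < 0 or length <= 0 or start + length > self.shape[dim]:
            raise ValueError(f"narrow out of bounds on {self.name}")
        shape = list(self.shape)
        shape[dim] = length
        return ToyLazyWeight(self.name, (shape[0], shape[1]), self.ops + ((dim, start, length),))

    def to_dense(self):
        # Any op that would need real bytes fails loudly instead of guessing.
        raise NotImplementedError("disallowed op: placeholder has no storage")


class PlanEntry(NamedTuple):
    source: str                          # trainer tensor name
    region: Tuple[int, int, int, int]    # (row_start, n_rows, col_start, n_cols) in the source
    dest: str                            # destination parameter name
    dest_row: int                        # first destination row


def resolve(lazy: ToyLazyWeight) -> Tuple[int, int, int, int]:
    """Fold a chain of narrows into one rectangular region of the source."""
    offsets = [0, 0]
    for dim, start, _length in lazy.ops:
        offsets[dim] += start            # nested narrows compose by adding offsets
    return (offsets[0], lazy.shape[0], offsets[1], lazy.shape[1])


def toy_qkv_loader(weights: List[ToyLazyWeight], tp_rank: int, tp_size: int,
                   plan: List[PlanEntry]) -> None:
    """Mimics a fused-QKV loader: each TP rank keeps its row block of q, k, v."""
    dest_row = 0
    for w in weights:                    # order: q, k, v
        rows = w.shape[0] // tp_size
        shard = w.narrow(0, tp_rank * rows, rows)
        plan.append(PlanEntry(w.name, resolve(shard), "qkv_proj.weight", dest_row))
        dest_row += rows
```

Running the toy loader once with placeholders yields a static plan (source region to destination rows) that could then be replayed on every sync without re-running the loader.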
Design
Trainer = passive, sharded weight store (trainer/rl/broadcast/nixl.py)
- Registers each rank's local DTensor shards (to_local()): never gathers, never converts naming/format. Serves the native prime-format state dict as-is.
- Publishes one TrainerTable to Model Express: per state-dict tensor, which dim-0 row range lives on which rank's NIXL buffer. Everything is sharded on dim 0 (FSDP shards output/vocab, expert-parallel shards experts); expert weights use vLLM's FusedMoE global-expert numbering, everything else takes its range straight from the DTensor placement.
- Serves bf16 live storage when params are already bf16 (zero per-sync work); when they are fp32 master weights (the optimization_dtype default), each rank casts only its local shard into a persistent bf16 buffer (a sharded cast, still no gather).

vLLM worker = sharded pull client (inference/vllm/worker/nixl.py)
- Bakes its pull plan once by driving the model's own load_weights with zero-storage LazyWeight placeholders. vLLM's loaders (fused QKV, merged gate/up, FusedMoE EP routing) slice them and copy_ into views of live params; each copy_ records (trainer tensor, op chain, destination view). Disallowed ops raise: loud failure over wrong bytes.
- Resolves each chain to a region of the full logical tensor and routes it onto the owning trainer shards (weight_transfer/sharding.py). Each worker pulls only its slices from whichever ranks hold them, e.g. an EP worker pulls its experts from the matching trainer EP ranks; a tp=1 dense param is pulled row-block-by-row-block from every FSDP rank.
- Issues one batched NIXL READ per trainer rank, then runs process_weights_after_loading so any kernel-format repacking happens in place, exactly as a normal vLLM weight reload. The receive is zero-copy: NIXL writes directly into the destination param (no intermediate buffer; the "zero-copy receive" the vLLM PR aspires to but couldn't get from Ray RDT).

The only naming bridge is a generic, metadata-only adapter (weight_transfer/adapter.py): prime's stacked experts (experts.w1/w2/w3) explode into per-expert HF lazy views, and mlp.router.gate → mlp.gate. Zero per-model code; dense models need nothing.

No artificial limitations. Kernel/quantized formats are handled by running process_weights_after_loading after the pull (the whole point of the lazy-tensor path); there is no ban on quantized or non-identity MoE backends. The only hard requirements are an RDMA-registrable allocator (no expandable_segments) and no EPLB (runtime expert rearrangement would invalidate the baked plan).

Coordination rides the existing STABLE/NCCL_READY markers; additions are a step-scoped NIXL_DONE (workers block on it before pulling) and per-rank NIXL_PULLED acks (the trainer waits for all acks before returning, so no shard buffer is overwritten, and the trainer process does not exit, under an in-flight pull). Model Express carries only the one-time table.

Deployment
- Set [weight_broadcast] type = "nixl" (shared config propagates to trainer/orchestrator/inference).
- Both trainer and inference nodes must load the CUDA-enabled UCX from third_party/ucx and pin UCX_NET_DEVICES to active RDMA ports (the nixl wheel's bundled UCX has no CUDA support). Build the MX server binaries once via scripts/install_modelexpress.sh; build CUDA UCX + NIXL via scripts/install_nixl_from_source.sh.
- Configs: configs/qwen3_30b_math_nixl_2node/ (NIXL + an NCCL-baseline twin).

Validation
2-node (1 train + 1 infer) Qwen3-30B-A3B, EP=8 (tp=1 dp=8 enable_expert_parallel), Hendrycks math, 10 steps, against an identical NCCL-broadcast baseline. Acceptance: Mismatch KL stable, not growing, and ≈ the NCCL baseline.

(see Validation result below)
Tests
tests/unit/weight_transfer/test_pull_plan.py: CPU end-to-end of the sharded mechanism: lazy bake against vLLM-style loaders (fused QKV, EP expert routing with skips, TP source narrows, padded vocab), chain resolution, region→shard routing across multiple trainer ranks, run matching, byte-level pull simulation verified bit-identical; plus loud-failure paths (unsupported ops, copy-materializing chains, length mismatches). Full unit suite passes.

🤖 Generated with Claude Code
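The region→shard routing exercised by these tests can be pictured with a small dim-0 example: given which trainer rank owns which rows, split a requested row range into one read per owning rank. This is a minimal sketch assuming half-open, non-overlapping row ranges; `route_rows` and the tuple layouts are hypothetical and are not taken from weight_transfer/sharding.py.

```python
from typing import List, Tuple

# (rank, row_start, row_end) with row_end exclusive; assumed non-overlapping.
Shard = Tuple[int, int, int]
# (rank, offset inside that rank's shard, offset inside the destination, n_rows)
Read = Tuple[int, int, int, int]


def route_rows(shard_table: List[Shard], start: int, length: int) -> List[Read]:
    """Split the logical row range [start, start + length) across owning ranks."""
    if length <= 0:
        raise ValueError("length must be positive")
    end = start + length
    reads: List[Read] = []
    for rank, lo, hi in sorted(shard_table, key=lambda s: s[1]):
        a, b = max(start, lo), min(end, hi)      # overlap with this shard
        if a < b:
            reads.append((rank, a - lo, a - start, b - a))
    covered = sum(n for _, _, _, n in reads)
    if covered != length:
        # Fail loudly rather than leave part of the destination unwritten.
        raise ValueError(f"rows [{start}, {end}) not fully covered by the shard table")
    return reads


# Three trainer ranks, four rows each; a worker needs logical rows 3..8.
table = [(0, 0, 4), (1, 4, 8), (2, 8, 12)]
reads = route_rows(table, start=3, length=6)
```

In this example the worker would read one row from rank 0, four rows from rank 1, and one row from rank 2, which matches the description of a dense param being pulled row-block-by-row-block from the FSDP ranks that hold it.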
Validation result ✅
2-node (1 train + 1 infer, 8 GPUs each), Qwen3-30B-A3B, EP=8 (tp=1 dp=8 enable_expert_parallel), Hendrycks math, 10 steps. NIXL run vs an identical NCCL-broadcast baseline.
- Completed (RL trainer finished!), no NIXL errors.
- Mismatch KL: flat, not growing, ≈ baseline.
Online FP8 (block-wise) inference ✅
Re-validated with the worker switched to the upstream-RDT mechanism (bake through vLLM's layerwise reload + process_weights_after_loading), pure NIXL+MX. Same 2-node EP=8 Qwen3-30B-A3B math run, inference under online block-FP8 (quantization="fp8_per_block", DeepGEMM block-FP8 MoE backend); trainer still streams bf16 shards and the inference side re-quantizes bf16→fp8-blockwise on every reload.
- Completed (RL trainer finished!).
- Mismatch KL over 10 steps: 0.0107, 0.0118, 0.0100, 0.0132, 0.0098, 0.0081, 0.0137, 0.0102, 0.0117, 0.0095. Stable, not growing, oscillating around a ~0.011 floor (the inherent gap between the bf16-precision trainer logprobs and the fp8-quantized inference weights; ~7× the bf16 path's 0.0015, as expected for fp8).

The same code path covers both: process_weights_after_loading re-quantizes for fp8 and is a no-op for unquantized. Config: configs/qwen3_30b_math_nixl_2node/rl_fp8.toml.
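The "stable, not growing, ~0.011 floor, ~7×" reading of the FP8 numbers can be checked with a few lines of arithmetic. The sketch below uses only the ten Mismatch KL values and the 0.0015 bf16 figure quoted above; the helper names and the choice of a least-squares slope as the "not growing" criterion are assumptions made for illustration.

```python
from typing import List

# Mismatch KL per step for the online-FP8 run, as reported above.
KL_FP8: List[float] = [0.0107, 0.0118, 0.0100, 0.0132, 0.0098,
                       0.0081, 0.0137, 0.0102, 0.0117, 0.0095]
KL_BF16 = 0.0015  # bf16-path figure quoted above


def mean(xs: List[float]) -> float:
    return sum(xs) / len(xs)


def least_squares_slope(ys: List[float]) -> float:
    """Slope of the best-fit line through (step index, value)."""
    n = len(ys)
    x_bar = (n - 1) / 2
    y_bar = mean(ys)
    sxy = sum((i - x_bar) * (y - y_bar) for i, y in enumerate(ys))
    sxx = sum((i - x_bar) ** 2 for i in range(n))
    return sxy / sxx


kl_mean = mean(KL_FP8)                 # the "floor" the values oscillate around
kl_ratio = kl_mean / KL_BF16           # how much larger than the bf16 path
kl_slope = least_squares_slope(KL_FP8) # per-step trend; non-positive means not growing
```

On these ten values the mean comes out close to 0.011, the ratio a little above 7, and the fitted slope is slightly negative, which is consistent with the summary given for the FP8 run.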