Skip to content

feat: Support for dtensor ppo#2837

Open
fujial-code wants to merge 17 commits into
mainfrom
ppo-dtensor-path-a
Open

feat: Support for dtensor ppo#2837
fujial-code wants to merge 17 commits into
mainfrom
ppo-dtensor-path-a

Conversation

@fujial-code

@fujial-code fujial-code commented Jun 16, 2026

Copy link
Copy Markdown

What does this PR do ?

Adds DTensor/FSDP2 support for the PPO value model path and aligns the DTensor PPO 1.5B GSM8K recipe with the existing Megatron-Core PPO baseline.
The main correctness fix is a value temporal alignment issue: DTensor value inference/training was not using the same V(s_t) convention as the Megatron PPO path. After right-shifting token values to match Megatron's temporal semantics, the DTensor reward curve aligns with the Megatron baseline on the Qwen2.5-1.5B GSM8K 1-node recipe.

Issues

close [#2047]

Summary of Changes

DTensor PPO Value Model

  • Adds DTensorValueWorkerV2 for PPO value model training and inference on the DTensor/FSDP2 backend.
  • Wires DTensor value worker selection through LMValue.
  • Supports PPO value inference, value training, and checkpoint save/load paths for the DTensor backend.
  • Adds lifecycle compatibility with the PPO loop, including inference/training finish hooks where needed.

Value Temporal Alignment

  • Fixes the DTensor value convention to match the Megatron PPO path.
  • Applies right-shifted values so token-level value predictions correspond to V(s_t).
  • This was the key fix for the DTensor vs Megatron reward mismatch observed during GSM8K validation.

Recipe and Config Alignment

  • Adds the DTensor 1-node 8-GPU Qwen2.5-1.5B GSM8K PPO recipe:
    • examples/configs/recipes/llm/ppo-qwen2.5-1.5b-gsm8k-1n8g-dtensor.yaml
  • Aligns the DTensor recipe with the Megatron recipe for A/B comparison:
    • ppo_epochs: 1
    • math_verify_impl: hf_math_verify
    • matching value loss config structure
  • Removes the unused GSM8K verifier-specific path and uses the existing HF math verifier for the GSM8K recipe.

Runtime Compatibility Fixes

  • Updates DTensor worker call sites to match current Automodel API signatures.
  • Removes stale kwargs that no longer exist in the current forward/checkpoint APIs.
  • Passes checkpointing config when saving the PPO value model checkpoint.
  • Adds missing imports and no-op lifecycle methods needed by the PPO loop.

Tests

  • Unit tests: Added DTensor PPO value temporal alignment tests for _right_shift_values and _RightShiftLossWrapper.
  • Functional test: Added a 2-GPU DTensor PPO smoke test using Qwen2.5-0.5B on GSM8K for 2 steps with metric sanity checks.
  • Nightly test: Added a 1-node 8-GPU Qwen2.5-1.5B GSM8K DTensor PPO test, checking train/reward > 0.75 and validation/accuracy > 0.65 at step 40.
  • Local checks: Ran Python compile, shell syntax checks, nightly dry-run, and git diff --check.

Validation

Validated the DTensor PPO path against the Megatron-Core PPO baseline on the Qwen2.5-1.5B GSM8K 1-node 8-GPU setup.

Ran 100-step PPO training for both backends with matched recipe settings:

  • Model: Qwen2.5-1.5B-Instruct
  • Dataset: GSM8K
  • Cluster: 1 node, 8 GPUs
  • Backends:
    • DTensor/FSDP2 Automodel backend
    • Megatron-Core backend
  • Matched settings:
    • ppo_epochs: 1
    • math_verify_impl: hf_math_verify
    • matched value loss config
    • DTensor train/logprob micro batch size aligned to 4 for the final run

Results:

  • train/reward: DTensor and Megatron follow the same learning curve and both converge to around 0.88-0.90 by step 100.
  • validation/accuracy: DTensor and Megatron track closely across training and reach around 0.78-0.80 by step 100.
  • The reward and validation accuracy curves are aligned after fixing the value temporal shift and aligning the recipe settings.
    These results indicate that the DTensor PPO value path now matches the Megatron-Core PPO baseline behavior on the 1.5B GSM8K recipe.

@copy-pr-bot

copy-pr-bot Bot commented Jun 16, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@fujial-code fujial-code added the CI:Lfast Runs a fast test suite and re-use nightly `main` container (but sync dependencies to PRs version) label Jun 16, 2026
@fujial-code

Copy link
Copy Markdown
Author

/ok to test 980ef04

@fujial-code

Copy link
Copy Markdown
Author

/ok to test 980ef04

@copy-pr-bot

copy-pr-bot Bot commented Jun 16, 2026

Copy link
Copy Markdown

/ok to test 980ef04

@fujial-code, there was an error processing your request: E2

See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/

@fujial-code

Copy link
Copy Markdown
Author

/ok to test 980ef04

@copy-pr-bot

copy-pr-bot Bot commented Jun 16, 2026

Copy link
Copy Markdown

/ok to test 980ef04

@fujial-code, there was an error processing your request: E2

See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/

@fujial-code

Copy link
Copy Markdown
Author

/ok to test 8681202

@fujial-code

Copy link
Copy Markdown
Author

/ok to test 138ae84

yuki-97 and others added 8 commits June 16, 2026 20:25
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Co-authored-by: bg51717 <biguo@nvidia.com>
Co-authored-by: Lukasz Pierscieniewski <l.pierscieniewski@gmail.com>
Co-authored-by: Gerald Shen <geshen@nvidia.com>
Signed-off-by: Jia Liufu <fujial@nvidia.com>
Three small infra fixes required to get the DTensor V2 PPO surface working
end-to-end on top of bg51717/ppo:

- automodel/setup.py: optional import shim for NeMoAutoModelForTokenClassification
  (older nemo_automodel submodule on this worktree doesn't export the symbol,
  but the eager top-level import broke EVERY backend including Megatron).

- algorithms/loss/loss_functions.py: MseValueLossFn declares
  input_type = LossInputType.LOGIT and renames the param `values` -> `logits`
  to match the LossFunction protocol contract that automodel_forward_backward
  dispatches on.

- models/policy/workers/dtensor_policy_worker_v2.py: add no-op
  finish_training / finish_inference overrides so the PPO algorithm loop's
  per-step lifecycle calls don't AttributeError on the DTensor backend.

None of these change Megatron behavior.

Signed-off-by: Jia Liufu <fujial@nvidia.com>
Re-architect dtensor_value_worker_v2.py to mirror megatron_value_worker.py's
critic dynamics, so DTensor PPO matches Megatron's reward curve on the GSM8K
1n8g Qwen2.5-1.5B recipe.

Why: the original DTensor value worker used the HuggingFace `score` head as
the value head, trained under the same FSDP2 MixedPrecisionPolicy
(param_dtype=bf16) as the backbone. PPO's critic signal is narrow
(grad ~1e-3, update ~1e-5) and the bf16 Adam precision floor (~4e-5 ULP/2
near typical weight magnitudes) silently truncates per-step updates to zero
-> value/loss stuck ~12x Megatron -> miscalibrated advantage -> reward gap
~16-29% vs Megatron baseline.

What:
- Introduce a standalone fp32 ValueHead module (nn.Linear(hidden, 1) with
  fp32 autocast forward), separate from the HF `score` head.
- Freeze HF score (requires_grad_(False)) so the FSDP2 backbone optimizer
  doesn't update it; bypass it at inference.
- Dedicated torch.optim.AdamW for the ValueHead, so its m, v moments live in
  fp32 -> per-step update accumulates correctly even at lr=1e-5.
- Right-shift values by one position (V[t] = V(before t)) to align with
  Megatron's time convention.
- Manual DP all-reduce on ValueHead grads since it isn't inside FSDP.
- Side-car value_head.pt checkpoint for save/load.
- _backbone_forward uses register_forward_hook on the inner Qwen2Model to
  capture last_hidden_state while still routing through the wrapped
  top-level FSDP forward, so DTensor input-cast (Replicate on input_ids
  before sharded-embedding matmul) still happens. Bypassing the top wrapper
  would crash with `aten.embedding.default got mixed torch.Tensor and DTensor`.
- lm_value.py: minor surface adjustments to keep the value interface contract
  consistent with the new worker.

Result on 100-step GSM8K 1n8g (Qwen2.5-1.5B) vs Megatron A/B twin:
  reward       end-of-train: dtensor 0.873-0.882  vs megatron 0.876  (<1% diff)
  value/loss   end-of-train: dtensor ~0.030       vs megatron ~0.008 (~4x;
               residual bf16 backbone precision -- does not affect reward
               because advantage normalization removes the constant bias
               before policy update)
  policy/loss  stable, no blow-up; magnitudes differ from megatron because
               the steeper fp32 advantage signal produces a larger but
               equally-directional policy gradient.

Megatron value worker is untouched; this only changes
dtensor_value_worker_v2.py + the small lm_value.py surface tweak.

Signed-off-by: Jia Liufu <fujial@nvidia.com>
- New recipe: examples/configs/recipes/llm/ppo-qwen2.5-1.5b-gsm8k-1n8g-dtensor.yaml
  DTensor V2 twin of the existing Megatron 1n8g GSM8K recipe. Pins
  policy/value optimizer LR to 1e-6 / 1e-5 to match the Megatron twin
  (the value worker rewrite alone isn't enough -- without LR parity the
  DTensor policy LR was 2x slower).

- Megatron recipe: wandb naming aligned to the same project
  (nemo-rl-ppo-port-v1) with distinct -dtensor / -megatron run suffixes,
  so the two backend runs render side-by-side as A/B in wandb. No training
  semantics changed.

- New guide: docs/guides/ppo-dtensor.md
  Quick-start, recipe pins, value-worker architecture note, verification
  table vs Megatron, and known limitations (vLLM seed nondeterminism +
  residual value/loss gap explanation).

- docs/index.md: add the new guide to the Guides toctree under ppo.md.

Signed-off-by: Jia Liufu <fujial@nvidia.com>
…d-bearing

Path A's standalone fp32 ValueHead + dedicated optimizer + FSDP-bypass hook
+ side-car checkpoint manager all worked, but ablation showed only the
right-shift was actually closing the reward gap. This commit replaces the
Path A scaffolding with the pre-Path-A DTensor value worker + a 3-line
right-shift fix, matching megatron_value_worker.py's V(s_t) temporal
semantics exactly.

Causality evidence from parallel ablations:

  12235997 (Path A + ABLATION_HEAD_BF16=1 = bf16 head + right-shift ON)
    -> reward 0.37 -> ~0.89 over 100 steps (matches megatron A/B twin)
  12237461 (Path A + ABLATION_NO_RIGHT_SHIFT=1 = fp32 head, right-shift OFF)
    -> reward 0.37 -> ~0.79 over 100 steps (loses ~10 pts -- the original gap)

So right-shift is the load-bearing fix; the Path A precision / optimizer /
FSDP machinery is not needed for parity.

Tree changes (this commit's content == cherry-pick of true-minimal 27b3b7e):

* nemo_rl/models/value/workers/dtensor_value_worker_v2.py
  -> replaced Path A worker (~700 lines, standalone fp32 ValueHead +
     dedicated optimizer + FSDP-bypass + side-car checkpoint) with the
     pre-Path-A baseline worker (HF score head inside FSDP, backbone
     optimizer, standard checkpoint), plus the right-shift fix wired into
     train() via _RightShiftLossWrapper and into get_values() via
     _right_shift_values.

* nemo_rl/models/value/lm_value.py
  -> Path A's special-case dispatch is no longer needed; restored the
     simple "DTensor enabled -> DTensorValueWorkerV2" path.

* docs/guides/ppo-dtensor.md (DELETED) and docs/index.md (REVERTED)
  -> Path A docs no longer apply.

* examples/configs/recipes/llm/ppo-qwen2.5-1.5b-gsm8k-1n8g-megatron.yaml
  -> reverted Path A's LR pinning; the dtensor twin recipe (kept) is
     where the matched LRs live.

* examples/configs/recipes/llm/ppo-qwen2.5-1.5b-gsm8k-1n8g-dtensor.yaml
  (KEPT, identical content) -- pins policy/value LR to 1e-6 / 1e-5 to
  match the megatron A/B twin.

Verified end-to-end by Job C v5 (12248222): exit 0 after 100 steps,
final reward ~0.89, matching megatron baseline (12236465) and the Path A
ablation (12235997).

The 4 follow-up commits in this series fix latent API drift between this
fork's pre-Path-A worker and current automodel APIs (those bugs were
masked by lm_value.py's NotImplementedError guard before).

Signed-off-by: Jia Liufu <fujial@nvidia.com>
…t init

Job 12246618 failed at worker startup with:

  TypeError: AutomodelCheckpointManager.__init__() got an unexpected
  keyword argument 'model_state_dict_keys'

Pre-Path-A worker (blob 509625e, inherited from upstream bg51717/ppo)
called AutomodelCheckpointManager(dp_mesh, tp_mesh, model_state_dict_keys,
moe_mesh). This fork's signature in nemo_rl/models/automodel/checkpoint.py
is just (dp_mesh, tp_mesh, moe_mesh) -- the model_state_dict_keys param
was either dropped or never landed on this branch.

Path A's worker (current ablation HEAD) already omits the same kwarg,
so this is a known divergence with a known fix. Mirroring that here.

All other AutomodelCheckpointManager method signatures (save_checkpoint,
load_checkpoint, init_checkpointer) match the pre-Path-A call sites
verbatim, so this is the only signature-drift fix needed.

Signed-off-by: Jia Liufu <fujial@nvidia.com>
Job 12246898 failed at value worker __init__ with:
  TypeError: setup_model_and_optimizer() got an unexpected keyword
  argument 'distributed_manager'. Did you mean 'distributed_context'?

Path A's worker (current ablation HEAD) documents 3 signature-drift
fixes its own __init__ needed against this fork's APIs. Apply the same
3 fixes to the pre-Path-A worker (mechanical mirror of Path A's notes,
no semantic divergence):

1. setup_model_and_optimizer(distributed_manager=...)
   -> setup_model_and_optimizer(distributed_context=...)
   (this fork renamed the kwarg in PR #2027)

2. ModelAndOptimizerState unpacking: 11 slots -> 10
   Drop self.model_state_dict_keys -- this fork's ModelAndOptimizerState
   NamedTuple (nemo_rl/models/automodel/config.py) has 10 fields and
   does not include model_state_dict_keys. Without this:
     ValueError: not enough values to unpack (expected 11, got 10)

3. RuntimeConfig unpacking: 12 slots -> 13
   Add _runtime_sampling_params slot before _runtime_is_reward_model.
   This fork's RuntimeConfig has 13 fields with sampling_params right
   before is_reward_model. bg51717/ppo stripped sampling_params from
   train()/get_values() call sites too (which is what we want and what
   the pre-Path-A worker already has), so we discard the value into
   _runtime_sampling_params rather than assigning self.sampling_params.
   Without this:
     ValueError: too many values to unpack (expected 12)

All 3 fixes are mechanical and documented verbatim in Path A's worker
NOTE comments (commit fb38a59 lines 297-339).

Signed-off-by: Jia Liufu <fujial@nvidia.com>
… calls

Job 12247164 reached value-worker init (3 prior signature fixes worked),
got into the rollout/get_values cycle, then crashed with:

  TypeError: forward_with_post_processing_fn() got an unexpected
  keyword argument 'cfg'

Pre-Path-A worker passed `cfg=self.cfg` to two functions that this fork
no longer accepts it on (PR #2027 stripped the param):

- automodel_forward_backward (train() microbatch loop, line 437)
- forward_with_post_processing_fn (get_values() forward, line 561)

LossPostProcessor.__init__ and ScorePostProcessor.__init__ DO still
accept cfg, so those call sites are left alone.

This is the same pattern as the previous setup_model_and_optimizer +
ModelAndOptimizerState/RuntimeConfig fixes -- pre-Path-A worker was
inherited from upstream bg51717/ppo and never re-tested against this
fork's APIs. Path A worked around it by writing a custom microbatch
loop that didn't go through these functions at all.

Signed-off-by: Jia Liufu <fujial@nvidia.com>
Signed-off-by: Jia Liufu <fujial@nvidia.com>
@fujial-code

Copy link
Copy Markdown
Author

/ok to test 125a14e

Signed-off-by: Jia Liufu <fujial@nvidia.com>
@fujial-code

Copy link
Copy Markdown
Author

/ok to test a9529c4

Signed-off-by: Jia Liufu <fujial@nvidia.com>
@fujial-code

Copy link
Copy Markdown
Author

/ok to test 66ccdb3

Signed-off-by: fujial-code <fujial@nvidia.com>
@fujial-code fujial-code marked this pull request as ready for review June 18, 2026 09:59
@fujial-code fujial-code requested review from a team as code owners June 18, 2026 09:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CI:Lfast Runs a fast test suite and re-use nightly `main` container (but sync dependencies to PRs version)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[dtensor] PPO

3 participants