Skip to content

Asyncrl/sc sync weights#2831

Open
mehraakash wants to merge 29 commits into
NVIDIA-NeMo:mainfrom
mehraakash:asyncrl/sc_sync_weights
Open

Asyncrl/sc sync weights#2831
mehraakash wants to merge 29 commits into
NVIDIA-NeMo:mainfrom
mehraakash:asyncrl/sc_sync_weights

Conversation

@mehraakash

Copy link
Copy Markdown

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

mehraakash and others added 29 commits June 8, 2026 15:29
…orker

Adds begin_train_step / train_microbatch / finish_train_step / abort_train_step
on MegatronPolicyWorkerImpl, mirroring the DTensor v1/v2 implementations but
adapted for mcore's contiguous grad bucket + pipeline-schedule reduce path.

Mechanism:
- begin_train_step: zero_grad_buffer + optimizer.zero_grad, store loss_fn /
  gbs / mbs / local_valid_seqs/toks accumulators on _train_step_state, and
  null model.config.grad_sync_func (saved for restore) so the PP scheduler's
  direct reduce dispatch cannot bypass no_sync.
- train_microbatch(data): wrap one ``megatron_forward_backward`` invocation
  in ``with self.model.no_sync():`` so mcore DDP hooks accumulate
  ``param.main_grad`` locally without dispatching the cross-DP reduce.
  Pass ``global_valid_seqs/toks=tensor(1.0)`` so the loss returns
  un-normalized sums; backward deposits raw d(sum)/dθ. Accumulate local
  mask sums + per-mb metrics + the total pipeline-microbatch count
  (for finish-time MoE aux-loss scaling).
- finish_train_step: all_reduce mask sums to get true N (toks for
  TOKEN_LEVEL loss, seqs for SEQUENCE_LEVEL), call
  self.model.scale_gradients(1/N), then the one true cross-DP reduce via
  start_grad_sync + finish_grad_sync, optimizer.step (clips internally),
  restore grad_sync_func, scheduler.step(increment=gbs). Rescale per-mb
  metrics by 1/N (linear-in-1/N math), aggregate, surface global counts.
- abort_train_step: restore grad_sync_func, zero_grad_buffer + zero_grad,
  drop state. ``trainer_version`` unchanged.

Sync ``train()`` is left untouched.

Includes CPU unit tests at tests/unit/models/policy/test_megatron_split_state.py
covering the lifecycle and call-order invariants (no_sync wrap,
grad_sync_func save/restore, mask-sum accumulation, N selection by
loss_type, abort idempotence, MoE scaling). Marked pytest.mark.mcore so
they run only in mcore-enabled CI containers.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Pre-existing zero-error file from NVIDIA-NeMo#2078 (Eagle3) that was never added
to the project-includes whitelist. Carrying the fix forward in this
PR to unblock the lint job.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
The file is introduced by NVIDIA-NeMo#2692 (DTensor PR), not by this branch.
Whitelisting it here causes pyrefly to fail with 'No Python files
matched pattern' since the file does not exist on mcore.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
cloudpickle traverses globals/closures of each method when serializing
the Ray actor class. With torch 2.11, 'config' in __code__.co_names
matches torch.distributed.config (a non-pickleable ConfigModuleInstance),
breaking actor creation with:
  TypeError: cannot pickle 'ConfigModuleInstance' object
  Could not serialize the actor class ...MegatronPolicyWorker

Same workaround as the existing sync train(): read 'config' via
getattr-by-string in begin/finish/abort_train_step.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
test_megatron_split_state.py eagerly imports megatron_policy_worker
which transitively imports megatron.bridge. In non-mcore shards (Models,
Vllm, Sglang, Automodel_Policy), megatron.bridge isn't installed so
collection of this file fails, killing every other test in the shard.

pytest.importorskip stops collection cleanly when megatron.bridge is
not available. The pytest.mark.mcore filter still ensures these tests
only run in mcore shards.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
`finish_train_step` was calling both `model.start_grad_sync()` and
`model.finish_grad_sync()` unconditionally. In megatron-core's BucketGroup,
when `overlap_grad_reduce=False` (the production default), `finish_grad_sync`
internally invokes `start_grad_sync(force_all_reduce=True)`, so calling
`start_grad_sync` ourselves first dispatches the synchronous collective
once, and `finish_grad_sync` dispatches it again — scaling the reduced
gradient by ~world_size. The optimizer then steps with grads that are
DP-world-size too large.

Parity tests pass because they assert call order / state shape, not
numerical equivalence at world_size > 1. The bug surfaces only at runtime.

Fix: only call `start_grad_sync` explicitly when `overlap_grad_reduce=True`
(the async path that needs an explicit dispatch since `model.no_sync()`
gates off the per-microbatch hooks). For `overlap_grad_reduce=False`,
`finish_grad_sync` handles everything.

Tests:

- Parametrize `test_start_then_finish_grad_sync_called_after_rescale` on
  `overlap_grad_reduce` and assert the correct call order in each case.
- Add `distributed_data_parallel_config` to the test cfg fixture so the
  worker's branch can read the flag.

Reported by terrykong in PR NVIDIA-NeMo#2683 review.

Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Two bugs around the split-API state machine fixed in one pass:

1. **no_sync_func leak (latent assertion on step 2 with PP=1).**
   `forward_backward_no_pipelining` (the PP=1 path, which is the common
   one) wraps inner microbatches in `model.config.no_sync_func` but runs
   the *last* microbatch OUTSIDE of it. Our outer `with self.model.no_sync():`
   in `train_microbatch` was therefore bypassed for the trailing MB.
   `register_grad_ready` leaked per-param counts past the explicit DP
   sync at `finish_train_step`, and the next `begin_train_step` then
   asserted on the stale counts (typically step 2).

   Fix: in `begin_train_step` also save and null `no_sync_func` (set to
   `contextlib.nullcontext`). Restore both hooks at finish/abort via the
   new `_restore_saved_grad_sync_func` helper.

   Spotted in yuki-97's PR NVIDIA-NeMo#2819 (`begin_train_step` also nulls
   `config.no_sync_func`); same fix landed here.

2. **No exception safety around the open step (terrykong, NVIDIA-NeMo#2683:640).**
   If `train_microbatch` or `finish_train_step` raised mid-body, both
   `grad_sync_func` and (now also) `no_sync_func` would stay nulled and
   future steps would run with the PP scheduler bypass disabled silently.

   Fix: extract `_train_microbatch_body` and `_finish_train_step_body`,
   wrap each entry method in try/except that calls
   `_restore_saved_grad_sync_func` before re-raising. Caller is still
   expected to invoke `abort_train_step` (idempotent on the saved
   values) to drop `_train_step_state`.

3. **Cleanup (terrykong, NVIDIA-NeMo#2683:582).** Drop dead `no_sync_active` field
   from `_split_step_state_init` — never read or written; `no_sync` is
   applied via the `with self.model.no_sync():` context manager.

Also: add module-level `log = logging.getLogger(__name__)` so the
try/except handlers can `log.exception(...)` if the restore itself fails.

Signed-off-by: Akash Mehra <akamehra@nvidia.com>
The split-API runs each microbatch with global_valid_*=1 (raw sums) and
applies the 1/N rescale once at finish, but the previous code used a
single inv_n for every non-min/max metric. ClippedPGLossFn normalizes
different metrics by different denominators — and some not at all — so
the single-scalar rescale was wrong for:

- Always-toks-normalized metrics under a SEQUENCE_LEVEL loss
  (token_mult_prob_error, gen_kl_error, policy_kl_error,
  js_divergence_error, approx_entropy, probs_ratio, probs_ratio_clamped).
  These came out off by ~average-sequence-length.
- Raw count metrics (num_valid_samples, num_unmasked_tokens) — these are
  ABSOLUTE counts the loss never normalized; the previous code divided
  them by N (e.g. 8.0 came out as 0.0039). Sync path passes them through
  via num_global_batches=1 → identity, then grpo.py's downstream reducer
  sums across microbatches for the global count.

Gradients are unaffected — they derive only from `loss`, whose
normalizer matches inv_n (loss-type-aware).

Fix: add module-level `_METRIC_NORMALIZATION_KIND` mapping each known
metric name to one of {toks, seqs, loss_type, none, skip}. In
`finish_train_step`, look up each metric's kind and apply the right
scale. Unknown metric names default to "loss_type" — preserves prior
behavior for any metric not in the table.

The catalog follows terrykong's review on NVIDIA-NeMo#2683 (`:789`/`:845`).
sampling_importance_ratio and is_oob_ratio have a flag-dependent
denominator (sequence_level_importance_ratios / sequence_level_loss_mask
toggles); for now they default to "loss_type" with a TODO noting the
seq-mask-tis + token-level mismatch.

Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Two small follow-ups to terrykong's NVIDIA-NeMo#2683 review:

- :832 — capture curr_lr / curr_wd BEFORE scheduler.step so the per-mb
  metrics carry the value of THIS step, not the next. Drops the
  duplicate capture from inside the rescale loop.
- :538 — drop the misleading "Mirrors the v1/v2 implementations" line
  from the split-API block docstring; v1/v2 don't define these methods.
  Replaced with a brief surface description (begin / train_microbatch /
  finish / abort) before the mcore-specific details.

No behavior change.

Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Adds the DTensor v1 + v2 backend implementations of the split-API methods
(begin_train_step / train_microbatch / finish_train_step / abort_train_step)
introduced in the mcore PR, plus the PolicyTrainerActor Ray wrapper that
exposes the API to SingleController via ``.remote(...)``.

- DTensor v1: state machine using the N=1 placeholder trick. Per microbatch
  the loss is called with ``global_valid_*=tensor(1.0)`` so backward
  deposits un-normalized gradients; ``finish_train_step`` all_reduces local
  mask sums to recover the true N (toks for TOKEN_LEVEL, seqs for
  SEQUENCE_LEVEL), rescales ``p.grad`` by 1/N, runs grad_norm/clip,
  optimizer.step, scheduler.step. Bin iteration with DP-rank dummy-bin
  padding handles seq-packing / dynamic-batching uneven splits without
  desyncing NCCL.

- DTensor v2: same shape, built on the v2 helpers (LossPostProcessor,
  get_microbatch_iterator, automodel_forward_backward,
  scale_grads_and_clip_grad_norm). The clip helper has MoE/EP awareness;
  the manual 1/N rescale runs before it.

- PolicyTrainerActor: ``@ray.remote(num_cpus=1, num_gpus=0)`` wrapper that
  owns a TQPolicy instance and exposes ``train_from_meta`` (sync proxy),
  ``prepare_logprobs_from_meta``, and the split API. ``trainer_version``
  advances on sync train_from_meta or on ``finish_train_step`` (never on
  abort).

- GPU parity tests at tests/unit/models/policy/test_split_train_step.py:
  numerical equivalence vs sync ``train()`` for token-level and seq-level
  losses, multi-bin-per-call under seq-packing, and split-state-machine
  lifecycle (double begin, abort idempotence). Parameterised over v1/v2.

Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
The file exists on this branch (added by the dtensor split-API work)
and pyrefly reports 0 errors for it, so it must be in the whitelist.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Update PolicyTrainerActor.prepare_logprobs_from_meta to accept
refresh_policy_logprobs / refresh_reference_logprobs kwargs and only
dispatch the requested calls. Mirrors the dispatcher signature added to
TQPolicy in NVIDIA-NeMo#2700; SingleController in NVIDIA-NeMo#2700 passes config-driven flags
based on advantage_policy_logprobs_field / advantage_reference_logprobs_field.

Body inlined here (calls TQPolicy.get_logprobs_from_meta and
get_reference_policy_logprobs_from_meta directly) so this PR stays
self-contained without depending on NVIDIA-NeMo#2700's TQPolicy helper. Once both
PRs merge, this can collapse to a one-line delegate.

Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Three issues caught running tests/unit/models/policy/test_split_train_step.py
locally on the v0.5.0 container:

- test_split_train_step_parity_seq_packing was missing several config
  keys the seq-packing path requires (make_sequence_length_divisible_by,
  algorithm, sequence_length_round, logprob_mb_tokens). It also ran at
  fp32 but the seq-packing path goes through FlashAttention which
  requires fp16/bf16; switched to bfloat16.
- All three parametrized cases passed tokenizer=None; replaced with
  get_tokenizer(config["tokenizer"]) so PolicyConfig wiring works.
- _drive_split_path bypassed Policy._shard_for_train and called
  train_microbatch on each worker directly with the unsharded data, so
  the worker iterators received batches without micro_batch_indices
  populated. Now shards via run_all_workers_sharded_data — each DP rank
  gets its corresponding shard.

All six parametrized cases (token_level / seq_packing / state_machine x
v1 / v2) pass locally on the v0.5.0 container.

Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Drives the begin/train_microbatch/finish split API in NVIDIA-NeMo#2683 and
group, per-group prepare_logprobs (when configured) -> advantage_pump
-> train_microbatch_from_meta (queued), one finish_train_step + one
clear_samples at end-of-step. Buffer capacity released per group, not
per step.

- StalenessSampler.select_one_group: picks one eligible prompt group;
  same predicate as select_indices, sort by (lag, indices[0]).
- SingleControllerConfig.target_prompt_groups_per_step: explicit per-
  step admission count; validated against min_prompt_groups_per_batch.
- _reap_in_flight_nonblocking: ray.wait(timeout=0) drain helper.
- DryRunTrainer: split-API stub with begin/microbatch/finish/abort
  invariants for dry-run tests.
- 7 streaming dry-run tests: arrival order, finish-time trainer_version
  tick, strict on-policy filter, long-tail overlap, abort idempotence,
  empty-step no-op, single clear_samples per step.

Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Required by 'Check if any files with zero errors not in whitelist'
guard in cicd-main.yml. Both files have zero pyrefly errors; without
the whitelist entry the lint job fails.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
- target_prompt_groups_per_step defaults to min_prompt_groups_per_batch
  when unset, preserving backwards compat with tests that exercised the
  old sync semantics (min=1 → target=1, not the new 8 default).
- _reap_in_flight_nonblocking switched from ray.wait(timeout=0) to
  asyncio.wait on ensure_future(ref) — ray.wait inside an async actor
  did not reliably reflect cross-process ref readiness, leaving in-flight
  refs misreported as pending.
- Inner _train_pump loop yields with asyncio.sleep(0) at the top of
  each iteration to keep the actor's event loop responsive under fast
  microbatch cycling.
- DryRunGenWorker: capture _call_count before sleep to avoid the
  asyncio race that interleaved two parallel rollout invocations into
  the same group_id (only group-0002 was committed).
- DryRunTrainer.train_microbatch_from_meta also appends to
  _train_start_times so streaming-path tests see when the trainer began.
- Ping threshold loosened from 1.0s to 3.0s — streaming dispatches more
  RPCs than the sync path; the test still asserts liveness, not strict
  latency.
- DryRunReapInFlightHelperActor mirrors the production helper's
  asyncio.wait pattern.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
pyrefly: '<' is not supported between Literal[0] and None — the field
is Optional[int] until __init__ coerces it to int. Add an assert at
the inner loop to narrow the type for the type-checker.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
TensorDict's source param is typed as dict[NestedKey, CompatibleType];
our dict[str, torch.Tensor] is operationally compatible. Adding pyrefly
ignore comments instead of reshaping the dict to a TensorDict-internal
type.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
- Thread loss_fn through SingleController + begin_train_step. Real bug
  against TQPolicy / MegatronPolicyWorker whose begin_train_step requires
  loss_fn with no default; previously masked by the dryrun stub.
- SC owns trainer_version: drop result["trainer_version"] read (key not
  emitted by real backends); bump immediately after finish_train_step
  succeeds, before clear_samples, so SC's counter matches worker state
  even if clear_samples fails.
- Add TQPolicy.prepare_logprobs_from_meta(meta, refresh_policy_logprobs=,
  refresh_reference_logprobs=) dispatcher over the existing
  get_*_from_meta methods. SC wires it from the per-group hook with
  config-driven flags. PolicyTrainerActor signature mirrored in NVIDIA-NeMo#2692.
- Wrap _train_pump per-step body in try/except: on mid-step worker
  failure, call abort_train_step on the worker and clear the
  consumed-ids ledger before re-raising. TODO: retry policy is a
  follow-up.
- Document max_concurrency=1 contract in _train_pump docstring; set it
  explicitly on MegatronPolicyWorker rather than relying on Ray's
  sync-actor default.
- Migrate SingleControllerConfig from @DataClass to
  pydantic.BaseModel(extra="allow") to match the repo's MasterConfig
  style. Drop the extra: dict workaround field and the runtime Literal
  assert (pydantic validates at construction).
- Move five cross-actor TQ schema field names to module-level constants
  (PROMPT_IDS_FIELD, REWARD_FIELD, TOKEN_MASK_FIELD, SAMPLE_MASK_FIELD,
  ADVANTAGE_OUTPUT_FIELD). They're a fixed protocol with the rollout
  actor, not user-tunable.
- Drop dead RAY_TEMP_DIR macOS workaround from the dryrun test header
  (Linux CI doesn't need it; mac devs can set RAY_TMPDIR in shell).
- Tighten reap timeout to 0.0 for a true non-blocking peek; was 0.05.
- print -> log.info for the strict_on_policy auto-set notice.
- Add dryrun test exercising prepare_logprobs_from_meta with both flags
  set, only policy set, and neither set.

Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Whitespace-only fix flagged by `uv run --group dev ruff format --check`
after the BaseModel migration in the previous commit. No behavior change.

Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Make the looping similar to sync path.

Changes:

- SingleControllerConfig: add `train_global_batch_size` (Optional[int],
  coerced to samples_per_step when None to preserve current behavior).
- _train_pump: wrap the per-step body in `for mb_idx in range(num_minibatches)`.
  Each iteration is one begin / microbatch×K / finish cycle = one opt.step,
  where K = train_global_batch_size // generations_per_prompt.

Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Mirror grpo_sync's logging surface so SC produces the same train metrics
and timing telemetry under W&B / TB. No-op if no logger is attached
(dryrun tests don't need one).

Phase 1 — log_metrics:
- Add `logger` (Optional) param to SingleControllerActor.__init__.
- Store as self._logger.
- Capture train_results returned by finish_train_step (was discarded).
- Per opt.step: logger.log_metrics(train_results, step=trainer_version,
  prefix="train") matching grpo_sync.py:1282.

Phase 2 — Timer:
- Import nemo_rl.utils.timer.Timer; instantiate self._timer when logger
  is attached, else None.
- _timed(label) helper returns timer.time(label) ctx-mgr or nullcontext().
- Wrap key phases in _train_pump: prepare_logprobs, advantage_pump,
  policy_training (the finish_train_step call), clear_samples,
  sync_weights.
- After each opt.step: logger.log_metrics(timer.get_timing_metrics(),
  step=trainer_version, prefix="timing/train", step_finished=True);
  timer.reset() so the next opt.step measures only its own phases.

Errors from logger calls are caught and logged (log.exception) so a
broken logger backend doesn't fail the training loop.

Deferred (Phase 3+): rollout metrics aggregation, validation hook,
performance/TFLOPS metrics, jsonl batched dumps, W&B generation samples,
memory_tracker snapshots. Each has design decisions (validation cadence
in an async flow, sampling vs every-step, etc.) that warrant separate
follow-ups.

Signed-off-by: Akash Mehra <akamehra@nvidia.com>
SC was calling `weight_synchronizer.sync_weights(trainer_version)` but
the production `WeightSynchronizer.sync_weights(*, timer=None, kv_scales=None)`
interface takes no version arg — transport-specific impls (IPC/HTTP/NCCL)
own the transfer. The version was a vestige of the dryrun stub.

Changes:
- _sync_weights calls `sync_weights()` with no args, matching the
  canonical WeightSynchronizer ABC.
- Explicit version propagation: after sync, `await gen.set_weight_version
  .remote(self._trainer_version)` stamps the rollout actor so newly-
  produced KVBatchMeta tags carry the right weight_version. Mirrors
  grpo.py:2903's `trajectory_collector.set_weight_version(...)` pattern
  and Yuki's NVIDIA-NeMo#2819 `rollout_manager.set_weight_version(new_v)`.
- getattr-guard + try/except around set_weight_version: best-effort,
  log + continue if the gen actor doesn't expose the setter (older
  interfaces or alternate stubs).

DryRunWeightSynchronizer:
- Signature updated to match: `sync_weights()` no longer takes
  trainer_version. gen_handle constructor arg retained (and deleted in
  __init__) so existing fixtures don't have to change, but the stub no
  longer touches gen — SC drives propagation.

All 26 dryrun tests pass in the v0.5.0 container.

This change targets the SC consumer (PR NVIDIA-NeMo#2700) — not the trainer side.
Built as a follow-up layer (branch asyncrl/sc_sync_weights) on top of
asyncrl/split_train_sc so NVIDIA-NeMo#2700 review can proceed independently.

Signed-off-by: Akash Mehra <akamehra@nvidia.com>
@mehraakash mehraakash requested a review from a team as a code owner June 16, 2026 04:01
@mehraakash mehraakash requested review from a team as code owners June 16, 2026 04:01
@copy-pr-bot

copy-pr-bot Bot commented Jun 16, 2026

Copy link
Copy Markdown

Auto-sync is disabled for ready for review pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants