Fix checkpoint loading with rerun state machine #4448

Open
YangFei1990 wants to merge 7 commits into NVIDIA:main from YangFei1990:fix_state_machine

Conversation

@YangFei1990

What does this PR do ?

Fix silent checkpoint corruption when --ckpt-assume-constant-structure is combined with the rerun state machine. RerunStateMachine.state_dict() now always emits the rerun_state_machine_state ShardedObject (with a sentinel payload in the steady state), so the cached SavePlan stays valid across the transition from a normal save to a fault save.

Issue tracking

Linked issue: Fixes #4378

Root cause

With --ckpt-assume-constant-structure, TorchDistSaveShardedStrategy caches the SavePlan from the first save and reuses it for subsequent saves. In the steady state, RerunStateMachine.state_dict() returned None, so the cached plan had no entry for rerun_state_machine_state. When a fault later caused state_dict() to emit a ShardedObject, the stale cached plan silently dropped it:

  • common.pt correctly reflected that the rerun state machine was active (non-sentinel state, etc.).
  • The corresponding sharded payload was not written to disk and not recorded in .metadata.
    On load, the ShardedObject template could not be resolved against the checkpoint, failing with:
RuntimeError: Missing key in checkpoint state_dict: rerun_state_machine_state/shard__<world_size>
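
The caching behavior described above can be sketched as follows. This is a hypothetical simplification (a real `SavePlan` in `TorchDistSaveShardedStrategy` carries far more than a key set), but it shows how a plan cached on the first save silently drops a key that first appears in a later save:

```python
# Hypothetical sketch of the root cause: a save plan cached on the first
# save drops any key that only appears in later saves.

def build_plan(state_dict):
    # Stand-in for SavePlan creation: just capture the key set.
    return set(state_dict.keys())

cached_plan = None

def save(state_dict, assume_constant_structure=True):
    """Write only the keys recorded in the (possibly stale) cached plan."""
    global cached_plan
    if cached_plan is None or not assume_constant_structure:
        cached_plan = build_plan(state_dict)
    return {k: v for k, v in state_dict.items() if k in cached_plan}

# Steady-state save: state_dict() returned None, so no rerun key exists
# when the plan is built and cached.
written = save({"model": "weights"})

# Fault save: the new key is silently filtered out by the stale cached plan,
# so the shard never reaches disk or .metadata -- hence the load-time error.
written = save({"model": "weights", "rerun_state_machine_state": "fault ctx"})
```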

Fix (Option 2 from the issue)

RerunStateMachine.state_dict() now always returns a dict containing a ShardedObject whenever the machine is enabled and ckpt_format == "torch_dist". The ShardedObject's key / global_shape / global_offset are identical across every save; only the payload differs (sentinel in steady state, real fault context when a fault is pending). This makes the sharded structure constant, which is exactly what --ckpt-assume-constant-structure assumes.
Secondary change: _sanitize_data_iterators is now skipped in the steady state, so callers are not required to wrap their training iterator in RerunDataIterator for every normal save — that requirement only applies while a fault is mid-flight.
Backward compatibility:

  • RerunMode.DISABLED still returns None on the save path (no checkpoint bloat).
  • Non-torch_dist formats still return None on the save path.
  • force=True (load path) continues to produce a template regardless of mode/format, matching the existing load-side contract in megatron/training/checkpointing.py.
  • Old checkpoints load unchanged: validate_state_dict() still ignores checkpoints whose state == NOT_RUNNING_YET, so a sentinel entry is a no-op on load.
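
A minimal sketch of the fix's contract (the names, shapes, and mode/state strings here are illustrative stand-ins, not the real `megatron` API): the entry's key, global shape, and global offset are identical across every save, and only the payload varies between the sentinel and the real fault context.

```python
from dataclasses import dataclass

@dataclass
class ShardedObject:
    # Illustrative stand-in for megatron's ShardedObject.
    key: str
    data: object
    global_shape: tuple
    global_offset: tuple

def state_dict(mode, ckpt_format, state, fault_context=None):
    # Unchanged behavior: disabled mode and non-torch_dist formats emit nothing.
    if mode == "DISABLED" or ckpt_format != "torch_dist":
        return None
    # Sentinel payload in the steady state; real context when a fault is pending.
    payload = fault_context if state == "WILL_RERUN_FROM_CHECKPOINT" else None
    return {
        "rerun_state_machine_state": ShardedObject(
            key="rerun_state_machine_state",
            data=payload,
            global_shape=(1,),   # constant across every save
            global_offset=(0,),  # constant across every save
        )
    }

steady = state_dict("VALIDATE_RESULTS", "torch_dist", "NOT_RUNNING_YET")
fault = state_dict("VALIDATE_RESULTS", "torch_dist",
                   "WILL_RERUN_FROM_CHECKPOINT", {"step": 42})
```

Because the sharded structure is identical in both calls, a SavePlan cached from the steady-state save remains valid for the fault save.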

Note

Ideally we should implement option 1 from the issue, recreating the cache whenever the checkpoint content changes; that would be the more permanent fix. However, it requires far more changes: the cache system is complicated and involves more components, and detecting a change in the checkpoint signature is non-trivial and might require a refactor or redesign. For this specific bug we therefore implemented this immediate fix.

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or tag @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into the `dev` branch: the proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

@YangFei1990 YangFei1990 requested review from a team as code owners April 23, 2026 16:27
@svcnvidia-nemo-ci svcnvidia-nemo-ci marked this pull request as draft April 23, 2026 16:27
@github-actions
Contributor

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.

@copy-pr-bot

copy-pr-bot Bot commented Apr 23, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@svcnvidia-nemo-ci svcnvidia-nemo-ci added this to the Core 0.16 milestone Apr 23, 2026
@YangFei1990 YangFei1990 marked this pull request as ready for review April 23, 2026 16:28
@svcnvidia-nemo-ci svcnvidia-nemo-ci requested a review from a team April 23, 2026 16:28
@svcnvidia-nemo-ci svcnvidia-nemo-ci added Final Review PR is in the "final review" stage complexity: low labels Apr 23, 2026
@Phlip79
Member

Phlip79 commented Apr 23, 2026

/claude review

# In NOT_RUNNING_YET this is all zero/None defaults (a sentinel);
# in WILL_RERUN_FROM_CHECKPOINT it carries the real fault context. The
# ShardedObject key/shape/offset are identical in both cases. This is keep the
# checkpoint's sharded structure constant across saves (a
Contributor


Typo: "This is keep" → "This keeps"

Suggested change
# checkpoint's sharded structure constant across saves (a
# ShardedObject key/shape/offset are identical in both cases. This keeps the

Contributor


The suggested change looks to be on the wrong line, but the correction is accurate. Please change "This is keep" to "This keeps".

Author


Done

Contributor

@claude claude Bot left a comment


Fix looks correct — the ShardedObject sentinel in steady state ensures the cached SavePlan stays structurally valid across the normal→fault transition. Tests cover the key invariants well. Left one minor typo nit on a comment.

@maanug-nv maanug-nv requested a review from sbak5 April 23, 2026 20:46
Contributor

@maanug-nv maanug-nv left a comment


fix and tests lgtm

@cyme
Contributor

cyme commented Apr 23, 2026

FWIW, the motivation for returning a None rerun state dict for a "regular" checkpoint (state == NOT_RUNNING_YET) is to ensure backward compatibility, i.e. the ability to resume a regular checkpoint from a codebase commit prior to the introduction of the rerun state machine. This was a fix we did in this MR.

I just want to point out that this will no longer work with this change. But it's not clear it's still a requirement.

@YangFei1990
Author

> FWIW the motivation for returning a None rerun state dict for "regular" checkpoint (state == NOT_RUNNING_YET) is to ensure backward compatibility, i.e. the ability to resume a regular checkpoint from a codebase commit prior to the introduction of the rerun state machine. This was a fix we did in this MR.
>
> I just want to point out that this will no longer work with this change. But it's not clear it's still a requirement.

@cyme The PR is backward compatible, i.e. it can still load checkpoints saved with an older Megatron version. But you are right that it breaks forward compatibility: if a user saves a checkpoint with a newer Megatron version that includes this commit, they cannot resume with an older Megatron version that does not have the rerun state machine. I think it is a pretty uncommon use case.

@cyme
Contributor

cyme commented Apr 23, 2026

Understood @YangFei1990 .

> I think it is a pretty uncommon use case.

This was deemed important enough to get the issue fixed (this commit). I think several users had run into the issue and complained. But that was 9 months ago, and it is probably no longer needed.

@svcnvidia-nemo-ci svcnvidia-nemo-ci added Approved All necessary approvals have been made and removed Final Review PR is in the "final review" stage labels Apr 24, 2026
@YangFei1990 YangFei1990 added this pull request to the merge queue Apr 24, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/24868829952

@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks Apr 24, 2026
@YangFei1990 YangFei1990 added this pull request to the merge queue Apr 24, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/24872343810

@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks Apr 24, 2026
@YangFei1990 YangFei1990 added this pull request to the merge queue Apr 24, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/24873748604

@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/24874613980

@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks Apr 24, 2026

Labels

Approved All necessary approvals have been made complexity: low

Projects

None yet

Development

Successfully merging this pull request may close these issues.

--ckpt-assume-constant-structure causes missing shard files when RerunStateMachine saves a fault checkpoint

8 participants