Fix checkpoint loading with rerun state machine #4448
YangFei1990 wants to merge 7 commits into NVIDIA:main from
Conversation
This PR has been automatically converted to draft because all PRs must start as drafts. When you are ready for review, click Ready for Review to begin the review process. This will:
See the contribution guide for more details.
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
/claude review
Review comment on the code:

    # In NOT_RUNNING_YET this is all zero/None defaults (a sentinel);
    # in WILL_RERUN_FROM_CHECKPOINT it carries the real fault context. The
    # ShardedObject key/shape/offset are identical in both cases. This is keep the
    # checkpoint's sharded structure constant across saves (a

Typo: "This is keep" → "This keeps"

Suggested change:

    - # checkpoint's sharded structure constant across saves (a
    + # ShardedObject key/shape/offset are identical in both cases. This keeps the

The suggested change looks to be on the wrong line, but the correction is accurate; please change "This is keep" to "This keeps".
FWIW the motivation for returning a I just want to point out that this will no longer work with this change. But it's not clear it's still a requirement.
@cyme The PR is backward compatible, i.e. it can still load checkpoints saved with an older Megatron version. But you are right that it breaks forward compatibility: if a user saves a checkpoint with a newer Megatron version that includes this commit but wants to resume with an older Megatron version that does not have the rerun state machine, that will no longer work. I think it is a pretty uncommon use case.
Understood @YangFei1990.
This was deemed important enough to get the issue fixed (this commit). I think several users had run into the issue and complained. But that was 9 months ago, and it is probably no longer needed.
🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/24868829952

🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/24872343810

🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/24873748604

🔄 Merge queue validation started! You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/24874613980
What does this PR do?
Fix silent checkpoint corruption when `--ckpt-assume-constant-structure` is combined with the rerun state machine. `RerunStateMachine.state_dict()` now always emits the `rerun_state_machine_state` `ShardedObject` (with a sentinel payload in the steady state), so the cached `SavePlan` stays valid across the transition from a normal save to a fault save.
Issue tracking
Linked issue: Fixes #4378
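The stale-plan failure described in the summary can be reproduced with a toy sketch. This is not Megatron code; the class and function names below are invented for illustration, and the "plan" is reduced to a plain set of keys:

```python
# Toy reproduction of the failure mode: a save strategy that builds its
# write plan once, then reuses it. Any key that first appears in a later
# save is silently dropped, which is exactly the corruption this PR fixes.

def build_plan(state_dict):
    """A 'plan' here is simply the set of keys that will be written."""
    return set(state_dict)

class CachingSaveStrategy:
    def __init__(self):
        self.cached_plan = None  # filled on the first save, then reused

    def save(self, state_dict):
        if self.cached_plan is None:
            self.cached_plan = build_plan(state_dict)  # first save only
        # Later saves reuse the stale plan; unknown keys are dropped.
        return {k: v for k, v in state_dict.items() if k in self.cached_plan}

strategy = CachingSaveStrategy()
# Steady-state save: the rerun state machine contributed nothing.
strategy.save({"model": 1, "optimizer": 2})
# Fault save: a new key appears, but the cached plan silently drops it.
written = strategy.save(
    {"model": 1, "optimizer": 2, "rerun_state_machine_state": "fault-ctx"}
)
print("rerun_state_machine_state" in written)  # False: silently dropped
```

The fix below keeps the key set constant from the very first save, so the cached plan never goes stale.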
Root cause
With `--ckpt-assume-constant-structure`, `TorchDistSaveShardedStrategy` caches the `SavePlan` from the first save and reuses it for subsequent saves. In the steady state, `RerunStateMachine.state_dict()` returned `None`, so the cached plan had no entry for `rerun_state_machine_state`. When a fault later caused `state_dict()` to emit a `ShardedObject`, the stale cached plan silently dropped it: `common.pt` correctly reflected that the rerun state machine was active (non-sentinel `state`, etc.), while the `ShardedObject` entry never made it into `.metadata`. On load, the `ShardedObject` template could not be resolved against the checkpoint, so loading failed.
Fix (Option 2 from the issue)
`RerunStateMachine.state_dict()` now always returns a dict containing a `ShardedObject` whenever the machine is enabled and `ckpt_format == "torch_dist"`. The `ShardedObject`'s `key`/`global_shape`/`global_offset` are identical across every save; only the payload differs (sentinel in steady state, real fault context when a fault is pending). This makes the sharded structure constant, which is exactly what `--ckpt-assume-constant-structure` assumes.
Secondary change:
`_sanitize_data_iterators` is now skipped in the steady state, so callers are not required to wrap their training iterator in `RerunDataIterator` for every normal save; that requirement only applies while a fault is mid-flight.
Backward compatibility:
- `RerunMode.DISABLED` still returns `None` on the save path (no checkpoint bloat).
- Non-`torch_dist` formats still return `None` on the save path.
- `force=True` (load path) continues to produce a template regardless of mode/format, matching the existing load-side contract in `megatron/training/checkpointing.py`.
- `validate_state_dict()` still ignores checkpoints whose `state == NOT_RUNNING_YET`, so a sentinel entry is a no-op on load.

Note
Ideally we should implement Option 1 mentioned in the issue, recreating the cache whenever the checkpoint content changes; that would be the more permanent fix. However, it requires far more changes: the caching system is complicated and involves more components, and detecting a change in the checkpoint signature is non-trivial and might require a refactor/redesign. For this specific bug we therefore implemented this immediate fix.
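The save-side and load-side contracts described above can be sketched together. This is a hedged simplification, not the PR's implementation: `ShardedObject` here is a stand-in dataclass rather than the real `megatron.core.dist_checkpointing` type, and the sentinel/restore helpers are invented names that only mirror the behavior described in this PR:

```python
# Sketch of the fix's core idea: always emit the rerun-state entry with a
# constant key/global_shape/global_offset so a cached SavePlan built on the
# first save remains valid for a later fault save; only the payload varies.
from dataclasses import dataclass
from typing import Any, Optional

NOT_RUNNING_YET = "NOT_RUNNING_YET"  # sentinel state, per the discussion above

@dataclass
class ShardedObject:  # stand-in for the real dist_checkpointing type
    key: str
    data: Any
    global_shape: tuple
    global_offset: tuple

def rerun_state_dict(fault_context: Optional[dict] = None) -> dict:
    """Save path: always emit the entry; sentinel payload in steady state."""
    payload = fault_context if fault_context else {"state": NOT_RUNNING_YET}
    return {
        "rerun_state_machine_state": ShardedObject(
            key="rerun_state_machine_state",
            data=payload,
            global_shape=(1,),   # identical on every save
            global_offset=(0,),  # identical on every save
        )
    }

def should_restore(loaded: Optional[dict]) -> bool:
    """Load path: a missing entry (old checkpoint) and the sentinel are no-ops."""
    return loaded is not None and loaded.get("state") != NOT_RUNNING_YET

steady = rerun_state_dict()
fault = rerun_state_dict({"state": "WILL_RERUN_FROM_CHECKPOINT", "step": 123})
obj_a = steady["rerun_state_machine_state"]
obj_b = fault["rerun_state_machine_state"]
# Constant sharded structure: key/shape/offset match; only data differs.
print(obj_a.key == obj_b.key)      # True
print(should_restore(obj_a.data))  # False: sentinel ignored on load
print(should_restore(obj_b.data))  # True: real fault context restored
```

Because the structure never changes, the first-save `SavePlan` already contains the entry, so a subsequent fault save writes the real payload through the cached plan instead of silently dropping it.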
Contribution process
Pre-checks
Code review
Feel free to message or comment @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!
All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.
Step 1: Mark PR as "Ready for Review"
`.github/CODEOWNERS`. Final Review might get declined if these requirements are not fulfilled.
Step 2: Final Review
For PRs that change `megatron/core`, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned. For PRs outside `megatron/core`, this step is skipped.
Step 3: Approved
Once all required reviewers have approved, the Approved label is applied automatically.
Merge
Any member of mcore-engineers will be able to merge your PR.
For MRs into `dev` branch
The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.