Add on-demand full-state checkpointing for OpenShift AI / KubeFlow preemption#686
Conversation
Implements signal-driven checkpoint-and-exit for distributed training jobs running in OpenShift AI as KubeFlow training jobs or multi-node bare metal. When `on_demand_checkpointing=True` is set in TrainingArgs:

- The parent process (`run_training`) installs handlers for SIGTERM, SIGINT, SIGUSR1, SIGUSR2, SIGXCPU, and SIGHUP — covering all signals Kubernetes/OpenShift sends before the hard SIGKILL.
- On signal receipt, a trigger file is atomically written to `/dev/shm` (tmpfs, shared within the pod, zero disk I/O).
- Worker processes check for the trigger file after each optimizer step via an `all_reduce(MAX)` collective, ensuring global consensus across all ranks on all nodes.
- When any rank detects the trigger, all ranks collectively save a full-state distributed checkpoint (model + optimizer + LR scheduler), then exit gracefully.
- The parent waits up to 300s for workers to complete the checkpoint before proceeding with normal shutdown.

https://claude.ai/code/session_01HSxsk7SnMULJxy7uafe7t3
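The atomic trigger-file write described above could be sketched roughly as follows. The directory layout, file naming, and function name are assumptions for illustration, not the PR's actual code:

```python
import os
import tempfile

# Assumed location: tmpfs shared within the pod, so the write costs no disk I/O.
TRIGGER_DIR = "/dev/shm"


def write_trigger_file(job_id: str, trigger_dir: str = TRIGGER_DIR) -> str:
    """Atomically create a per-job trigger file that workers can poll."""
    path = os.path.join(trigger_dir, f"checkpoint_trigger_{job_id}")
    # Write to a temp file first, then rename: os.replace is atomic on POSIX
    # within one filesystem, so workers never see a partially written trigger.
    fd, tmp_path = tempfile.mkstemp(dir=trigger_dir)
    try:
        with os.fdopen(fd, "w") as f:
            f.write(job_id)
        os.replace(tmp_path, path)
    except BaseException:
        os.unlink(tmp_path)
        raise
    return path
```

Including the job id in the filename keeps concurrent jobs that share the node's `/dev/shm` from tripping each other's triggers.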
📝 Walkthrough

Adds an opt-in on-demand, signal-triggered full-state checkpointing mode: a new TrainingArgs flag and CLI option, a parent-side signal handler to create a shared trigger, a worker-side consensus check and save-on-demand flow, and integration points in minibatch processing to interrupt and persist training state.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Parent as Parent Process
    participant Signal as ParentSignalHandler
    participant Worker as Worker Process(es)
    participant Trigger as Trigger File (/dev/shm)
    participant Dist as Distributed Backend
    participant Checkpoint as Checkpoint Storage
    Note over Parent,Worker: On-demand checkpoint flow
    Parent->>Signal: install()
    Worker->>Worker: training loop -> process_batch(interrupt_check)
    Parent->>Parent: receives termination signal
    Parent->>Signal: handler invoked
    Signal->>Trigger: write_trigger_file(job_id)
    Worker->>Trigger: trigger_file_exists()
    Worker->>Dist: all_reduce(MAX, local_flag)
    Dist-->>Worker: consensus_flag
    alt consensus_flag == true
        Worker->>Checkpoint: save_on_demand_checkpoint(full_state=True)
        Checkpoint-->>Worker: saved
        Worker->>Trigger: remove_trigger_file()
        Worker->>Worker: exit early
    end
    Parent->>Parent: wait for workers (timeout)
    Parent->>Signal: uninstall()
    Parent->>Parent: exit
```
Actionable comments posted: 2
🧹 Nitpick comments (1)
src/instructlab/training/on_demand_checkpoint.py (1)
225-229: Consider rank-gating the global-consensus log. When a checkpoint is requested, every rank logs the same message. Logging only on rank 0 would reduce shutdown-time log bursts on large jobs.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/instructlab/training/on_demand_checkpoint.py` around lines 225-229: the log message is emitted by every rank when a checkpoint is requested; gate it to only run on the main/rank-0 process to avoid log storms. Wrap the existing logger.info block (the code that runs when `requested` is truthy) with a check for the main process—e.g., `if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:` or, if the project exposes a helper like is_main_process(), use that—then call logger.info only inside that conditional while leaving the checkpoint request flow unchanged.
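The rank gate the review suggests might look like this sketch. The helper name and the environment-variable fallback are assumptions; the project may already expose an equivalent helper:

```python
import os


def is_main_process() -> bool:
    """True only on the global rank-0 process.

    Prefers torch.distributed when it is initialized; otherwise falls back
    to the RANK environment variable that launchers like torchrun set.
    """
    try:
        import torch.distributed as dist

        if dist.is_available() and dist.is_initialized():
            return dist.get_rank() == 0
    except ImportError:
        pass  # torch not installed: fall back to the launcher's env var
    return int(os.environ.get("RANK", "0")) == 0
```

Wrapping the consensus log as `if is_main_process(): logger.info(...)` leaves the checkpoint-request flow itself unchanged while silencing the N-1 duplicate lines.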
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/instructlab/training/main_ds.py`:
- Around line 882-900: The code computes failure using process.poll() before
sending terminate()/kill(), so if the subprocess exits after forced shutdown the
failure status can be stale; update the logic inside the shutdown path in
main_ds.py to recompute process_code and failure after you perform
terminate()/kill() and any subsequent wait() calls (use process.wait with a
timeout then process.poll()), and then decide whether to log success or raise
based on the new failure value; apply the same fix for the second occurrence
referenced (the block around the later terminate/kill sequence) and reference
process.wait, process.poll, terminate(), kill(), and the logger.error messages
when updating the flow.
- Around line 821-833: The ParentSignalHandler is being instantiated without a
job identifier causing shared trigger files; update the instantiation to pass a
stable job id (e.g., use train_args.job_id or another unique training identifier
available in scope) so ParentSignalHandler(job_id=...) is used and the
handler.install() uses a namespaced trigger path; ensure the same job_id is
passed to any worker-side reader logic so trigger files live under a per-job
namespace instead of the global default.
---
Nitpick comments:
In `@src/instructlab/training/on_demand_checkpoint.py`:
- Around lines 225-229: the log message is emitted by every rank when a
checkpoint is requested; gate it to only run on the main/rank-0 process to avoid
log storms. Wrap the existing logger.info block (the code that runs when
`requested` is truthy) with a check for the main process—e.g.,
`if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:` or, if
the project exposes a helper like is_main_process(), use that—then call
logger.info only inside that conditional while leaving the checkpoint request
flow unchanged.
📒 Files selected for processing (3)
src/instructlab/training/config.py
src/instructlab/training/main_ds.py
src/instructlab/training/on_demand_checkpoint.py
- Fix mypy error: properly type _original_handlers dict with _SignalHandler type alias instead of bare object
- Fix ruff/isort: remove duplicate comment, fix import ordering
- Namespace trigger file with rdzv_id as job_id so concurrent jobs sharing /dev/shm don't interfere with each other
- Recompute subprocess failure status after forced termination to avoid stale exit code
- Gate consensus log message to rank 0 to reduce log noise on large jobs
Move the checkpoint request check from after the full optimizer step to after each minibatch's backward pass inside BatchLossManager.process_batch. This ensures the system responds within one fwd+bwd cycle (~1-2s) even when gradient accumulation spans many minibatches, giving more time to save before Kubernetes sends SIGKILL after the grace period. The check is passed as an optional interrupt_check callback to keep checkpoint-specific logic out of BatchLossManager. When triggered, the batch loop breaks early and the training loop saves the checkpoint immediately, skipping the optimizer step to preserve the pre-step model state for exact resumption.
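The callback wiring described above might look like the following sketch. The function shape and names are illustrative, not BatchLossManager's actual interface:

```python
from typing import Callable, Optional


def process_batch(
    minibatches: list,
    run_fwd_bwd: Callable[[object], float],
    interrupt_check: Optional[Callable[[], bool]] = None,
) -> tuple[list[float], bool]:
    """Run fwd+bwd per minibatch, breaking early if an interrupt is requested.

    Passing `interrupt_check` as a callback keeps checkpoint-specific logic
    out of the batch loop: the loop only knows "should I stop?", not why.
    Returns the per-minibatch losses and whether the loop broke early.
    """
    losses: list[float] = []
    for mb in minibatches:
        losses.append(run_fwd_bwd(mb))
        # Check after each backward pass (~one fwd+bwd cycle of latency)
        # instead of only after the full gradient-accumulation window.
        if interrupt_check is not None and interrupt_check():
            return losses, True
    return losses, False
```

When the loop reports an early break, the training loop saves immediately and skips the optimizer step, preserving the pre-step model state for exact resumption.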
When the training subprocess fails after an on-demand checkpoint signal was received, the error message now includes guidance to increase terminationGracePeriodSeconds or reduce fwd/bwd pass time so the checkpoint check fires before SIGKILL arrives.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/instructlab/training/main_ds.py`:
- Around line 1160-1170: The help text for the "--on_demand_checkpointing"
argparse option is inaccurate: it says workers check "after each training step"
but the implementation triggers checks after each minibatch backward pass (see
BatchLossManager.process_batch). Update the parser.add_argument help string for
"--on_demand_checkpointing" to explicitly say the check happens after each
minibatch/backward pass (or "after each minibatch backward pass") and mention
that this is the granularity for checkpoint-trigger latency so the doc matches
the behavior in BatchLossManager.process_batch.
Update --on_demand_checkpointing help text and TrainingArgs description to accurately state that workers check for the trigger file after each minibatch backward pass, not after each training step.
Expand on-demand checkpointing to check for a trigger at five points:

1. Before each minibatch forward pass
2. Before each minibatch backward pass
3. After each minibatch backward pass (existing)
4. Before the optimizer step
5. After the optimizer step

This minimizes the latency between a termination signal arriving and the checkpoint being saved, which is critical when the SIGKILL grace period is short (e.g. 30s on OpenShift/Kubernetes). Also cleans up the save-and-exit logic in train() by extracting a _save_and_exit() helper to eliminate three nearly identical blocks, and fixes _compute_average_loss to handle the case where the minibatch loop is interrupted before any forward pass completes.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
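The five gate points can be sketched as one training step with a single injected check. Everything here (function names, callable signatures) is an assumption for illustration, not the PR's actual train() code:

```python
from typing import Callable


def train_step(
    minibatches: list,
    forward: Callable[[object], object],
    backward: Callable[[object], None],
    optimizer_step: Callable[[], None],
    maybe_save_and_exit: Callable[[], bool],
) -> bool:
    """One gradient-accumulation step with the five checkpoint gates.

    `maybe_save_and_exit` returns True if a checkpoint was requested and
    saved; the step then returns True so the training loop can stop.
    """
    for mb in minibatches:
        if maybe_save_and_exit():      # 1. before the forward pass
            return True
        loss = forward(mb)
        if maybe_save_and_exit():      # 2. before the backward pass
            return True
        backward(loss)
        if maybe_save_and_exit():      # 3. after the backward pass
            return True
    if maybe_save_and_exit():          # 4. before the optimizer step
        return True
    optimizer_step()
    return maybe_save_and_exit()       # 5. after the optimizer step
```

Gates 1-4 fire before the optimizer step, so an interrupted step saves the pre-step model state; only gate 5 runs on the post-step state, matching a normal end-of-step checkpoint.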