From 4d82b3d98dee52b891189637d6c66c0f4b2f8d7c Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 26 Feb 2026 22:26:25 +0000 Subject: [PATCH 01/14] Add on-demand full-state checkpointing for OpenShift AI / KubeFlow preemption MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements signal-driven checkpoint-and-exit for distributed training jobs running in OpenShift AI as KubeFlow training jobs or multi-node bare metal. When `on_demand_checkpointing=True` is set in TrainingArgs: - Parent process (run_training) installs handlers for SIGTERM, SIGINT, SIGUSR1, SIGUSR2, SIGXCPU, and SIGHUP — covering all signals Kubernetes/OpenShift sends before the hard SIGKILL. - On signal receipt, a trigger file is atomically written to /dev/shm (tmpfs, shared within the pod, zero disk I/O). - Worker processes check for the trigger file after each optimizer step via an all_reduce(MAX) collective, ensuring global consensus across all ranks on all nodes. - When any rank detects the trigger, all ranks collectively save a full-state distributed checkpoint (model + optimizer + LR scheduler) then exit gracefully. - Parent waits up to 300s for workers to complete the checkpoint before proceeding with normal shutdown. https://claude.ai/code/session_01HSxsk7SnMULJxy7uafe7t3 --- src/instructlab/training/config.py | 13 + src/instructlab/training/main_ds.py | 102 ++++++- .../training/on_demand_checkpoint.py | 277 ++++++++++++++++++ 3 files changed, 385 insertions(+), 7 deletions(-) create mode 100644 src/instructlab/training/on_demand_checkpoint.py diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 911c3898..7c235582 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -351,6 +351,19 @@ class TrainingArgs(BaseModel): description="How often to evaluate validation loss (in training steps). 
Required when validation_split > 0.", ) + on_demand_checkpointing: bool = Field( + default=False, + description=( + "Enable on-demand full-state checkpointing triggered by Unix signals. " + "When enabled, the parent process intercepts termination signals " + "(SIGTERM, SIGINT, SIGUSR1, SIGUSR2, SIGXCPU, SIGHUP) and writes a " + "trigger file to /dev/shm. Worker processes check for this trigger " + "after each training step and collectively save a distributed " + "checkpoint before exiting gracefully. Designed for OpenShift AI / " + "KubeFlow training jobs where preemption signals must be handled." + ), + ) + @model_validator(mode="after") def validate_validation_config(self): if not 0.0 <= self.validation_split < 1.0: diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 072b27c6..8b47919d 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -173,6 +173,7 @@ def train( accelerator: Accelerator, val_data_loader=None, validation_frequency=None, + on_demand_checkpointing: bool = False, ): model.train() @@ -183,6 +184,15 @@ def train( metric_logger = logging.getLogger("instructlab.training.metrics") base_logger = logging.getLogger("instructlab.training") + # Import on-demand checkpointing utilities once if the feature is enabled + if on_demand_checkpointing: + from instructlab.training.on_demand_checkpoint import ( + check_checkpoint_requested, + save_on_demand_checkpoint, + ) + + base_logger.info("On-demand checkpointing is enabled in worker process.") + # Mini_trainer approach: batch_size will be determined dynamically by data loader # For save logic, use effective_batch_size since that's the target samples_seen = 0 @@ -308,6 +318,22 @@ def train( base_logger.debug("RANK (%d) waiting at post-save barrier.", local_rank) dist.barrier() + # --- On-demand checkpointing: check if a signal triggered a save --- + if on_demand_checkpointing and check_checkpoint_requested(): + 
save_on_demand_checkpoint( + args=args, + accelerator=accelerator, + model=model, + tokenizer=model.tokenizer, + samples_seen=samples_seen, + epoch=epoch, + is_lora=bool(args.lora_r), + ) + base_logger.info( + "On-demand checkpoint saved. Exiting training gracefully." + ) + return + global_step += 1 if local_rank == 0: inner_pb.update(1) @@ -561,6 +587,7 @@ def main(args): accelerator=accelerator, val_data_loader=val_loader, validation_frequency=validation_frequency, + on_demand_checkpointing=getattr(args, "on_demand_checkpointing", False), ) dist.barrier() @@ -791,7 +818,24 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.keep_last_checkpoint_only: command.append("--keep_last_checkpoint_only") + if train_args.on_demand_checkpointing: + command.append("--on_demand_checkpointing") + logger.info("Running training command as subprocess: %s", " ".join(command)) + + # --- On-demand checkpointing: install signal handlers in the parent --- + signal_handler = None + if train_args.on_demand_checkpointing: + # First Party + from instructlab.training.on_demand_checkpoint import ParentSignalHandler + + signal_handler = ParentSignalHandler() + signal_handler.install() + logger.info( + "On-demand checkpointing is ENABLED. " + "Termination signals will trigger a full-state checkpoint before exit." + ) + process = None interrupt: KeyboardInterrupt | Exception | None = None failure = False @@ -811,19 +855,49 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: interrupt = e finally: if "process" not in locals() or process is None: + if signal_handler is not None: + signal_handler.uninstall() return + # If a signal was caught by the on-demand checkpoint handler, give + # the workers time to detect the trigger file and save a checkpoint + # before we start sending our own signals to the subprocess. 
+ if signal_handler is not None and signal_handler.signal_received is not None: + logger.info( + "On-demand checkpoint: signal %s received. Waiting for workers to " + "save checkpoint before proceeding with shutdown...", + signal_handler.signal_received.name, + ) + # Give workers generous time to complete the checkpoint save. + # The workers will exit on their own after saving. + try: + process.wait(timeout=300) + except subprocess.TimeoutExpired: + logger.warning( + "On-demand checkpoint: workers did not finish within 300s. " + "Proceeding with shutdown." + ) + # wait for the process to exit so we can properly read the exit code - process.wait(timeout=60) + try: + process.wait(timeout=60) + except subprocess.TimeoutExpired: + pass process_code = process.poll() - failure = process_code != 0 + failure = process_code is not None and process_code != 0 - if not failure: - logger.info("Operation completed successfully! 🎉") + if process_code is not None and not failure: + logger.info("Operation completed successfully!") else: - logger.error( - f"Training subprocess has not exited yet. Sending SIGTERM. Process code: {process_code}" - ) + if process_code is None: + logger.error( + "Training subprocess has not exited yet. Sending SIGTERM." + ) + else: + logger.error( + "Training subprocess exited with code %d. Sending SIGTERM.", + process_code, + ) process.terminate() try: @@ -835,6 +909,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: ) process.kill() + if signal_handler is not None: + signal_handler.uninstall() + if interrupt: raise interrupt if failure: @@ -1045,6 +1122,17 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: ), ) + parser.add_argument( + "--on_demand_checkpointing", + action="store_true", + default=False, + help=( + "Enable on-demand full-state checkpointing triggered by Unix signals. 
" + "When enabled, workers check for a trigger file in /dev/shm after each " + "training step and collectively save a distributed checkpoint before " + "exiting. Designed for OpenShift AI / KubeFlow preemption handling." + ), + ) parser.add_argument( "--use_liger", action="store_true", diff --git a/src/instructlab/training/on_demand_checkpoint.py b/src/instructlab/training/on_demand_checkpoint.py new file mode 100644 index 00000000..8d4e7462 --- /dev/null +++ b/src/instructlab/training/on_demand_checkpoint.py @@ -0,0 +1,277 @@ +# SPDX-License-Identifier: Apache-2.0 + +""" +On-demand checkpointing for distributed training. + +This module enables graceful checkpoint-and-exit when termination signals are +received. It is designed for environments like OpenShift AI / KubeFlow where +training jobs can be preempted at any time and the platform sends Unix signals +before killing the pod. + +Architecture +------------ +There are two sides to this feature: + +**Parent process** (``run_training`` in ``main_ds.py``): + Installs signal handlers that catch every signal OpenShift / Kubernetes can + send before a SIGKILL. When a signal arrives the handler writes a small + *trigger file* to ``/dev/shm`` (a tmpfs shared between containers in the + same pod). Because ``/dev/shm`` is node-local, every worker on the **same + node** can see the file instantly with zero network I/O. + +**Worker processes** (torchrun children): + After every optimizer step the training loop calls + ``check_checkpoint_requested()``. Each rank checks its local ``/dev/shm`` + for the trigger file, converts the boolean to a tensor, and does an + ``all_reduce(MAX)`` so that if *any* rank on *any* node detected the + trigger, *every* rank agrees to save a checkpoint. This works correctly in + multi-node training because all_reduce is a global collective. 
+ +Signals handled +--------------- +We intercept every signal that Kubernetes / OpenShift can deliver before the +hard SIGKILL (which cannot be caught): + +* **SIGTERM** – the standard graceful-shutdown signal. Kubernetes sends this + first (configurable via ``terminationGracePeriodSeconds``). +* **SIGINT** – sent on Ctrl-C or by some job controllers. +* **SIGUSR1 / SIGUSR2** – commonly used by batch schedulers and custom + preemption controllers to signal upcoming eviction. +* **SIGXCPU** – sent when CPU time limits are exceeded (relevant for jobs + with resource quotas). +* **SIGHUP** – sent when the controlling terminal disconnects; some + container runtimes forward this on pod eviction. +""" + +# Standard +import logging +import os +import signal +import tempfile +from pathlib import Path +from typing import Optional + +# Third Party +import torch +import torch.distributed as dist + +logger = logging.getLogger("instructlab.training") + +# --------------------------------------------------------------------------- +# Trigger file helpers +# --------------------------------------------------------------------------- + +# The trigger file lives in /dev/shm which is a tmpfs (RAM-backed filesystem). +# It is: +# 1. Extremely fast (no disk I/O). +# 2. Shared between all containers in the same Kubernetes pod. +# 3. Automatically cleaned up when the pod is destroyed. +_TRIGGER_DIR = Path("/dev/shm") +_TRIGGER_FILENAME = "instructlab_checkpoint_requested" + + +def _get_trigger_path(job_id: Optional[str] = None) -> Path: + """Return the path to the checkpoint trigger file. + + An optional *job_id* can be supplied to avoid collisions if multiple + training jobs share the same ``/dev/shm`` (unlikely but possible). + """ + name = f"{_TRIGGER_FILENAME}_{job_id}" if job_id else _TRIGGER_FILENAME + return _TRIGGER_DIR / name + + +def write_trigger_file(job_id: Optional[str] = None) -> Path: + """Create the trigger file that tells workers to checkpoint. 
+ + This is called from the *parent* process signal handler. + Returns the path that was written. + """ + path = _get_trigger_path(job_id) + # Use a atomic write via tempfile + rename to avoid partial reads. + fd, tmp = tempfile.mkstemp(dir=_TRIGGER_DIR, prefix=".ckpt_trigger_") + try: + os.write(fd, b"1") + finally: + os.close(fd) + os.rename(tmp, path) + logger.info( + "On-demand checkpoint trigger file written: %s", + path, + ) + return path + + +def trigger_file_exists(job_id: Optional[str] = None) -> bool: + """Check whether the trigger file exists (worker-side).""" + return _get_trigger_path(job_id).exists() + + +def remove_trigger_file(job_id: Optional[str] = None) -> None: + """Remove the trigger file after the checkpoint has been saved.""" + path = _get_trigger_path(job_id) + try: + path.unlink(missing_ok=True) + except OSError: + pass + + +# --------------------------------------------------------------------------- +# Parent-side signal handling +# --------------------------------------------------------------------------- + +# Signals that OpenShift / Kubernetes / batch schedulers may send before +# the hard SIGKILL. SIGKILL (9) and SIGSTOP (19) cannot be caught. +_CATCHABLE_SIGNALS = ( + signal.SIGTERM, # Kubernetes default graceful shutdown signal + signal.SIGINT, # Ctrl-C / some job controllers + signal.SIGUSR1, # Custom preemption controllers + signal.SIGUSR2, # Custom preemption controllers + signal.SIGXCPU, # CPU time limit exceeded (resource quotas) + signal.SIGHUP, # Terminal disconnect / some eviction paths +) + + +class ParentSignalHandler: + """Installs signal handlers in the parent (launcher) process. + + When any of the catchable signals fire, the handler: + 1. Writes the trigger file to ``/dev/shm``. + 2. Records that a signal was received (so the caller can decide to + wait for the child process to finish checkpointing). + + The handler is idempotent – multiple signals will not create multiple + trigger files. 
+ + Parameters + ---------- + job_id : str, optional + Unique identifier for this training job. Used to namespace the + trigger file. + """ + + def __init__(self, job_id: Optional[str] = None): + self.job_id = job_id + self.signal_received: Optional[signal.Signals] = None + self._original_handlers: dict[signal.Signals, object] = {} + self._trigger_written = False + + def install(self) -> None: + """Register signal handlers for all catchable signals.""" + for sig in _CATCHABLE_SIGNALS: + try: + self._original_handlers[sig] = signal.getsignal(sig) + signal.signal(sig, self._handle) + except (OSError, ValueError): + # Some signals may not be available on all platforms + logger.debug("Could not install handler for %s", sig.name) + + logger.info( + "On-demand checkpoint signal handlers installed for: %s", + ", ".join(s.name for s in self._original_handlers), + ) + + def uninstall(self) -> None: + """Restore original signal handlers.""" + for sig, handler in self._original_handlers.items(): + try: + signal.signal(sig, handler) + except (OSError, ValueError): + pass + self._original_handlers.clear() + + def _handle(self, signum: int, _frame) -> None: + """Signal handler callback.""" + sig = signal.Signals(signum) + logger.info( + "On-demand checkpoint: received signal %s (%d). " + "Writing trigger file for workers to checkpoint before exit.", + sig.name, + signum, + ) + self.signal_received = sig + + if not self._trigger_written: + write_trigger_file(self.job_id) + self._trigger_written = True + + +# --------------------------------------------------------------------------- +# Worker-side synchronization +# --------------------------------------------------------------------------- + + +def check_checkpoint_requested(job_id: Optional[str] = None) -> bool: + """Check across all ranks whether an on-demand checkpoint was requested. + + This function must be called by **all ranks** at the same point in the + training loop (it contains a collective all_reduce). 
+ + Returns ``True`` if any rank detected the trigger file, meaning all + ranks should save a checkpoint. + """ + local_trigger = trigger_file_exists(job_id) + + # Convert to a tensor and all-reduce (MAX) so that if ANY rank on ANY + # node saw the trigger, every rank gets True. + trigger_tensor = torch.tensor( + [1 if local_trigger else 0], + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + dist.all_reduce(trigger_tensor, op=dist.ReduceOp.MAX) + + requested = trigger_tensor.item() > 0 + + if requested: + logger.info( + "On-demand checkpoint: global consensus reached – " + "all ranks will save a checkpoint." + ) + # Clean up the trigger file so that if the process somehow + # continues, we don't save again immediately. + remove_trigger_file(job_id) + + return requested + + +def save_on_demand_checkpoint( + args, + accelerator, + model, + tokenizer, + samples_seen: int, + epoch: int, + is_lora: bool, +) -> None: + """Save a full-state distributed checkpoint for on-demand resume. + + This is a thin wrapper that calls the existing ``save_checkpoint`` + utility with ``full_state=True`` so that optimizer + LR scheduler + state are also persisted, enabling exact training resumption. 
+ """ + # First Party – imported here to avoid circular imports + from instructlab.training.utils import save_checkpoint + + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if local_rank == 0: + logger.info( + "On-demand checkpoint: saving full-state checkpoint at " + "epoch=%d, samples_seen=%d", + epoch, + samples_seen, + ) + + save_checkpoint( + args=args, + accelerator=accelerator, + model=model, + tokenizer=tokenizer, + samples_seen=samples_seen, + is_lora=is_lora, + full_state=True, + hf_format=True, + epoch=epoch, + ) + + if local_rank == 0: + logger.info("On-demand checkpoint: checkpoint saved successfully.") From 7cffa29ad018753e0683e7d443b8c8753303b6e2 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Mon, 2 Mar 2026 19:28:43 +0000 Subject: [PATCH 02/14] Address review feedback for on-demand checkpointing - Fix mypy error: properly type _original_handlers dict with _SignalHandler type alias instead of bare object - Fix ruff/isort: remove duplicate comment, fix import ordering - Namespace trigger file with rdzv_id as job_id so concurrent jobs sharing /dev/shm don't interfere with each other - Recompute subprocess failure status after forced termination to avoid stale exit code - Gate consensus log message to rank 0 to reduce log noise on large jobs --- src/instructlab/training/main_ds.py | 57 ++++++++++++------- .../training/on_demand_checkpoint.py | 37 +++++++----- 2 files changed, 57 insertions(+), 37 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 8b47919d..605dc2a7 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -185,12 +185,15 @@ def train( base_logger = logging.getLogger("instructlab.training") # Import on-demand checkpointing utilities once if the feature is enabled + checkpoint_job_id = None if on_demand_checkpointing: + # First Party from instructlab.training.on_demand_checkpoint import ( 
check_checkpoint_requested, save_on_demand_checkpoint, ) + checkpoint_job_id = os.environ.get("INSTRUCTLAB_ON_DEMAND_JOB_ID") base_logger.info("On-demand checkpointing is enabled in worker process.") # Mini_trainer approach: batch_size will be determined dynamically by data loader @@ -319,7 +322,9 @@ def train( dist.barrier() # --- On-demand checkpointing: check if a signal triggered a save --- - if on_demand_checkpointing and check_checkpoint_requested(): + if on_demand_checkpointing and check_checkpoint_requested( + checkpoint_job_id + ): save_on_demand_checkpoint( args=args, accelerator=accelerator, @@ -829,11 +834,16 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: # First Party from instructlab.training.on_demand_checkpoint import ParentSignalHandler - signal_handler = ParentSignalHandler() + # Use rdzv_id to namespace the trigger file so concurrent jobs + # sharing /dev/shm don't interfere with each other. + checkpoint_job_id = str(torch_args.rdzv_id) + os.environ["INSTRUCTLAB_ON_DEMAND_JOB_ID"] = checkpoint_job_id + signal_handler = ParentSignalHandler(job_id=checkpoint_job_id) signal_handler.install() logger.info( - "On-demand checkpointing is ENABLED. " - "Termination signals will trigger a full-state checkpoint before exit." + "On-demand checkpointing is ENABLED (job_id=%s). " + "Termination signals will trigger a full-state checkpoint before exit.", + checkpoint_job_id, ) process = None @@ -884,30 +894,33 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: except subprocess.TimeoutExpired: pass process_code = process.poll() - failure = process_code is not None and process_code != 0 - if process_code is not None and not failure: + if process_code is not None and process_code == 0: logger.info("Operation completed successfully!") - else: - if process_code is None: - logger.error( - "Training subprocess has not exited yet. Sending SIGTERM." 
- ) - else: + elif process_code is None: + logger.error("Training subprocess has not exited yet. Sending SIGTERM.") + process.terminate() + try: + logger.info("Waiting for process to exit, 60s...") + process.wait(timeout=60) + except subprocess.TimeoutExpired: logger.error( - "Training subprocess exited with code %d. Sending SIGTERM.", - process_code, + "Training subprocess did not terminate before timeout, sending SIGKILL." ) - - process.terminate() - try: - logger.info("Waiting for process to exit, 60s...") - process.wait(timeout=60) - except subprocess.TimeoutExpired: + process.kill() + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + pass + else: logger.error( - "Training subprocess did not terminate before timeout, sending SIGKILL." + "Training subprocess exited with code %d.", + process_code, ) - process.kill() + + # Recompute final exit status after any forced shutdown + process_code = process.poll() + failure = process_code is None or process_code != 0 if signal_handler is not None: signal_handler.uninstall() diff --git a/src/instructlab/training/on_demand_checkpoint.py b/src/instructlab/training/on_demand_checkpoint.py index 8d4e7462..b9643531 100644 --- a/src/instructlab/training/on_demand_checkpoint.py +++ b/src/instructlab/training/on_demand_checkpoint.py @@ -44,17 +44,23 @@ """ # Standard +from pathlib import Path +from typing import Callable, Optional, Union import logging import os import signal import tempfile -from pathlib import Path -from typing import Optional +import types # Third Party import torch import torch.distributed as dist +# Type alias matching the return type of signal.getsignal(). 
+_SignalHandler = Union[ + Callable[[int, Optional[types.FrameType]], None], int, signal.Handlers, None +] + logger = logging.getLogger("instructlab.training") # --------------------------------------------------------------------------- @@ -122,12 +128,12 @@ def remove_trigger_file(job_id: Optional[str] = None) -> None: # Signals that OpenShift / Kubernetes / batch schedulers may send before # the hard SIGKILL. SIGKILL (9) and SIGSTOP (19) cannot be caught. _CATCHABLE_SIGNALS = ( - signal.SIGTERM, # Kubernetes default graceful shutdown signal - signal.SIGINT, # Ctrl-C / some job controllers - signal.SIGUSR1, # Custom preemption controllers - signal.SIGUSR2, # Custom preemption controllers - signal.SIGXCPU, # CPU time limit exceeded (resource quotas) - signal.SIGHUP, # Terminal disconnect / some eviction paths + signal.SIGTERM, # Kubernetes default graceful shutdown signal + signal.SIGINT, # Ctrl-C / some job controllers + signal.SIGUSR1, # Custom preemption controllers + signal.SIGUSR2, # Custom preemption controllers + signal.SIGXCPU, # CPU time limit exceeded (resource quotas) + signal.SIGHUP, # Terminal disconnect / some eviction paths ) @@ -152,7 +158,7 @@ class ParentSignalHandler: def __init__(self, job_id: Optional[str] = None): self.job_id = job_id self.signal_received: Optional[signal.Signals] = None - self._original_handlers: dict[signal.Signals, object] = {} + self._original_handlers: dict[signal.Signals, _SignalHandler] = {} self._trigger_written = False def install(self) -> None: @@ -174,7 +180,7 @@ def uninstall(self) -> None: """Restore original signal handlers.""" for sig, handler in self._original_handlers.items(): try: - signal.signal(sig, handler) + signal.signal(sig, handler) # type: ignore[arg-type] except (OSError, ValueError): pass self._original_handlers.clear() @@ -223,10 +229,11 @@ def check_checkpoint_requested(job_id: Optional[str] = None) -> bool: requested = trigger_tensor.item() > 0 if requested: - logger.info( - "On-demand 
checkpoint: global consensus reached – " - "all ranks will save a checkpoint." - ) + if dist.is_initialized() and dist.get_rank() == 0: + logger.info( + "On-demand checkpoint: global consensus reached – " + "all ranks will save a checkpoint." + ) # Clean up the trigger file so that if the process somehow # continues, we don't save again immediately. remove_trigger_file(job_id) @@ -249,7 +256,7 @@ def save_on_demand_checkpoint( utility with ``full_state=True`` so that optimizer + LR scheduler state are also persisted, enabling exact training resumption. """ - # First Party – imported here to avoid circular imports + # First Party from instructlab.training.utils import save_checkpoint local_rank = int(os.environ.get("LOCAL_RANK", "0")) From 799417bf6154542499f5927cf781f1811e3e5336 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Mon, 2 Mar 2026 20:27:05 +0000 Subject: [PATCH 03/14] Check for on-demand checkpoint after each minibatch backward Move the checkpoint request check from after the full optimizer step to after each minibatch's backward pass inside BatchLossManager.process_batch. This ensures the system responds within one fwd+bwd cycle (~1-2s) even when gradient accumulation spans many minibatches, giving more time to save before Kubernetes sends SIGKILL after the grace period. The check is passed as an optional interrupt_check callback to keep checkpoint-specific logic out of BatchLossManager. When triggered, the batch loop breaks early and the training loop saves the checkpoint immediately, skipping the optimizer step to preserve the pre-step model state for exact resumption. 
--- .../training/batch_loss_manager.py | 22 ++++++++- src/instructlab/training/main_ds.py | 49 +++++++++++-------- 2 files changed, 49 insertions(+), 22 deletions(-) diff --git a/src/instructlab/training/batch_loss_manager.py b/src/instructlab/training/batch_loss_manager.py index cc6da021..b199ac17 100644 --- a/src/instructlab/training/batch_loss_manager.py +++ b/src/instructlab/training/batch_loss_manager.py @@ -7,7 +7,8 @@ """ # Standard -from dataclasses import dataclass +from collections.abc import Callable +from dataclasses import dataclass, field import logging # Third Party @@ -33,6 +34,7 @@ class BatchMetrics: accumulated_aux_loss: torch.Tensor | None grad_accum_steps: int num_minibatches: int + interrupted: bool = field(default=False) class BatchLossManager: @@ -62,12 +64,21 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int): self.local_rank: int = local_rank self.torch_device = torch.device("cuda", local_rank) - def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]: + def process_batch( + self, + batch: list[CollatedItem], + interrupt_check: Callable[[], bool] | None = None, + ) -> tuple[BatchMetrics, float]: """ Process a batch of minibatches, computing losses and accumulating gradients. Args: batch: List of minibatches to process + interrupt_check: Optional callback invoked after each minibatch's + backward pass. If it returns ``True``, gradient accumulation + stops early and ``BatchMetrics.interrupted`` is set. Used by + on-demand checkpointing to react within one fwd+bwd cycle + instead of waiting for the full optimizer step. 
Returns: tuple: (BatchMetrics, average_loss_across_ranks) @@ -82,6 +93,7 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] accumulated_loss = 0.0 accumulated_aux_loss = 0.0 grad_accum_steps = 0 + interrupted = False # process each minibatch for mb in batch: @@ -108,6 +120,11 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] if raw_losses.aux_loss is not None: accumulated_aux_loss += raw_losses.aux_loss + # check for early exit (e.g. on-demand checkpoint requested) + if interrupt_check is not None and interrupt_check(): + interrupted = True + break + # reduce metrics across ranks batch_total_samples, batch_total_length = self._reduce_metrics( batch_total_samples, batch_total_length @@ -127,6 +144,7 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] accumulated_aux_loss=accumulated_aux_loss, grad_accum_steps=grad_accum_steps, num_minibatches=num_minibatches, + interrupted=interrupted, ) return metrics, avg_loss_across_ranks diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 605dc2a7..88a48f4d 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -233,11 +233,38 @@ def train( continue start = time.time() - # Process the batch using the BatchLossManager + # Process the batch using the BatchLossManager. + # When on-demand checkpointing is enabled, pass a callback so + # the check runs after every minibatch backward rather than + # waiting for the full optimizer step. 
+ _interrupt_check = ( + (lambda: check_checkpoint_requested(checkpoint_job_id)) + if on_demand_checkpointing + else None + ) batch_metrics, avg_loss_across_ranks = batch_loss_manager.process_batch( - batch + batch, interrupt_check=_interrupt_check ) + # If the batch was interrupted by an on-demand checkpoint + # request, save immediately and exit — skip the optimizer step + # since we want to preserve the pre-step model state for + # exact resumption. + if batch_metrics.interrupted: + save_on_demand_checkpoint( + args=args, + accelerator=accelerator, + model=model, + tokenizer=model.tokenizer, + samples_seen=samples_seen, + epoch=epoch, + is_lora=bool(args.lora_r), + ) + base_logger.info( + "On-demand checkpoint saved. Exiting training gracefully." + ) + return + # Update samples seen samples_seen += batch_metrics.total_samples @@ -321,24 +348,6 @@ def train( base_logger.debug("RANK (%d) waiting at post-save barrier.", local_rank) dist.barrier() - # --- On-demand checkpointing: check if a signal triggered a save --- - if on_demand_checkpointing and check_checkpoint_requested( - checkpoint_job_id - ): - save_on_demand_checkpoint( - args=args, - accelerator=accelerator, - model=model, - tokenizer=model.tokenizer, - samples_seen=samples_seen, - epoch=epoch, - is_lora=bool(args.lora_r), - ) - base_logger.info( - "On-demand checkpoint saved. Exiting training gracefully." - ) - return - global_step += 1 if local_rank == 0: inner_pb.update(1) From d089910ce320e70657cd21041a3739e33891cf46 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Mon, 2 Mar 2026 20:29:01 +0000 Subject: [PATCH 04/14] Add diagnostic note when on-demand checkpoint fails to save in time When the training subprocess fails after an on-demand checkpoint signal was received, the error message now includes guidance to increase terminationGracePeriodSeconds or reduce fwd/bwd pass time so the checkpoint check fires before SIGKILL arrives. 
--- src/instructlab/training/main_ds.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 88a48f4d..05f3a851 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -937,9 +937,22 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if interrupt: raise interrupt if failure: - raise RuntimeError( - "Suffered a failure during distributed training. Please see the training logs for more context." - ) + msg = "Suffered a failure during distributed training. Please see the training logs for more context." + if ( + signal_handler is not None + and signal_handler.signal_received is not None + ): + msg += ( + f"\n\nNote: signal {signal_handler.signal_received.name} was" + " received and on-demand checkpointing was enabled, but the" + " training subprocess did not exit cleanly. This usually" + " means the process was killed (SIGKILL) before the" + " checkpoint could be saved. To fix this, increase" + " terminationGracePeriodSeconds in your pod spec to give" + " workers more time, or reduce the model's forward/backward" + " pass time so the checkpoint check fires sooner." + ) + raise RuntimeError(msg) if __name__ == "__main__": From 68fb576049130659bcc4de391467766858996e9c Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Mon, 2 Mar 2026 20:38:56 +0000 Subject: [PATCH 05/14] Fix help text: checkpoint check happens after each minibatch backward Update --on_demand_checkpointing help text and TrainingArgs description to accurately state that workers check for the trigger file after each minibatch backward pass, not after each training step. 
--- src/instructlab/training/config.py | 2 +- src/instructlab/training/main_ds.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 7c235582..11182a0e 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -358,7 +358,7 @@ class TrainingArgs(BaseModel): "When enabled, the parent process intercepts termination signals " "(SIGTERM, SIGINT, SIGUSR1, SIGUSR2, SIGXCPU, SIGHUP) and writes a " "trigger file to /dev/shm. Worker processes check for this trigger " - "after each training step and collectively save a distributed " + "after each minibatch backward pass and collectively save a distributed " "checkpoint before exiting gracefully. Designed for OpenShift AI / " "KubeFlow training jobs where preemption signals must be handled." ), diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 05f3a851..91f73008 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -1164,7 +1164,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: help=( "Enable on-demand full-state checkpointing triggered by Unix signals. " "When enabled, workers check for a trigger file in /dev/shm after each " - "training step and collectively save a distributed checkpoint before " + "minibatch backward pass and collectively save a distributed checkpoint before " "exiting. Designed for OpenShift AI / KubeFlow preemption handling." ), ) From d21bf12e912c7f306f4313610865bbf3eeaef467 Mon Sep 17 00:00:00 2001 From: Oleg S <97077423+RobotSail@users.noreply.github.com> Date: Fri, 20 Mar 2026 10:27:58 -0400 Subject: [PATCH 06/14] Add checkpoint checks at 5 synchronization points per training step Expand on-demand checkpointing to check for a trigger at five points: 1. Before each minibatch forward pass 2. Before each minibatch backward pass 3. 
After each minibatch backward pass (existing) 4. Before the optimizer step 5. After the optimizer step This minimizes the latency between a termination signal arriving and the checkpoint being saved, which is critical when the SIGKILL grace period is short (e.g. 30s on OpenShift/Kubernetes). Also cleans up the save-and-exit logic in train() by extracting a _save_and_exit() helper to eliminate three nearly identical blocks, and fixes _compute_average_loss to handle the case where the minibatch loop is interrupted before any forward pass completes. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../training/batch_loss_manager.py | 37 +++++++++---- src/instructlab/training/main_ds.py | 53 +++++++++++++------ .../training/on_demand_checkpoint.py | 28 +++++++--- 3 files changed, 85 insertions(+), 33 deletions(-) diff --git a/src/instructlab/training/batch_loss_manager.py b/src/instructlab/training/batch_loss_manager.py index b199ac17..46e2af30 100644 --- a/src/instructlab/training/batch_loss_manager.py +++ b/src/instructlab/training/batch_loss_manager.py @@ -74,10 +74,11 @@ def process_batch( Args: batch: List of minibatches to process - interrupt_check: Optional callback invoked after each minibatch's - backward pass. If it returns ``True``, gradient accumulation + interrupt_check: Optional callback invoked at three points per + minibatch: before forward, before backward, and after + backward. If it returns ``True`` at any point, processing stops early and ``BatchMetrics.interrupted`` is set. Used by - on-demand checkpointing to react within one fwd+bwd cycle + on-demand checkpointing to react as quickly as possible instead of waiting for the full optimizer step. 
Returns: @@ -97,6 +98,11 @@ def process_batch( # process each minibatch for mb in batch: + # Check for on-demand checkpoint before forward + if interrupt_check is not None and interrupt_check(): + interrupted = True + break + # extract minibatch-specific info micro_batch_size = mb["num_samples"] total_length = mb["total_length"] @@ -108,10 +114,16 @@ def process_batch( # prepare model inputs model_inputs = self._prepare_model_inputs(mb) - # compute loss and backward pass + # compute loss (forward pass) scaled_loss, raw_losses = self.model.compute_loss( model_inputs, self.world_size, batch_num_loss_counted_tokens ) + + # Check for on-demand checkpoint before backward + if interrupt_check is not None and interrupt_check(): + interrupted = True + break + self.accelerator.backward(scaled_loss) # accumulate losses @@ -120,7 +132,7 @@ def process_batch( if raw_losses.aux_loss is not None: accumulated_aux_loss += raw_losses.aux_loss - # check for early exit (e.g. on-demand checkpoint requested) + # Check for on-demand checkpoint after backward if interrupt_check is not None and interrupt_check(): interrupted = True break @@ -183,8 +195,8 @@ def _reduce_metrics( def _compute_average_loss( self, - accumulated_loss: torch.Tensor, - accumulated_aux_loss: torch.Tensor | None, + accumulated_loss: torch.Tensor | float, + accumulated_aux_loss: torch.Tensor | float | None, batch_num_loss_counted_tokens: int, ) -> float: """Compute average loss across all ranks for metrics logging.""" @@ -195,11 +207,16 @@ def _compute_average_loss( if accumulated_aux_loss is not None: total_batch_loss += accumulated_aux_loss + # Extract scalar value — accumulated_loss may be a plain float if the + # minibatch loop was interrupted before any forward pass completed. 
+ if isinstance(total_batch_loss, torch.Tensor): + loss_value = total_batch_loss.detach().item() + else: + loss_value = float(total_batch_loss) + # reduce across ranks avg_loss_across_ranks = self.accelerator.reduce( - torch.tensor( - total_batch_loss.detach().item(), device=self.accelerator.device - ), + torch.tensor(loss_value, device=self.accelerator.device), reduction="mean", ).item() diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 91f73008..a0cc802f 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -196,6 +196,22 @@ def train( checkpoint_job_id = os.environ.get("INSTRUCTLAB_ON_DEMAND_JOB_ID") base_logger.info("On-demand checkpointing is enabled in worker process.") + def _save_and_exit(checkpoint_location: str) -> None: + """Save an on-demand checkpoint and exit the training loop.""" + save_on_demand_checkpoint( + args=args, + accelerator=accelerator, + model=model, + tokenizer=model.tokenizer, + samples_seen=samples_seen, + epoch=epoch, + is_lora=bool(args.lora_r), + ) + base_logger.info( + "On-demand checkpoint saved (%s). Exiting training.", + checkpoint_location, + ) + # Mini_trainer approach: batch_size will be determined dynamically by data loader # For save logic, use effective_batch_size since that's the target samples_seen = 0 @@ -251,22 +267,14 @@ def train( # since we want to preserve the pre-step model state for # exact resumption. if batch_metrics.interrupted: - save_on_demand_checkpoint( - args=args, - accelerator=accelerator, - model=model, - tokenizer=model.tokenizer, - samples_seen=samples_seen, - epoch=epoch, - is_lora=bool(args.lora_r), - ) - base_logger.info( - "On-demand checkpoint saved. Exiting training gracefully." 
- ) + _save_and_exit("during minibatch processing") return - # Update samples seen - samples_seen += batch_metrics.total_samples + if on_demand_checkpointing and check_checkpoint_requested( + checkpoint_job_id + ): + _save_and_exit("before optimizer step") + return base_logger.info( f"Epoch: {epoch}, Step: {global_step}, Rank: {dist.get_rank()}, loss = {avg_loss_across_ranks:.6f}, grad_accum_steps = {batch_metrics.grad_accum_steps}" @@ -275,6 +283,15 @@ def train( # Take optimizer step after all minibatches accelerator.take_optimizer_step() + # Update samples seen after the optimizer step has been applied + samples_seen += batch_metrics.total_samples + + if on_demand_checkpointing and check_checkpoint_requested( + checkpoint_job_id + ): + _save_and_exit("after optimizer step") + return + if local_rank == 0: elapsed_time = time.time() - start overall_throughput = batch_metrics.total_samples / elapsed_time @@ -1163,9 +1180,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: default=False, help=( "Enable on-demand full-state checkpointing triggered by Unix signals. " - "When enabled, workers check for a trigger file in /dev/shm after each " - "minibatch backward pass and collectively save a distributed checkpoint before " - "exiting. Designed for OpenShift AI / KubeFlow preemption handling." + "When enabled, workers check for a trigger file in /dev/shm at five " + "synchronization points per step (before/after each minibatch forward " + "and backward pass, and before/after the optimizer step) and collectively " + "save a distributed checkpoint before exiting. Designed for OpenShift AI / " + "KubeFlow preemption handling." 
), ) parser.add_argument( diff --git a/src/instructlab/training/on_demand_checkpoint.py b/src/instructlab/training/on_demand_checkpoint.py index b9643531..20238ee5 100644 --- a/src/instructlab/training/on_demand_checkpoint.py +++ b/src/instructlab/training/on_demand_checkpoint.py @@ -20,12 +20,28 @@ node** can see the file instantly with zero network I/O. **Worker processes** (torchrun children): - After every optimizer step the training loop calls - ``check_checkpoint_requested()``. Each rank checks its local ``/dev/shm`` - for the trigger file, converts the boolean to a tensor, and does an - ``all_reduce(MAX)`` so that if *any* rank on *any* node detected the - trigger, *every* rank agrees to save a checkpoint. This works correctly in - multi-node training because all_reduce is a global collective. + The training loop calls ``check_checkpoint_requested()`` at five + synchronization points per training step, allowing the system to + react as quickly as possible to termination signals: + + 1. **Before each minibatch forward pass** — no partial computation; + the current state is saved as-is. + 2. **Before each minibatch backward pass** — the forward result is + discarded; the pre-step state is saved. + 3. **After each minibatch backward pass** — gradients are computed but + not yet applied; the pre-step state is saved (gradients will be + recomputed on resume). + 4. **Before the optimizer step** — all minibatches are done and + gradients are ready, but the step is skipped; the pre-step state + is saved. + 5. **After the optimizer step** — the step has been applied; + ``samples_seen`` is updated and the post-step state is saved. + + Each rank checks its local ``/dev/shm`` for the trigger file, converts + the boolean to a tensor, and does an ``all_reduce(MAX)`` so that if + *any* rank on *any* node detected the trigger, *every* rank agrees to + save a checkpoint. This works correctly in multi-node training because + all_reduce is a global collective. 
Signals handled --------------- From 3794d5e4a5f06738db4bb29dbc8e8d84d789ebfa Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:01:31 +0000 Subject: [PATCH 07/14] Add stale trigger cleanup and exact mid-epoch resume Two fixes to the on-demand checkpointing feature: 1. Stale trigger file cleanup: ParentSignalHandler.install() now checks for and removes any existing trigger file before installing signal handlers. If the file exists before handlers are installed, it's from a previous run that was killed before workers could clean it up. Prevents a new training job from immediately checkpointing and exiting. 2. Exact mid-epoch resume: save_on_demand_checkpoint() now persists global_step in the checkpoint metadata alongside current_epoch and samples_seen. On resume, load_latest_full_state() detects the global_step field and sets last_step accordingly, so the training loop fast-forwards to the exact step within the epoch. Without this, mid-epoch checkpoints would skip to the next epoch on resume, losing remaining steps. Tested with Qwen2-1.5B-Instruct on 2 GPUs: interrupted at step 19/25, checkpoint saved with global_step=19, resumed and completed steps 20-25. 
--- src/instructlab/training/main_ds.py | 1 + .../training/on_demand_checkpoint.py | 28 ++++++++++++++++++- src/instructlab/training/utils.py | 26 +++++++++++++---- 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index a0cc802f..5c3ac28d 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -205,6 +205,7 @@ def _save_and_exit(checkpoint_location: str) -> None: tokenizer=model.tokenizer, samples_seen=samples_seen, epoch=epoch, + global_step=global_step, is_lora=bool(args.lora_r), ) base_logger.info( diff --git a/src/instructlab/training/on_demand_checkpoint.py b/src/instructlab/training/on_demand_checkpoint.py index 20238ee5..2d21a6d9 100644 --- a/src/instructlab/training/on_demand_checkpoint.py +++ b/src/instructlab/training/on_demand_checkpoint.py @@ -179,6 +179,25 @@ def __init__(self, job_id: Optional[str] = None): def install(self) -> None: """Register signal handlers for all catchable signals.""" + # Clear any stale trigger file from a previous run. If the file + # exists before we've even installed signal handlers, it cannot + # be from this job — it's left over from a prior run that was + # killed before the workers could clean it up. + if trigger_file_exists(self.job_id): + logger.info( + "On-demand checkpoint: clearing stale trigger file from " + "a previous run (job_id=%s).", + self.job_id, + ) + try: + remove_trigger_file(self.job_id) + except Exception: + logger.warning( + "On-demand checkpoint: failed to remove stale trigger file, " + "but continuing anyway.", + exc_info=True, + ) + for sig in _CATCHABLE_SIGNALS: try: self._original_handlers[sig] = signal.getsignal(sig) @@ -264,6 +283,7 @@ def save_on_demand_checkpoint( tokenizer, samples_seen: int, epoch: int, + global_step: int, is_lora: bool, ) -> None: """Save a full-state distributed checkpoint for on-demand resume. 
@@ -271,6 +291,10 @@ def save_on_demand_checkpoint( This is a thin wrapper that calls the existing ``save_checkpoint`` utility with ``full_state=True`` so that optimizer + LR scheduler state are also persisted, enabling exact training resumption. + + The ``global_step`` is saved to the checkpoint metadata so that + on resume the training loop can fast-forward to the exact step + within the epoch where training was interrupted. """ # First Party from instructlab.training.utils import save_checkpoint @@ -279,8 +303,9 @@ def save_on_demand_checkpoint( if local_rank == 0: logger.info( "On-demand checkpoint: saving full-state checkpoint at " - "epoch=%d, samples_seen=%d", + "epoch=%d, global_step=%d, samples_seen=%d", epoch, + global_step, samples_seen, ) @@ -294,6 +319,7 @@ def save_on_demand_checkpoint( full_state=True, hf_format=True, epoch=epoch, + global_step=global_step, ) if local_rank == 0: diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index fc31858e..332adfb6 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -794,6 +794,7 @@ def save_checkpoint( epoch: int = None, hf_format: bool = True, full_state: bool = False, + global_step: int | None = None, ) -> None: if hf_format: save_hf_format_accelerate( @@ -812,10 +813,11 @@ def save_checkpoint( is_lora=is_lora, epoch=epoch, samples_seen=samples_seen, + global_step=global_step, ) -def save_full_state(args, accelerator, is_lora: bool, epoch: int, samples_seen: int): +def save_full_state(args, accelerator, is_lora: bool, epoch: int, samples_seen: int, global_step: int | None = None): """ Saves model, optimizer, and lr_scheduler state. TODO: save model config - decided not to do this. @@ -848,9 +850,11 @@ def _get_state_dict_patched(model, unwrap=False): # save metadata file for current training status if accelerator.is_main_process: - # TODO: should we set the global_step here rather than calculating global_step - # based on samples_seen? 
metadata = {"current_epoch": epoch, "samples_seen": samples_seen} + # Save global_step when provided (on-demand mid-epoch checkpoints) + # so that resume can fast-forward to the exact training step. + if global_step is not None: + metadata["global_step"] = global_step torch.save(metadata, output_dir / "training_metadata.json") log_rank_0(f"\033[93mSaving training state: {metadata}\033[0m", to_print=True) @@ -895,10 +899,22 @@ def load_latest_full_state(args, accelerator) -> None: f"\033[93mTraining metadata loaded: {training_metadata}\033[0m", to_print=True ) - # previous epoch is basis for current epoch. - args.__dict__["current_epoch"] = training_metadata["current_epoch"] + 1 args.__dict__["samples_seen"] = training_metadata["samples_seen"] + if "global_step" in training_metadata: + # On-demand mid-epoch checkpoint: resume at the same epoch and + # fast-forward to the exact step via last_step. + args.__dict__["current_epoch"] = training_metadata["current_epoch"] + args.__dict__["last_step"] = training_metadata["global_step"] + log_rank_0( + f"\033[93mResuming mid-epoch: epoch={args.current_epoch}, " + f"last_step={args.last_step}\033[0m", + to_print=True, + ) + else: + # Epoch-boundary checkpoint: start at the next epoch. 
+ args.__dict__["current_epoch"] = training_metadata["current_epoch"] + 1 + def freeze_router_params(model: Model): """ From 7e75b62e19520559ceed4fdab7cb697cac04d361 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:06:41 +0000 Subject: [PATCH 08/14] Fix ruff formatting in utils.py --- src/instructlab/training/utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 332adfb6..5a4bd8e0 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -817,7 +817,14 @@ def save_checkpoint( ) -def save_full_state(args, accelerator, is_lora: bool, epoch: int, samples_seen: int, global_step: int | None = None): +def save_full_state( + args, + accelerator, + is_lora: bool, + epoch: int, + samples_seen: int, + global_step: int | None = None, +): """ Saves model, optimizer, and lr_scheduler state. TODO: save model config - decided not to do this. From fb1c3434acab0db9f334a6a7bc919025979f03d4 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:07:24 +0000 Subject: [PATCH 09/14] Add documentation for on-demand checkpointing feature --- docs/on_demand_checkpointing.md | 139 ++++++++++++++++++++++++++++++++ 1 file changed, 139 insertions(+) create mode 100644 docs/on_demand_checkpointing.md diff --git a/docs/on_demand_checkpointing.md b/docs/on_demand_checkpointing.md new file mode 100644 index 00000000..b3cb9f52 --- /dev/null +++ b/docs/on_demand_checkpointing.md @@ -0,0 +1,139 @@ +# On-Demand Checkpointing + +On-demand checkpointing enables graceful checkpoint-and-exit when termination +signals are received during training. It is designed for environments like +OpenShift AI and KubeFlow where training jobs can be preempted at any time. 
+ +## How It Works + +When enabled, the system installs signal handlers in the parent (launcher) +process that catch termination signals before the hard SIGKILL. When a signal +arrives: + +1. The parent writes a trigger file to `/dev/shm` (a fast, node-local tmpfs). +2. Worker processes check for the trigger file at multiple synchronization + points during each training step. +3. Workers coordinate via `all_reduce` so that if any rank on any node + detects the trigger, all ranks agree to save. +4. A full-state checkpoint (model + optimizer + LR scheduler) is saved. +5. Workers exit cleanly. + +On resume, the training loop detects the mid-epoch checkpoint, restores the +full training state, and fast-forwards to the exact step where training was +interrupted. + +## Signals Handled + +The following signals are intercepted (SIGKILL cannot be caught): + +| Signal | Source | +|--------|--------| +| SIGTERM | Kubernetes graceful shutdown (default) | +| SIGINT | Ctrl-C / some job controllers | +| SIGUSR1 | Custom preemption controllers | +| SIGUSR2 | Custom preemption controllers | +| SIGXCPU | CPU time limit exceeded (resource quotas) | +| SIGHUP | Terminal disconnect / some eviction paths | + +## Usage + +### Python API + +```python +from instructlab.training.config import TorchrunArgs, TrainingArgs +from instructlab.training import run_training + +torch_args = TorchrunArgs( + nproc_per_node=8, + nnodes=1, + node_rank=0, + rdzv_id=12345, + rdzv_endpoint="127.0.0.1:29500", +) + +train_args = TrainingArgs( + model_path="Qwen/Qwen2-1.5B-Instruct", + data_path="./data.jsonl", + data_output_dir="./processed", + ckpt_output_dir="./checkpoints", + num_epochs=3, + on_demand_checkpointing=True, # Enable the feature + # ... 
other training args +) + +run_training(torch_args, train_args) +``` + +### CLI + +```bash +torchrun --nproc-per-node=8 \ + instructlab/training/main_ds.py \ + --model_name_or_path Qwen/Qwen2-1.5B-Instruct \ + --data_path ./data.jsonl \ + --output_dir ./checkpoints \ + --on_demand_checkpointing \ + ... +``` + +## Resume Behavior + +When a checkpoint saved by on-demand checkpointing is found in the output +directory, the training loop automatically: + +1. Loads the full optimizer and LR scheduler state from the checkpoint. +2. Reads `global_step` from the checkpoint metadata to determine where + training was interrupted. +3. Resumes at the **same epoch** and fast-forwards to the exact step, + skipping already-completed batches. + +This differs from epoch-boundary checkpoints, which resume at the start of +the next epoch. + +### Checkpoint Metadata + +On-demand checkpoints store additional metadata compared to epoch-boundary +checkpoints: + +```json +{ + "current_epoch": 0, + "samples_seen": 144, + "global_step": 19 +} +``` + +The `global_step` field is what distinguishes an on-demand checkpoint from an +epoch-boundary one. When present, the resume logic keeps `current_epoch` +unchanged and sets `last_step = global_step` to enable fast-forwarding. + +## Multi-Node Training + +The trigger file mechanism works correctly across multiple nodes: + +- The trigger file lives on `/dev/shm`, which is node-local. Each node's + parent process writes its own trigger file when it receives a signal. +- Workers use `all_reduce(MAX)` to synchronize: if any rank on any node + detects a trigger, all ranks on all nodes agree to save. +- The checkpoint itself is saved to the shared filesystem (the configured + `ckpt_output_dir`), accessible by all nodes on resume. + +## Stale Trigger Files + +If a previous training run was killed before workers could clean up the +trigger file, the new run's `ParentSignalHandler` detects and removes it +during initialization. 
This prevents a new job from immediately +checkpointing and exiting due to a leftover trigger from a prior run. + +## Kubernetes / OpenShift Configuration + +To give workers enough time to save a checkpoint before the hard SIGKILL, +increase `terminationGracePeriodSeconds` in your pod spec: + +```yaml +spec: + terminationGracePeriodSeconds: 300 # 5 minutes +``` + +The default of 30 seconds may not be enough for large models. The checkpoint +save time depends on model size, number of GPUs, and filesystem speed. From 04298c485ebb19c08b6324f3e4ad6f43235bb65a Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:11:10 +0000 Subject: [PATCH 10/14] Fix sync point count in help text and docstring --- src/instructlab/training/main_ds.py | 9 ++++----- src/instructlab/training/on_demand_checkpoint.py | 7 ++++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 5c3ac28d..35064eb6 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -1181,11 +1181,10 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: default=False, help=( "Enable on-demand full-state checkpointing triggered by Unix signals. " - "When enabled, workers check for a trigger file in /dev/shm at five " - "synchronization points per step (before/after each minibatch forward " - "and backward pass, and before/after the optimizer step) and collectively " - "save a distributed checkpoint before exiting. Designed for OpenShift AI / " - "KubeFlow preemption handling." + "When enabled, workers check for a trigger file in /dev/shm at multiple " + "synchronization points (three times per minibatch and twice around the " + "optimizer step) and collectively save a distributed checkpoint before " + "exiting. Designed for OpenShift AI / KubeFlow preemption handling." 
), ) parser.add_argument( diff --git a/src/instructlab/training/on_demand_checkpoint.py b/src/instructlab/training/on_demand_checkpoint.py index 2d21a6d9..fa99b48d 100644 --- a/src/instructlab/training/on_demand_checkpoint.py +++ b/src/instructlab/training/on_demand_checkpoint.py @@ -20,9 +20,10 @@ node** can see the file instantly with zero network I/O. **Worker processes** (torchrun children): - The training loop calls ``check_checkpoint_requested()`` at five - synchronization points per training step, allowing the system to - react as quickly as possible to termination signals: + The training loop calls ``check_checkpoint_requested()`` at multiple + synchronization points per training step (three per minibatch plus + two around the optimizer step), allowing the system to react as + quickly as possible to termination signals: 1. **Before each minibatch forward pass** — no partial computation; the current state is saved as-is. From d577f4729b7345b56fa4521cca4e2ae67203e61d Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:13:40 +0000 Subject: [PATCH 11/14] Document manual trigger file creation for on-demand checkpointing --- docs/on_demand_checkpointing.md | 46 +++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/docs/on_demand_checkpointing.md b/docs/on_demand_checkpointing.md index b3cb9f52..58b4555b 100644 --- a/docs/on_demand_checkpointing.md +++ b/docs/on_demand_checkpointing.md @@ -125,6 +125,52 @@ trigger file, the new run's `ParentSignalHandler` detects and removes it during initialization. This prevents a new job from immediately checkpointing and exiting due to a leftover trigger from a prior run. +## Manually Triggering a Checkpoint + +You can trigger a checkpoint-and-exit without sending a signal by writing +the trigger file directly. This is useful for debugging, testing, or +integration with custom orchestration that doesn't use Unix signals. 
+ +The trigger file path is: + +``` +/dev/shm/instructlab_checkpoint_requested_ +``` + +Where `` is the `rdzv_id` passed to `TorchrunArgs`. If no job ID +was set, the path is `/dev/shm/instructlab_checkpoint_requested` (no suffix). + +To trigger a checkpoint from a shell on any node in the training cluster: + +```bash +# Find the job ID (it's the rdzv_id, also stored in the environment) +JOB_ID=$(printenv INSTRUCTLAB_ON_DEMAND_JOB_ID) + +# Write the trigger file +echo 1 > /dev/shm/instructlab_checkpoint_requested_${JOB_ID} +``` + +Or without the job ID: + +```bash +echo 1 > /dev/shm/instructlab_checkpoint_requested +``` + +Workers check for the trigger file at each synchronization point in the +training loop (multiple times per step). Once any rank on any node detects +it, all ranks coordinate via `all_reduce` to save a checkpoint and exit. + +You only need to write the file on **one node** — the `all_reduce` ensures +all nodes participate even if they don't see the file locally. + +From Python: + +```python +from instructlab.training.on_demand_checkpoint import write_trigger_file + +write_trigger_file(job_id="12345") # or job_id=None for default path +``` + ## Kubernetes / OpenShift Configuration To give workers enough time to save a checkpoint before the hard SIGKILL, From 563afbd53bda79a84e8aa8cc7eaedc9be01b4c82 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:15:23 +0000 Subject: [PATCH 12/14] Improve manual trigger docs: clarify job ID requirement --- docs/on_demand_checkpointing.md | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/docs/on_demand_checkpointing.md b/docs/on_demand_checkpointing.md index 58b4555b..70da36a5 100644 --- a/docs/on_demand_checkpointing.md +++ b/docs/on_demand_checkpointing.md @@ -131,29 +131,25 @@ You can trigger a checkpoint-and-exit without sending a signal by writing the trigger file directly. 
This is useful for debugging, testing, or integration with custom orchestration that doesn't use Unix signals. -The trigger file path is: - -``` -/dev/shm/instructlab_checkpoint_requested_ -``` - -Where `` is the `rdzv_id` passed to `TorchrunArgs`. If no job ID -was set, the path is `/dev/shm/instructlab_checkpoint_requested` (no suffix). - -To trigger a checkpoint from a shell on any node in the training cluster: +The trigger file lives in `/dev/shm` and is named using the job ID that +the training process was started with. To find the correct filename and +create the trigger: ```bash -# Find the job ID (it's the rdzv_id, also stored in the environment) -JOB_ID=$(printenv INSTRUCTLAB_ON_DEMAND_JOB_ID) +# Find the trigger filename for the running job — look for the job ID +# that was set when training started +ls /dev/shm/instructlab_checkpoint_requested* -# Write the trigger file -echo 1 > /dev/shm/instructlab_checkpoint_requested_${JOB_ID} +# Create the trigger file (use the exact name shown by ls above) +touch /dev/shm/instructlab_checkpoint_requested_ ``` -Or without the job ID: +If you don't know the job ID, you can read it from the training process +environment: ```bash -echo 1 > /dev/shm/instructlab_checkpoint_requested +# From inside the same pod / container where training is running +cat /proc/$(pgrep -f main_ds.py | head -1)/environ | tr '\0' '\n' | grep INSTRUCTLAB_ON_DEMAND_JOB_ID ``` Workers check for the trigger file at each synchronization point in the From 13a0e34914388f5cdccf670846cb9b4c725f4aa3 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:19:41 +0000 Subject: [PATCH 13/14] Simplify trigger file: remove job_id namespacing Drop the job_id suffix from the trigger file path. The file is now always /dev/shm/instructlab_checkpoint_requested with no suffix. 
The namespacing was defensive against concurrent jobs sharing /dev/shm, but in practice Kubernetes pods each get their own /dev/shm. This makes manual triggering trivial: touch /dev/shm/instructlab_checkpoint_requested --- docs/on_demand_checkpointing.md | 22 ++------- src/instructlab/training/main_ds.py | 19 ++------ .../training/on_demand_checkpoint.py | 46 +++++++------------ 3 files changed, 26 insertions(+), 61 deletions(-) diff --git a/docs/on_demand_checkpointing.md b/docs/on_demand_checkpointing.md index 70da36a5..b21d71d5 100644 --- a/docs/on_demand_checkpointing.md +++ b/docs/on_demand_checkpointing.md @@ -131,25 +131,11 @@ You can trigger a checkpoint-and-exit without sending a signal by writing the trigger file directly. This is useful for debugging, testing, or integration with custom orchestration that doesn't use Unix signals. -The trigger file lives in `/dev/shm` and is named using the job ID that -the training process was started with. To find the correct filename and -create the trigger: +The trigger file is always at a fixed path. To trigger a checkpoint +(e.g. 
via `kubectl exec` into the training pod): ```bash -# Find the trigger filename for the running job — look for the job ID -# that was set when training started -ls /dev/shm/instructlab_checkpoint_requested* - -# Create the trigger file (use the exact name shown by ls above) -touch /dev/shm/instructlab_checkpoint_requested_ -``` - -If you don't know the job ID, you can read it from the training process -environment: - -```bash -# From inside the same pod / container where training is running -cat /proc/$(pgrep -f main_ds.py | head -1)/environ | tr '\0' '\n' | grep INSTRUCTLAB_ON_DEMAND_JOB_ID +touch /dev/shm/instructlab_checkpoint_requested ``` Workers check for the trigger file at each synchronization point in the @@ -164,7 +150,7 @@ From Python: ```python from instructlab.training.on_demand_checkpoint import write_trigger_file -write_trigger_file(job_id="12345") # or job_id=None for default path +write_trigger_file() ``` ## Kubernetes / OpenShift Configuration diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 35064eb6..3cb8767b 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -185,7 +185,6 @@ def train( base_logger = logging.getLogger("instructlab.training") # Import on-demand checkpointing utilities once if the feature is enabled - checkpoint_job_id = None if on_demand_checkpointing: # First Party from instructlab.training.on_demand_checkpoint import ( @@ -193,7 +192,6 @@ def train( save_on_demand_checkpoint, ) - checkpoint_job_id = os.environ.get("INSTRUCTLAB_ON_DEMAND_JOB_ID") base_logger.info("On-demand checkpointing is enabled in worker process.") def _save_and_exit(checkpoint_location: str) -> None: @@ -255,7 +253,7 @@ def _save_and_exit(checkpoint_location: str) -> None: # the check runs after every minibatch backward rather than # waiting for the full optimizer step. 
         _interrupt_check = (
-            (lambda: check_checkpoint_requested(checkpoint_job_id))
+            (lambda: check_checkpoint_requested())
             if on_demand_checkpointing
             else None
         )
@@ -271,9 +269,7 @@ def _save_and_exit(checkpoint_location: str) -> None:
                     _save_and_exit("during minibatch processing")
                     return
 
-            if on_demand_checkpointing and check_checkpoint_requested(
-                checkpoint_job_id
-            ):
+            if on_demand_checkpointing and check_checkpoint_requested():
                 _save_and_exit("before optimizer step")
                 return
 
@@ -287,9 +283,7 @@ def _save_and_exit(checkpoint_location: str) -> None:
             # Update samples seen after the optimizer step has been applied
             samples_seen += batch_metrics.total_samples
 
-            if on_demand_checkpointing and check_checkpoint_requested(
-                checkpoint_job_id
-            ):
+            if on_demand_checkpointing and check_checkpoint_requested():
                 _save_and_exit("after optimizer step")
                 return
 
@@ -863,14 +857,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
 
-    # Use rdzv_id to namespace the trigger file so concurrent jobs
-    # sharing /dev/shm don't interfere with each other.
-    checkpoint_job_id = str(torch_args.rdzv_id)
-    os.environ["INSTRUCTLAB_ON_DEMAND_JOB_ID"] = checkpoint_job_id
-    signal_handler = ParentSignalHandler(job_id=checkpoint_job_id)
+    signal_handler = ParentSignalHandler()
     signal_handler.install()
     logger.info(
-        "On-demand checkpointing is ENABLED (job_id=%s). "
+        "On-demand checkpointing is ENABLED. "
         "Termination signals will trigger a full-state checkpoint before exit.",
-        checkpoint_job_id,
     )
 
     process = None
diff --git a/src/instructlab/training/on_demand_checkpoint.py b/src/instructlab/training/on_demand_checkpoint.py
index fa99b48d..be1771b5 100644
--- a/src/instructlab/training/on_demand_checkpoint.py
+++ b/src/instructlab/training/on_demand_checkpoint.py
@@ -93,23 +88,18 @@
 _TRIGGER_FILENAME = "instructlab_checkpoint_requested"
 
 
-def _get_trigger_path(job_id: Optional[str] = None) -> Path:
-    """Return the path to the checkpoint trigger file.
-
-    An optional *job_id* can be supplied to avoid collisions if multiple
-    training jobs share the same ``/dev/shm`` (unlikely but possible).
-    """
-    name = f"{_TRIGGER_FILENAME}_{job_id}" if job_id else _TRIGGER_FILENAME
-    return _TRIGGER_DIR / name
+def _get_trigger_path() -> Path:
+    """Return the path to the checkpoint trigger file."""
+    return _TRIGGER_DIR / _TRIGGER_FILENAME
 
 
-def write_trigger_file(job_id: Optional[str] = None) -> Path:
+def write_trigger_file() -> Path:
     """Create the trigger file that tells workers to checkpoint.
 
     This is called from the *parent* process signal handler.
     Returns the path that was written.
     """
-    path = _get_trigger_path(job_id)
-    # Use a atomic write via tempfile + rename to avoid partial reads.
+    path = _get_trigger_path()
+    # Use an atomic write via tempfile + rename to avoid partial reads.
     fd, tmp = tempfile.mkstemp(dir=_TRIGGER_DIR, prefix=".ckpt_trigger_")
     try:
@@ -124,14 +119,14 @@ def write_trigger_file(job_id: Optional[str] = None) -> Path:
     return path
 
 
-def trigger_file_exists(job_id: Optional[str] = None) -> bool:
+def trigger_file_exists() -> bool:
     """Check whether the trigger file exists (worker-side)."""
-    return _get_trigger_path(job_id).exists()
+    return _get_trigger_path().exists()
 
 
-def remove_trigger_file(job_id: Optional[str] = None) -> None:
+def remove_trigger_file() -> None:
     """Remove the trigger file after the checkpoint has been saved."""
-    path = _get_trigger_path(job_id)
+    path = _get_trigger_path()
     try:
         path.unlink(missing_ok=True)
     except OSError:
@@ -165,15 +160,9 @@ class ParentSignalHandler:
 
     The handler is idempotent – multiple signals will not create multiple
     trigger files.
 
-    Parameters
-    ----------
-    job_id : str, optional
-        Unique identifier for this training job. Used to namespace the
-        trigger file.
""" - def __init__(self, job_id: Optional[str] = None): - self.job_id = job_id + def __init__(self): self.signal_received: Optional[signal.Signals] = None self._original_handlers: dict[signal.Signals, _SignalHandler] = {} self._trigger_written = False @@ -184,14 +173,13 @@ def install(self) -> None: # exists before we've even installed signal handlers, it cannot # be from this job — it's left over from a prior run that was # killed before the workers could clean it up. - if trigger_file_exists(self.job_id): + if trigger_file_exists(): logger.info( "On-demand checkpoint: clearing stale trigger file from " - "a previous run (job_id=%s).", - self.job_id, + "a previous run.", ) try: - remove_trigger_file(self.job_id) + remove_trigger_file() except Exception: logger.warning( "On-demand checkpoint: failed to remove stale trigger file, " @@ -233,7 +221,7 @@ def _handle(self, signum: int, _frame) -> None: self.signal_received = sig if not self._trigger_written: - write_trigger_file(self.job_id) + write_trigger_file() self._trigger_written = True @@ -242,7 +230,7 @@ def _handle(self, signum: int, _frame) -> None: # --------------------------------------------------------------------------- -def check_checkpoint_requested(job_id: Optional[str] = None) -> bool: +def check_checkpoint_requested() -> bool: """Check across all ranks whether an on-demand checkpoint was requested. This function must be called by **all ranks** at the same point in the @@ -251,7 +239,7 @@ def check_checkpoint_requested(job_id: Optional[str] = None) -> bool: Returns ``True`` if any rank detected the trigger file, meaning all ranks should save a checkpoint. """ - local_trigger = trigger_file_exists(job_id) + local_trigger = trigger_file_exists() # Convert to a tensor and all-reduce (MAX) so that if ANY rank on ANY # node saw the trigger, every rank gets True. 
@@ -272,7 +260,7 @@ def check_checkpoint_requested(job_id: Optional[str] = None) -> bool: ) # Clean up the trigger file so that if the process somehow # continues, we don't save again immediately. - remove_trigger_file(job_id) + remove_trigger_file() return requested From a9b8d8cfce0e709181f8195d770adc398d1cc4ec Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Tue, 24 Mar 2026 17:49:22 +0000 Subject: [PATCH 14/14] Fix pylint: remove unnecessary lambda wrapper --- src/instructlab/training/main_ds.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 3cb8767b..6a4813d1 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -253,9 +253,7 @@ def _save_and_exit(checkpoint_location: str) -> None: # the check runs after every minibatch backward rather than # waiting for the full optimizer step. _interrupt_check = ( - (lambda: check_checkpoint_requested()) - if on_demand_checkpointing - else None + check_checkpoint_requested if on_demand_checkpointing else None ) batch_metrics, avg_loss_across_ranks = batch_loss_manager.process_batch( batch, interrupt_check=_interrupt_check