diff --git a/src/instructlab/training/batch_loss_manager.py b/src/instructlab/training/batch_loss_manager.py index cc6da021..46e2af30 100644 --- a/src/instructlab/training/batch_loss_manager.py +++ b/src/instructlab/training/batch_loss_manager.py @@ -7,7 +7,8 @@ """ # Standard -from dataclasses import dataclass +from collections.abc import Callable +from dataclasses import dataclass, field import logging # Third Party @@ -33,6 +34,7 @@ class BatchMetrics: accumulated_aux_loss: torch.Tensor | None grad_accum_steps: int num_minibatches: int + interrupted: bool = field(default=False) class BatchLossManager: @@ -62,12 +64,22 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int): self.local_rank: int = local_rank self.torch_device = torch.device("cuda", local_rank) - def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]: + def process_batch( + self, + batch: list[CollatedItem], + interrupt_check: Callable[[], bool] | None = None, + ) -> tuple[BatchMetrics, float]: """ Process a batch of minibatches, computing losses and accumulating gradients. Args: batch: List of minibatches to process + interrupt_check: Optional callback invoked at three points per + minibatch: before forward, before backward, and after + backward. If it returns ``True`` at any point, processing + stops early and ``BatchMetrics.interrupted`` is set. Used by + on-demand checkpointing to react as quickly as possible + instead of waiting for the full optimizer step. 
Returns: tuple: (BatchMetrics, average_loss_across_ranks) @@ -82,9 +94,15 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] accumulated_loss = 0.0 accumulated_aux_loss = 0.0 grad_accum_steps = 0 + interrupted = False # process each minibatch for mb in batch: + # Check for on-demand checkpoint before forward + if interrupt_check is not None and interrupt_check(): + interrupted = True + break + # extract minibatch-specific info micro_batch_size = mb["num_samples"] total_length = mb["total_length"] @@ -96,10 +114,16 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] # prepare model inputs model_inputs = self._prepare_model_inputs(mb) - # compute loss and backward pass + # compute loss (forward pass) scaled_loss, raw_losses = self.model.compute_loss( model_inputs, self.world_size, batch_num_loss_counted_tokens ) + + # Check for on-demand checkpoint before backward + if interrupt_check is not None and interrupt_check(): + interrupted = True + break + self.accelerator.backward(scaled_loss) # accumulate losses @@ -108,6 +132,11 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] if raw_losses.aux_loss is not None: accumulated_aux_loss += raw_losses.aux_loss + # Check for on-demand checkpoint after backward + if interrupt_check is not None and interrupt_check(): + interrupted = True + break + # reduce metrics across ranks batch_total_samples, batch_total_length = self._reduce_metrics( batch_total_samples, batch_total_length @@ -127,6 +156,7 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] accumulated_aux_loss=accumulated_aux_loss, grad_accum_steps=grad_accum_steps, num_minibatches=num_minibatches, + interrupted=interrupted, ) return metrics, avg_loss_across_ranks @@ -165,8 +195,8 @@ def _reduce_metrics( def _compute_average_loss( self, - accumulated_loss: torch.Tensor, - accumulated_aux_loss: torch.Tensor | None, + accumulated_loss: 
torch.Tensor | float, + accumulated_aux_loss: torch.Tensor | float | None, batch_num_loss_counted_tokens: int, ) -> float: """Compute average loss across all ranks for metrics logging.""" @@ -177,11 +207,16 @@ def _compute_average_loss( if accumulated_aux_loss is not None: total_batch_loss += accumulated_aux_loss + # Extract scalar value — accumulated_loss may be a plain float if the + # minibatch loop was interrupted before any forward pass completed. + if isinstance(total_batch_loss, torch.Tensor): + loss_value = total_batch_loss.detach().item() + else: + loss_value = float(total_batch_loss) + # reduce across ranks avg_loss_across_ranks = self.accelerator.reduce( - torch.tensor( - total_batch_loss.detach().item(), device=self.accelerator.device - ), + torch.tensor(loss_value, device=self.accelerator.device), reduction="mean", ).item() diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 911c3898..11182a0e 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -351,6 +351,19 @@ class TrainingArgs(BaseModel): description="How often to evaluate validation loss (in training steps). Required when validation_split > 0.", ) + on_demand_checkpointing: bool = Field( + default=False, + description=( + "Enable on-demand full-state checkpointing triggered by Unix signals. " + "When enabled, the parent process intercepts termination signals " + "(SIGTERM, SIGINT, SIGUSR1, SIGUSR2, SIGXCPU, SIGHUP) and writes a " + "trigger file to /dev/shm. Worker processes check for this trigger " + "after each minibatch backward pass and collectively save a distributed " + "checkpoint before exiting gracefully. Designed for OpenShift AI / " + "KubeFlow training jobs where preemption signals must be handled." 
+ ), + ) + @model_validator(mode="after") def validate_validation_config(self): if not 0.0 <= self.validation_split < 1.0: diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 072b27c6..a0cc802f 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -173,6 +173,7 @@ def train( accelerator: Accelerator, val_data_loader=None, validation_frequency=None, + on_demand_checkpointing: bool = False, ): model.train() @@ -183,6 +184,34 @@ def train( metric_logger = logging.getLogger("instructlab.training.metrics") base_logger = logging.getLogger("instructlab.training") + # Import on-demand checkpointing utilities once if the feature is enabled + checkpoint_job_id = None + if on_demand_checkpointing: + # First Party + from instructlab.training.on_demand_checkpoint import ( + check_checkpoint_requested, + save_on_demand_checkpoint, + ) + + checkpoint_job_id = os.environ.get("INSTRUCTLAB_ON_DEMAND_JOB_ID") + base_logger.info("On-demand checkpointing is enabled in worker process.") + + def _save_and_exit(checkpoint_location: str) -> None: + """Save an on-demand checkpoint and exit the training loop.""" + save_on_demand_checkpoint( + args=args, + accelerator=accelerator, + model=model, + tokenizer=model.tokenizer, + samples_seen=samples_seen, + epoch=epoch, + is_lora=bool(args.lora_r), + ) + base_logger.info( + "On-demand checkpoint saved (%s). Exiting training.", + checkpoint_location, + ) + # Mini_trainer approach: batch_size will be determined dynamically by data loader # For save logic, use effective_batch_size since that's the target samples_seen = 0 @@ -220,13 +249,32 @@ def train( continue start = time.time() - # Process the batch using the BatchLossManager + # Process the batch using the BatchLossManager. + # When on-demand checkpointing is enabled, pass a callback so + # the check runs after every minibatch backward rather than + # waiting for the full optimizer step. 
+ _interrupt_check = ( + (lambda: check_checkpoint_requested(checkpoint_job_id)) + if on_demand_checkpointing + else None + ) batch_metrics, avg_loss_across_ranks = batch_loss_manager.process_batch( - batch + batch, interrupt_check=_interrupt_check ) - # Update samples seen - samples_seen += batch_metrics.total_samples + # If the batch was interrupted by an on-demand checkpoint + # request, save immediately and exit — skip the optimizer step + # since we want to preserve the pre-step model state for + # exact resumption. + if batch_metrics.interrupted: + _save_and_exit("during minibatch processing") + return + + if on_demand_checkpointing and check_checkpoint_requested( + checkpoint_job_id + ): + _save_and_exit("before optimizer step") + return base_logger.info( f"Epoch: {epoch}, Step: {global_step}, Rank: {dist.get_rank()}, loss = {avg_loss_across_ranks:.6f}, grad_accum_steps = {batch_metrics.grad_accum_steps}" @@ -235,6 +283,15 @@ def train( # Take optimizer step after all minibatches accelerator.take_optimizer_step() + # Update samples seen after the optimizer step has been applied + samples_seen += batch_metrics.total_samples + + if on_demand_checkpointing and check_checkpoint_requested( + checkpoint_job_id + ): + _save_and_exit("after optimizer step") + return + if local_rank == 0: elapsed_time = time.time() - start overall_throughput = batch_metrics.total_samples / elapsed_time @@ -561,6 +618,7 @@ def main(args): accelerator=accelerator, val_data_loader=val_loader, validation_frequency=validation_frequency, + on_demand_checkpointing=getattr(args, "on_demand_checkpointing", False), ) dist.barrier() @@ -791,7 +849,29 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.keep_last_checkpoint_only: command.append("--keep_last_checkpoint_only") + if train_args.on_demand_checkpointing: + command.append("--on_demand_checkpointing") + logger.info("Running training command as subprocess: %s", " ".join(command)) + + # --- 
On-demand checkpointing: install signal handlers in the parent --- + signal_handler = None + if train_args.on_demand_checkpointing: + # First Party + from instructlab.training.on_demand_checkpoint import ParentSignalHandler + + # Use rdzv_id to namespace the trigger file so concurrent jobs + # sharing /dev/shm don't interfere with each other. + checkpoint_job_id = str(torch_args.rdzv_id) + os.environ["INSTRUCTLAB_ON_DEMAND_JOB_ID"] = checkpoint_job_id + signal_handler = ParentSignalHandler(job_id=checkpoint_job_id) + signal_handler.install() + logger.info( + "On-demand checkpointing is ENABLED (job_id=%s). " + "Termination signals will trigger a full-state checkpoint before exit.", + checkpoint_job_id, + ) + process = None interrupt: KeyboardInterrupt | Exception | None = None failure = False @@ -811,36 +891,85 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: interrupt = e finally: if "process" not in locals() or process is None: + if signal_handler is not None: + signal_handler.uninstall() return - # wait for the process to exit so we can properly read the exit code - process.wait(timeout=60) - process_code = process.poll() - failure = process_code != 0 - - if not failure: - logger.info("Operation completed successfully! 🎉") - else: - logger.error( - f"Training subprocess has not exited yet. Sending SIGTERM. Process code: {process_code}" + # If a signal was caught by the on-demand checkpoint handler, give + # the workers time to detect the trigger file and save a checkpoint + # before we start sending our own signals to the subprocess. + if signal_handler is not None and signal_handler.signal_received is not None: + logger.info( + "On-demand checkpoint: signal %s received. Waiting for workers to " + "save checkpoint before proceeding with shutdown...", + signal_handler.signal_received.name, ) + # Give workers generous time to complete the checkpoint save. + # The workers will exit on their own after saving. 
+ try: + process.wait(timeout=300) + except subprocess.TimeoutExpired: + logger.warning( + "On-demand checkpoint: workers did not finish within 300s. " + "Proceeding with shutdown." + ) - process.terminate() + # wait for the process to exit so we can properly read the exit code try: - logger.info("Waiting for process to exit, 60s...") process.wait(timeout=60) except subprocess.TimeoutExpired: + pass + process_code = process.poll() + + if process_code is not None and process_code == 0: + logger.info("Operation completed successfully!") + elif process_code is None: + logger.error("Training subprocess has not exited yet. Sending SIGTERM.") + process.terminate() + try: + logger.info("Waiting for process to exit, 60s...") + process.wait(timeout=60) + except subprocess.TimeoutExpired: + logger.error( + "Training subprocess did not terminate before timeout, sending SIGKILL." + ) + process.kill() + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + pass + else: logger.error( - "Training subprocess did not terminate before timeout, sending SIGKILL." + "Training subprocess exited with code %d.", + process_code, ) - process.kill() + + # Recompute final exit status after any forced shutdown + process_code = process.poll() + failure = process_code is None or process_code != 0 + + if signal_handler is not None: + signal_handler.uninstall() if interrupt: raise interrupt if failure: - raise RuntimeError( - "Suffered a failure during distributed training. Please see the training logs for more context." - ) + msg = "Suffered a failure during distributed training. Please see the training logs for more context." + if ( + signal_handler is not None + and signal_handler.signal_received is not None + ): + msg += ( + f"\n\nNote: signal {signal_handler.signal_received.name} was" + " received and on-demand checkpointing was enabled, but the" + " training subprocess did not exit cleanly. 
This usually" + " means the process was killed (SIGKILL) before the" + " checkpoint could be saved. To fix this, increase" + " terminationGracePeriodSeconds in your pod spec to give" + " workers more time, or reduce the model's forward/backward" + " pass time so the checkpoint check fires sooner." + ) + raise RuntimeError(msg) if __name__ == "__main__": @@ -1045,6 +1174,19 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: ), ) + parser.add_argument( + "--on_demand_checkpointing", + action="store_true", + default=False, + help=( + "Enable on-demand full-state checkpointing triggered by Unix signals. " + "When enabled, workers check for a trigger file in /dev/shm at five " + "synchronization points per step (before/after each minibatch forward " + "and backward pass, and before/after the optimizer step) and collectively " + "save a distributed checkpoint before exiting. Designed for OpenShift AI / " + "KubeFlow preemption handling." + ), + ) parser.add_argument( "--use_liger", action="store_true", diff --git a/src/instructlab/training/on_demand_checkpoint.py b/src/instructlab/training/on_demand_checkpoint.py new file mode 100644 index 00000000..20238ee5 --- /dev/null +++ b/src/instructlab/training/on_demand_checkpoint.py @@ -0,0 +1,300 @@ +# SPDX-License-Identifier: Apache-2.0 + +""" +On-demand checkpointing for distributed training. + +This module enables graceful checkpoint-and-exit when termination signals are +received. It is designed for environments like OpenShift AI / KubeFlow where +training jobs can be preempted at any time and the platform sends Unix signals +before killing the pod. + +Architecture +------------ +There are two sides to this feature: + +**Parent process** (``run_training`` in ``main_ds.py``): + Installs signal handlers that catch every signal OpenShift / Kubernetes can + send before a SIGKILL. 
When a signal arrives the handler writes a small + *trigger file* to ``/dev/shm`` (a tmpfs shared between containers in the + same pod). Because ``/dev/shm`` is node-local, every worker on the **same + node** can see the file instantly with zero network I/O. + +**Worker processes** (torchrun children): + The training loop calls ``check_checkpoint_requested()`` at five + synchronization points per training step, allowing the system to + react as quickly as possible to termination signals: + + 1. **Before each minibatch forward pass** — no partial computation; + the current state is saved as-is. + 2. **Before each minibatch backward pass** — the forward result is + discarded; the pre-step state is saved. + 3. **After each minibatch backward pass** — gradients are computed but + not yet applied; the pre-step state is saved (gradients will be + recomputed on resume). + 4. **Before the optimizer step** — all minibatches are done and + gradients are ready, but the step is skipped; the pre-step state + is saved. + 5. **After the optimizer step** — the step has been applied; + ``samples_seen`` is updated and the post-step state is saved. + + Each rank checks its local ``/dev/shm`` for the trigger file, converts + the boolean to a tensor, and does an ``all_reduce(MAX)`` so that if + *any* rank on *any* node detected the trigger, *every* rank agrees to + save a checkpoint. This works correctly in multi-node training because + all_reduce is a global collective. + +Signals handled +--------------- +We intercept every signal that Kubernetes / OpenShift can deliver before the +hard SIGKILL (which cannot be caught): + +* **SIGTERM** – the standard graceful-shutdown signal. Kubernetes sends this + first (configurable via ``terminationGracePeriodSeconds``). +* **SIGINT** – sent on Ctrl-C or by some job controllers. +* **SIGUSR1 / SIGUSR2** – commonly used by batch schedulers and custom + preemption controllers to signal upcoming eviction. 
+* **SIGXCPU** – sent when CPU time limits are exceeded (relevant for jobs + with resource quotas). +* **SIGHUP** – sent when the controlling terminal disconnects; some + container runtimes forward this on pod eviction. +""" + +# Standard +from pathlib import Path +from typing import Callable, Optional, Union +import logging +import os +import signal +import tempfile +import types + +# Third Party +import torch +import torch.distributed as dist + +# Type alias matching the return type of signal.getsignal(). +_SignalHandler = Union[ + Callable[[int, Optional[types.FrameType]], None], int, signal.Handlers, None +] + +logger = logging.getLogger("instructlab.training") + +# --------------------------------------------------------------------------- +# Trigger file helpers +# --------------------------------------------------------------------------- + +# The trigger file lives in /dev/shm which is a tmpfs (RAM-backed filesystem). +# It is: +# 1. Extremely fast (no disk I/O). +# 2. Shared between all containers in the same Kubernetes pod. +# 3. Automatically cleaned up when the pod is destroyed. +_TRIGGER_DIR = Path("/dev/shm") +_TRIGGER_FILENAME = "instructlab_checkpoint_requested" + + +def _get_trigger_path(job_id: Optional[str] = None) -> Path: + """Return the path to the checkpoint trigger file. + + An optional *job_id* can be supplied to avoid collisions if multiple + training jobs share the same ``/dev/shm`` (unlikely but possible). + """ + name = f"{_TRIGGER_FILENAME}_{job_id}" if job_id else _TRIGGER_FILENAME + return _TRIGGER_DIR / name + + +def write_trigger_file(job_id: Optional[str] = None) -> Path: + """Create the trigger file that tells workers to checkpoint. + + This is called from the *parent* process signal handler. + Returns the path that was written. + """ + path = _get_trigger_path(job_id) + # Use an atomic write via tempfile + rename to avoid partial reads. 
+ fd, tmp = tempfile.mkstemp(dir=_TRIGGER_DIR, prefix=".ckpt_trigger_") + try: + os.write(fd, b"1") + finally: + os.close(fd) + os.rename(tmp, path) + logger.info( + "On-demand checkpoint trigger file written: %s", + path, + ) + return path + + +def trigger_file_exists(job_id: Optional[str] = None) -> bool: + """Check whether the trigger file exists (worker-side).""" + return _get_trigger_path(job_id).exists() + + +def remove_trigger_file(job_id: Optional[str] = None) -> None: + """Remove the trigger file after the checkpoint has been saved.""" + path = _get_trigger_path(job_id) + try: + path.unlink(missing_ok=True) + except OSError: + pass + + +# --------------------------------------------------------------------------- +# Parent-side signal handling +# --------------------------------------------------------------------------- + +# Signals that OpenShift / Kubernetes / batch schedulers may send before +# the hard SIGKILL. SIGKILL (9) and SIGSTOP (19) cannot be caught. +_CATCHABLE_SIGNALS = ( + signal.SIGTERM, # Kubernetes default graceful shutdown signal + signal.SIGINT, # Ctrl-C / some job controllers + signal.SIGUSR1, # Custom preemption controllers + signal.SIGUSR2, # Custom preemption controllers + signal.SIGXCPU, # CPU time limit exceeded (resource quotas) + signal.SIGHUP, # Terminal disconnect / some eviction paths +) + + +class ParentSignalHandler: + """Installs signal handlers in the parent (launcher) process. + + When any of the catchable signals fire, the handler: + 1. Writes the trigger file to ``/dev/shm``. + 2. Records that a signal was received (so the caller can decide to + wait for the child process to finish checkpointing). + + The handler is idempotent – multiple signals will not create multiple + trigger files. + + Parameters + ---------- + job_id : str, optional + Unique identifier for this training job. Used to namespace the + trigger file. 
+ """ + + def __init__(self, job_id: Optional[str] = None): + self.job_id = job_id + self.signal_received: Optional[signal.Signals] = None + self._original_handlers: dict[signal.Signals, _SignalHandler] = {} + self._trigger_written = False + + def install(self) -> None: + """Register signal handlers for all catchable signals.""" + for sig in _CATCHABLE_SIGNALS: + try: + self._original_handlers[sig] = signal.getsignal(sig) + signal.signal(sig, self._handle) + except (OSError, ValueError): + # Some signals may not be available on all platforms + logger.debug("Could not install handler for %s", sig.name) + + logger.info( + "On-demand checkpoint signal handlers installed for: %s", + ", ".join(s.name for s in self._original_handlers), + ) + + def uninstall(self) -> None: + """Restore original signal handlers.""" + for sig, handler in self._original_handlers.items(): + try: + signal.signal(sig, handler) # type: ignore[arg-type] + except (OSError, ValueError): + pass + self._original_handlers.clear() + + def _handle(self, signum: int, _frame) -> None: + """Signal handler callback.""" + sig = signal.Signals(signum) + logger.info( + "On-demand checkpoint: received signal %s (%d). " + "Writing trigger file for workers to checkpoint before exit.", + sig.name, + signum, + ) + self.signal_received = sig + + if not self._trigger_written: + write_trigger_file(self.job_id) + self._trigger_written = True + + +# --------------------------------------------------------------------------- +# Worker-side synchronization +# --------------------------------------------------------------------------- + + +def check_checkpoint_requested(job_id: Optional[str] = None) -> bool: + """Check across all ranks whether an on-demand checkpoint was requested. + + This function must be called by **all ranks** at the same point in the + training loop (it contains a collective all_reduce). + + Returns ``True`` if any rank detected the trigger file, meaning all + ranks should save a checkpoint. 
+ """ + local_trigger = trigger_file_exists(job_id) + + # Convert to a tensor and all-reduce (MAX) so that if ANY rank on ANY + # node saw the trigger, every rank gets True. + trigger_tensor = torch.tensor( + [1 if local_trigger else 0], + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + dist.all_reduce(trigger_tensor, op=dist.ReduceOp.MAX) + + requested = trigger_tensor.item() > 0 + + if requested: + if dist.is_initialized() and dist.get_rank() == 0: + logger.info( + "On-demand checkpoint: global consensus reached – " + "all ranks will save a checkpoint." + ) + # Clean up the trigger file so that if the process somehow + # continues, we don't save again immediately. + remove_trigger_file(job_id) + + return requested + + +def save_on_demand_checkpoint( + args, + accelerator, + model, + tokenizer, + samples_seen: int, + epoch: int, + is_lora: bool, +) -> None: + """Save a full-state distributed checkpoint for on-demand resume. + + This is a thin wrapper that calls the existing ``save_checkpoint`` + utility with ``full_state=True`` so that optimizer + LR scheduler + state are also persisted, enabling exact training resumption. + """ + # First Party + from instructlab.training.utils import save_checkpoint + + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if local_rank == 0: + logger.info( + "On-demand checkpoint: saving full-state checkpoint at " + "epoch=%d, samples_seen=%d", + epoch, + samples_seen, + ) + + save_checkpoint( + args=args, + accelerator=accelerator, + model=model, + tokenizer=tokenizer, + samples_seen=samples_seen, + is_lora=is_lora, + full_state=True, + hf_format=True, + epoch=epoch, + ) + + if local_rank == 0: + logger.info("On-demand checkpoint: checkpoint saved successfully.")