Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 43 additions & 8 deletions src/instructlab/training/batch_loss_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
"""

# Standard
from dataclasses import dataclass
from collections.abc import Callable
from dataclasses import dataclass, field
import logging

# Third Party
Expand All @@ -33,6 +34,7 @@ class BatchMetrics:
accumulated_aux_loss: torch.Tensor | None
grad_accum_steps: int
num_minibatches: int
interrupted: bool = field(default=False)


class BatchLossManager:
Expand Down Expand Up @@ -62,12 +64,22 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int):
self.local_rank: int = local_rank
self.torch_device = torch.device("cuda", local_rank)

def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]:
def process_batch(
self,
batch: list[CollatedItem],
interrupt_check: Callable[[], bool] | None = None,
) -> tuple[BatchMetrics, float]:
"""
Process a batch of minibatches, computing losses and accumulating gradients.

Args:
batch: List of minibatches to process
interrupt_check: Optional callback invoked at three points per
minibatch: before forward, before backward, and after
backward. If it returns ``True`` at any point, processing
stops early and ``BatchMetrics.interrupted`` is set. Used by
on-demand checkpointing to react as quickly as possible
instead of waiting for the full optimizer step.

Returns:
tuple: (BatchMetrics, average_loss_across_ranks)
Expand All @@ -82,9 +94,15 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
accumulated_loss = 0.0
accumulated_aux_loss = 0.0
grad_accum_steps = 0
interrupted = False

# process each minibatch
for mb in batch:
# Check for on-demand checkpoint before forward
if interrupt_check is not None and interrupt_check():
interrupted = True
break

# extract minibatch-specific info
micro_batch_size = mb["num_samples"]
total_length = mb["total_length"]
Expand All @@ -96,10 +114,16 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
# prepare model inputs
model_inputs = self._prepare_model_inputs(mb)

# compute loss and backward pass
# compute loss (forward pass)
scaled_loss, raw_losses = self.model.compute_loss(
model_inputs, self.world_size, batch_num_loss_counted_tokens
)

# Check for on-demand checkpoint before backward
if interrupt_check is not None and interrupt_check():
interrupted = True
break

self.accelerator.backward(scaled_loss)

# accumulate losses
Expand All @@ -108,6 +132,11 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
if raw_losses.aux_loss is not None:
accumulated_aux_loss += raw_losses.aux_loss

# Check for on-demand checkpoint after backward
if interrupt_check is not None and interrupt_check():
interrupted = True
break

# reduce metrics across ranks
batch_total_samples, batch_total_length = self._reduce_metrics(
batch_total_samples, batch_total_length
Expand All @@ -127,6 +156,7 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
accumulated_aux_loss=accumulated_aux_loss,
grad_accum_steps=grad_accum_steps,
num_minibatches=num_minibatches,
interrupted=interrupted,
)

return metrics, avg_loss_across_ranks
Expand Down Expand Up @@ -165,8 +195,8 @@ def _reduce_metrics(

def _compute_average_loss(
self,
accumulated_loss: torch.Tensor,
accumulated_aux_loss: torch.Tensor | None,
accumulated_loss: torch.Tensor | float,
accumulated_aux_loss: torch.Tensor | float | None,
batch_num_loss_counted_tokens: int,
) -> float:
"""Compute average loss across all ranks for metrics logging."""
Expand All @@ -177,11 +207,16 @@ def _compute_average_loss(
if accumulated_aux_loss is not None:
total_batch_loss += accumulated_aux_loss

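# Extract scalar value — accumulated_loss may still be its plain-float
# initial value (0.0) if the minibatch loop was interrupted before any loss
# was accumulated, i.e. before the first backward pass completed (this
# includes an interrupt raised between the forward and backward pass).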
if isinstance(total_batch_loss, torch.Tensor):
loss_value = total_batch_loss.detach().item()
else:
loss_value = float(total_batch_loss)

# reduce across ranks
avg_loss_across_ranks = self.accelerator.reduce(
torch.tensor(
total_batch_loss.detach().item(), device=self.accelerator.device
),
torch.tensor(loss_value, device=self.accelerator.device),
reduction="mean",
).item()

Expand Down
13 changes: 13 additions & 0 deletions src/instructlab/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,19 @@ class TrainingArgs(BaseModel):
description="How often to evaluate validation loss (in training steps). Required when validation_split > 0.",
)

on_demand_checkpointing: bool = Field(
default=False,
description=(
"Enable on-demand full-state checkpointing triggered by Unix signals. "
"When enabled, the parent process intercepts termination signals "
"(SIGTERM, SIGINT, SIGUSR1, SIGUSR2, SIGXCPU, SIGHUP) and writes a "
"trigger file to /dev/shm. Worker processes check for this trigger "
"after each minibatch backward pass and collectively save a distributed "
"checkpoint before exiting gracefully. Designed for OpenShift AI / "
"KubeFlow training jobs where preemption signals must be handled."
),
)

@model_validator(mode="after")
def validate_validation_config(self):
if not 0.0 <= self.validation_split < 1.0:
Expand Down
Loading
Loading