
Add Callbacks to Instructlab Training #694

@RobotSail

Description


Overview

Goal: Add Callbacks to Instructlab-Training

Why:

We want to facilitate use-cases such as:

  • initiating eval jobs within a GPU cluster on checkpoint save
  • ending the training loop early based on some condition

In order to implement this, we would add features in the following order:

  1. implement simple, read-only callbacks
  2. implement callbacks that can control flow of training loop (e.g. stopping the loop early, initiating a checkpoint, etc.)

Related Issues:

Expected usage

Users calling the run_training API would provide their callbacks through a flat interface on the TrainingArgs object.

As an example:

args = TrainingArgs(
  on_checkpoint=run_eval,
  on_validation_loss=end_if_diverged,
  # ... etc
)
run_training(args, torch_args)

Implementation Details

To implement callbacks in this training library, we need to address two major pieces:

  1. Taking callbacks that were provided to the run_training API and passing them to the torchrun subprocesses which run the actual training loop
  2. Receiving the callbacks inside of the training loop (child process of torchrun) and executing them at the appropriate points without interfering with the training process

We will eventually also want to implement callbacks that can modify the actions performed by the training loop itself; however, this will require additional effort to ensure proper synchronization of the global process group. Those callbacks are out of scope for this issue.

The callback API

We can define a general callback as a function which accepts a context object along with arbitrary keyword arguments:

def my_callback(context, **kwargs) -> None:
  # do stuff
  return

Certain events pass additional arguments to their callbacks; for example, on_save may have the following interface:

def on_save(context, checkpoint_path: str, **kwargs) -> None:
  # do stuff
  return

While the on_evaluate callback would look like this:

def on_evaluate(context, validation_loss: float, **kwargs) -> None:
  # do stuff
  return

We expect these callbacks to be entirely self-contained: any packages they use must be imported inside the callback function and must be present in whatever venv the training loop runs in.

Additionally, since we expect the initial callbacks to be read-only and inconsequential to the training loop, the training loop will wrap them as asynchronous tasks which run on the global event loop. For this reason, callbacks should be careful with their use of async functionality; a callback that blocks for a long time will delay other tasks scheduled on the loop.

Warning

To prevent unexpected behavior and complex testing scenarios, callbacks will not be allowed to propagate exceptions into the training loop. Any exception must be handled by the callback itself; unhandled exceptions will cause the callback to fail silently.
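A wrapper of the kind the training loop might apply is sketched below; `swallow_exceptions` is an illustrative name, not part of the library. It logs a callback's exception and returns None rather than letting the error reach the loop:

```python
# Illustrative sketch: exceptions raised by a callback are logged but never
# propagate into the training loop, matching the warning above.
import functools
import logging

logger = logging.getLogger(__name__)

def swallow_exceptions(callback):
    @functools.wraps(callback)
    def wrapped(context, **kwargs):
        try:
            return callback(context, **kwargs)
        except Exception:
            # Failures are logged and otherwise silent.
            logger.exception("callback %s raised; ignoring", callback.__name__)
            return None
    return wrapped
```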

Adding callbacks to TrainingArgs

To maintain a simple interface, the TrainingArgs Pydantic model (in src/instructlab/training/config.py) would expose registration for each callback with a flat interface like this:

from instructlab.training.config import TrainingArgs

args = TrainingArgs(
  on_save=evaluate_gsm8k,
  on_evaluate=save_if_improved,
  ...
)

# then we call training
run_training(args, torch_args)

The API for these event hooks should also allow providing multiple callbacks. You would see a definition like this:

class TrainingArgs(pydantic.BaseModel):
  # ...
  on_save: typing.Callable | list[typing.Callable] | None = Field(None, ...)

Passing callbacks to torchrun procs

In order to execute callbacks during training, they must first be passed between the main process where run_training is called and the child processes that are spawned to run the training loop in parallel. When run_training is first called, it uses the subprocess.Popen API (via StreamablePopen from utils.py) in order to invoke torchrun which then creates a worker for each GPU. The training loop itself expects to be invoked as a CLI, so run_training maps all of the fields in the TrainingArgs to the CLI flags expected by the actual training script. You can see this in action here: https://github.com/instructlab/training/blob/main/src/instructlab/training/main_ds.py#L624-L801

Therefore we can provide these callbacks to our training script by serializing them with base64, passing them as string arguments to the CLI, then deserializing and exec-ing each callback to register it in the child's memory space. As long as the callback is self-contained, it will execute in the child process just as it would in the main process.
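A hedged sketch of that round trip, assuming the source-based approach (`inspect.getsource` on the parent side, `exec` on the child side); `encode_callback` and `decode_callback` are illustrative names, not library API:

```python
# Hypothetical sketch of moving a self-contained callback across the
# torchrun process boundary as a base64 CLI string.
import base64
import inspect
import textwrap

def encode_callback(fn) -> str:
    # Parent process: grab the function's source and base64-encode it so
    # it can ride along as a plain string argument to the CLI.
    src = textwrap.dedent(inspect.getsource(fn))
    return base64.b64encode(src.encode("utf-8")).decode("ascii")

def decode_callback(encoded: str, name: str):
    # Child process: decode and exec the source, then pull the function
    # object out of the namespace it populated.
    src = base64.b64decode(encoded).decode("utf-8")
    namespace: dict = {}
    exec(src, namespace)  # callback must be fully self-contained
    return namespace[name]
```

Note that `exec` (not `eval`) is required here, since a `def` statement is not an expression.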

Registering and executing callbacks

Once deserialized, each callback will be registered by a callback manager on the global main (rank-0) process only.
This avoids race conditions and a flood of evaluation requests from every process after each checkpoint save. Support for firing callbacks on other ranks can be added later; restricting execution to rank 0 keeps the initial implementation straightforward.

In order to manage callbacks, we should have a class like CallbacksManager which registers the callbacks and is responsible for providing them with the relevant context when they're called.

The callback manager would have an on_<event> method for each of the following events:

  • on_train_begin: Called after model.train() and state initialization, before the epoch loop (~line 200 in main_ds.py:train())
  • on_epoch_begin: Called at the top of the epoch for loop, after sampler.set_epoch(epoch) (~line 208)
  • on_step_begin: Called at the beginning of each training step, after start = time.time() (~line 221)
  • on_before_forward: Called inside BatchLossManager.process_batch(), before model.compute_loss() (in batch_loss_manager.py)
  • on_after_forward: Called inside BatchLossManager.process_batch(), after model.compute_loss() returns (scaled_loss, raw_losses), before backward
  • on_before_backward: Called inside BatchLossManager.process_batch(), before accelerator.backward(scaled_loss)
  • on_after_backward: Called inside BatchLossManager.process_batch(), after accelerator.backward(scaled_loss)
  • on_before_optimizer_step: Called before accelerator.take_optimizer_step() (~line 236), which internally does clip_grad_norm_, optimizer.step(), lr_scheduler.step(), optimizer.zero_grad()
  • on_after_optimizer_step: Called after accelerator.take_optimizer_step() (~line 236)
  • on_log: Called around metric_logger.info() (~lines 258-279), after metrics dict is built with elapsed_time, overall_throughput, current_lr, cuda_mem_allocated, global_grad_norm, etc.
  • on_evaluate: Called after compute_validation_loss() (~line 295), which returns {"val_loss": float, "val_num_tokens": int}
  • on_save: Called after save_checkpoint() and dist.barrier() (~lines 297-309)
  • on_step_end: Called after global_step += 1 and torch.cuda.empty_cache() (~line 314)
  • on_epoch_end: Called after the inner step loop exits, before/after epoch checkpoint (~line 315)
  • on_train_end: Called after the epoch loop exits, before final save (~line 331)

Internally, the callbacks should be treated as asynchronous coroutines. When the manager's event hook is triggered, it would create a task for each of the registered callbacks and schedule them for execution on the global event loop. There should be no expectation of reading the return values of the callbacks or that they pass values from one to another.
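A minimal sketch of such a manager, assuming a register/fire interface; the class name matches the proposal above, but the method names and internals are assumptions:

```python
# Minimal CallbacksManager sketch: callbacks are scheduled as tasks on the
# running event loop, exceptions are swallowed, return values are ignored.
import asyncio
from collections import defaultdict

class CallbacksManager:
    def __init__(self):
        # event name (e.g. "on_save") -> list of registered callbacks
        self._callbacks = defaultdict(list)

    def register(self, event: str, callback) -> None:
        self._callbacks[event].append(callback)

    async def _run_one(self, callback, context, **kwargs):
        # Exceptions never propagate into the training loop.
        try:
            result = callback(context, **kwargs)
            if asyncio.iscoroutine(result):
                await result
        except Exception:
            pass  # a real implementation would log here

    def fire(self, event: str, context, **kwargs) -> list[asyncio.Task]:
        # Schedule every callback registered for this event as a task on
        # the running loop; the training loop never awaits their results.
        loop = asyncio.get_running_loop()
        return [
            loop.create_task(self._run_one(cb, context, **kwargs))
            for cb in self._callbacks[event]
        ]
```

The training loop would then call e.g. `manager.fire("on_save", context, checkpoint_path=path)` right after `save_checkpoint()`.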

Context available to callbacks

The train() function has several key variables that should be exposed to callbacks via a context object:

| Variable | Description |
| --- | --- |
| `global_step` | Current training step (starts at 1) |
| `epoch` | Current epoch |
| `samples_seen` | Cumulative samples processed |
| `avg_loss_across_ranks` | Average loss across all ranks for the current step |
| `batch_metrics` | `BatchMetrics` dataclass: `total_samples`, `total_length`, `num_loss_counted_tokens`, `accumulated_loss`, `accumulated_aux_loss`, `grad_accum_steps`, `num_minibatches` |
| `val_metrics` | Dict with `val_loss` and `val_num_tokens` (when validation is enabled) |
| `elapsed_time` | Time taken for the current step |
| `overall_throughput` | Tokens per second |
| `current_lr` | Current learning rate |
| `global_grad_norm` | Gradient norm after clipping |
| `cuda_mem_allocated` | Current GPU memory usage |
| `args` | Full argparse namespace with all training config |

Key files involved

| File | Role |
| --- | --- |
| `src/instructlab/training/config.py` | All Pydantic config classes (`TrainingArgs`, `TorchrunArgs`, etc.) |
| `src/instructlab/training/main_ds.py` | Entry point, torchrun spawning via `run_training()`, `main()` init, `train()` loop |
| `src/instructlab/training/batch_loss_manager.py` | `BatchLossManager`: forward/backward loop over minibatches, loss reduction. Forward/backward hooks go here. |
| `src/instructlab/training/model.py` | Model wrapper, `CausalLMModel`, `LigerModel`, `compute_loss()`, `setup_optimizer()` |
| `src/instructlab/training/accelerator.py` | `Accelerator` wrapper around HF Accelerate, FSDP/DS config, `take_optimizer_step()` |
| `src/instructlab/training/utils.py` | `save_checkpoint()`, `save_hf_format_accelerate()`, `load_latest_full_state()`, `StreamablePopen` |

Architectural notes

  1. The train() function is the core loop in main_ds.py (~lines 170-340). It receives model, accelerator, train/val loaders, and args. The CallbacksManager would be passed as an additional parameter.

  2. The run_training() API launches a subprocess (main_ds.py, lines 571-843) via StreamablePopen. TrainingArgs fields are mapped 1:1 to CLI flags and parsed back via argparse in the subprocess. Callbacks must be serialized (e.g., base64-encoded source via inspect.getsource or dill/cloudpickle) to cross this process boundary.

  3. Forward/backward passes are inside BatchLossManager (batch_loss_manager.py), not inline in train(). The process_batch() method loops over minibatches, calls model.compute_loss(), then accelerator.backward(). Hooks around forward/backward need to be injected into this class.

  4. Optimizer step is inside Accelerator (accelerator.py). The take_optimizer_step() method bundles clip_grad_norm_, optimizer.step(), lr_scheduler.step(), and optimizer.zero_grad().

  5. Validation is a separate function compute_validation_loss() (~lines 87-167 in main_ds.py). It switches to model.eval(), runs a forward-only loop over the val loader, does all_reduce for loss/tokens across ranks, then restores model.train().

  6. Logging uses Python's logging module with custom handlers (JSONL async writer, TensorBoard, W&B, MLflow) configured at setup time. The callback system should coexist with this, not replace it.

  7. The loop is distributed (FSDP or DeepSpeed via HF Accelerate). Callbacks should fire on rank-0 only by default to avoid race conditions. The context object should expose is_main_process for callbacks that need to check.
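The rank-0 gating in note 7 can be sketched without any torch dependency by reading the RANK environment variable that torchrun exports to each worker; a real implementation would more likely use `torch.distributed.get_rank()`. The helper names here are hypothetical:

```python
# Hedged sketch of rank-0 gating for callback dispatch. torchrun sets RANK
# for every worker; absence of the variable is treated as single-process.
import os

def is_main_process() -> bool:
    return int(os.environ.get("RANK", "0")) == 0

def fire_on_main(manager_fire, event: str, context, **kwargs):
    # Only rank 0 fires callbacks, avoiding duplicate eval jobs per GPU.
    if is_main_process():
        return manager_fire(event, context, **kwargs)
    return []
```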
