Overview
Goal: Add Callbacks to Instructlab-Training
Why:
We want to facilitate use-cases such as:
- initiating eval jobs within a GPU cluster on checkpoint save
- ending the training loop early based on some condition
In order to implement this, we would add features in the following order:
- implement simple, read-only callbacks
- implement callbacks that can control flow of training loop (e.g. stopping the loop early, initiating a checkpoint, etc.)
Related Issues:
- Red-Hat-AI-Innovation-Team/mini_trainer#76 — equivalent implementation for mini-trainer
Expected usage
Users calling the run_training API would directly provide their callbacks to the TrainingArgs object through a flat API.
As an example:
```python
run_training(
    on_checkpoint=run_eval,
    on_validation_loss=end_if_diverged,
    # ... etc
)
```

Implementation Details
To implement callbacks in this training library, we have to worry about two major pieces:
- Taking callbacks that were provided to the `run_training` API and passing them to the `torchrun` subprocesses which run the actual training loop
- Receiving the callbacks inside of the training loop (child process of `torchrun`) and executing them at the appropriate points without interfering with the training process
We will eventually also want to implement callbacks which can modify what actions are performed by the training loop itself, however this will require some additional effort to ensure proper synchronization of the global process group. These callbacks will not be covered in this issue though.
The callback API
We can define a general callback as a generic function which accepts some general context as well as arbitrary keyword arguments:

```python
def my_callback(context, **kwargs) -> None:
    # do stuff
    return
```

Callbacks may also receive arguments specific to certain events, e.g. `on_save` may have the following interface:

```python
def on_save(context, checkpoint_path: str, **kwargs) -> None:
    # do stuff
    return
```

While the `on_evaluate` callback would look like this:

```python
def on_evaluate(context, validation_loss: float, **kwargs) -> None:
    # do stuff
    return
```

We expect these callbacks to be entirely self-contained: any and all packages must be imported inside the callback function, and they must be present within whatever venv the training loop runs in.
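For example, a self-contained `on_save`-style callback keeps all of its imports inside the function body. The name `log_checkpoint` and the dict-shaped context are illustrative assumptions, not part of the library:

```python
def log_checkpoint(context, checkpoint_path: str, **kwargs) -> None:
    """Hypothetical on_save callback. Everything it needs is imported
    inside the function body so it survives serialization and runs in
    the child process's venv."""
    import json
    import os

    record = {"step": context.get("global_step"), "checkpoint": checkpoint_path}
    with open(os.path.join(checkpoint_path, "callback_record.json"), "w") as f:
        json.dump(record, f)
```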
Additionally, since we expect the initial callbacks to be read-only and inconsequential to the training loop, the training loop will wrap them as asynchronous tasks which run on the global event loop. For this reason, callbacks should be careful with their usage of async functionality.
Warning
To prevent unexpected behavior or complex testing scenarios, callbacks will not be allowed to propagate exceptions into the training loop; any exception must be handled by the callback itself, or it will fail silently.
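A minimal sketch of how the training loop might wrap each callback so exceptions are contained. All names here are illustrative; this version logs failures rather than dropping them completely, which is one possible policy:

```python
import asyncio
import logging

logger = logging.getLogger(__name__)

def wrap_callback(callback):
    """Wrap a user callback so its exceptions never reach the training loop."""
    async def safe_callback(context, **kwargs):
        try:
            result = callback(context, **kwargs)
            # support both sync and async callbacks
            if asyncio.iscoroutine(result):
                await result
        except Exception:
            # contain the error: log it instead of crashing training
            logger.exception("callback %r failed", getattr(callback, "__name__", callback))
    return safe_callback
```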
Adding callbacks to TrainingArgs
To maintain a simple interface, the TrainingArgs Pydantic model (in src/instructlab/training/config.py) would expose registration for each callback with a flat interface like this:
```python
from instructlab.training.config import TrainingArgs

args = TrainingArgs(
    on_save=evaluate_gsm8k,
    on_evaluate=save_if_improved,
    ...
)

# then we call training
run_training(args, torch_args)
```

The API for these event hooks should also allow providing multiple callbacks. You would see a definition like this:

```python
class TrainingArgs(pydantic.BaseModel):
    # ...
    on_save: typing.Callable | list[typing.Callable] | None = Field(None, ...)
```

Passing callbacks to torchrun procs
In order to execute callbacks during training, they must first be passed between the main process where run_training is called and the child processes that are spawned to run the training loop in parallel. When run_training is first called, it uses the subprocess.Popen API (via StreamablePopen from utils.py) in order to invoke torchrun which then creates a worker for each GPU. The training loop itself expects to be invoked as a CLI, so run_training maps all of the fields in the TrainingArgs to the CLI flags expected by the actual training script. You can see this in action here: https://github.com/instructlab/training/blob/main/src/instructlab/training/main_ds.py#L624-L801
Therefore we can provide these callbacks to our training script by first serializing them with base64, passing them as string arguments to the CLI, then deserializing and eval-ing each callback to register it in the child's memory space. As long as a callback is self-contained, it should execute in the child process just as it would in the main process.
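The round trip could be sketched as follows. This is an assumption about the eventual implementation; here the source is extracted with `inspect.getsource` and rebuilt with `exec`, though `dill`/`cloudpickle` would be an alternative serialization strategy:

```python
import base64
import inspect
import textwrap

def encode_callback(fn) -> str:
    """Serialize a callback's source as a base64 string for a CLI flag."""
    src = textwrap.dedent(inspect.getsource(fn))
    return base64.b64encode(src.encode("utf-8")).decode("ascii")

def decode_callback(encoded: str, name: str):
    """Rebuild the callback in the child process by exec-ing its source."""
    src = base64.b64decode(encoded.encode("ascii")).decode("utf-8")
    namespace = {}
    # only safe because callbacks are required to be self-contained
    exec(src, namespace)
    return namespace[name]
```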
Registering and executing callbacks
Once they're deserialized, each callback will be registered by a callback manager on the global main/rank-0 process only.
This avoids race conditions and a flood of evaluation requests from every rank after each checkpoint save. Support for firing callbacks on other ranks can be added later; restricting to rank 0 keeps the initial implementation straightforward.
In order to manage callbacks, we should have a class like CallbacksManager which registers the callbacks and is responsible for providing them with the relevant context when they're called.
The callback manager would have an on_<event> method for each of the following events:
- `on_train_begin`: Called after `model.train()` and state initialization, before the epoch loop (~line 200 in `main_ds.py:train()`)
- `on_epoch_begin`: Called at the top of the epoch `for` loop, after `sampler.set_epoch(epoch)` (~line 208)
- `on_step_begin`: Called at the beginning of each training step, after `start = time.time()` (~line 221)
- `on_before_forward`: Called inside `BatchLossManager.process_batch()`, before `model.compute_loss()` (in `batch_loss_manager.py`)
- `on_after_forward`: Called inside `BatchLossManager.process_batch()`, after `model.compute_loss()` returns `(scaled_loss, raw_losses)`, before backward
- `on_before_backward`: Called inside `BatchLossManager.process_batch()`, before `accelerator.backward(scaled_loss)`
- `on_after_backward`: Called inside `BatchLossManager.process_batch()`, after `accelerator.backward(scaled_loss)`
- `on_before_optimizer_step`: Called before `accelerator.take_optimizer_step()` (~line 236), which internally does `clip_grad_norm_`, `optimizer.step()`, `lr_scheduler.step()`, `optimizer.zero_grad()`
- `on_after_optimizer_step`: Called after `accelerator.take_optimizer_step()` (~line 236)
- `on_log`: Called around `metric_logger.info()` (~lines 258-279), after the metrics dict is built with `elapsed_time`, `overall_throughput`, `current_lr`, `cuda_mem_allocated`, `global_grad_norm`, etc.
- `on_evaluate`: Called after `compute_validation_loss()` (~line 295), which returns `{"val_loss": float, "val_num_tokens": int}`
- `on_save`: Called after `save_checkpoint()` and `dist.barrier()` (~lines 297-309)
- `on_step_end`: Called after `global_step += 1` and `torch.cuda.empty_cache()` (~line 314)
- `on_epoch_end`: Called after the inner step loop exits, before/after the epoch checkpoint (~line 315)
- `on_train_end`: Called after the epoch loop exits, before the final save (~line 331)
Internally, the callbacks should be treated as asynchronous coroutines. When the manager's event hook is triggered, it would create a task for each of the registered callbacks and schedule them for execution on the global event loop. There should be no expectation of reading the return values of the callbacks or that they pass values from one to another.
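A minimal sketch of such a manager, under the assumptions above (fire-and-forget, no return values, exceptions contained). The class name `CallbacksManager` comes from this proposal; the method names and structure here are illustrative:

```python
import asyncio
from collections import defaultdict
from typing import Any, Callable

class CallbacksManager:
    """Registers callbacks per event and dispatches them safely.

    A real implementation would schedule each callback as a task on the
    global event loop; this sketch awaits them directly for simplicity.
    """

    def __init__(self) -> None:
        self._callbacks: dict[str, list[Callable]] = defaultdict(list)

    def register(self, event: str, callback: Callable) -> None:
        self._callbacks[event].append(callback)

    async def fire(self, event: str, context: Any, **kwargs) -> None:
        for cb in self._callbacks[event]:
            try:
                result = cb(context, **kwargs)
                # support async callbacks as well as plain functions
                if asyncio.iscoroutine(result):
                    await result
            except Exception:
                # exceptions never propagate into the training loop
                pass
```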
Context available to callbacks
The train() function has several key variables that should be exposed to callbacks via a context object:
| Variable | Description |
|---|---|
| `global_step` | Current training step (starts at 1) |
| `epoch` | Current epoch |
| `samples_seen` | Cumulative samples processed |
| `avg_loss_across_ranks` | Average loss across all ranks for the current step |
| `batch_metrics` | `BatchMetrics` dataclass: `total_samples`, `total_length`, `num_loss_counted_tokens`, `accumulated_loss`, `accumulated_aux_loss`, `grad_accum_steps`, `num_minibatches` |
| `val_metrics` | Dict with `val_loss` and `val_num_tokens` (when validation is enabled) |
| `elapsed_time` | Time taken for the current step |
| `overall_throughput` | Tokens per second |
| `current_lr` | Current learning rate |
| `global_grad_norm` | Gradient norm after clipping |
| `cuda_mem_allocated` | Current GPU memory usage |
| `args` | Full argparse namespace with all training config |
Key files involved
| File | Role |
|---|---|
| `src/instructlab/training/config.py` | All Pydantic config classes (`TrainingArgs`, `TorchrunArgs`, etc.) |
| `src/instructlab/training/main_ds.py` | Entry point, torchrun spawning via `run_training()`, `main()` init, `train()` loop |
| `src/instructlab/training/batch_loss_manager.py` | `BatchLossManager` — forward/backward loop over minibatches, loss reduction. Forward/backward hooks go here. |
| `src/instructlab/training/model.py` | Model wrapper, `CausalLMModel`, `LigerModel`, `compute_loss()`, `setup_optimizer()` |
| `src/instructlab/training/accelerator.py` | `Accelerator` wrapper around HF Accelerate, FSDP/DS config, `take_optimizer_step()` |
| `src/instructlab/training/utils.py` | `save_checkpoint()`, `save_hf_format_accelerate()`, `load_latest_full_state()`, `StreamablePopen` |
Architectural notes
- The `train()` function is the core loop in `main_ds.py` (~lines 170-340). It receives model, accelerator, train/val loaders, and args. The `CallbacksManager` would be passed as an additional parameter.
- The `run_training()` API launches a subprocess (`main_ds.py`, lines 571-843) via `StreamablePopen`. `TrainingArgs` fields are mapped 1:1 to CLI flags and parsed back via `argparse` in the subprocess. Callbacks must be serialized (e.g., base64-encoded source via `inspect.getsource` or `dill`/`cloudpickle`) to cross this process boundary.
- Forward/backward passes are inside `BatchLossManager` (`batch_loss_manager.py`), not inline in `train()`. The `process_batch()` method loops over minibatches, calls `model.compute_loss()`, then `accelerator.backward()`. Hooks around forward/backward need to be injected into this class.
- Optimizer step is inside `Accelerator` (`accelerator.py`). The `take_optimizer_step()` method bundles `clip_grad_norm_`, `optimizer.step()`, `lr_scheduler.step()`, and `optimizer.zero_grad()`.
- Validation is a separate function, `compute_validation_loss()` (~lines 87-167 in `main_ds.py`). It switches to `model.eval()`, runs a forward-only loop over the val loader, does `all_reduce` for loss/tokens across ranks, then restores `model.train()`.
- Logging uses Python's `logging` module with custom handlers (JSONL async writer, TensorBoard, W&B, MLflow) configured at setup time. The callback system should coexist with this, not replace it.
- The loop is distributed (FSDP or DeepSpeed via HF Accelerate). Callbacks should fire on rank-0 only by default to avoid race conditions. The context object should expose `is_main_process` for callbacks that need to check.