Dockerfile (4 changes: 2 additions & 2 deletions)
@@ -1,4 +1,4 @@
-# Use an official Python 3.12 runtime as a parent image
+# Use an official Python 3.12 runtime as a parent image
 FROM python:3.12-slim

 # Set the working directory
@@ -32,4 +32,4 @@ RUN pip install ./jax_tpu_embedding-0.1.0.dev20260121-cp312-cp312-manylinux_2_31
 COPY . /app

 # Default command to run the training script
-CMD ["python", "recml/examples/dlrm_experiment_test.py"]
\ No newline at end of file
+CMD ["python", "recml/examples/dlrm_experiment_test.py"]
recml/core/training/jax_trainer.py (24 changes: 23 additions & 1 deletion)
@@ -20,6 +20,7 @@
 import os
 import pprint
+import time
 from typing import Any, Generic, Protocol, Self, TypeVar

 from absl import logging
 from clu import data as clu_data
@@ -571,9 +572,30 @@ def _train_n_steps(
     for step in range(start_step, start_step + num_steps):
       with jax.profiler.StepTraceAnnotation("train", step_num=step):
         train_batch = next(train_iter)
+        step_start = time.time()
         inputs = self._partitioner.shard_inputs(train_batch)
         state, metrics_update = train_step(inputs, state)
-        metrics_accum.accumulate(metrics_update, step)
+
+        # Block until the step's outputs are ready so the measured duration
+        # covers device execution, not just JAX's asynchronous dispatch.
+        jax.block_until_ready(metrics_update)
+        step_duration = time.time() - step_start
+
+        timing_metrics = {
+            'perf/step_time_ms': base_metrics.scalar(step_duration * 1000),
+            'perf/steps_per_sec': base_metrics.scalar(
+                1.0 / step_duration if step_duration > 0 else 0
+            ),
+        }
+
+        # Derive examples/sec when the task reports its global batch size.
+        if 'common/batch_size' in metrics_update:
+          bs = metrics_update['common/batch_size'].compute()
+          timing_metrics['perf/throughput_ex_per_sec'] = base_metrics.scalar(
+              bs / step_duration
+          )
+
+        metrics_accum.accumulate({**metrics_update, **timing_metrics}, step)
+
+        # Log once per training loop rather than on every step.
+        if step == start_step:
+          global_step = int(jax.device_get(state.step))
+          logging.info(
+              'DEBUG: Trainer running step %d. Duration: %.4fs',
+              global_step,
+              step_duration,
+          )

       self.report_progress(step)
       if (step != start_step + num_steps - 1) and self._enable_checkpointing:
         self._maybe_save_checkpoint(step, state)
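
Why the new jax.block_until_ready(metrics_update) call matters: JAX dispatches device computations asynchronously, so reading the wall clock right after train_step returns would mostly measure dispatch overhead rather than the step itself. A minimal, self-contained sketch of the difference (illustrative only; fake_train_step and the array shapes are made up and not part of this PR):

import time

import jax
import jax.numpy as jnp

@jax.jit
def fake_train_step(x):
  # Stand-in for a real train step: enough device work to be measurable.
  return jnp.tanh(x @ x.T).sum()

x = jnp.ones((2048, 2048))
fake_train_step(x).block_until_ready()  # Warm up so compile time is excluded.

start = time.time()
out = fake_train_step(x)
dispatch_s = time.time() - start  # Tiny: only measures enqueueing the work.

start = time.time()
out = fake_train_step(x)
jax.block_until_ready(out)  # Wait for the device computation to finish.
step_s = time.time() - start  # Covers the actual step, as this change intends.

print(f'dispatch: {dispatch_s * 1e3:.3f} ms, full step: {step_s * 1e3:.3f} ms')

The trade-off is that blocking on every step disables JAX's pipelining of host dispatch against device execution, so the measurement itself can slow training slightly.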
recml/examples/dlrm_experiment.py (3 changes: 3 additions & 0 deletions)
@@ -276,17 +276,20 @@ def _loss_fn(params: jt.PyTree) -> tuple[jt.Scalar, jt.Array]:
       loss = jnp.mean(optax.sigmoid_binary_cross_entropy(logits, label), axis=0)
       return loss, logits

+    global_batch_size = self.train_data.global_batch_size
     grad_fn = jax.value_and_grad(_loss_fn, has_aux=True, allow_int=True)
     (loss, logits), grads = grad_fn(state.params)
     state = state.update(grads=grads)

+    print(f'DEBUG: global batch size {global_batch_size}.')
     metrics = {
         'loss': recml.metrics.scalar(loss),
         'accuracy': recml.metrics.binary_accuracy(label, logits, threshold=0.0),
         'auc': recml.metrics.aucpr(label, logits, from_logits=True),
         'aucroc': recml.metrics.aucroc(label, logits, from_logits=True),
         'label/mean': recml.metrics.mean(label),
         'prediction/mean': recml.metrics.mean(jax.nn.sigmoid(logits)),
+        'common/batch_size': recml.metrics.scalar(global_batch_size),
     }
     return state, metrics
recml/examples/dlrm_experiment_test.py (2 changes: 1 addition & 1 deletion)
@@ -42,7 +42,7 @@ def test_dlrm_experiment(self):

     experiment.task.train_data.global_batch_size = 128
     experiment.task.eval_data.global_batch_size = 128
-    experiment.trainer.train_steps = 12
+    experiment.trainer.train_steps = 120
     experiment.trainer.steps_per_loop = 4
     experiment.trainer.steps_per_eval = 4
     experiment.trainer.enable_checkpointing = False
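
As a back-of-envelope check of the throughput metric this longer run exercises (my arithmetic, not output from the PR): perf/throughput_ex_per_sec is batch_size / step_duration, so with the test's global batch size of 128 and a hypothetical 50 ms step:

batch_size = 128         # from the test configuration above
step_duration = 0.050    # seconds; an assumed, illustrative step time
print(batch_size / step_duration)  # 2560.0 examples/sec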