
Conversation

@ChingTsai (Collaborator) commented Jan 29, 2026

Description

FIXES: b/478823561

This PR resolves discrepancies in loss calculation and training step counts when running SFT with gradient accumulation (GA) enabled. Currently, MaxText.sft.sft_trainer (which uses the Tunix trainer) breaks when gradient accumulation is turned on, diverging significantly from the native implementation in MaxText.sft_trainer. This change aligns the Tunix-based SFT logic with the native behavior.

Problem Statement

  • Loss Disparity: With GA enabled, MaxText-Tunix reported losses at a massively different scale than the native implementation. The root cause is that MaxText-Tunix reuses the native loss_fn, whose logic explicitly skips dividing by total_weights inside the function and defers that normalization to a later stage. Tunix inherited the unnormalized sum but never performed the deferred division, producing broken, inflated loss values (see the sketch after this list).

(Figure: pr_fix_vs_original_vs_native — loss curves comparing the PR fix, the original Tunix behavior, and the native implementation.)

  • Step Count Mismatch: MaxText-Native handles micro-batching internally by reshaping the full global batch, whereas Tunix relies on the input pipeline to provide pre-sized micro-batches. Without this adjustment, Tunix ingested a full global batch at every step, which skewed the epoch calculation and caused the run to terminate prematurely relative to the native implementation.
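To make the normalization issue concrete, here is a minimal sketch of a loss function that returns an unnormalized sum together with the weight total, deferring the division to the caller. All names, shapes, and values are illustrative; this is not MaxText's actual loss_fn.

```python
import jax
import jax.numpy as jnp

def unnormalized_loss(logits, targets, weights):
  """Summed token loss plus the weight total; the mean is deferred to the caller."""
  log_probs = jax.nn.log_softmax(logits)
  token_loss = -jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
  total_loss = jnp.sum(token_loss * weights)  # note: NOT divided by total_weights
  total_weights = jnp.sum(weights)
  return total_loss, total_weights

logits = jnp.zeros((2, 4, 8))                 # [batch, seq, vocab], dummy values
targets = jnp.zeros((2, 4), dtype=jnp.int32)  # dummy token ids
weights = jnp.ones((2, 4))                    # padding mask
total_loss, total_weights = unnormalized_loss(logits, targets, weights)

# Native path: this division happens at a later stage of the train step.
mean_loss = total_loss / total_weights
# Broken Tunix path: total_loss was reported as-is, so its scale grew with the
# token count instead of being a per-token mean.
```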

Changes

  • Introduced a new configuration flag, use_tunix

    • Determines whether gradient accumulation (GA) is performed through the Tunix implementation.
  • Modified the loss function

    • Aligned the Tunix loss calculation with native behavior. These changes are backward-compatible and do not affect the existing native GA implementation.
  • Pre-sliced batches in data preprocessing

    • Ensures that when Tunix GA is used, the micro-batch (chunk) size is correctly calculated during data preparation; see the worked example after this list.
  • Added a Tunix SFT integration test that verifies the gradient accumulation loss
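As a concrete illustration of the pre-slicing change, the numbers below follow the batch_size expression from the diffs later in this conversation; the chip and host counts are hypothetical.

```python
# Hypothetical single-host setup for illustration; mirrors the batch_size
# expression changed in _hf_data_processing.py.
global_batch_size = 32             # e.g. per_device_batch_size=4 on 8 chips
process_count = 1                  # jax.process_count() on one host
gradient_accumulation_steps = 4

# Native GA: the pipeline yields the full per-host batch; train.py splits
# it into micro-batches internally.
native_host_batch = global_batch_size // process_count                 # 32

# Tunix GA: the pipeline itself yields pre-sized micro-batches.
tunix_micro_batch = native_host_batch // gradient_accumulation_steps   # 8
```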

Tests

End2End

```sh
python3 -m MaxText.sft.sft_trainer \
    src/MaxText/configs/sft.yml \
    run_name=$RUN_NAME \
    base_output_directory=..../qwen3-4b \
    model_name=qwen3-4b \
    load_parameters_path=..../qwen3-4b/0/items \
    tokenizer_path=Qwen/qwen3-4b \
    steps=$train_step \
    profiler=xplane \
    hf_path=arrow \
    dataset_type=hf \
    train_split=train \
    hf_train_files=..../data-00000-of-00001.arrow \
    hf_eval_files=..../data-00000-of-00001.arrow \
    per_device_batch_size=4 \
    gradient_accumulation_steps=4 \
    max_target_length=1024 \
    learning_rate=1.3e-5 \
    warmup_steps_fraction=0.05 \
    data_shuffle_seed=42 \
    gradient_clipping_threshold=1 \
    learning_rate_final_fraction=0 \
    weight_dtype=bfloat16
```

After applying the changes, the loss graphs of both versions are now almost identical.

(Figure: graph_2026-01-29_15-52-15 — near-identical loss curves for the two versions after the fix.)

Integration Test

```sh
python -m unittest tests.integration.gradient_accumulation_test.GradientAccumulationTest.test_tunix_sft_grad_accumulate_same_loss
```
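Roughly, that test has the following shape (an illustrative sketch only; run_sft is a hypothetical stand-in, and the real test in tests/integration/gradient_accumulation_test.py drives the MaxText trainers instead):

```python
import unittest

def run_sft(use_tunix: bool):
  """Hypothetical stand-in that would launch the SFT workload with
  gradient_accumulation_steps > 1 and collect per-step losses."""
  return [2.31, 1.87, 1.54]  # placeholder values

class GradientAccumulationSketch(unittest.TestCase):
  def test_tunix_sft_grad_accumulate_same_loss(self):
    # Run the same workload through both trainers and require the
    # per-step losses to agree within tolerance.
    native_losses = run_sft(use_tunix=False)
    tunix_losses = run_sft(use_tunix=True)
    for native, tunix in zip(native_losses, tunix_losses):
      self.assertAlmostEqual(native, tunix, delta=1e-3)
```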

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@ChingTsai changed the title from "Fix loss and batching when using tunix" to "Fix gradient accumulation in post training" Jan 29, 2026
@ChingTsai changed the title from "Fix gradient accumulation in post training" to "Fix gradient accumulation in post training sft" Jan 29, 2026

codecov bot commented Jan 29, 2026

Codecov Report

❌ Patch coverage is 12.50000% with 7 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/MaxText/input_pipeline/_hf_data_processing.py | 14.28% | 5 Missing and 1 partial ⚠️ |
| src/MaxText/train.py | 0.00% | 0 Missing and 1 partial ⚠️ |


@ChingTsai force-pushed the jimmytsai/fix-ga-in-sft-trainer branch from 69e7031 to f36a364 on January 29, 2026 08:44
@ChingTsai self-assigned this Jan 29, 2026
@ChingTsai force-pushed the jimmytsai/fix-ga-in-sft-trainer branch 2 times, most recently from b891b70 to d44bc7b on January 29, 2026 14:28
@ChingTsai marked this pull request as ready for review January 29, 2026 14:28
@ChingTsai force-pushed the jimmytsai/fix-ga-in-sft-trainer branch 3 times, most recently from 5ddf27b to 99d4d28 on January 30, 2026 06:31

@github-actions bot left a comment:

📋 Review Summary

This pull request introduces a new implementation for gradient accumulation, "Tunix GA," which is handled at the data pipeline level. The changes are well-structured, and the logic is consistently applied across the configuration, data processing, and training loop. The addition of a comprehensive integration test is a great way to ensure the correctness of this new feature.

🔍 General Feedback

  • The implementation is clean and avoids major refactoring by integrating the logic into the existing data pipeline.
  • The use of a specific entrypoint to enable the feature (use_tunix_ga=True) is a good way to control the configuration.
  • The integration test is robust and covers the essential aspects of the feature.

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

mt_config = pyconfig.initialize(argv)
mt_config = pyconfig.initialize(argv, use_tunix_ga=True)
@SurbhiJainUSC (Collaborator) commented Jan 30, 2026:

Can use_tunix_ga be part of SFT configs: https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/sft.yml? It can be set to True by default. Also, we can rename it to use_tunix.

@ChingTsai (Author) replied:

Fixed.

@ChingTsai force-pushed the jimmytsai/fix-ga-in-sft-trainer branch 4 times, most recently from e628e2c to 5a4f67e on February 2, 2026 02:28
```diff
  os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
  )
- config = pyconfig.initialize(argv)
+ config = pyconfig.initialize(argv, use_tunix=False)
```
@ChingTsai (Author) commented Feb 2, 2026:

Add an override here to maintain backward compatibility with the native sft_trainer.

@ChingTsai force-pushed the jimmytsai/fix-ga-in-sft-trainer branch 4 times, most recently from 6926f52 to a49598f on February 2, 2026 09:27
@NuojCheng self-assigned this Feb 2, 2026
```python
use_tunix: bool = Field(
    False,
    description="Whether to use the Tunix implementation for gradient accumulation.",
)
```
A reviewer (Collaborator) commented:

I think the name use_tunix is too broad. If we are only interested in cases for sft + tunix + GA, please use a more specific name, e.g. use_tunix_gradient_accumulation.

```diff
- batch_size = global_batch_size // jax.process_count()
+ # Tunix GA requires per-micro-batch slicing at the data level,
+ # whereas Native GA processes the full batch and splits it internally.
+ batch_size = global_batch_size // jax.process_count() // (config.gradient_accumulation_steps if config.use_tunix else 1)
```
A reviewer (Collaborator) commented:

nit: I would prefer

```python
if config.use_tunix:
    batch_size = global_batch_size // jax.process_count() // config.gradient_accumulation_steps
else:
    batch_size = global_batch_size // jax.process_count()
```

A second reviewer (Collaborator) commented:

Agree. I love expressions over statements but this is a bit complex for the former

```diff
  assert global_batch_size % global_mesh.size == 0, "Batch size should be divisible by number of global devices."
+ # Tunix GA requires per-micro-batch slicing at the data level,
+ # whereas Native GA processes the full batch and splits it internally.
+ batch_size = global_batch_size // jax.process_count() // (num_microbatches if use_tunix else 1)
```
A reviewer (Collaborator) commented:

Same as the comment above.

```diff
  # EPS was used to avoid division by zero, but it's not needed when gradient
  # accumulation is enabled since there's no division.
- if config.gradient_accumulation_steps > 1:
+ if config.gradient_accumulation_steps > 1 and not config.use_tunix:
```
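For context, a minimal sketch of the two paths this condition selects between; the function and variable names are illustrative, not the actual train.py code.

```python
EPS = 1e-8  # guards the division when total_weights could be zero

def finalize_loss(total_loss, total_weights, gradient_accumulation_steps, use_tunix):
  if gradient_accumulation_steps > 1 and not use_tunix:
    # Native GA: skip the division here; the trainer divides by the
    # accumulated total_weights after summing over all micro-batches.
    return total_loss
  # Tunix GA (and the no-GA case): normalize immediately, so each
  # micro-step reports a proper mean loss.
  return total_loss / (total_weights + EPS)
```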
A reviewer (Collaborator) commented:

Please add comments explaining why we don't follow the standard GA path when use_tunix=true.

@NuojCheng (Collaborator) left a comment:

It is a bit concerning to me that this PR fails tests/integration/train_tests.py::TrainTests::test_tpu_zero1_gradient_accumulation with OOM. The changes made by this PR should not affect regular GA performance in train.py. We will need to take a look at the xprofs to figure out why.

@richjames0 (Collaborator) left a comment:

LGTM once Nuojin's comments are addressed and he's happy :)

@ChingTsai force-pushed the jimmytsai/fix-ga-in-sft-trainer branch 3 times, most recently from 49587e1 to d56abda on February 3, 2026 03:48
@ChingTsai requested a review from parambole as a code owner February 3, 2026 05:09
@ChingTsai force-pushed the jimmytsai/fix-ga-in-sft-trainer branch 2 times, most recently from f14c32a to d7e843e on February 3, 2026 07:36
@ChingTsai force-pushed the jimmytsai/fix-ga-in-sft-trainer branch from f70b81c to 6855036 on February 3, 2026 10:14
