From c1db017c844f5571ba1f17f9f176a63a1e7fd78d Mon Sep 17 00:00:00 2001 From: Mohit Khatwani Date: Fri, 6 Feb 2026 18:29:32 +0000 Subject: [PATCH] move src/configs to src/maxtext/configs --- .github/workflows/run_jupyter_notebooks.yml | 9 +- .github/workflows/run_pathways_tests.yml | 2 +- .../workflows/run_tests_against_package.yml | 2 +- .vscode/launch.json | 8 +- PREFLIGHT.md | 8 +- README.md | 2 +- benchmarks/api_server/README.md | 6 +- .../api_server/launch_gke_server.sh.template | 2 +- benchmarks/api_server/start_server.sh | 2 +- benchmarks/globals.py | 5 +- benchmarks/maxtext_xpk_runner.py | 8 +- benchmarks/mmlu/mmlu_eval.py | 6 +- codecov.yml | 2 +- .../convert_checkpoint.md | 10 +- .../emergency_checkpointing.md | 204 +++++++-------- .../multi_tier_checkpointing.md | 235 +++++++++--------- .../data_input_pipeline/data_input_grain.md | 48 ++-- .../data_input_pipeline/data_input_hf.md | 4 +- .../data_input_pipeline/data_input_tfds.md | 2 +- .../features_and_diagnostics.md | 40 ++- .../gcp_workload_observability.md | 9 +- .../ml_workload_diagnostics.md | 79 +++--- .../monitor_goodput.md | 83 ++++--- .../understand_logs_and_metrics.md | 72 ++++-- docs/guides/optimization/sharding.md | 14 +- .../architecture/architecture_overview.md | 91 ++++--- docs/reference/core_concepts/checkpoints.md | 31 +-- .../core_concepts/moe_configuration.md | 66 +++-- docs/reference/core_concepts/quantization.md | 95 +++---- docs/run_maxtext/run_maxtext_localhost.md | 18 +- .../run_maxtext_single_host_gpu.md | 15 +- .../run_maxtext_via_multihost_job.md | 73 +++--- .../run_maxtext_via_multihost_runner.md | 137 +++++----- docs/run_maxtext/run_maxtext_via_pathways.md | 64 ++--- docs/run_maxtext/run_maxtext_via_xpk.md | 185 +++++++------- docs/tutorials/first_run.md | 16 +- .../tutorials/posttraining/full_finetuning.md | 4 +- .../posttraining/knowledge_distillation.md | 10 +- docs/tutorials/posttraining/multimodal.md | 10 +- docs/tutorials/posttraining/rl.md | 24 +- .../posttraining/rl_on_multi_host.md | 12 +- docs/tutorials/posttraining/sft.md | 2 +- .../posttraining/sft_on_multi_host.md | 4 +- docs/tutorials/pretraining.md | 46 ++-- src/MaxText/configs/v5p/gpt3_175b/v5p_1024.sh | 15 -- src/MaxText/configs/v5p/gpt3_175b/v5p_2048.sh | 15 -- src/MaxText/configs/v5p/gpt3_175b/v5p_3072.sh | 15 -- .../agent/ckpt_conversion_agent/README.md | 6 +- src/MaxText/get_flops.py | 2 +- src/MaxText/globals.py | 8 +- .../vllm/maxtext_vllm_adapter/adapter.py | 1 + src/MaxText/layerwise_quantization.py | 2 +- src/MaxText/pyconfig.py | 8 +- src/MaxText/rl/train_rl.py | 8 +- .../utils/ckpt_conversion/compare_hf_ckpt.py | 4 +- .../examples/convert_gemma2_to_hf.sh | 4 +- .../examples/convert_gemma2_to_mt.sh | 6 +- .../examples/convert_gemma3_to_hf.sh | 4 +- .../utils/ckpt_conversion/to_huggingface.py | 2 +- .../utils/ckpt_conversion/to_maxtext.py | 4 +- .../llama_mistral_mixtral_orbax_to_hf.py | 2 +- src/{MaxText => maxtext}/configs/README.md | 58 ++--- src/{MaxText => maxtext}/configs/__init__.py | 0 src/{MaxText => maxtext}/configs/base.yml | 2 +- .../configs/decoupled_base_test.yml | 0 .../configs/experimental/1024b.sh | 4 +- .../configs/experimental/128b.sh | 4 +- .../configs/experimental/256b.sh | 4 +- .../configs/experimental/32b.sh | 4 +- .../configs/experimental/512b.sh | 4 +- .../configs/experimental/64b.sh | 4 +- .../configs/gpu}/a3/llama_2_7b/16vm.sh | 4 +- .../configs/gpu}/a3/llama_2_7b/1vm.sh | 4 +- .../configs/gpu}/a3/llama_2_7b/2vm.sh | 4 +- .../configs/gpu}/a3/llama_2_7b/4vm.sh | 4 +- .../configs/gpu}/a3/llama_2_7b/8vm.sh | 4 +- .../configs/gpu}/a3/llama_2_7b/README.md | 0 .../configs/gpu}/a3/llama_3.1_405b/128vm.sh | 4 +- .../configs/gpu}/gpu_smoke_test.yml | 0 .../configs/gpu/models}/llama2_70b.yml | 0 .../configs/gpu/models}/llama2_7b.yml | 0 .../configs/gpu/models}/llama3.1_405b.yml | 0 .../configs/gpu/models}/llama3_70b.yml | 0 .../configs/gpu/models}/llama3_8b.yml | 0 .../configs/gpu/models}/mixtral_8x1b.yml | 0 .../configs/gpu/models}/mixtral_8x2b.yml | 0 .../configs/gpu/models}/mixtral_8x7b.yml | 0 .../configs/inference}/inference.yml | 0 .../inference}/inference_jetstream.yml | 0 .../disaggregation/llama3_405b_v6e-16-16.yml | 2 +- .../interleaved}/llama2_70b_v5e-16.yml | 2 +- .../interleaved/llama3_405b_v5e-64.yml | 2 +- .../interleaved/llama3_70b_v5e-16.yml | 2 +- .../configs/models/deepseek-custom.yml | 0 .../configs/models/deepseek2-16b.yml | 0 .../configs/models/deepseek2-236b.yml | 0 .../configs/models/deepseek3-671b-2dfsdp.yml | 0 .../configs/models/deepseek3-671b.yml | 0 .../configs/models/deepseek3-test.yml | 0 src/maxtext/configs/models/deepseek3-tiny.yml | 50 ++++ .../configs/models/deepseek3.2-671b.yml | 0 .../configs/models/gemma-2b.yml | 0 .../configs/models/gemma-7b.yml | 0 .../configs/models/gemma2-27b.yml | 0 .../configs/models/gemma2-2b.yml | 0 .../configs/models/gemma2-9b.yml | 0 .../configs/models/gemma3-12b.yml | 0 .../configs/models/gemma3-27b.yml | 0 .../configs/models/gemma3-4b.yml | 0 .../configs/models/gpt-oss-120b.yml | 0 .../configs/models/gpt-oss-20b.yml | 0 .../configs/models/gpt3-175b.yml | 0 .../configs/models/gpt3-22b.yml | 0 .../configs/models/gpt3-52k.yml | 0 .../configs/models/gpt3-6b.yml | 0 .../configs/models/kimi-k2-1t.yml | 0 .../configs/models/llama2-13b.yml | 0 .../configs/models/llama2-70b.yml | 0 .../configs/models/llama2-7b.yml | 0 .../configs/models/llama3-405b.yml | 0 .../configs/models/llama3-70b.yml | 0 .../configs/models/llama3-8b.yml | 0 .../configs/models/llama3.1-405b.yml | 0 .../configs/models/llama3.1-70b.yml | 0 .../configs/models/llama3.1-8b.yml | 0 .../configs/models/llama3.3-70b.yml | 0 .../configs/models/llama4-17b-128e.yml | 0 .../configs/models/llama4-17b-16e.yml | 0 .../configs/models/mistral-7b.yml | 0 .../configs/models/mixtral-8x22b.yml | 0 .../configs/models/mixtral-8x7b.yml | 0 .../configs/models/olmo3_32b.yml | 0 .../configs/models/olmo3_7b.yml | 0 .../configs/models/qwen3-0.6b.yml | 0 .../configs/models/qwen3-14b.yml | 0 .../configs/models/qwen3-235b-a22b.yml | 0 .../configs/models/qwen3-30b-a3b.yml | 0 .../configs/models/qwen3-32b.yml | 0 .../configs/models/qwen3-480b-a35b.yml | 0 .../configs/models/qwen3-4b-thinking-2507.yml | 0 .../configs/models/qwen3-4b.yml | 0 .../configs/models/qwen3-8b.yml | 0 .../configs/models/qwen3-next-80b-a3b.yml | 2 +- .../configs/models/qwen3-omni-30b-a3b.yml | 0 .../configs/post_train}/distillation.yml | 0 .../configs/post_train}/dpo.yml | 0 .../configs/post_train}/rl.yml | 0 .../configs/post_train}/rl_mt_jt.yml | 0 .../post_train}/sft-vision-chartqa.yml | 0 .../post_train}/sft-vision-slidevqa.yml | 0 .../configs/post_train}/sft.yml | 0 .../configs/quantization/README.md | 0 .../quantization/dense_llm_subchannel.json | 0 .../dense_llm_weight_only_scale.json | 0 .../quantization/int4_weight_only.json | 0 .../quantization/int8_weight_only.json | 0 .../configs/tpu}/tpu_smoke_test.yml | 0 .../configs => maxtext/configs/tpu}/v4/22b.sh | 6 +- .../configs => maxtext/configs/tpu}/v4/52b.sh | 6 +- .../configs/tpu}/v4/README.md | 0 .../configs/tpu}/v5e/128b.sh | 6 +- .../configs/tpu}/v5e/16b.sh | 6 +- .../configs/tpu}/v5e/32b.sh | 6 +- .../configs/tpu}/v5e/64b.sh | 6 +- .../configs/tpu}/v5e/README.md | 0 .../configs/tpu}/v5e/gpt3_175b.sh | 6 +- .../configs/tpu}/v5e/llama2_13b.sh | 6 +- .../configs/tpu}/v5e/llama2_70b.sh | 6 +- .../tpu/v5e}/llama2_70b_v5e-16.yml | 2 +- .../configs/tpu}/v5e/llama2_7b.sh | 6 +- .../configs/tpu}/v5e/llama3_405b_v5e-64.yml | 2 +- .../configs/tpu}/v5e/llama3_70b_v5e-16.yml | 2 +- .../configs/tpu}/v5p/1024b.sh | 6 +- .../configs/tpu}/v5p/128b.sh | 6 +- .../configs/tpu}/v5p/256b.sh | 6 +- .../configs/tpu}/v5p/32b.sh | 6 +- .../configs/tpu}/v5p/512b.sh | 6 +- .../configs/tpu}/v5p/64b.sh | 6 +- .../configs/tpu}/v5p/README.md | 0 .../tpu}/v5p/gpt3_175b/gpt3_175b_base.sh | 8 +- .../configs/tpu/v5p/gpt3_175b/v5p_1024.sh | 15 ++ .../configs/tpu}/v5p/gpt3_175b/v5p_12288.sh | 4 +- .../configs/tpu/v5p/gpt3_175b/v5p_2048.sh | 15 ++ .../configs/tpu/v5p/gpt3_175b/v5p_3072.sh | 15 ++ .../configs/tpu}/v5p/gpt3_175b/v5p_4096.sh | 6 +- .../configs/tpu}/v5p/gpt3_175b/v5p_8192.sh | 6 +- .../configs/tpu}/v5p/llama2_70b.sh | 6 +- .../configs/tpu}/v5p/llama2_7b.sh | 6 +- .../configs/tpu/v6e}/gemma2_27b.sh | 4 +- .../configs/tpu/v6e}/gemma2_9b.sh | 4 +- .../configs/tpu/v6e}/gemma3_27b.sh | 4 +- .../configs/tpu/v6e}/gpt3_175b.sh | 4 +- .../v6e/inference/llama4_maverick_v6e-64.yml | 2 +- .../configs/tpu/v6e}/llama2_7b_4096.sh | 4 +- .../configs/tpu/v6e}/mixtral_8x7b.sh | 4 +- src/{MaxText => maxtext}/configs/types.py | 0 src/{MaxText => maxtext}/configs/vllm.yml | 0 src/maxtext/examples/demo_decoding.ipynb | 4 +- .../examples/multimodal_gemma3_demo.ipynb | 6 +- src/maxtext/examples/rl_llama3_demo.ipynb | 2 +- src/maxtext/examples/sft_llama3_demo.ipynb | 55 ++-- src/maxtext/examples/sft_qwen3_demo.ipynb | 5 +- .../examples/sft_train_and_evaluate.py | 4 +- .../gpu/microbenchmark_llama2-70b_h100-8.sh | 2 +- .../maxengine_server_entrypoint.sh | 2 +- src/maxtext/inference/mlperf/README.md | 4 +- .../gpu/benchmarks_llama2-70b-h100_8.sh | 2 +- .../inference/mlperf/llama_offline_run.sh | 2 +- .../benchmarks_llama2-70b-trillium_2x4.sh | 2 +- src/maxtext/scratch_code/gemma_7b.sh | 4 +- .../run_inference_microbenchmark.sh | 2 +- .../trainers/post_train/sft/train_sft.py | 6 +- src/maxtext/utils/maxtext_utils.py | 2 +- src/maxtext/vllm_decode.py | 4 +- .../gpu/a3/test_convergence_125m_params.sh | 2 +- .../gpu/a3/test_convergence_1b_params.sh | 2 +- tests/end_to_end/gpu/a3/test_gemma3_logits.sh | 4 +- tests/end_to_end/gpu/a3/test_llama2_7b.sh | 6 +- tests/end_to_end/gpu/mixtral/test_8x7b.sh | 6 +- .../gpu/te/run_single_node_model_parallel.sh | 2 +- .../gpu/test_collective_matmul_llama2_7b.sh | 2 +- .../end_to_end/gpu/test_fp8_gemm_llama2_7b.sh | 2 +- .../test_checkpoint_compatibility.sh | 6 +- tests/end_to_end/test_checkpointing.sh | 4 +- .../test_generate_param_only_checkpoint.sh | 6 +- .../end_to_end/test_mtc_phase_2_save_path.sh | 2 +- .../test_multi_tier_checkpointing.sh | 4 +- tests/end_to_end/tpu/deepseek/Run_DeepSeek.md | 12 +- .../tpu/deepseek/v2-16b/test_deepseek.sh | 8 +- .../tpu/deepseek/v3-671b/2_test_deepseek.sh | 4 +- .../tpu/deepseek/v3-671b/test_deepseek_mtp.sh | 2 +- tests/end_to_end/tpu/gemma/2b/test_gemma.sh | 18 +- tests/end_to_end/tpu/gemma/7b/1_test_gemma.sh | 2 +- tests/end_to_end/tpu/gemma/7b/2_test_gemma.sh | 14 +- .../end_to_end/tpu/gemma2/27b/1_test_gemma.sh | 2 +- .../end_to_end/tpu/gemma2/27b/2_test_gemma.sh | 6 +- tests/end_to_end/tpu/gemma2/2b/test_gemma2.sh | 16 +- .../tpu/gemma2/2b/test_gemma2_to_hf.sh | 4 +- .../tpu/gemma2/2b/test_gemma2_to_mt.sh | 12 +- .../end_to_end/tpu/gemma2/9b/1_test_gemma.sh | 2 +- .../end_to_end/tpu/gemma2/9b/2_test_gemma.sh | 6 +- .../end_to_end/tpu/gemma3/12b/test_gemma3.sh | 10 +- .../end_to_end/tpu/gemma3/27b/test_gemma3.sh | 10 +- tests/end_to_end/tpu/gemma3/4b/test_gemma3.sh | 10 +- .../gemma3/4b/test_gemma3_multimodal_sft.sh | 10 +- .../tpu/gemma3/4b/test_gemma3_to_hf.sh | 4 +- .../tpu/gemma3/4b/test_gemma3_to_mt.sh | 16 +- tests/end_to_end/tpu/gemma3/Run_Gemma3.md | 6 +- .../tpu/gpt_oss/120b/test_gpt_oss.sh | 10 +- .../tpu/gpt_oss/20b/test_gpt_oss.sh | 10 +- tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md | 10 +- .../tpu/llama2/13b/1_test_llama2_13b.sh | 2 +- .../tpu/llama2/13b/2_test_llama2_13b.sh | 12 +- .../tpu/llama2/70b/1_test_llama2_70b.sh | 2 +- .../tpu/llama2/70b/2_test_llama2_70b.sh | 14 +- .../tpu/llama2/7b/test_llama2_7b.sh | 20 +- .../tpu/llama3.1/405b/2_test_llama3.1_405b.sh | 6 +- .../tpu/llama3.1/405b/3_test_llama3.1_405b.sh | 2 +- .../tpu/llama3.1/70b/1_test_llama3.1_70b.sh | 2 +- .../tpu/llama3.1/70b/2_test_llama3.1_70b.sh | 8 +- .../tpu/llama3.1/70b/3_test_llama3.1_70b.sh | 14 +- .../tpu/llama3.1/8b/1_test_llama3.1_8b.sh | 2 +- .../tpu/llama3.1/8b/2_test_llama3.1_8b.sh | 12 +- .../tpu/llama3.1/8b/3_test_llama3.1_8b.sh | 14 +- tests/end_to_end/tpu/llama3.1/8b/run_sft.sh | 4 +- .../tpu/llama3.3/70b/1_test_llama3.3_70b.sh | 2 +- .../tpu/llama3.3/70b/2_test_llama3.3_70b.sh | 8 +- .../tpu/llama3/70b/1_test_llama3_70b.sh | 2 +- .../tpu/llama3/70b/2_test_llama3_70b.sh | 14 +- .../tpu/llama3/8b/1_test_llama3_8b.sh | 2 +- .../tpu/llama3/8b/2_test_llama3_8b.sh | 14 +- tests/end_to_end/tpu/llama4/2_test_llama4.sh | 2 +- tests/end_to_end/tpu/llama4/Run_Llama4.md | 4 +- tests/end_to_end/tpu/llama_finetuning_test.sh | 2 +- .../tpu/mistral/7b/test_mistral-7b.sh | 6 +- .../tpu/mixtral/8x22b/2_test_mixtral.sh | 6 +- .../tpu/mixtral/8x7b/1_test_mixtral.sh | 2 +- .../tpu/mixtral/8x7b/2_test_mixtral.sh | 12 +- .../qwen3-235b-a22b/1_test_qwen3_235b_a22b.sh | 2 +- .../moe/qwen3-30b-a3b/1_test_qwen3_30b_a3b.sh | 2 +- .../qwen3-480b-a35b/1_test_qwen3_480b_a35b.sh | 2 +- tests/end_to_end/tpu/qwen/moe/run_qwen_moe.md | 4 +- .../1_test_qwen3_next_80b_a3b.sh | 2 +- .../tpu/qwen/next/run_qwen3_next.md | 2 +- .../tpu/qwen3/4b/test_qwen3_to_hf.sh | 4 +- .../tpu/qwen3/4b/test_qwen3_to_mt.sh | 10 +- tests/end_to_end/tpu/run_sft.sh | 6 +- .../tpu/test_checkpoint_resharding.sh | 4 +- .../tpu/test_convergence_1b_params.sh | 2 +- .../tpu/test_decode_load_quantized_ckpt.sh | 2 +- .../tpu/test_decode_save_quantized_ckpt.sh | 2 +- tests/end_to_end/tpu/test_dpo.sh | 6 +- tests/end_to_end/tpu/test_gpt3.sh | 2 +- tests/end_to_end/tpu/test_sft_trainer.sh | 4 +- tests/end_to_end/tpu/test_vocab_creation.sh | 2 +- tests/integration/aot_identical_test.py | 10 +- tests/integration/decode_tests.py | 4 +- tests/integration/determinism_test.py | 5 +- .../integration/gradient_accumulation_test.py | 4 +- .../inference_microbenchmark_smoke_test.py | 2 +- .../integration/smoke/train_gpu_smoke_test.py | 2 +- tests/integration/smoke/train_smoke_test.py | 4 +- .../train_using_ragged_dot_smoke_test.py | 3 +- tests/integration/train_tests.py | 4 +- tests/integration/xaot_test.py | 10 +- tests/unit/attention_test.py | 4 +- tests/unit/configs_test.py | 58 ++--- tests/unit/configs_value_test.py | 4 +- tests/unit/deepseek32_vs_reference_test.py | 5 +- tests/unit/grain_data_processing_test.py | 8 +- tests/unit/moe_test.py | 6 +- tests/unit/pyconfig_deprecated_test.py | 9 +- tests/unit/qwen3_omni_layers_test.py | 4 +- tests/unit/sft_data_processing_test.py | 2 +- tests/unit/sft_hooks_test.py | 2 +- tests/unit/train_compile_test.py | 11 +- tests/utils/attention_test_util.py | 4 +- tests/utils/test_helper.py | 38 +++ .../generate_distillation_data.py | 2 +- 319 files changed, 1652 insertions(+), 1409 deletions(-) delete mode 100644 src/MaxText/configs/v5p/gpt3_175b/v5p_1024.sh delete mode 100644 src/MaxText/configs/v5p/gpt3_175b/v5p_2048.sh delete mode 100644 src/MaxText/configs/v5p/gpt3_175b/v5p_3072.sh rename src/{MaxText => maxtext}/configs/README.md (62%) rename src/{MaxText => maxtext}/configs/__init__.py (100%) rename src/{MaxText => maxtext}/configs/base.yml (99%) rename src/{MaxText => maxtext}/configs/decoupled_base_test.yml (100%) rename src/{MaxText => maxtext}/configs/experimental/1024b.sh (87%) rename src/{MaxText => maxtext}/configs/experimental/128b.sh (88%) rename src/{MaxText => maxtext}/configs/experimental/256b.sh (88%) rename src/{MaxText => maxtext}/configs/experimental/32b.sh (88%) rename src/{MaxText => maxtext}/configs/experimental/512b.sh (87%) rename src/{MaxText => maxtext}/configs/experimental/64b.sh (88%) rename src/{MaxText/configs => maxtext/configs/gpu}/a3/llama_2_7b/16vm.sh (87%) rename src/{MaxText/configs => maxtext/configs/gpu}/a3/llama_2_7b/1vm.sh (87%) rename src/{MaxText/configs => maxtext/configs/gpu}/a3/llama_2_7b/2vm.sh (87%) rename src/{MaxText/configs => maxtext/configs/gpu}/a3/llama_2_7b/4vm.sh (87%) rename src/{MaxText/configs => maxtext/configs/gpu}/a3/llama_2_7b/8vm.sh (87%) rename src/{MaxText/configs => maxtext/configs/gpu}/a3/llama_2_7b/README.md (100%) rename src/{MaxText/configs => maxtext/configs/gpu}/a3/llama_3.1_405b/128vm.sh (90%) rename src/{MaxText/configs => maxtext/configs/gpu}/gpu_smoke_test.yml (100%) rename src/{MaxText/configs/models/gpu => maxtext/configs/gpu/models}/llama2_70b.yml (100%) rename src/{MaxText/configs/models/gpu => maxtext/configs/gpu/models}/llama2_7b.yml (100%) rename src/{MaxText/configs/models/gpu => maxtext/configs/gpu/models}/llama3.1_405b.yml (100%) rename src/{MaxText/configs/models/gpu => maxtext/configs/gpu/models}/llama3_70b.yml (100%) rename src/{MaxText/configs/models/gpu => maxtext/configs/gpu/models}/llama3_8b.yml (100%) rename src/{MaxText/configs/models/gpu => maxtext/configs/gpu/models}/mixtral_8x1b.yml (100%) rename src/{MaxText/configs/models/gpu => maxtext/configs/gpu/models}/mixtral_8x2b.yml (100%) rename src/{MaxText/configs/models/gpu => maxtext/configs/gpu/models}/mixtral_8x7b.yml (100%) rename src/{MaxText/configs => maxtext/configs/inference}/inference.yml (100%) rename src/{MaxText/configs => maxtext/configs/inference}/inference_jetstream.yml (100%) rename src/maxtext/{inference/configs/multi_host => configs/inference/multihost}/disaggregation/llama3_405b_v6e-16-16.yml (97%) rename src/{MaxText/configs/v5e => maxtext/configs/inference/multihost/interleaved}/llama2_70b_v5e-16.yml (97%) rename src/maxtext/{inference/configs/multi_host => configs/inference/multihost}/interleaved/llama3_405b_v5e-64.yml (97%) rename src/maxtext/{inference/configs/multi_host => configs/inference/multihost}/interleaved/llama3_70b_v5e-16.yml (97%) rename src/{MaxText => maxtext}/configs/models/deepseek-custom.yml (100%) rename src/{MaxText => maxtext}/configs/models/deepseek2-16b.yml (100%) rename src/{MaxText => maxtext}/configs/models/deepseek2-236b.yml (100%) rename src/{MaxText => maxtext}/configs/models/deepseek3-671b-2dfsdp.yml (100%) rename src/{MaxText => maxtext}/configs/models/deepseek3-671b.yml (100%) rename src/{MaxText => maxtext}/configs/models/deepseek3-test.yml (100%) create mode 100644 src/maxtext/configs/models/deepseek3-tiny.yml rename src/{MaxText => maxtext}/configs/models/deepseek3.2-671b.yml (100%) rename src/{MaxText => maxtext}/configs/models/gemma-2b.yml (100%) rename src/{MaxText => maxtext}/configs/models/gemma-7b.yml (100%) rename src/{MaxText => maxtext}/configs/models/gemma2-27b.yml (100%) rename src/{MaxText => maxtext}/configs/models/gemma2-2b.yml (100%) rename src/{MaxText => maxtext}/configs/models/gemma2-9b.yml (100%) rename src/{MaxText => maxtext}/configs/models/gemma3-12b.yml (100%) rename src/{MaxText => maxtext}/configs/models/gemma3-27b.yml (100%) rename src/{MaxText => maxtext}/configs/models/gemma3-4b.yml (100%) rename src/{MaxText => maxtext}/configs/models/gpt-oss-120b.yml (100%) rename src/{MaxText => maxtext}/configs/models/gpt-oss-20b.yml (100%) rename src/{MaxText => maxtext}/configs/models/gpt3-175b.yml (100%) rename src/{MaxText => maxtext}/configs/models/gpt3-22b.yml (100%) rename src/{MaxText => maxtext}/configs/models/gpt3-52k.yml (100%) rename src/{MaxText => maxtext}/configs/models/gpt3-6b.yml (100%) rename src/{MaxText => maxtext}/configs/models/kimi-k2-1t.yml (100%) rename src/{MaxText => maxtext}/configs/models/llama2-13b.yml (100%) rename src/{MaxText => maxtext}/configs/models/llama2-70b.yml (100%) rename src/{MaxText => maxtext}/configs/models/llama2-7b.yml (100%) rename src/{MaxText => maxtext}/configs/models/llama3-405b.yml (100%) rename src/{MaxText => maxtext}/configs/models/llama3-70b.yml (100%) rename src/{MaxText => maxtext}/configs/models/llama3-8b.yml (100%) rename src/{MaxText => maxtext}/configs/models/llama3.1-405b.yml (100%) rename src/{MaxText => maxtext}/configs/models/llama3.1-70b.yml (100%) rename src/{MaxText => maxtext}/configs/models/llama3.1-8b.yml (100%) rename src/{MaxText => maxtext}/configs/models/llama3.3-70b.yml (100%) rename src/{MaxText => maxtext}/configs/models/llama4-17b-128e.yml (100%) rename src/{MaxText => maxtext}/configs/models/llama4-17b-16e.yml (100%) rename src/{MaxText => maxtext}/configs/models/mistral-7b.yml (100%) rename src/{MaxText => maxtext}/configs/models/mixtral-8x22b.yml (100%) rename src/{MaxText => maxtext}/configs/models/mixtral-8x7b.yml (100%) rename src/{MaxText => maxtext}/configs/models/olmo3_32b.yml (100%) rename src/{MaxText => maxtext}/configs/models/olmo3_7b.yml (100%) rename src/{MaxText => maxtext}/configs/models/qwen3-0.6b.yml (100%) rename src/{MaxText => maxtext}/configs/models/qwen3-14b.yml (100%) rename src/{MaxText => maxtext}/configs/models/qwen3-235b-a22b.yml (100%) rename src/{MaxText => maxtext}/configs/models/qwen3-30b-a3b.yml (100%) rename src/{MaxText => maxtext}/configs/models/qwen3-32b.yml (100%) rename src/{MaxText => maxtext}/configs/models/qwen3-480b-a35b.yml (100%) rename src/{MaxText => maxtext}/configs/models/qwen3-4b-thinking-2507.yml (100%) rename src/{MaxText => maxtext}/configs/models/qwen3-4b.yml (100%) rename src/{MaxText => maxtext}/configs/models/qwen3-8b.yml (100%) rename src/{MaxText => maxtext}/configs/models/qwen3-next-80b-a3b.yml (96%) rename src/{MaxText => maxtext}/configs/models/qwen3-omni-30b-a3b.yml (100%) rename src/{MaxText/configs => maxtext/configs/post_train}/distillation.yml (100%) rename src/{MaxText/configs => maxtext/configs/post_train}/dpo.yml (100%) rename src/{MaxText/configs => maxtext/configs/post_train}/rl.yml (100%) rename src/{MaxText/configs => maxtext/configs/post_train}/rl_mt_jt.yml (100%) rename src/{MaxText/configs => maxtext/configs/post_train}/sft-vision-chartqa.yml (100%) rename src/{MaxText/configs => maxtext/configs/post_train}/sft-vision-slidevqa.yml (100%) rename src/{MaxText/configs => maxtext/configs/post_train}/sft.yml (100%) rename src/{MaxText => maxtext}/configs/quantization/README.md (100%) rename src/{MaxText => maxtext}/configs/quantization/dense_llm_subchannel.json (100%) rename src/{MaxText => maxtext}/configs/quantization/dense_llm_weight_only_scale.json (100%) rename src/{MaxText => maxtext}/configs/quantization/int4_weight_only.json (100%) rename src/{MaxText => maxtext}/configs/quantization/int8_weight_only.json (100%) rename src/{MaxText/configs => maxtext/configs/tpu}/tpu_smoke_test.yml (100%) rename src/{MaxText/configs => maxtext/configs/tpu}/v4/22b.sh (83%) rename src/{MaxText/configs => maxtext/configs/tpu}/v4/52b.sh (84%) rename src/{MaxText/configs => maxtext/configs/tpu}/v4/README.md (100%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5e/128b.sh (82%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5e/16b.sh (81%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5e/32b.sh (81%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5e/64b.sh (81%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5e/README.md (100%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5e/gpt3_175b.sh (80%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5e/llama2_13b.sh (80%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5e/llama2_70b.sh (80%) rename src/maxtext/{inference/configs/multi_host/interleaved => configs/tpu/v5e}/llama2_70b_v5e-16.yml (97%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5e/llama2_7b.sh (80%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5e/llama3_405b_v5e-64.yml (97%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5e/llama3_70b_v5e-16.yml (97%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5p/1024b.sh (83%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5p/128b.sh (83%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5p/256b.sh (83%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5p/32b.sh (84%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5p/512b.sh (83%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5p/64b.sh (83%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5p/README.md (100%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5p/gpt3_175b/gpt3_175b_base.sh (85%) create mode 100644 src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_1024.sh rename src/{MaxText/configs => maxtext/configs/tpu}/v5p/gpt3_175b/v5p_12288.sh (68%) create mode 100644 src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_2048.sh create mode 100644 src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_3072.sh rename src/{MaxText/configs => maxtext/configs/tpu}/v5p/gpt3_175b/v5p_4096.sh (50%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5p/gpt3_175b/v5p_8192.sh (50%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5p/llama2_70b.sh (83%) rename src/{MaxText/configs => maxtext/configs/tpu}/v5p/llama2_7b.sh (84%) rename src/{MaxText/configs/trillium => maxtext/configs/tpu/v6e}/gemma2_27b.sh (85%) rename src/{MaxText/configs/trillium => maxtext/configs/tpu/v6e}/gemma2_9b.sh (86%) rename src/{MaxText/configs/trillium => maxtext/configs/tpu/v6e}/gemma3_27b.sh (85%) rename src/{MaxText/configs/trillium => maxtext/configs/tpu/v6e}/gpt3_175b.sh (86%) rename src/{MaxText/configs => maxtext/configs/tpu}/v6e/inference/llama4_maverick_v6e-64.yml (98%) rename src/{MaxText/configs/trillium => maxtext/configs/tpu/v6e}/llama2_7b_4096.sh (84%) rename src/{MaxText/configs/trillium => maxtext/configs/tpu/v6e}/mixtral_8x7b.sh (84%) rename src/{MaxText => maxtext}/configs/types.py (100%) rename src/{MaxText => maxtext}/configs/vllm.yml (100%) create mode 100644 tests/utils/test_helper.py diff --git a/.github/workflows/run_jupyter_notebooks.yml b/.github/workflows/run_jupyter_notebooks.yml index b9af2b74d1..b1efc5e59f 100644 --- a/.github/workflows/run_jupyter_notebooks.yml +++ b/.github/workflows/run_jupyter_notebooks.yml @@ -90,8 +90,11 @@ jobs: PYTHONPATH: "${{ github.workspace }}/src" HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | - MAXTEXT_REPO_ROOT=$(pwd) - MAXTEXT_NOTEBOOKS_ROOT="$MAXTEXT_REPO_ROOT/src/maxtext/examples" + source .venv/bin/activate + + export MAXTEXT_REPO_ROOT=$(pwd) + export MAXTEXT_PKG_DIR=$(pwd)/src/maxtext + export MAXTEXT_NOTEBOOKS_ROOT="$MAXTEXT_REPO_ROOT/src/maxtext/examples" for notebook in "$MAXTEXT_NOTEBOOKS_ROOT"/{sft,rl}*.ipynb; do filename=$(basename "$notebook") @@ -101,7 +104,7 @@ jobs: echo "Running $filename ..." echo "------------------------------------------------------" - .venv/bin/papermill "$notebook" "$output_name" -k maxtext_venv + papermill "$notebook" "$output_name" -k maxtext_venv done - name: Record Commit IDs shell: bash diff --git a/.github/workflows/run_pathways_tests.yml b/.github/workflows/run_pathways_tests.yml index 08ab9eab32..ecf4182f22 100644 --- a/.github/workflows/run_pathways_tests.yml +++ b/.github/workflows/run_pathways_tests.yml @@ -100,7 +100,7 @@ jobs: export MAXTEXT_REPO_ROOT=$(pwd) export MAXTEXT_ASSETS_ROOT=$(pwd)/src/maxtext/assets export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/tests/assets - export MAXTEXT_PKG_DIR=$(pwd)/src/MaxText + export MAXTEXT_PKG_DIR=$(pwd)/src/maxtext # TODO(b/454659463): Enable test_default_hlo_match after volume mount is supported. .venv/bin/python3 -m pytest ${{ inputs.pytest_addopts }} -v -m "${FINAL_PYTEST_MARKER}" -k "not AotHloIdenticalTest and not CompileThenLoad" --durations=0 env: diff --git a/.github/workflows/run_tests_against_package.yml b/.github/workflows/run_tests_against_package.yml index 7ad07d1c17..6afda428ce 100644 --- a/.github/workflows/run_tests_against_package.yml +++ b/.github/workflows/run_tests_against_package.yml @@ -110,7 +110,7 @@ jobs: export MAXTEXT_REPO_ROOT=$(pwd) export MAXTEXT_ASSETS_ROOT=$(pwd)/src/maxtext/assets export MAXTEXT_TEST_ASSETS_ROOT=$(pwd)/tests/assets - export MAXTEXT_PKG_DIR=$(pwd)/src/MaxText + export MAXTEXT_PKG_DIR=$(pwd)/src/maxtext # omit this libtpu init args for gpu tests if [ "${{ inputs.device_type }}" != "cuda12" ]; then export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536' diff --git a/.vscode/launch.json b/.vscode/launch.json index fbf766b182..4da832302a 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -9,7 +9,7 @@ "justMyCode": false, "python": "python3", "module": "maxtext.decode", - "args": ["src/MaxText/configs/base.yml", + "args": ["src/maxtext/configs/base.yml", "run_name=runner_$(date +%Y-%m-%d-%H-%M)", "base_output_directory=gs://test-maxtext-output", "dataset_path=gs://test-maxtext-dataset", @@ -36,7 +36,7 @@ "justMyCode": false, "python": "python3", "module": "maxtext.decode", - "args": ["src/MaxText/configs/base.yml", + "args": ["src/maxtext/configs/base.yml", "run_name=runner_$(date +%Y-%m-%d-%H-%M)", "base_output_directory=gs://test-maxtext-output", "dataset_path=gs://test-maxtext-dataset", @@ -52,7 +52,7 @@ "justMyCode": false, "python": "python3", "module": "MaxText.train", - "args": ["src/MaxText/configs/base.yml", + "args": ["src/maxtext/configs/base.yml", "run_name=runner_$(date +%Y-%m-%d-%H-%M)", "base_output_directory=gs://test-maxtext-output", "dataset_path=gs://test-maxtext-dataset", @@ -68,7 +68,7 @@ "python": "python3", "module": "maxtext.inference.inference_microbenchmark", "args": [ - "src/MaxText/configs/base.yml", + "src/maxtext/configs/base.yml", "model_name=llama2-7b", "tokenizer_path=src/maxtext/assets/tokenizers/tokenizer.llama2", "weight_dtype=bfloat16", diff --git a/PREFLIGHT.md b/PREFLIGHT.md index b589bc1f4f..ddb44fb8ca 100644 --- a/PREFLIGHT.md +++ b/PREFLIGHT.md @@ -7,12 +7,12 @@ Before you run ML workload on Multihost with GCE or GKE, simply apply `bash pref Here is an example for GCE: ``` -bash preflight.sh PLATFORM=GCE && python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$YOUR_JOB_NAME +bash preflight.sh PLATFORM=GCE && python3 -m MaxText.train src/maxtext/configs/base.yml run_name=$YOUR_JOB_NAME ``` Here is an example for GKE: ``` -bash preflight.sh PLATFORM=GKE && python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$YOUR_JOB_NAME +bash preflight.sh PLATFORM=GKE && python3 -m MaxText.train src/maxtext/configs/base.yml run_name=$YOUR_JOB_NAME ``` # Optimization 2: Numa binding (You can only apply this to v4 and v5p) @@ -22,14 +22,14 @@ For GCE, [preflight.sh](https://github.com/google/maxtext/blob/main/preflight.sh) will help you install `numactl` dependency, so you can use it directly, here is an example: ``` -bash preflight.sh PLATFORM=GCE && numactl --membind 0 --cpunodebind=0 python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$YOUR_JOB_NAME +bash preflight.sh PLATFORM=GCE && numactl --membind 0 --cpunodebind=0 python3 -m MaxText.train src/maxtext/configs/base.yml run_name=$YOUR_JOB_NAME ``` For GKE, `numactl` should be built into your docker image from [maxtext_tpu_dependencies.Dockerfile](https://github.com/google/maxtext/blob/main/dependencies/dockerfiles/maxtext_tpu_dependencies.Dockerfile), so you can use it directly if you built the maxtext docker image. Here is an example ``` -bash preflight.sh PLATFORM=GKE && numactl --membind 0 --cpunodebind=0 python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$YOUR_JOB_NAME +bash preflight.sh PLATFORM=GKE && numactl --membind 0 --cpunodebind=0 python3 -m MaxText.train src/maxtext/configs/base.yml run_name=$YOUR_JOB_NAME ``` 1. `numactl`: This is the command-line tool used for controlling NUMA policy for processes or shared memory. It's particularly useful on multi-socket systems where memory locality can impact performance. diff --git a/README.md b/README.md index 383905d50d..f8d23e5efb 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies ## 🔥 Latest news 🔥 * \[December 22, 2025\] [Muon optimizer](https://kellerjordan.github.io/posts/muon) is now supported. -* \[December 10, 2025\] DeepSeek V3.1 is now supported. Use existing configs for [DeepSeek V3 671B](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/models/deepseek3-671b.yml) and load in V3.1 checkpoint to use model. +* \[December 10, 2025\] DeepSeek V3.1 is now supported. Use existing configs for [DeepSeek V3 671B](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek3-671b.yml) and load in V3.1 checkpoint to use model. * \[December 9, 2025\] [New RL and SFT Notebook tutorials](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/examples) are available. * \[December 4, 2025\] The [ReadTheDocs documentation site](https://maxtext.readthedocs.io/en/latest/index.html) has been reorganized. * \[December 3, 2025\] Multi-host support for GSPO and GRPO is now available via [new RL tutorials](https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/rl_on_multi_host.html). diff --git a/benchmarks/api_server/README.md b/benchmarks/api_server/README.md index ad65c1c27f..ff51371e54 100644 --- a/benchmarks/api_server/README.md +++ b/benchmarks/api_server/README.md @@ -33,7 +33,7 @@ export HF_TOKEN= The primary way to launch the API server is by using the `start_server.sh` script. This script ensures that the server is run from the project's root directory, which is necessary for the Python interpreter to find all the required modules. -The script takes the path to a base configuration file (e.g., `MaxText/configs/base.yml`) followed by any number of model-specific configuration overrides. +The script takes the path to a base configuration file (e.g., `maxtext/configs/base.yml`) followed by any number of model-specific configuration overrides. ### Benchmarking Configuration @@ -56,7 +56,7 @@ Here is an example of how to launch the server with a `qwen3-30b-a3b` model, con # Make sure you are in the root directory of the maxtext project. bash benchmarks/api_server/start_server.sh \ - MaxText/configs/base.yml \ + maxtext/configs/base.yml \ model_name="qwen3-30b-a3b" \ tokenizer_path="Qwen/Qwen3-30B-A3B-Thinking-2507" \ load_parameters_path="" \ @@ -135,7 +135,7 @@ CMD="export HF_TOKEN=${HF_TOKEN} && \ pip install --upgrade pip && \ pip install -r benchmarks/api_server/requirements.txt && \ bash benchmarks/api_server/start_server.sh \ - MaxText/configs/base.yml \ + maxtext/configs/base.yml \ model_name="${MODEL_NAME}" \ tokenizer_path="${TOKENIZER_PATH}" \ load_parameters_path="${LOAD_PARAMETERS_PATH}" \ diff --git a/benchmarks/api_server/launch_gke_server.sh.template b/benchmarks/api_server/launch_gke_server.sh.template index fedb5d8dd8..2f80608c47 100644 --- a/benchmarks/api_server/launch_gke_server.sh.template +++ b/benchmarks/api_server/launch_gke_server.sh.template @@ -53,7 +53,7 @@ CMD="export HF_TOKEN=${HF_TOKEN} && \ pip install --upgrade pip && \ pip install -r benchmarks/api_server/requirements.txt && \ bash benchmarks/api_server/start_server.sh \ - MaxText/configs/base.yml \ + maxtext/configs/base.yml \ model_name=\"${MODEL_NAME}\" \ tokenizer_path=\"${TOKENIZER_PATH}\" \ load_parameters_path=\"${LOAD_PARAMETERS_PATH}\" \ diff --git a/benchmarks/api_server/start_server.sh b/benchmarks/api_server/start_server.sh index 7d9b9aa431..cb115163a7 100644 --- a/benchmarks/api_server/start_server.sh +++ b/benchmarks/api_server/start_server.sh @@ -20,7 +20,7 @@ # # Example: # bash benchmarks/api_server/start_server.sh \ -# MaxText/configs/base.yml \ +# maxtext/configs/base.yml \ # model_name="qwen3-30b-a3b" \ # tokenizer_path="Qwen/Qwen3-30B-A3B-Thinking-2507" \ # load_parameters_path="" \ diff --git a/benchmarks/globals.py b/benchmarks/globals.py index db5ba34183..ed184f7dcd 100644 --- a/benchmarks/globals.py +++ b/benchmarks/globals.py @@ -25,7 +25,10 @@ r if os.path.isdir(os.path.join(r := os.path.dirname(os.path.dirname(__file__)), ".git")) else MAXTEXT_PKG_DIR, ) +# This is the configs root: with "base.yml"; "models/"; &etc. +MAXTEXT_CONFIGS_DIR = os.environ.get("MAXTEXT_CONFIGS_DIR", os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "configs")) + # This is the assets root: with "tokenizers/"; &etc. MAXTEXT_ASSETS_ROOT = os.environ.get("MAXTEXT_ASSETS_ROOT", os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "assets")) -__all__ = ["MAXTEXT_ASSETS_ROOT", "MAXTEXT_PKG_DIR", "MAXTEXT_REPO_ROOT"] +__all__ = ["MAXTEXT_ASSETS_ROOT", "MAXTEXT_CONFIGS_DIR", "MAXTEXT_PKG_DIR", "MAXTEXT_REPO_ROOT"] diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py index 9837556f86..168a2e1554 100644 --- a/benchmarks/maxtext_xpk_runner.py +++ b/benchmarks/maxtext_xpk_runner.py @@ -35,7 +35,7 @@ import omegaconf import benchmarks.maxtext_trillium_model_configs as model_configs -from benchmarks.globals import MAXTEXT_PKG_DIR +from benchmarks.globals import MAXTEXT_CONFIGS_DIR from benchmarks.command_utils import run_command_with_updates import benchmarks.xla_flags_library as xla_flags from benchmarks.disruption_management.disruption_handler import DisruptionConfig @@ -107,7 +107,7 @@ class WorkloadConfig: generate_metrics_and_upload_to_big_query: bool = True hardware_id: str = "v6e" metrics_gcs_file: str = "" - base_config: str = os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml") + base_config: str = os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml") topology: str = dataclasses.field(init=False) num_devices_per_slice: int = dataclasses.field(init=False) db_project: str = "" @@ -354,7 +354,7 @@ def _build_args_from_config(wl_config: WorkloadConfig) -> dict: "xla_flags": f"'{xla_flags_str}'", "dataset": dataset, "run_type": "maxtext-xpk", - "config_file": os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + "config_file": os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml"), "topology": wl_config.topology, "tuning_params": f"'{tuning_params_str}'", "db_project": wl_config.db_project, @@ -440,7 +440,7 @@ def build_user_command( f"export JAX_PLATFORMS={jax_platforms} &&", "export ENABLE_PJRT_COMPATIBILITY=true &&", "export MAXTEXT_ASSETS_ROOT=/deps/src/maxtext/assets MAXTEXT_PKG_DIR=/deps/src/MaxText MAXTEXT_REPO_ROOT=/deps &&" - f'{hlo_dump} python3 -m MaxText.train {os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")}', + f'{hlo_dump} python3 -m MaxText.train {os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")}', f"{config_tuning_params}", f"steps={wl_config.num_steps}", f"model_name={wl_config.model.model_type}", diff --git a/benchmarks/mmlu/mmlu_eval.py b/benchmarks/mmlu/mmlu_eval.py index f4db7ceee9..3d5b017631 100644 --- a/benchmarks/mmlu/mmlu_eval.py +++ b/benchmarks/mmlu/mmlu_eval.py @@ -20,13 +20,13 @@ To run the MMLU benchmark: # Default is zero-shot prompting -python3 -m benchmarks.mmlu.mmlu_eval src/MaxText/configs/base.yml \ +python3 -m benchmarks.mmlu.mmlu_eval src/maxtext/configs/base.yml \ tokenizer_path=src/maxtext/assets/tokenizer_llama3.tiktoken \ load_parameters_path=check_point_path model_name=llama3.1-8b \ max_prefill_predict_length=1024 max_target_length=2048 ici_tensor_parallelism=4 per_device_batch_size=1 # Example of using the prompt_template flag for Chain-of-Thought (CoT) prompting: -python3 -m benchmarks.mmlu.mmlu_eval src/MaxText/configs/base.yml \ +python3 -m benchmarks.mmlu.mmlu_eval src/maxtext/configs/base.yml \ tokenizer_path=src/maxtext/assets/tokenizer_llama3.tiktoken \ load_parameters_path=check_point_path model_name=llama3.1-8b \ max_prefill_predict_length=1024 max_target_length=2048 ici_tensor_parallelism=4 per_device_batch_size=1 \ @@ -34,7 +34,7 @@ {choices}\nAnswer: Let's think step by step." # Example of using the prompt_template flag for 5-shot prompting (replace with actual examples): -python3 -m benchmarks.mmlu.mmlu_eval src/MaxText/configs/base.yml \ +python3 -m benchmarks.mmlu.mmlu_eval src/maxtext/configs/base.yml \ tokenizer_path=src/maxtext/assets/tokenizer_llama3.tiktoken \ load_parameters_path=check_point_path model_name=llama3.1-8b \ max_prefill_predict_length=1024 max_target_length=2048 ici_tensor_parallelism=4 per_device_batch_size=1 \ diff --git a/codecov.yml b/codecov.yml index 9511faab27..a2db5365c4 100644 --- a/codecov.yml +++ b/codecov.yml @@ -33,7 +33,7 @@ fixes: - "/github/workspace/::" ignore: - "src/maxtext/assets" - - "src/MaxText/configs" + - "src/maxtext/configs" - "src/maxtext/examples" - "src/MaxText/experimental" - "src/maxtext/inference" diff --git a/docs/guides/checkpointing_solutions/convert_checkpoint.md b/docs/guides/checkpointing_solutions/convert_checkpoint.md index b37d2923c8..31196685dc 100644 --- a/docs/guides/checkpointing_solutions/convert_checkpoint.md +++ b/docs/guides/checkpointing_solutions/convert_checkpoint.md @@ -66,7 +66,7 @@ export LAZY_LOAD_TENSORS= # True to use lazy load, False to u Finally, run below command to complete the conversion ```bash -python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \ +python3 -m MaxText.utils.ckpt_conversion.to_maxtext maxtext/configs/base.yml \ model_name=${HF_MODEL} \ hf_access_token=${HF_TOKEN} \ base_output_directory=${MODEL_CHECKPOINT_DIRECTORY} \ @@ -104,7 +104,7 @@ Use the `to_huggingface.py` script to convert a MaxText checkpoint into the Hugg The following command converts a MaxText checkpoint and saves it locally, to GCS, or uploads it directly to the Hugging Face Hub. ```bash -python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/MaxText/configs/base.yml \ +python3 -m MaxText.utils.ckpt_conversion.to_huggingface src/maxtext/configs/base.yml \ model_name= \ load_parameters_path= \ base_output_directory= \ @@ -131,7 +131,7 @@ To ensure the conversion was successful, you can use the `tests/utils/forward_pa ### Usage ```bash -python3 -m tests.utils.forward_pass_logit_checker src/MaxText/configs/base.yml \ +python3 -m tests.utils.forward_pass_logit_checker src/maxtext/configs/base.yml \ tokenizer_path=assets/ \ load_parameters_path= \ model_name= \ @@ -216,8 +216,8 @@ To extend conversion support to a new model architecture, you must define its sp - In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer. 2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion. -1. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py), add the new model key in `HF_IDS`. -1. **Add transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in ['src/MaxText/configs/models'](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture. +3. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py), add the new model key in `HF_IDS`. +4. **Add transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in ['src/maxtext/configs/models'](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture. Here is an example [PR to add support for gemma3 multi-modal model](https://github.com/AI-Hypercomputer/maxtext/pull/1983) diff --git a/docs/guides/checkpointing_solutions/emergency_checkpointing.md b/docs/guides/checkpointing_solutions/emergency_checkpointing.md index f162765cbe..9e900bdda6 100644 --- a/docs/guides/checkpointing_solutions/emergency_checkpointing.md +++ b/docs/guides/checkpointing_solutions/emergency_checkpointing.md @@ -4,12 +4,12 @@ Emergency checkpointing is a vital feature for large-scale, multi-slice training ## Assumptions -* **GKE Environment**: A **Google Kubernetes Engine (GKE)** cluster must be used. GCE infrastructure solutions like QueuedResources are not supported. -* **Multi-Tier Checkpointing Enabled on GKE cluster level**: The Multi-Tier Checkpointing feature must be enabled and configured on your GKE cluster. This involves setting up the necessary CSI drivers and configurations as outlined in the [Google Cloud Checkpointing Documentation](https://cloud.google.com/kubernetes-engine/docs/how-to/machine-learning/training/multi-tier-checkpointing). -* **Multi-Slice Workload**: The training job must be a [multi-slice environment](https://cloud.google.com/kubernetes-engine/docs/how-to/tpu-multislice), meaning it utilizes more than one node pool. -* **Orbax Checkpointer**: The [Orbax library](https://orbax.readthedocs.io) must be used for checkpointing in your training script. -* **Ramdisk Mounted via Jobset**: Each workload pod must have a [ramdisk directory mounted by Jobset](https://cloud.google.com/kubernetes-engine/docs/how-to/machine-learning/training/multi-tier-checkpointing#update-jobset) using the Multi-Tier Checkpointing CSI driver. This provides a high-speed, in-memory storage location for checkpoints. -* **Supported TPU types**: [v4](https://cloud.google.com/tpu/docs/v4), [v5e](https://cloud.google.com/tpu/docs/v5e), [v5p](https://cloud.google.com/tpu/docs/v5p), and [v6e](https://cloud.google.com/tpu/docs/v6e) +- **GKE Environment**: A **Google Kubernetes Engine (GKE)** cluster must be used. GCE infrastructure solutions like QueuedResources are not supported. +- **Multi-Tier Checkpointing Enabled on GKE cluster level**: The Multi-Tier Checkpointing feature must be enabled and configured on your GKE cluster. This involves setting up the necessary CSI drivers and configurations as outlined in the [Google Cloud Checkpointing Documentation](https://cloud.google.com/kubernetes-engine/docs/how-to/machine-learning/training/multi-tier-checkpointing). +- **Multi-Slice Workload**: The training job must be a [multi-slice environment](https://cloud.google.com/kubernetes-engine/docs/how-to/tpu-multislice), meaning it utilizes more than one node pool. +- **Orbax Checkpointer**: The [Orbax library](https://orbax.readthedocs.io) must be used for checkpointing in your training script. +- **Ramdisk Mounted via Jobset**: Each workload pod must have a [ramdisk directory mounted by Jobset](https://cloud.google.com/kubernetes-engine/docs/how-to/machine-learning/training/multi-tier-checkpointing#update-jobset) using the Multi-Tier Checkpointing CSI driver. This provides a high-speed, in-memory storage location for checkpoints. +- **Supported TPU types**: [v4](https://cloud.google.com/tpu/docs/v4), [v5e](https://cloud.google.com/tpu/docs/v5e), [v5p](https://cloud.google.com/tpu/docs/v5p), and [v6e](https://cloud.google.com/tpu/docs/v6e) ## Cluster creation using XPK @@ -17,10 +17,10 @@ To run workloads with Emergency Checkpointing, you need a Google Kubernetes Engi The `xpk` script provides a streamlined way to create a GKE cluster with all the required MTC settings. The key flags used are: -* `--enable-mtc`: Enables the Multi-Tier Checkpointing feature. -* `--enable-gcsfuse-csi-driver`: Installs the required GCS FUSE CSI driver. -* `--mtc-ramdisk-size`: Allocates an in-memory ramdisk on each node for fast, local checkpoints. -* `--mtc-gcs-bucket`: Specifies the GCS bucket. It is not utilized in emergency checkpointing, but is needed to deploy checkpointing configurations. +- `--enable-mtc`: Enables the Multi-Tier Checkpointing feature. +- `--enable-gcsfuse-csi-driver`: Installs the required GCS FUSE CSI driver. +- `--mtc-ramdisk-size`: Allocates an in-memory ramdisk on each node for fast, local checkpoints. +- `--mtc-gcs-bucket`: Specifies the GCS bucket. It is not utilized in emergency checkpointing, but is needed to deploy checkpointing configurations. ### Calculating ramdisk size per host @@ -41,117 +41,121 @@ It's a good practice to add a **10-15% buffer** . Let's walk through an example for a large model. -* **Model**: A 70 billion parameter language model. -* **Training Slice**: A nodepool with **32 hosts**. +- **Model**: A 70 billion parameter language model. +- **Training Slice**: A nodepool with **32 hosts**. -1. **Estimate Total Checkpoint Size**: - `70,000,000,000 parameters × 12 bytes/parameter = 840,000,000,000 bytes` - `840,000,000,000 bytes ≈ 840 GB` -2. **Calculate Per-Host Checkpoint shard**: - `(Total Checkpoint Size / 32 hosts) = 26.25 GB per host` +1. **Estimate Total Checkpoint Size**: + `70,000,000,000 parameters × 12 bytes/parameter = 840,000,000,000 bytes` + `840,000,000,000 bytes ≈ 840 GB` -3. **Calculate Per-Host Ramdisk Size**: - `(Per-Host Checkpoint shard) * 2 = 52.50 GB per host` +2. **Calculate Per-Host Checkpoint shard**: + `(Total Checkpoint Size / 32 hosts) = 26.25 GB per host` -4. **Add a Safety Buffer (e.g., 15%)**: - `(Per-Host Ramdisk Size) × 1.15 ≈ 60.3 GB` +3. **Calculate Per-Host Ramdisk Size**: + `(Per-Host Checkpoint shard) * 2 = 52.50 GB per host` + +4. **Add a Safety Buffer (e.g., 15%)**: + `(Per-Host Ramdisk Size) × 1.15 ≈ 60.3 GB` In this scenario, you should configure each pod in that slice with a ramdisk of at least **60 GB**. ### Example XPK cluster creation command -1. **Set up environment variables:** - ```bash - OUTPUT_PATH= - PROJECT_ID= - ZONE= - CLUSTER_NAME= - TPU_TYPE= #example: v6e-256 - MACHINE_TYPE= - NUM_SLICES= - RAMDISK_SIZE= #example: 60000Mi - GKE_VERSION= #example: 1.32.3-gke.1785000 - ``` -2. **Configure gcloud:** - ```bash - gcloud config set project ${PROJECT_ID} - gcloud config set compute/zone ${ZONE} - ``` -3. **Clone the XPK repository:** - ```bash - git clone [https://github.com/AI-Hypercomputer/xpk.git](https://github.com/AI-Hypercomputer/xpk.git) - ``` -4. **Run the cluster creation command:** - ```bash - python3 xpk/xpk.py cluster create \ - --cluster ${CLUSTER_NAME} \ - --cluster-cpu-machine-type=${MACHINE_TYPE} \ - --num-slices=${NUM_SLICES} \ - --tpu-type=${TPU_TYPE} \ - --enable-mtc \ - --enable-gcsfuse-csi-driver \ - --mtc-ramdisk-size=${RAMDISK_SIZE} \ - --mtc-gcs-bucket=${OUTPUT_PATH} \ - --gke-version=${GKE_VERSION} - ``` +1. **Set up environment variables:** + ```bash + OUTPUT_PATH= + PROJECT_ID= + ZONE= + CLUSTER_NAME= + TPU_TYPE= #example: v6e-256 + MACHINE_TYPE= + NUM_SLICES= + RAMDISK_SIZE= #example: 60000Mi + GKE_VERSION= #example: 1.32.3-gke.1785000 + ``` +2. **Configure gcloud:** + ```bash + gcloud config set project ${PROJECT_ID} + gcloud config set compute/zone ${ZONE} + ``` +3. **Clone the XPK repository:** + ```bash + git clone [https://github.com/AI-Hypercomputer/xpk.git](https://github.com/AI-Hypercomputer/xpk.git) + ``` +4. **Run the cluster creation command:** + ```bash + python3 xpk/xpk.py cluster create \ + --cluster ${CLUSTER_NAME} \ + --cluster-cpu-machine-type=${MACHINE_TYPE} \ + --num-slices=${NUM_SLICES} \ + --tpu-type=${TPU_TYPE} \ + --enable-mtc \ + --enable-gcsfuse-csi-driver \ + --mtc-ramdisk-size=${RAMDISK_SIZE} \ + --mtc-gcs-bucket=${OUTPUT_PATH} \ + --gke-version=${GKE_VERSION} + ``` ## MaxText configuration MaxText provides a set of configuration flags to control checkpointing options. This configuration manages a `two-tiered checkpointing` system designed for both durability and rapid recovery. -* **Local Emergency Checkpoints**: It saves checkpoints much more frequently to a fast, local directory on each host (i.e. a ramdisk). If a preemption or failure occurs, the job can restore from this recent local copy, minimizing lost work without needing to download from slower persistent storage. This feature is enabled by setting `enable_checkpointing`, `enable_emergency_checkpoint`, `local_checkpoint_directory` and a non-zero `local_checkpoint_period`. +- **Local Emergency Checkpoints**: It saves checkpoints much more frequently to a fast, local directory on each host (i.e. a ramdisk). If a preemption or failure occurs, the job can restore from this recent local copy, minimizing lost work without needing to download from slower persistent storage. This feature is enabled by setting `enable_checkpointing`, `enable_emergency_checkpoint`, `local_checkpoint_directory` and a non-zero `local_checkpoint_period`. -* **Persistent Checkpoints**: These are standard checkpoints saved periodically and much more rarely to durable storage(i.e. GCS bucket). They ensure that you can recover your training state even after a complete cluster failure. This is controlled by `enable_checkpointing`, and `checkpoint_period`. +- **Persistent Checkpoints**: These are standard checkpoints saved periodically and much more rarely to durable storage(i.e. GCS bucket). They ensure that you can recover your training state even after a complete cluster failure. This is controlled by `enable_checkpointing`, and `checkpoint_period`. -| Flag | Description | Type | Default | -| :--- | :--- | :--- | :--- | -| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` | -| `enable_emergency_checkpoint` | When set to (`True`), this flag enables the two-tiered emergency checkpointing feature. | `boolean` | `False` | -| `async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` | -| `local_checkpoint_directory` | The high-speed local filesystem path(i.e. ramdisk) where **emergency checkpoints** are saved. Setting this path, along with a non-zero `local_checkpoint_period`, enables the emergency checkpointing feature. | `string` | `""` | -| `local_checkpoint_period` | The interval, in training steps, for how often a **local checkpoint** is saved. This should be set to a much smaller value than `checkpoint_period` for frequent, low-overhead saves. | `integer` | `0` | -| `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved to **persistent storage**. | `integer` | `10000` | -| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage. | `boolean` | `False` | +| Flag | Description | Type | Default | +| :------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :-------- | :------ | +| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` | +| `enable_emergency_checkpoint` | When set to (`True`), this flag enables the two-tiered emergency checkpointing feature. | `boolean` | `False` | +| `async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` | +| `local_checkpoint_directory` | The high-speed local filesystem path(i.e. ramdisk) where **emergency checkpoints** are saved. Setting this path, along with a non-zero `local_checkpoint_period`, enables the emergency checkpointing feature. | `string` | `""` | +| `local_checkpoint_period` | The interval, in training steps, for how often a **local checkpoint** is saved. This should be set to a much smaller value than `checkpoint_period` for frequent, low-overhead saves. | `integer` | `0` | +| `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved to **persistent storage**. | `integer` | `10000` | +| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage. | `boolean` | `False` | ## Workload creation using XPK The flags below would give the user access to the ramdisk in their workload: -| Flag | Description | -| :--- | :--- | -| `--mtc-enabled` | Enables the Multi-Tier Checkpointing feature, by mounting ramdisk to the workload pods, using csi drivers. | +| Flag | Description | +| :-------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `--mtc-enabled` | Enables the Multi-Tier Checkpointing feature, by mounting ramdisk to the workload pods, using csi drivers. | | `--ramdisk-directory` | Specifies the mount path inside each pod where the high-speed ramdisk will be accessible. Your training application should write its local, emergency checkpoints to this path. | ### Example XPK workload creation command -1. **Set up environment variables:** - ```bash - RAMDISK_DIRECTORY= - WORKLOAD_NAME= - TPU_TYPE= - NUM_SLICES= - PROJECT_ID= - LOCAL_CHECKPOINT_PERIOD=<> - CHECKPOINT_PEROID= - STEPS= - DATA_PATH= - OUTPUT_PATH= - ``` - -2. **Define the Docker image:** - ```bash - DOCKER_IMAGE=gcr.io/${PROJECT_ID}/${USER}_mtc_runner:latest - ``` - -3. **Run the workload creation command:** - ```bash - python3 xpk/xpk.py workload create \ - --cluster ${CLUSTER_NAME} \ - --docker-image ${DOCKER_IMAGE} \ - --workload ${WORKLOAD_NAME} \ - --tpu-type=${TPU_TYPE} \ - --num-slices=${NUM_SLICES} \ - --ramdisk-directory=${RAMDISK_DIRECTORY} \ - --mtc-enabled \ - --command "python3 src/MaxText/train.py src/MaxText/configs/base.yml base_output_directory=$OUTPUT_PATH dataset_path=$DATA_PATH steps=120 per_device_batch_size=6 enable_checkpoint_cloud_logger=True checkpoint_period=${CHECKPOINT_PEROID} enable_emergency_checkpoint=True local_checkpoint_period=${LOCAL_CHECKPOINT_PERIOD} local_checkpoint_directory=/${RAMDISK_DIRECTORY}" - ``` +1. **Set up environment variables:** + + ```bash + RAMDISK_DIRECTORY= + WORKLOAD_NAME= + TPU_TYPE= + NUM_SLICES= + PROJECT_ID= + LOCAL_CHECKPOINT_PERIOD=<> + CHECKPOINT_PEROID= + STEPS= + DATA_PATH= + OUTPUT_PATH= + ``` + +2. **Define the Docker image:** + + ```bash + DOCKER_IMAGE=gcr.io/${PROJECT_ID}/${USER}_mtc_runner:latest + ``` + +3. **Run the workload creation command:** + + ```bash + python3 xpk/xpk.py workload create \ + --cluster ${CLUSTER_NAME} \ + --docker-image ${DOCKER_IMAGE} \ + --workload ${WORKLOAD_NAME} \ + --tpu-type=${TPU_TYPE} \ + --num-slices=${NUM_SLICES} \ + --ramdisk-directory=${RAMDISK_DIRECTORY} \ + --mtc-enabled \ + --command "python3 src/MaxText/train.py src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH dataset_path=$DATA_PATH steps=120 per_device_batch_size=6 enable_checkpoint_cloud_logger=True checkpoint_period=${CHECKPOINT_PEROID} enable_emergency_checkpoint=True local_checkpoint_period=${LOCAL_CHECKPOINT_PERIOD} local_checkpoint_directory=/${RAMDISK_DIRECTORY}" + ``` diff --git a/docs/guides/checkpointing_solutions/multi_tier_checkpointing.md b/docs/guides/checkpointing_solutions/multi_tier_checkpointing.md index 5bd80c2ca5..8770971d4e 100644 --- a/docs/guides/checkpointing_solutions/multi_tier_checkpointing.md +++ b/docs/guides/checkpointing_solutions/multi_tier_checkpointing.md @@ -4,40 +4,40 @@ Multi-tier checkpointing is a solution designed to optimize the storage and mana ## Purpose and benefits -* **Addresses frequent interruptions**: Large-scale ML training jobs are prone to frequent interruptions (potentially hourly), and recovery from these can be slow. -* **Improves Goodput**: By saving checkpoints more frequently and efficiently, multi-tier checkpointing reduces the amount of lost progress when a failure occurs, thereby increasing the overall Goodput of the training process. -* **Reduces MTTR**: The multi-tiered approach allows for faster restoration of training progress after a disruption. -* **Optimized restore**: During the ML training workload's startup, available checkpoint shards are asynchronously copied to the local ramdisk. These shards are pulled from the fastest available source, whether from local peer nodes or the backup on GCS persistent storage. This process ensures the data is ready to be picked up by Orbax from the ramdisk, minimizing startup delays. +- **Addresses frequent interruptions**: Large-scale ML training jobs are prone to frequent interruptions (potentially hourly), and recovery from these can be slow. +- **Improves Goodput**: By saving checkpoints more frequently and efficiently, multi-tier checkpointing reduces the amount of lost progress when a failure occurs, thereby increasing the overall Goodput of the training process. +- **Reduces MTTR**: The multi-tiered approach allows for faster restoration of training progress after a disruption. +- **Optimized restore**: During the ML training workload's startup, available checkpoint shards are asynchronously copied to the local ramdisk. These shards are pulled from the fastest available source, whether from local peer nodes or the backup on GCS persistent storage. This process ensures the data is ready to be picked up by Orbax from the ramdisk, minimizing startup delays. ## Architecture and tiers Multi-tier checkpointing stores checkpoints across multiple tiers of storage: -* **RAM (in-memory)**: Checkpoints are stored in each node's RAM for the fastest access and lowest latency. This is used for frequent, local saves. -* **In-cluster (peer replication)**: Checkpoints are replicated to other nodes or slices within the cluster. -* **GCS (persistent storage)**: Checkpoints are backed up to GCS for long-term durability and global accessibility. This tier is used for less frequent, but more robust, saves. +- **RAM (in-memory)**: Checkpoints are stored in each node's RAM for the fastest access and lowest latency. This is used for frequent, local saves. +- **In-cluster (peer replication)**: Checkpoints are replicated to other nodes or slices within the cluster. +- **GCS (persistent storage)**: Checkpoints are backed up to GCS for long-term durability and global accessibility. This tier is used for less frequent, but more robust, saves. ## Implementation details -* **GKE Component**: A managed GKE component is involved in handling high-scale checkpointing, including controllers, daemonsets, worker discovery, and rank assignment. -* **Local Storage**: For multi-tier checkpointing, local storage (such as ramdisk provided by a CSI ephemeral driver) is used for checkpoints, persisting across workload pod deletions. -* **Replication Service**: A replication service in a managed GKE component replicates checkpoints in-cluster and backs up local checkpoint files to GCS at certain intervals and is responsible for fetching latest checkpoint files to nodes without local checkpoints during restoration. +- **GKE Component**: A managed GKE component is involved in handling high-scale checkpointing, including controllers, daemonsets, worker discovery, and rank assignment. +- **Local Storage**: For multi-tier checkpointing, local storage (such as ramdisk provided by a CSI ephemeral driver) is used for checkpoints, persisting across workload pod deletions. +- **Replication Service**: A replication service in a managed GKE component replicates checkpoints in-cluster and backs up local checkpoint files to GCS at certain intervals and is responsible for fetching latest checkpoint files to nodes without local checkpoints during restoration. ## Comparison with other checkpointing methods -* **GCS Checkpointing**: This involves saving the model state directly to durable storage like GCS. However, this can be slow at larger model/cluster scales, blocking training, and leading to redundant data copies. -* **Emergency/Ramdisk Checkpointing**: While this method uses a low-latency ramdisk for checkpointing, Orbax manages the GCS save and restore operations at the workload level. As a result, saving to GCS blocks the training process during the device-to-host data transfer. -* **Multi-tier Checkpointing (Ramdisk + GCS)**: This approach combines the speed of ramdisk, the resilience of in-cluster replication, and the durability of GCS to offer a more robust and efficient solution. With Multi-Tier Checkpointing, above blocking issue is resolved because the GCS save is handled at the service level, operating on the checkpoint already saved locally. +- **GCS Checkpointing**: This involves saving the model state directly to durable storage like GCS. However, this can be slow at larger model/cluster scales, blocking training, and leading to redundant data copies. +- **Emergency/Ramdisk Checkpointing**: While this method uses a low-latency ramdisk for checkpointing, Orbax manages the GCS save and restore operations at the workload level. As a result, saving to GCS blocks the training process during the device-to-host data transfer. +- **Multi-tier Checkpointing (Ramdisk + GCS)**: This approach combines the speed of ramdisk, the resilience of in-cluster replication, and the durability of GCS to offer a more robust and efficient solution. With Multi-Tier Checkpointing, above blocking issue is resolved because the GCS save is handled at the service level, operating on the checkpoint already saved locally. ## Assumptions -* **GKE Environment**: A **Google Kubernetes Engine (GKE)** cluster must be used. GCE infrastructure solutions like QueuedResources are not supported. -* **Multi-Tier Checkpointing Enabled on GKE cluster level**: The Multi-Tier Checkpointing feature must be enabled and configured on your GKE cluster. This involves setting up the necessary CSI drivers and configurations as outlined in the [Google Cloud Checkpointing Documentation](https://cloud.google.com/kubernetes-engine/docs/how-to/machine-learning/training/multi-tier-checkpointing). -* **Multi-Slice Workload**: The training job must be a [multi-slice environment](https://cloud.google.com/kubernetes-engine/docs/how-to/tpu-multislice), meaning it utilizes more than one node pool. -* **Orbax Checkpointer**: The [Orbax library](https://orbax.readthedocs.io) must be used for checkpointing in your training script. -* **Ramdisk Mounted via Jobset**: Each workload pod must have a [ramdisk directory mounted by Jobset](https://cloud.google.com/kubernetes-engine/docs/how-to/machine-learning/training/multi-tier-checkpointing#update-jobset) using the Multi-Tier Checkpointing CSI driver. This provides a high-speed, in-memory storage location for checkpoints. -* **Supported TPU types**: [v4](https://cloud.google.com/tpu/docs/v4), [v5e](https://cloud.google.com/tpu/docs/v5e), [v5p](https://cloud.google.com/tpu/docs/v5p), and [v6e](https://cloud.google.com/tpu/docs/v6e) -* **Cluster version**: Gke cluster version needs to be later than [1.32.3-gke.1170000](https://cloud.google.com/kubernetes-engine/docs/how-to/machine-learning/training/multi-tier-checkpointing#existing-cluster). +- **GKE Environment**: A **Google Kubernetes Engine (GKE)** cluster must be used. GCE infrastructure solutions like QueuedResources are not supported. +- **Multi-Tier Checkpointing Enabled on GKE cluster level**: The Multi-Tier Checkpointing feature must be enabled and configured on your GKE cluster. This involves setting up the necessary CSI drivers and configurations as outlined in the [Google Cloud Checkpointing Documentation](https://cloud.google.com/kubernetes-engine/docs/how-to/machine-learning/training/multi-tier-checkpointing). +- **Multi-Slice Workload**: The training job must be a [multi-slice environment](https://cloud.google.com/kubernetes-engine/docs/how-to/tpu-multislice), meaning it utilizes more than one node pool. +- **Orbax Checkpointer**: The [Orbax library](https://orbax.readthedocs.io) must be used for checkpointing in your training script. +- **Ramdisk Mounted via Jobset**: Each workload pod must have a [ramdisk directory mounted by Jobset](https://cloud.google.com/kubernetes-engine/docs/how-to/machine-learning/training/multi-tier-checkpointing#update-jobset) using the Multi-Tier Checkpointing CSI driver. This provides a high-speed, in-memory storage location for checkpoints. +- **Supported TPU types**: [v4](https://cloud.google.com/tpu/docs/v4), [v5e](https://cloud.google.com/tpu/docs/v5e), [v5p](https://cloud.google.com/tpu/docs/v5p), and [v6e](https://cloud.google.com/tpu/docs/v6e) +- **Cluster version**: Gke cluster version needs to be later than [1.32.3-gke.1170000](https://cloud.google.com/kubernetes-engine/docs/how-to/machine-learning/training/multi-tier-checkpointing#existing-cluster). ## Cluster creation using XPK @@ -45,13 +45,12 @@ To run workloads with Multi-Tier Checkpointing (MTC), you need a Google Kubernet The [xpk script](https://github.com/AI-Hypercomputer/xpk/blob/main/xpk.py) provides a streamlined way to create a GKE cluster with all the required MTC settings. The key flags used are: -| Flag | Description | -| :--- | :--- | -| `--enable-mtc` | Enables the Multi-Tier Checkpointing feature. | -| `--enable-gcsfuse-csi-driver` | Installs the required GCS FUSE CSI driver. | -| `--mtc-ramdisk-size` | Allocates an in-memory ramdisk on each node for fast, local checkpoints. | -| `--mtc-gcs-bucket` | Specifies the GCS bucket. | - +| Flag | Description | +| :---------------------------- | :----------------------------------------------------------------------- | +| `--enable-mtc` | Enables the Multi-Tier Checkpointing feature. | +| `--enable-gcsfuse-csi-driver` | Installs the required GCS FUSE CSI driver. | +| `--mtc-ramdisk-size` | Allocates an in-memory ramdisk on each node for fast, local checkpoints. | +| `--mtc-gcs-bucket` | Specifies the GCS bucket. | ### Calculating ramdisk size per host @@ -72,116 +71,120 @@ It's a good practice to add a **10-15% buffer** . Let's walk through an example for a large model. -* **Model**: A 70 billion parameter language model. -* **Training Slice**: A nodepool with **32 hosts**. +- **Model**: A 70 billion parameter language model. +- **Training Slice**: A nodepool with **32 hosts**. + +1. **Estimate Total Checkpoint Size**: + `70,000,000,000 parameters × 12 bytes/parameter = 840,000,000,000 bytes` + `840,000,000,000 bytes ≈ 840 GB` -1. **Estimate Total Checkpoint Size**: - `70,000,000,000 parameters × 12 bytes/parameter = 840,000,000,000 bytes` - `840,000,000,000 bytes ≈ 840 GB` -2. **Calculate Per-Host Checkpoint shard**: - `(Total Checkpoint Size / 32 hosts) = 26.25 GB per host` +2. **Calculate Per-Host Checkpoint shard**: + `(Total Checkpoint Size / 32 hosts) = 26.25 GB per host` -3. **Calculate Per-Host Ramdisk Size**: - `(Per-Host Checkpoint shard) * 2 = 52.50 GB per host` +3. **Calculate Per-Host Ramdisk Size**: + `(Per-Host Checkpoint shard) * 2 = 52.50 GB per host` -4. **Add a Safety Buffer (e.g., 15%)**: - `(Per-Host Ramdisk Size) × 1.15 ≈ 60.3 GB` +4. **Add a Safety Buffer (e.g., 15%)**: + `(Per-Host Ramdisk Size) × 1.15 ≈ 60.3 GB` In this scenario, you should configure each pod in that slice with a ramdisk of at least **60 GB**. ### Example XPK cluster creation command -1. **Set up environment variables:** - ```bash - OUTPUT_PATH= - PROJECT_ID= - ZONE= - CLUSTER_NAME= - TPU_TYPE= #example: v6e-256 - MACHINE_TYPE= - NUM_SLICES= - RAMDISK_SIZE= #example: 60000Mi - GKE_VERSION= #example: 1.32.3-gke.1785000 - ``` -2. **Configure gcloud:** - ```bash - gcloud config set project ${PROJECT_ID} - gcloud config set compute/zone ${ZONE} - ``` -3. **Clone the XPK repository:** - ```bash - git clone [https://github.com/AI-Hypercomputer/xpk.git](https://github.com/AI-Hypercomputer/xpk.git) - ``` -4. **Run the cluster creation command:** - ```bash - python3 xpk/xpk.py cluster create \ - --cluster ${CLUSTER_NAME} \ - --cluster-cpu-machine-type=${MACHINE_TYPE} \ - --num-slices=${NUM_SLICES} \ - --tpu-type=${TPU_TYPE} \ - --enable-mtc \ - --enable-gcsfuse-csi-driver \ - --mtc-ramdisk-size=${RAMDISK_SIZE} \ - --mtc-gcs-bucket=${OUTPUT_PATH} \ - --gke-version=${GKE_VERSION} - ``` +1. **Set up environment variables:** + ```bash + OUTPUT_PATH= + PROJECT_ID= + ZONE= + CLUSTER_NAME= + TPU_TYPE= #example: v6e-256 + MACHINE_TYPE= + NUM_SLICES= + RAMDISK_SIZE= #example: 60000Mi + GKE_VERSION= #example: 1.32.3-gke.1785000 + ``` +2. **Configure gcloud:** + ```bash + gcloud config set project ${PROJECT_ID} + gcloud config set compute/zone ${ZONE} + ``` +3. **Clone the XPK repository:** + ```bash + git clone [https://github.com/AI-Hypercomputer/xpk.git](https://github.com/AI-Hypercomputer/xpk.git) + ``` +4. **Run the cluster creation command:** + ```bash + python3 xpk/xpk.py cluster create \ + --cluster ${CLUSTER_NAME} \ + --cluster-cpu-machine-type=${MACHINE_TYPE} \ + --num-slices=${NUM_SLICES} \ + --tpu-type=${TPU_TYPE} \ + --enable-mtc \ + --enable-gcsfuse-csi-driver \ + --mtc-ramdisk-size=${RAMDISK_SIZE} \ + --mtc-gcs-bucket=${OUTPUT_PATH} \ + --gke-version=${GKE_VERSION} + ``` ## MaxText configuration This configuration manages a `multi-tiered checkpointing` system designed for both durability and rapid recovery. -* **Local checkpointing**: Saves checkpoints much more frequently to a fast, local directory on each host (i.e. a ramdisk). If a preemption or failure occurs, the job can restore from this recent local copy almost instantly, minimizing lost work without needing to download from slower persistent storage. This feature is enabled by setting `enable_checkpointing`, `enable_multi_tier_checkpointing`, `local_checkpoint_directory`, and a non-zero `local_checkpoint_period` flags. +- **Local checkpointing**: Saves checkpoints much more frequently to a fast, local directory on each host (i.e. a ramdisk). If a preemption or failure occurs, the job can restore from this recent local copy almost instantly, minimizing lost work without needing to download from slower persistent storage. This feature is enabled by setting `enable_checkpointing`, `enable_multi_tier_checkpointing`, `local_checkpoint_directory`, and a non-zero `local_checkpoint_period` flags. -* **Backup checkpointing**: These are checkpoints saved periodically to persistent storage(i.e. GCS bucket). They ensure that you can recover your training state even after a complete job failure(repair of all nodepools). From User's perspective all restoration is from local ramdisk, its replicator service responsibility to make the checkpointing available to local storage in case of job restart. The interval for backup can be enabled by setting a non-zero `multi_tier_checkpointing_backup_interval_minutes` flags. +- **Backup checkpointing**: These are checkpoints saved periodically to persistent storage(i.e. GCS bucket). They ensure that you can recover your training state even after a complete job failure(repair of all nodepools). From User's perspective all restoration is from local ramdisk, its replicator service responsibility to make the checkpointing available to local storage in case of job restart. The interval for backup can be enabled by setting a non-zero `multi_tier_checkpointing_backup_interval_minutes` flags. -| Flag | Description | Type | Default | -| :--- | :--- | :--- | :--- | -| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` | -| `enable_multi_tier_checkpointing` | When set to (`True`), this flag enables the multi-tier checkpointing feature on maxtext level. | `boolean` | `False` | -| `local_checkpoint_directory` | The high-speed local filesystem path(i.e. ramdisk) where **Multi-tier checkpoints** are saved. Setting this path, along with a non-zero `local_checkpoint_period`, enables the Multi-tier Checkpointing feature. | `string` | `""` | -| `local_checkpoint_period` | The interval, in training steps, for how often a **Multi-tier checkpoint** is saved in local ramdisks. | `integer` | `0` | -| `multi_tier_checkpointing_backup_interval_minutes`| The interval, in minutes, for how often a **Multi-tier checkpoint** is saved to backup from local ramdisks. | `integer` | `0` | +| Flag | Description | Type | Default | +| :------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :-------- | :------ | +| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` | +| `enable_multi_tier_checkpointing` | When set to (`True`), this flag enables the multi-tier checkpointing feature on maxtext level. | `boolean` | `False` | +| `local_checkpoint_directory` | The high-speed local filesystem path(i.e. ramdisk) where **Multi-tier checkpoints** are saved. Setting this path, along with a non-zero `local_checkpoint_period`, enables the Multi-tier Checkpointing feature. | `string` | `""` | +| `local_checkpoint_period` | The interval, in training steps, for how often a **Multi-tier checkpoint** is saved in local ramdisks. | `integer` | `0` | +| `multi_tier_checkpointing_backup_interval_minutes` | The interval, in minutes, for how often a **Multi-tier checkpoint** is saved to backup from local ramdisks. | `integer` | `0` | ### Workload creation using XPK The flags below would give the user access to the ramdisk in their workload: -| Flag | Description | -| :--- | :--- | -| `--mtc-enabled` | Enables the Multi-Tier Checkpointing feature, by mounting ramdisk to the workload pods, using csi drivers. | +| Flag | Description | +| :-------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `--mtc-enabled` | Enables the Multi-Tier Checkpointing feature, by mounting ramdisk to the workload pods, using csi drivers. | | `--ramdisk-directory` | Specifies the mount path inside each pod where the high-speed ramdisk will be accessible. Your training application should write its local, emergency checkpoints to this path. | ### Example XPK workload creation command -1. **Set up environment variables:** - ```bash - RAMDISK_DIRECTORY= - WORKLOAD_NAME= - TPU_TYPE= - NUM_SLICES= - PROJECT_ID= - LOCAL_CHECKPOINT_PERIOD=<> - CHECKPOINT_PEROID= - STEPS= - DATA_PATH= - OUTPUT_PATH= - MULTI_TIER_CHECKPOINTING_BACKUP_INT_MIN= - ``` - -2. **Define the Docker image:** - ```bash - DOCKER_IMAGE=gcr.io/${PROJECT_ID}/${USER}_mtc_runner:latest - ``` - -3. **Run the workload creation command:** - ```bash - python3 xpk/xpk.py workload create \ - --cluster ${CLUSTER_NAME} \ - --docker-image ${DOCKER_IMAGE} \ - --workload ${WORKLOAD_NAME} \ - --tpu-type=${TPU_TYPE} \ - --num-slices=${NUM_SLICES} \ - --ramdisk-directory=${RAMDISK_DIRECTORY} \ - --mtc-enabled \ - --command "python3 src/MaxText/train.py src/MaxText/configs/base.yml base_output_directory=$OUTPUT_PATH dataset_path=$DATA_PATH steps=120 per_device_batch_size=6 enable_checkpoint_cloud_logger=True checkpoint_period=${CHECKPOINT_PEROID} enable_multi_tier_checkpointing=True local_checkpoint_period=${LOCAL_CHECKPOINT_PERIOD} local_checkpoint_directory=/${RAMDISK_DIRECTORY} multi_tier_checkpointing_backup_interval_minutes=${MULTI_TIER_CHECKPOINTING_BACKUP_INT_MIN}" - ``` +1. **Set up environment variables:** + + ```bash + RAMDISK_DIRECTORY= + WORKLOAD_NAME= + TPU_TYPE= + NUM_SLICES= + PROJECT_ID= + LOCAL_CHECKPOINT_PERIOD=<> + CHECKPOINT_PEROID= + STEPS= + DATA_PATH= + OUTPUT_PATH= + MULTI_TIER_CHECKPOINTING_BACKUP_INT_MIN= + ``` + +2. **Define the Docker image:** + + ```bash + DOCKER_IMAGE=gcr.io/${PROJECT_ID}/${USER}_mtc_runner:latest + ``` + +3. **Run the workload creation command:** + + ```bash + python3 xpk/xpk.py workload create \ + --cluster ${CLUSTER_NAME} \ + --docker-image ${DOCKER_IMAGE} \ + --workload ${WORKLOAD_NAME} \ + --tpu-type=${TPU_TYPE} \ + --num-slices=${NUM_SLICES} \ + --ramdisk-directory=${RAMDISK_DIRECTORY} \ + --mtc-enabled \ + --command "python3 src/MaxText/train.py src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH dataset_path=$DATA_PATH steps=120 per_device_batch_size=6 enable_checkpoint_cloud_logger=True checkpoint_period=${CHECKPOINT_PEROID} enable_multi_tier_checkpointing=True local_checkpoint_period=${LOCAL_CHECKPOINT_PERIOD} local_checkpoint_directory=/${RAMDISK_DIRECTORY} multi_tier_checkpointing_backup_interval_minutes=${MULTI_TIER_CHECKPOINTING_BACKUP_INT_MIN}" + ``` diff --git a/docs/guides/data_input_pipeline/data_input_grain.md b/docs/guides/data_input_pipeline/data_input_grain.md index 9f49605819..1a1ecd0099 100644 --- a/docs/guides/data_input_pipeline/data_input_grain.md +++ b/docs/guides/data_input_pipeline/data_input_grain.md @@ -3,12 +3,13 @@ ## The recommended input pipeline for determinism and resilience! [Grain](https://google-grain.readthedocs.io/en/latest/) is a library for reading data for training and evaluating JAX models. It’s designed to be: -* **Powerful**: Users can bring arbitrary Python transformations. -* **Flexible**: Users can readily override Grain components for their needs. -* **Deterministic**: Multiple runs of the same pipeline will produce the same outputs. -* **Resilient to preemptions**: With minimal-sized checkpoints, users can resume the dataloader from the point at which it was preempted and produce the same output as if it was never preempted. -* **Performant**: Achieved with multiprocessing with shared memory. Tested on multiple data modalities. -* **With minimal dependencies**: Does not depend on ML frameworks (Tensorflow). + +- **Powerful**: Users can bring arbitrary Python transformations. +- **Flexible**: Users can readily override Grain components for their needs. +- **Deterministic**: Multiple runs of the same pipeline will produce the same outputs. +- **Resilient to preemptions**: With minimal-sized checkpoints, users can resume the dataloader from the point at which it was preempted and produce the same output as if it was never preempted. +- **Performant**: Achieved with multiprocessing with shared memory. Tested on multiple data modalities. +- **With minimal dependencies**: Does not depend on ML frameworks (Tensorflow). ## Why is determinism important for a data input pipeline? @@ -20,19 +21,19 @@ Grain ensures determinism in data input pipelines by saving the pipeline's state ## Cases where determinism is crucial -* **Model sensitive to repetition**: When models are sensitive to the frequency with which they encounter specific examples, precise control over the order and repetition of data during training is essential. All LLMs belong to this category. -* **Convergence comparison**: In sensitive convergence experiments like testing quantization techniques, maintaining identical data batches between runs (e.g., quantized vs. unquantized) is essential for comparison. Determinism ensures consistency even when the runs are long and undergo saving/resuming at different steps. -* **Debug training anomalies**: When troubleshooting training spikes or anomalies, the ability to replay the exact data sequence helps distinguish between bad data batches and underlying hardware or software issues. +- **Model sensitive to repetition**: When models are sensitive to the frequency with which they encounter specific examples, precise control over the order and repetition of data during training is essential. All LLMs belong to this category. +- **Convergence comparison**: In sensitive convergence experiments like testing quantization techniques, maintaining identical data batches between runs (e.g., quantized vs. unquantized) is essential for comparison. Determinism ensures consistency even when the runs are long and undergo saving/resuming at different steps. +- **Debug training anomalies**: When troubleshooting training spikes or anomalies, the ability to replay the exact data sequence helps distinguish between bad data batches and underlying hardware or software issues. ## Data shuffling -* **Global shuffle**: This feature is only available when using Grain with [ArrayRecord](https://github.com/google/array_record) (random access) format, achieved by shuffling indices globally at the beginning of each epoch and then reading the elements according to the random order. This shuffle method effectively prevents local overfitting, leading to better training results. -* **Hierarchical shuffle**: For sequential access format [Parquet](https://arrow.apache.org/docs/python/parquet.html), shuffle is performed by these steps: file shuffling, interleave from files, and window shuffle using a fixed size buffer. +- **Global shuffle**: This feature is only available when using Grain with [ArrayRecord](https://github.com/google/array_record) (random access) format, achieved by shuffling indices globally at the beginning of each epoch and then reading the elements according to the random order. This shuffle method effectively prevents local overfitting, leading to better training results. +- **Hierarchical shuffle**: For sequential access format [Parquet](https://arrow.apache.org/docs/python/parquet.html), shuffle is performed by these steps: file shuffling, interleave from files, and window shuffle using a fixed size buffer. ## Using Grain 1. Grain currently supports two data formats: [ArrayRecord](https://github.com/google/array_record) (random access) and [Parquet](https://arrow.apache.org/docs/python/parquet.html) (partial random-access through row groups). Only the ArrayRecord format supports the global shuffle mentioned above. For converting a dataset into ArrayRecord, see [Apache Beam Integration for ArrayRecord](https://github.com/google/array_record/tree/main/beam). Additionally, other random access data sources can be supported via a custom [data source](https://google-grain.readthedocs.io/en/latest/data_sources.html) class. - * **Community Resource**: The MaxText community has created a [ArrayRecord Documentation](https://array-record.readthedocs.io/). Note: we appreciate the contribution from the community, but as of now it has not been verified by the MaxText or ArrayRecord developers yet. + - **Community Resource**: The MaxText community has created a [ArrayRecord Documentation](https://array-record.readthedocs.io/). Note: we appreciate the contribution from the community, but as of now it has not been verified by the MaxText or ArrayRecord developers yet. 2. When the dataset is hosted on a Cloud Storage bucket, Grain can read it through [Cloud Storage FUSE](https://cloud.google.com/storage/docs/gcs-fuse). The installation of Cloud Storage FUSE is included in [setup.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup.sh). The user then needs to mount the Cloud Storage bucket to a local path for each worker, using the script [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh). The script configures some parameters for the mount. ```sh @@ -44,11 +45,11 @@ MOUNT_PATH=$MOUNT_PATH \ Note that `FILE_PATH` is optional; when provided, the script runs `ls -R` for pre-filling the metadata cache (see ["Performance tuning best practices" on the Google Cloud documentation](https://cloud.google.com/storage/docs/cloud-storage-fuse/performance#improve-first-time-reads)). -3. Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet}`, `grain_train_files` in `src/MaxText/configs/base.yml` or through command line arguments to match the file pattern on the mounted local path. +1. Set `dataset_type=grain`, `grain_file_type={arrayrecord|parquet}`, `grain_train_files` in `src/maxtext/configs/base.yml` or through command line arguments to match the file pattern on the mounted local path. -4. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scenes](https://google-grain.readthedocs.io/en/latest/behind_the_scenes.html), [grain_pool.py](https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)). If you use a large number of workers, check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh) to avoid gcsfuse throttling. +2. Tune `grain_worker_count` for performance. This parameter controls the number of child processes used by Grain (more details in [behind_the_scenes](https://google-grain.readthedocs.io/en/latest/behind_the_scenes.html), [grain_pool.py](https://github.com/google/grain/blob/main/grain/_src/python/grain_pool.py)). If you use a large number of workers, check your config for gcsfuse in [setup_gcsfuse.sh](https://github.com/google/maxtext/blob/main/tools/setup/setup_gcsfuse.sh) to avoid gcsfuse throttling. -5. ArrayRecord Only: For multi-source blending, you can specify multiple data sources with their respective weights using semicolon (;) as a separator and a comma (,) for weights. The weights will be automatically normalized to sum to 1.0. For example: +3. ArrayRecord Only: For multi-source blending, you can specify multiple data sources with their respective weights using semicolon (;) as a separator and a comma (,) for weights. The weights will be automatically normalized to sum to 1.0. For example: ``` # Blend two data sources with 30% from first source and 70% from second source @@ -105,13 +106,13 @@ Note that `FILE_PATH` is optional; when provided, the script runs `ls -R` for pr Packing and multi-process prefetching (mp_prefetch) operations rely on buffers. When a data mixture is updated, these buffers cannot be recovered, leading to discarded unused elements and thus minor skipping in the training data. ``` -6. Example command: +4. Example command: ```sh bash tools/setup/setup_gcsfuse.sh \ DATASET_GCS_BUCKET=maxtext-dataset \ MOUNT_PATH=/tmp/gcsfuse && \ -python3 -m MaxText.train src/MaxText/configs/base.yml \ +python3 -m MaxText.train src/maxtext/configs/base.yml \ run_name= base_output_directory=gs:// \ dataset_type=grain \ grain_file_type=arrayrecord # or parquet \ @@ -119,9 +120,9 @@ grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record* \ grain_worker_count=2 ``` -7. Using validation set for evaluation +1. Using validation set for evaluation -When setting eval_interval > 0, evaluation will be run with a specified eval dataset. Example config (set in [`src/MaxText/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/base.yml) or through command line): +When setting eval_interval > 0, evaluation will be run with a specified eval dataset. Example config (set in [`src/maxtext/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/base.yml) or through command line): ```yaml eval_interval: 10000 @@ -129,9 +130,10 @@ eval_steps: 50 grain_eval_files: '/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-validation.array_record*' ``` -8. Experimental: resuming training with a different chip count +1. Experimental: resuming training with a different chip count In Grain checkpoints, each data-loading host has a corresponding JSON file. For cases where a user wants to resume training with a different number of data-loading hosts, MaxText provides an experimental feature: -* **Scaling up**: For example, if you have a checkpoint from 64 data-loading hosts and want to resume training with 128. This is achieved by having a subset of the hosts load the real data, which is then sent to the other hosts. The flag `expansion_factor_real_data` (default is -1) controls this behavior. When set to a value greater than 1, the number of hosts loading real data is `total number of hosts // expansion_factor_real_data`. Each of these data-loading hosts will load `expansion_factor_real_data * per_host_batch_size_to_train`. For code integrity, the non-loading hosts use a `PlaceHolderDataIterator` to generate dummy data, which is later discarded. A user can optionally set `max_checkify=true` to enable additional checks that ensure dummy data is not used for training. In this example, you would set `expansion_factor_real_data=2` to scale from 64 to 128 hosts. -* **Scaling down**: For example, if you have a checkpoint from 128 data-loading hosts and want to resume with 64. This is achieved by restoring multiple data iterators on each host. Set flag `expansion_factor_real_data` to have each host restore `1 / expansion_factor_real_data` data iterators. We then alternate between these iterators to produce batches. In this example, you would set `expansion_factor_real_data=0.5` to scale from 128 down to 64 hosts. -* **Note**: In both scaling up and scaling down scenarios, the `per_device_batch_size` must remain consistent. This is because Grain records the number of iterations (batches) in the iterator's state, and changing the batch size will result in either skipping or duplicating data. + +- **Scaling up**: For example, if you have a checkpoint from 64 data-loading hosts and want to resume training with 128. This is achieved by having a subset of the hosts load the real data, which is then sent to the other hosts. The flag `expansion_factor_real_data` (default is -1) controls this behavior. When set to a value greater than 1, the number of hosts loading real data is `total number of hosts // expansion_factor_real_data`. Each of these data-loading hosts will load `expansion_factor_real_data * per_host_batch_size_to_train`. For code integrity, the non-loading hosts use a `PlaceHolderDataIterator` to generate dummy data, which is later discarded. A user can optionally set `max_checkify=true` to enable additional checks that ensure dummy data is not used for training. In this example, you would set `expansion_factor_real_data=2` to scale from 64 to 128 hosts. +- **Scaling down**: For example, if you have a checkpoint from 128 data-loading hosts and want to resume with 64. This is achieved by restoring multiple data iterators on each host. Set flag `expansion_factor_real_data` to have each host restore `1 / expansion_factor_real_data` data iterators. We then alternate between these iterators to produce batches. In this example, you would set `expansion_factor_real_data=0.5` to scale from 128 down to 64 hosts. +- **Note**: In both scaling up and scaling down scenarios, the `per_device_batch_size` must remain consistent. This is because Grain records the number of iterations (batches) in the iterator's state, and changing the batch size will result in either skipping or duplicating data. diff --git a/docs/guides/data_input_pipeline/data_input_hf.md b/docs/guides/data_input_pipeline/data_input_hf.md index 25e52c1078..ee8d0c67a6 100644 --- a/docs/guides/data_input_pipeline/data_input_hf.md +++ b/docs/guides/data_input_pipeline/data_input_hf.md @@ -4,7 +4,7 @@ The Hugging Face pipeline supports streaming directly from the Hugging Face Hub, ## Example config for streaming from Hugging Face Hub (no download needed) -In [`src/MaxText/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/base.yml) or through command line, set the following parameters: +In [`src/maxtext/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/base.yml) or through command line, set the following parameters: ```yaml dataset_type: hf @@ -23,7 +23,7 @@ hf_access_token: '' # provide token if using gated dataset or tokenizer ## Example config for streaming from downloaded data in a Cloud Storage bucket -In [`src/MaxText/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/base.yml) or through the command line, set the following parameters: +In [`src/maxtext/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/base.yml) or through the command line, set the following parameters: ```yaml dataset_type: hf diff --git a/docs/guides/data_input_pipeline/data_input_tfds.md b/docs/guides/data_input_pipeline/data_input_tfds.md index 03c38e9838..acbf064055 100644 --- a/docs/guides/data_input_pipeline/data_input_tfds.md +++ b/docs/guides/data_input_pipeline/data_input_tfds.md @@ -6,7 +6,7 @@ bash download_dataset.sh {GCS_PROJECT} {GCS_BUCKET_NAME} ``` -2. In [`src/MaxText/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/base.yml) or through command line, set the following parameters: +2. In [`src/maxtext/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/base.yml) or through command line, set the following parameters: ```yaml dataset_type: tfds diff --git a/docs/guides/monitoring_and_debugging/features_and_diagnostics.md b/docs/guides/monitoring_and_debugging/features_and_diagnostics.md index 58049503f7..c2cd817a82 100644 --- a/docs/guides/monitoring_and_debugging/features_and_diagnostics.md +++ b/docs/guides/monitoring_and_debugging/features_and_diagnostics.md @@ -17,51 +17,61 @@ # Features and diagnostics ## Collect stack traces + When running a Single Program, Multiple Data (SPMD) job on accelerators, the overall process can hang if there is any error or any VM hangs/crashes for some reason. In this scenario, capturing stack traces will help to identify and troubleshoot the issues for the jobs running on TPU VMs. -The following configurations will help to debug a fault or when a program is stuck or hung somewhere by collecting stack traces. Change the parameter values accordingly in `src/MaxText/configs/base.yml`: +The following configurations will help to debug a fault or when a program is stuck or hung somewhere by collecting stack traces. Change the parameter values accordingly in `src/maxtext/configs/base.yml`: + 1. Set `collect_stack_trace: True` to enable collection of stack traces on faults or when the program is hung. This setting will periodically dump the traces for the program to help in debugging. To disable this, set `collect_stack_trace: False`. 2. Set `stack_trace_to_cloud: False` to display stack traces on console. `stack_trace_to_cloud: True` will create a temporary file in `/tmp/debugging` in the TPUs to store the stack traces. There is an agent running on TPU VMs that will periodically upload the traces from the temporary directory to cloud logging in the gcp project. You can view the traces in Logs Explorer on Cloud Logging using the following query: + ``` logName="projects//logs/tpu.googleapis.com%2Fruntime_monitor" jsonPayload.verb="stacktraceanalyzer" ``` + 3. `stack_trace_interval_seconds` signifies the duration in seconds between each stack trace collection event. Setting `stack_trace_interval_seconds: 600` will collect the stack traces every 600 seconds (10 minutes). Here is the related PyPI package: https://pypi.org/project/cloud-tpu-diagnostics. (aot-compilation)= + ## Ahead of Time compilation (AOT) + To compile your training run ahead of time, we provide a tool `train_compile.py`. This tool allows you to compile the main `train_step` in `train.py` for target hardware (e.g. a large number of v5e devices) without using the full cluster. ### TPU support You may use only a CPU or a single VM from a different family to pre-compile for a TPU cluster. This compilation helps with two main goals: -* It will flag any out of memory (OOM) information, such as when the `per_device_batch_size` is set too high, with an identical OOM stack trace as if it was compiled on the target hardware. +- It will flag any out of memory (OOM) information, such as when the `per_device_batch_size` is set too high, with an identical OOM stack trace as if it was compiled on the target hardware. -* The ahead of time compilation can be saved and then loaded for fast startup and restart times on the target hardware. +- The ahead of time compilation can be saved and then loaded for fast startup and restart times on the target hardware. -The tool `train_compile.py` is tightly linked to `train.py` and uses the same configuration file `src/MaxText/configs/base.yml`. Although you don't need to run on a TPU, you do need to install `jax[tpu]` in addition to other dependencies, so we recommend running `setup.sh` to install these if you have not already done so. +The tool `train_compile.py` is tightly linked to `train.py` and uses the same configuration file `src/maxtext/configs/base.yml`. Although you don't need to run on a TPU, you do need to install `jax[tpu]` in addition to other dependencies, so we recommend running `setup.sh` to install these if you have not already done so. #### Example AOT 1: Compile ahead of time basics + After installing the dependencies listed above, you are ready to compile ahead of time: + ```sh # Run the below on a single machine, e.g. a CPU -python3 MaxText.train_compile src/MaxText/configs/base.yml compile_topology=v5e-256 compile_topology_num_slices=2 \ +python3 MaxText.train_compile src/maxtext/configs/base.yml compile_topology=v5e-256 compile_topology_num_slices=2 \ global_parameter_scale=16 per_device_batch_size=4 ``` This will compile a 16B parameter MaxText model on 2 v5e pods. #### Example AOT 2: Save compiled function, then load and run it + Here is an example that saves then loads the compiled `train_step`, starting with the save: **Step 1: Run AOT and save compiled function** + ```sh # Run the below on a single machine, e.g. a CPU export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true" -python3 -m MaxText.train_compile src/MaxText/configs/base.yml compile_topology=v5e-256 \ +python3 -m MaxText.train_compile src/maxtext/configs/base.yml compile_topology=v5e-256 \ compile_topology_num_slices=2 \ compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16 \ per_device_batch_size=4 steps=10000 learning_rate=1e-3 @@ -70,32 +80,36 @@ python3 -m MaxText.train_compile src/MaxText/configs/base.yml compile_topology=v **Step 2: Run `train.py` and load the compiled function** To load the compiled train_step, you just need to pass `compiled_trainstep_file=my_compiled_train.pickle` into `train.py`: + ```sh # Run the below on each host of the target hardware, e.g. each host on 2 slices of v5e-256 export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true" -python3 -m MaxText.train src/MaxText/configs/base.yml run_name=example_load_compile \ +python3 -m MaxText.train src/maxtext/configs/base.yml run_name=example_load_compile \ compiled_trainstep_file=my_compiled_train.pickle \ global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3 \ base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket ``` -In the save step of example 2 above we included exporting the compiler flag `LIBTPU_INIT_ARGS` and `learning_rate` because those affect the compiled object `my_compiled_train.pickle.` The sizes of the model (e.g. `global_parameter_scale`, `max_sequence_length` and `per_device_batch`) are fixed when you initially compile via `compile_train.py`, you will see a size error if you try to run the saved compiled object with different sizes than you compiled with. However a subtle note is that the **learning rate schedule** is also fixed when you run `compile_train` - which is determined by both `steps` and `learning_rate`. The optimizer parameters such as `adam_b1` are passed only as shaped objects to the compiler - thus their real values are determined when you run `train.py`, not during the compilation. If you do pass in different shapes (e.g. `per_device_batch`), you will get a clear error message reporting that the compiled signature has different expected shapes than what was input. If you attempt to run on different hardware than the compilation targets requested via `compile_topology`, you will get an error saying there is a failure to map the devices from the compiled to your real devices. Using different XLA flags or a LIBTPU than what was compiled will probably run silently with the environment you compiled in without error. However there is no guaranteed behavior in this case; you should run in the same environment you compiled in. +In the save step of example 2 above we included exporting the compiler flag `LIBTPU_INIT_ARGS` and `learning_rate` because those affect the compiled object `my_compiled_train.pickle.` The sizes of the model (e.g. `global_parameter_scale`, `max_sequence_length` and `per_device_batch`) are fixed when you initially compile via `compile_train.py`, you will see a size error if you try to run the saved compiled object with different sizes than you compiled with. However a subtle note is that the **learning rate schedule** is also fixed when you run `compile_train` - which is determined by both `steps` and `learning_rate`. The optimizer parameters such as `adam_b1` are passed only as shaped objects to the compiler - thus their real values are determined when you run `train.py`, not during the compilation. If you do pass in different shapes (e.g. `per_device_batch`), you will get a clear error message reporting that the compiled signature has different expected shapes than what was input. If you attempt to run on different hardware than the compilation targets requested via `compile_topology`, you will get an error saying there is a failure to map the devices from the compiled to your real devices. Using different XLA flags or a LIBTPU than what was compiled will probably run silently with the environment you compiled in without error. However there is no guaranteed behavior in this case; you should run in the same environment you compiled in. ### GPU support -Ahead-of-time compilation is also supported for GPUs with some differences from TPUs: + +Ahead-of-time compilation is also supported for GPUs with some differences from TPUs: 1. GPU does not support compilation across hardware: A GPU host is still required to run AoT compilation, but a single GPU host can compile a program for a larger cluster of the same hardware. -1. For [A3 Cloud GPUs](https://cloud.google.com/compute/docs/gpus#h100-gpus), the maximum "slice" size is a single host, and the `compile_topology_num_slices` parameter represents the number of A3 machines to precompile for. +2. For [A3 Cloud GPUs](https://cloud.google.com/compute/docs/gpus#h100-gpus), the maximum "slice" size is a single host, and the `compile_topology_num_slices` parameter represents the number of A3 machines to precompile for. #### Example + This example illustrates the flags to use for a multihost GPU compilation targeting a cluster of 4 A3 hosts: **Step 1: Run AOT and save compiled function** + ```sh # Run the below on a single A3 machine export XLA_FLAGS="--xla_gpu_enable_async_collectives=true" -python3 -m MaxText.train_compile src/MaxText/configs/base.yml compile_topology=a3 \ +python3 -m MaxText.train_compile src/maxtext/configs/base.yml compile_topology=a3 \ compile_topology_num_slices=4 \ compiled_trainstep_file=my_compiled_train.pickle global_parameter_scale=16 \ attention=dot_product per_device_batch_size=4 steps=10000 learning_rate=1e-3 @@ -104,10 +118,11 @@ python3 -m MaxText.train_compile src/MaxText/configs/base.yml compile_topology=a **Step 2: Run `train.py` and load the compiled function** To load the compiled `train_step`, you just need to pass `compiled_trainstep_file=my_compiled_train.pickle` into `train.py`: + ```sh # Run the below on each of the 4 target A3 hosts. export XLA_FLAGS="--xla_gpu_enable_async_collectives=true" -python3 -m MaxText.train src/MaxText/configs/base.yml run_name=example_load_compile \ +python3 -m MaxText.train src/maxtext/configs/base.yml run_name=example_load_compile \ compiled_trainstep_file=my_compiled_train.pickle \ attention=dot_product global_parameter_scale=16 per_device_batch_size=4 steps=10000 learning_rate=1e-3 \ base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket @@ -115,7 +130,6 @@ python3 -m MaxText.train src/MaxText/configs/base.yml run_name=example_load_comp As in the TPU case, note that the compilation environment must match the execution environment, in this case by setting the same `XLA_FLAGS`. - ## Automatically upload logs to Vertex AI Tensorboard MaxText supports automatic upload of logs collected in a directory to a Tensorboard instance in Vertex AI. Follow [](use_vertex_ai_tensorboard.md) to know more. diff --git a/docs/guides/monitoring_and_debugging/gcp_workload_observability.md b/docs/guides/monitoring_and_debugging/gcp_workload_observability.md index 5328120ed7..8ad6ff8dfa 100644 --- a/docs/guides/monitoring_and_debugging/gcp_workload_observability.md +++ b/docs/guides/monitoring_and_debugging/gcp_workload_observability.md @@ -15,12 +15,14 @@ --> # Enable GCP workload observabiltiy + This guide provides an overview on how to enable GCP workload observability for your MaxText workload. ## Overview + Google offers a monitoring and alerting feature that is well suited for critical MaxText workloads sensitive to infrastructure changes. Once enabled, metrics will be automatically sent to [Cloud Monarch](https://research.google/pubs/monarch-googles-planet-scale-in-memory-time-series-database/) for monitoring. -If a metric hits its pre-defined threshold, the Google Cloud on-call team will be alerted to see if any action is needed. +If a metric hits its pre-defined threshold, the Google Cloud on-call team will be alerted to see if any action is needed. The feature currently supports heartbeat and performance (training step time in seconds) metrics. In the near future, support for the goodput metric will also be added. Users should work with their Customer Engineer (CE) and the Google team to define appropriate thresholds for the performance metrics. @@ -28,15 +30,18 @@ Users should work with their Customer Engineer (CE) and the Google team to defin This guide layouts how to enable the feature for your MaxText workload. ## Enabling GCP workload observabiltiy + User can control which metric they want to report via config: ### Heartbeat metric + - This metric will be a boolean flag. - To turn on this metric, set `report_heartbeat_metric_for_gcp_monitoring` to `True` - To control the frequency of heartbeat reporting (default is every 5 seconds), set `heartbeat_reporting_interval_in_seconds` to your desired value. ### Performance metric + - This metric will be a double, capturing the training step time in seconds. - To turn on this metric, set `report_performance_metric_for_gcp_monitoring` to `True` -For an example, please refer to [base.yml](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/base.yml). +For an example, please refer to [base.yml](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/base.yml). diff --git a/docs/guides/monitoring_and_debugging/ml_workload_diagnostics.md b/docs/guides/monitoring_and_debugging/ml_workload_diagnostics.md index f1fccfab67..005b39d490 100644 --- a/docs/guides/monitoring_and_debugging/ml_workload_diagnostics.md +++ b/docs/guides/monitoring_and_debugging/ml_workload_diagnostics.md @@ -15,12 +15,15 @@ --> # Running a workload with Google Cloud ML Diagnostics Enabled + This guide provides an overview on how to enable ML Diagnostics for your MaxText workload. ## Overview -Google Cloud ML Diagnostics is an end-to-end managed platform for ML Engineers to optimize and diagnose their AI/ML workloads on Google Cloud. The product allows ML Engineers to collect and visualize all their workload metrics, configs and profiles with one single platform, all within the same UI. The current product offering focuses on workloads running on XLA-based frameworks (JAX, Pytorch XLA, Tensorflow/Keras) on Google Cloud TPUs and GPUs. Current support is for JAX on Google Cloud TPUs only. + +Google Cloud ML Diagnostics is an end-to-end managed platform for ML Engineers to optimize and diagnose their AI/ML workloads on Google Cloud. The product allows ML Engineers to collect and visualize all their workload metrics, configs and profiles with one single platform, all within the same UI. The current product offering focuses on workloads running on XLA-based frameworks (JAX, Pytorch XLA, Tensorflow/Keras) on Google Cloud TPUs and GPUs. Current support is for JAX on Google Cloud TPUs only. ## Enabling ML Diagnostics on Maxtext Workload + MaxText has integrated the ML Diagnostics [SDK](https://github.com/AI-Hypercomputer/google-cloud-mldiagnostics?tab=readme-ov-file) in its code. You can enable ML Diagnostics with the **managed-mldiagnostics** flag. If this is enabled, it will - Create a managed MachineLearning run with all the MaxText configs. @@ -29,37 +32,43 @@ MaxText has integrated the ML Diagnostics [SDK](https://github.com/AI-Hypercompu ### Examples -1. Enable ML Diagnostics to just capture Maxtext metrics and configs - - python3 -m MaxText.train src/MaxText/configs/base.yml \ - run_name=${USER}-tpu-job \ - base_output_directory="gs://your-output-bucket/" \ - dataset_path="gs://your-dataset-bucket/" \ - steps=100 \ - log_period=10 \ - managed_mldiagnostics=True - -2. Enable ML Diagnostics to capture Maxtext metrics, configs and singlehost profiles (on the first TPU device) - - python3 -m MaxText.train src/MaxText/configs/base.yml \ - run_name=${USER}-tpu-job \ - base_output_directory="gs://your-output-bucket/" \ - dataset_path="gs://your-dataset-bucket/" \ - steps=100 \ - log_period=10 \ - profiler=xplane \ - managed_mldiagnostics=True - -3. Enable ML Diagnostics to capture Maxtext metrics, configs and multihost profiles (on all TPU devices) - - python3 -m MaxText.train src/MaxText/configs/base.yml \ - run_name=${USER}-tpu-job \ - base_output_directory="gs://your-output-bucket/" \ - dataset_path="gs://your-dataset-bucket/" \ - steps=100 \ - log_period=10 \ - profiler=xplane \ - upload_all_profiler_results=True \ - managed_mldiagnostics=True - -Users can deploy the workload across all supported environments, including the standard XPK workload types (**xpk workload create** or **xpk workload create-pathways**) or by running the workload directly on a standalone TPU VM. \ No newline at end of file +1. Enable ML Diagnostics to just capture Maxtext metrics and configs + + ``` + python3 -m MaxText.train src/maxtext/configs/base.yml \ + run_name=${USER}-tpu-job \ + base_output_directory="gs://your-output-bucket/" \ + dataset_path="gs://your-dataset-bucket/" \ + steps=100 \ + log_period=10 \ + managed_mldiagnostics=True + ``` + +2. Enable ML Diagnostics to capture Maxtext metrics, configs and singlehost profiles (on the first TPU device) + + ``` + python3 -m MaxText.train src/maxtext/configs/base.yml \ + run_name=${USER}-tpu-job \ + base_output_directory="gs://your-output-bucket/" \ + dataset_path="gs://your-dataset-bucket/" \ + steps=100 \ + log_period=10 \ + profiler=xplane \ + managed_mldiagnostics=True + ``` + +3. Enable ML Diagnostics to capture Maxtext metrics, configs and multihost profiles (on all TPU devices) + + ``` + python3 -m MaxText.train src/maxtext/configs/base.yml \ + run_name=${USER}-tpu-job \ + base_output_directory="gs://your-output-bucket/" \ + dataset_path="gs://your-dataset-bucket/" \ + steps=100 \ + log_period=10 \ + profiler=xplane \ + upload_all_profiler_results=True \ + managed_mldiagnostics=True + ``` + +Users can deploy the workload across all supported environments, including the standard XPK workload types (**xpk workload create** or **xpk workload create-pathways**) or by running the workload directly on a standalone TPU VM. diff --git a/docs/guides/monitoring_and_debugging/monitor_goodput.md b/docs/guides/monitoring_and_debugging/monitor_goodput.md index e1f89dff4d..8ed46a33c8 100644 --- a/docs/guides/monitoring_and_debugging/monitor_goodput.md +++ b/docs/guides/monitoring_and_debugging/monitor_goodput.md @@ -15,6 +15,7 @@ --> (monitor-goodput)= + # ML Goodput measurement MaxText supports automatic measurement and upload of workload metrics such as Goodput, Badput Breakdown and Step Time Deviation using the ML Goodput Measurement library. @@ -22,10 +23,12 @@ MaxText supports automatic measurement and upload of workload metrics such as Go The [ML Goodput Measurement](https://github.com/AI-Hypercomputer/ml-goodput-measurement) library currently supports monitoring workloads running on Google Cloud Platform. For more information on details of the library, visit the Github page or the [ml-goodput-measurement](https://pypi.org/project/ml-goodput-measurement/) PyPI package documentation. ## What is Goodput -Goodput is the metric that measures the efficiency of model training jobs, i.e. productive time spent on training progress proportional to the total time spent by the workload. It is an actionable way for users to monitor where they can improve to get the most value from their accelerators. + +Goodput is the metric that measures the efficiency of model training jobs, i.e. productive time spent on training progress proportional to the total time spent by the workload. It is an actionable way for users to monitor where they can improve to get the most value from their accelerators. ## What is Badput -Badput is the metric that measures time that a workload spent on anything that is not productive training proportional to the total time spent by the workload. For example, the time spent in accelerator initialization, training preparation, program startup, data loading, portions of checkpointing, disruptions and wasted progress since the last checkpoint etc. all contribute to Badput. + +Badput is the metric that measures time that a workload spent on anything that is not productive training proportional to the total time spent by the workload. For example, the time spent in accelerator initialization, training preparation, program startup, data loading, portions of checkpointing, disruptions and wasted progress since the last checkpoint etc. all contribute to Badput. The ML Goodput Measurement library exposes Badput Breakdown. Further details of each bucket can be found [here](https://github.com/AI-Hypercomputer/ml-goodput-measurement?tab=readme-ov-file#badput-breakdown-details) @@ -38,13 +41,14 @@ The ML Goodput Measurement library exposes step time deviation by computing idea ## How to use ML Goodput Measurement in MaxText ### Prerequisites + The usage of this package requires the setup of a Google Cloud project with billing enabled to properly use Google Cloud Logging. If you don't have a Google Cloud project, or if you don't have billing enabled for your Google Cloud project, then do the following: 1. In the Google Cloud console, on the project selector page, - [select or create a Google Cloud project](https://cloud.google.com/resource-manager/docs/creating-managing-projects). + [select or create a Google Cloud project](https://cloud.google.com/resource-manager/docs/creating-managing-projects). 2. Make sure that billing is enabled for your Google Cloud project. Instructions can be found [here](https://cloud.google.com/billing/docs/how-to/verify-billing-enabled#console) @@ -85,7 +89,7 @@ Please use a unique workload name, unless you intend to monitor cumulative Goodp MaxText enables Goodput recording and monitoring by default with `enable_goodput_recording=True` and `monitor_goodput=True`. You can configure the goodput upload frequency by setting `goodput_upload_interval_seconds`. ```bash -python3 -m MaxText.train src/MaxText/configs/base.yml base_output_directory=$OUTPUT_PATH \ +python3 -m MaxText.train src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH \ dataset_path=$DATA_PATH run_name=goodput-test-run steps=200 goodput_upload_interval_seconds=30 ``` @@ -94,7 +98,7 @@ python3 -m MaxText.train src/MaxText/configs/base.yml base_output_directory=$OUT MaxText enables step time deviation monitoring by default with `monitor_step_time_deviation=True`. You can configure the upload frequency by setting `step_deviation_interval_seconds`. ```bash -python3 -m MaxText.train src/MaxText/configs/base.yml base_output_directory=$OUTPUT_PATH \ +python3 -m MaxText.train src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH \ dataset_path=$DATA_PATH run_name=goodput-test-run steps=200 step_deviation_interval_seconds=30 ``` @@ -107,7 +111,7 @@ Enabling `enable_pathways_goodput` turns on Goodput measurement for Pathways wor ``` ```bash -python3 -m MaxText.train src/MaxText/configs/base.yml base_output_directory=$OUTPUT_PATH dataset_path=$DATA_PATH \ +python3 -m MaxText.train src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH dataset_path=$DATA_PATH \ run_name=goodput-test-run steps=200 goodput_upload_interval_seconds=30 enable_pathways_goodput=True ``` @@ -137,25 +141,25 @@ This feature is enabled by default, and no changes to the Monitoring API call ar ```python gcp_options = goodput_utils.GCPOptions( - project_id=None, # If None, the library will automatically identify from GCE internal metadata - location=None, # If None, the library will automatically identify from GCE internal metadata - replica_id='0', # Default is '0' - acc_type=None, # If None, the library will automatically identify from GCE internal metadata - enable_gcp_goodput_metrics=True, - enable_gcp_step_deviation_metrics=True, - ) + project_id=None, # If None, the library will automatically identify from GCE internal metadata + location=None, # If None, the library will automatically identify from GCE internal metadata + replica_id="0", # Default is '0' + acc_type=None, # If None, the library will automatically identify from GCE internal metadata + enable_gcp_goodput_metrics=True, + enable_gcp_step_deviation_metrics=True, +) goodput_monitor = monitoring.GoodputMonitor( - job_name=config.run_name, - logger_name=logger_name, - tensorboard_dir=config.tensorboard_dir, - upload_interval=config.goodput_upload_interval_seconds, - monitoring_enabled=True, - include_badput_breakdown=True, - include_step_deviation=True, - configured_ideal_step_time=None, # Optional, the library will compute ideal step time if it is not provided - gcp_options=gcp_options - ) + job_name=config.run_name, + logger_name=logger_name, + tensorboard_dir=config.tensorboard_dir, + upload_interval=config.goodput_upload_interval_seconds, + monitoring_enabled=True, + include_badput_breakdown=True, + include_step_deviation=True, + configured_ideal_step_time=None, # Optional, the library will compute ideal step time if it is not provided + gcp_options=gcp_options, +) ``` If you do not wish to send metrics to Google Cloud Monitoring then please set @@ -164,7 +168,7 @@ and `enable_gcp_step_deviation_metrics` to `False` for disabling step deviation metrics. ```bash -python3 -m MaxText.train src/MaxText/configs/base.yml base_output_directory=$OUTPUT_PATH dataset_path=$DATA_PATH \ +python3 -m MaxText.train src/maxtext/configs/base.yml base_output_directory=$OUTPUT_PATH dataset_path=$DATA_PATH \ run_name=goodput-test-run steps=200 goodput_upload_interval_seconds=30 enable_gcp_goodput_metrics=False \ enable_gcp_step_deviation_metrics=False ``` @@ -176,17 +180,20 @@ monitoring. Goodput, Badput and Step Time Deviation metrics can be monitored using GCM Metrics Explorer: -1. Verify that the workload is executing with monitoring enabled. This ensures automatic data ingestion into Google Cloud Monitoring. -2. Navigate to [Metrics Explorer](https://console.cloud.google.com/monitoring/metrics-explorer). Initiate metric selection by clicking `Select a metric` then search for and select the `Workload` resource. Subsequently, choose the `Workload` metric category. - - a. [**Productive Time:**](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/goodput_time) - Represents the cumulative duration the workload spent on productive tasks, - measured by `compute.googleapis.com/workload/goodput_time`. - b. [**Non-Productive Time:**](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/badput_time) - Represents the cumulative duration the workload spent on non-productive tasks, - measured by `compute.googleapis.com/workload/badput_time`. - c. [**Performance:**](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/performance) - Represents the workload's performance metric, specifically step deviation - in this context, measured by `compute.googleapis.com/workload/performance`. -3. Navigate to [Dashboards](https://console.cloud.google.com/monitoring/dashboards). -4. Create a custom dashboard if there isn't one and add useful widgets with the above mentioned metrics. +1. Verify that the workload is executing with monitoring enabled. This ensures automatic data ingestion into Google Cloud Monitoring. + +2. Navigate to [Metrics Explorer](https://console.cloud.google.com/monitoring/metrics-explorer). Initiate metric selection by clicking `Select a metric` then search for and select the `Workload` resource. Subsequently, choose the `Workload` metric category. + + a. [**Productive Time:**](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/goodput_time) + Represents the cumulative duration the workload spent on productive tasks, + measured by `compute.googleapis.com/workload/goodput_time`.\ + b. [**Non-Productive Time:**](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/badput_time) + Represents the cumulative duration the workload spent on non-productive tasks, + measured by `compute.googleapis.com/workload/badput_time`.\ + c. [**Performance:**](https://cloud.google.com/monitoring/api/metrics_gcp#:~:text=workload/performance) + Represents the workload's performance metric, specifically step deviation + in this context, measured by `compute.googleapis.com/workload/performance`. + +3. Navigate to [Dashboards](https://console.cloud.google.com/monitoring/dashboards). + +4. Create a custom dashboard if there isn't one and add useful widgets with the above mentioned metrics. diff --git a/docs/guides/monitoring_and_debugging/understand_logs_and_metrics.md b/docs/guides/monitoring_and_debugging/understand_logs_and_metrics.md index 385138a695..90fb77d745 100644 --- a/docs/guides/monitoring_and_debugging/understand_logs_and_metrics.md +++ b/docs/guides/monitoring_and_debugging/understand_logs_and_metrics.md @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. --> + (understand-logs-and-metrics)= + # Understand logs and metrics When you run a training job, MaxText produces detailed output logs. This guide shows you how to interpret these logs to understand your configuration and monitor performance. @@ -21,7 +23,7 @@ When you run a training job, MaxText produces detailed output logs. This guide s To start, run a simple pretraining job on a single-host TPU. For instance, we can run the following command on TPU v5p-8. The resulting log is used as an example throughout this guide. ```bash -python3 -m MaxText.train src/MaxText/configs/base.yml \ +python3 -m MaxText.train src/maxtext/configs/base.yml \ base_output_directory=gs://runner-maxtext-logs run_name=demo \ model_name=deepseek2-16b \ per_device_batch_size=24 max_target_length=2048 steps=10 dataset_type=synthetic enable_checkpointing=false @@ -32,13 +34,16 @@ per_device_batch_size=24 max_target_length=2048 steps=10 dataset_type=synthetic The first section of the log details the configuration of your run. This is crucial for debugging, as it shows you exactly which parameters were used. MaxText builds its configuration in layers. -- It starts with the **default configuration** from a YAML file. In our example, the file is [`src/MaxText/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/configs/base.yml). + +- It starts with the **default configuration** from a YAML file. In our example, the file is [`src/maxtext/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/maxtext/configs/base.yml). - Then, it overwrites any of these values with the arguments you provide in the **command line**. + ```none Updating keys from env and command line: ['run_name', 'model_name', 'enable_checkpointing', 'base_output_directory', 'per_device_batch_size', 'dataset_type', 'steps', 'max_target_length'] ``` -- It updates keys based on the **model-specific configuration** file. When you specify a model, like `deepseek2-16b`, MaxText reads the corresponding parameters from the [deepseek2-16b.yml](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/configs/models/deepseek2-16b.yml) file. + +- It updates keys based on the **model-specific configuration** file. When you specify a model, like `deepseek2-16b`, MaxText reads the corresponding parameters from the [deepseek2-16b.yml](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/maxtext/configs/models/deepseek2-16b.yml) file. ```none Running Model: deepseek2-16b @@ -49,9 +54,11 @@ MaxText builds its configuration in layers. ... Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_mlp_dim', 'base_moe_mlp_dim', 'base_num_decoder_layers', 'first_num_dense_layers', 'mlp_activations', 'vocab_size', 'enable_dropout', 'logits_via_embedding', 'normalization_layer_epsilon', 'num_experts', 'num_experts_per_tok', 'shared_experts', 'routed_scaling_factor', 'routed_score_func', 'routed_bias', 'decoder_block', 'attention_type', 'q_lora_rank', 'kv_lora_rank', 'qk_nope_head_dim', 'qk_rope_head_dim', 'v_head_dim', 'rope_type', 'rope_max_timescale', 'max_position_embeddings', 'original_max_position_embeddings', 'rope_factor', 'beta_fast', 'mscale'] ``` + Note that you cannot modify a key from both model config and command line. The final, consolidated configuration is printed last. + ```none # From base.yml default Config param opt_type: adamw @@ -69,7 +76,9 @@ Config param max_target_length: 2048 Config param data_sharding: (('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'),) ... ``` + This also includes the **output paths** for your run artifacts. + ``` Config param base_output_directory: gs://runner-maxtext-logs Config param run_name: demo @@ -83,31 +92,38 @@ Config param checkpoint_dir: gs://runner-maxtext-logs/demo/checkpoints/ MaxText organizes all of your run's artifacts into a main output directory. The primary location for your run is constructed by combining the `base_output_directory` and the `run_name` you specify in your command. Based on the logs above, the base path for this specific run is `gs://runner-maxtext-logs/demo`. Within this base path, MaxText creates several subdirectories for different types of artifacts. Many of these are optional and only created if you enable them with a specific flag. -* **TensorBoard logs (`tensorboard/`)** - * Flag: `enable_tensorboard=True` (default) - * Path: `gs://runner-maxtext-logs/demo/tensorboard/` -* **Profiler traces (`tensorboard/plugins/profile/`)** - * Flag: `profiler=xplane` - * Path: The profiler output is saved within the TensorBoard directory. +- **TensorBoard logs (`tensorboard/`)** + + - Flag: `enable_tensorboard=True` (default) + - Path: `gs://runner-maxtext-logs/demo/tensorboard/` + +- **Profiler traces (`tensorboard/plugins/profile/`)** + + - Flag: `profiler=xplane` + - Path: The profiler output is saved within the TensorBoard directory. + +- **Metrics in plain text (`metrics/`)** + + - Flag: `gcs_metrics=True` + - Path: `gs://runner-maxtext-logs/demo/metrics/` + +- **Configuration file (`config.yml`)** -* **Metrics in plain text (`metrics/`)** - * Flag: `gcs_metrics=True` - * Path: `gs://runner-maxtext-logs/demo/metrics/` + - Flag: `save_config_to_gcs=True` + - Path: `gs://runner-maxtext-logs/demo/config.yml` -* **Configuration file (`config.yml`)** - * Flag: `save_config_to_gcs=True` - * Path: `gs://runner-maxtext-logs/demo/config.yml` +- **Checkpoints (`checkpoints/`)** -* **Checkpoints (`checkpoints/`)** - * Flag: `enable_checkpointing=True` - * Path: `gs://runner-maxtext-logs/demo/checkpoints/` + - Flag: `enable_checkpointing=True` + - Path: `gs://runner-maxtext-logs/demo/checkpoints/` To generate all optional artifacts in one run, you can set the corresponding flags in the command line, like in the example below. This command enables tensorboard, profiler, text metrics, config saving, and checkpointing: + ```bash -python3 -m MaxText.train src/MaxText/configs/base.yml \ +python3 -m MaxText.train src/maxtext/configs/base.yml \ base_output_directory=gs://runner-maxtext-logs run_name=demo2 \ model_name=deepseek2-16b \ per_device_batch_size=24 max_target_length=2048 steps=10 dataset_type=synthetic \ @@ -134,7 +150,6 @@ Num_devices: 4, shape (1, 1, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1) - **Hardware**: You are running on the `TPU v5` accelerator with `4` total devices. - **Parallelism strategy**: The `shape` tuple `(1, 1, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1)` shows how your devices are arranged for parallelism. Recall from Section 1, `Config param data_sharding: (('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'),)`. This confirms that all 4 devices are being used for Fully Sharded Data Parallelism (FSDP), which is the default behavior. - ## 3. Resource accounting Before executing training, the program analyzes the resource requirements for your training job, specifically memory and compute (FLOPs). @@ -142,10 +157,13 @@ Before executing training, the program analyzes the resource requirements for yo ### 3.1. Memory analysis We first perform a "dry run" compilation of a training step to [analyze its memory requirement](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/train.py#L380-L382). This static analysis is performed by the XLA compiler. The log outputs [memory sizes](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/max_utils.py#L672-L690): + ```none Total memory size: 100.4 GB, Output size: 44.5 GB, Temp size: 55.9 GB, Argument size: 44.5 GB, Host temp size: 0.0 GB. ``` + The most important number is `Total memory size: 100.4 GB`. This is the total High Bandwidth Memory (HBM) the TPU device needs to execute the program. Here is a breakdown: + - `Argument size: 44.5 GB`: This is the memory needed to hold the inputs for your function. This typically includes the batch of data, parameter (master copy), and optimizer state (e.g., moment). - `Output size: 44.5 GB`: This is the space required to store the results of the computation, such as the updated model weights and updated optimizer states. - `Temp size: 55.9 GB`: This is the "scratch space" memory. It's used for all the intermediate values created during the forward and backward passes that are discarded once the step is complete. This includes activation (forward pass), gradient (backward pass), and parameter (working copy, if mixed precision). @@ -153,7 +171,6 @@ The most important number is `Total memory size: 100.4 GB`. This is the total Hi In addition, it shows temporary memory used on the host CPU. In this case, `Host temp size: 0.0 GB`, indicating that all the significant memory allocation happens on the accelerator device. - ### 3.2. Memory snapshot The previous section is a forecast of memory usage for entire training step, based on static analysis of the compiled code from the XLA compiler. To see the actual memory usage, we now turn to a real-time snapshot from the JAX runtime, captured right after the training state is initialized. @@ -169,11 +186,13 @@ Memstats: After params initialized: Using (GB) 44.63 / 95.74 (46.615835%) on TPU_2(process=0,(0,1,0,0)) Using (GB) 44.63 / 95.74 (46.615835%) on TPU_3(process=0,(1,1,0,0)) ``` + This log shows that each of the four TPUs has `95.74 GB` of available High Bandwidth Memory (HBM). The initial training state is evenly distributed across devices, with each using the same amount of `44.63 GB`. ### 3.3. Model TFLOP per device The **model FLOPs** are the floating point operations to perform model computation. For training, the computation includes a single forward and backward pass. + - In MaxText, we estimate model FLOPs by summing operations in matrix multiplications (matmuls); see [calculate_tflops_training_per_device](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/maxtext_utils.py#L480). - The number of model FLOPs is dependent on model architecture, input size (batch size, sequence length), and gradient accumulation steps. It does not include optimization operations. - We break down the FLOPs into two parts: @@ -181,6 +200,7 @@ The **model FLOPs** are the floating point operations to perform model computati - "Attention FLOPs" are matmuls in attention score computation like $\mathrm{softmax}{\left(\frac{QK^\top}{\sqrt{d}}\right)} V$. One **TFLOP** (TeraFLOP) is equal to $10^{12}$ FLOPs. The log shows the theoretical estimate of **model TFLOP per device**: + ```none Per train step: Total TFLOPs: 764.67 @@ -188,6 +208,7 @@ Per train step: ``` In this example, given `model=deepseek2-16b`, `per_device_batch_size=24`, `max_target_length=2048`, and no gradient accumulation, we have $\text{model tflop per device} \approx 764.67$. + - 94.54% of the TFLOPs are attributed to learnable weight and 5.46% are attributed to attention. - As you will see next, this number is important for calculating performance metrics, such as TFLOP/s/device and Model FLOPs Utilization (MFU). @@ -196,6 +217,7 @@ You can find more information about model FLOPs and MFU in the [Performance Metr ## 4. Training metrics Finally, we are getting to the training steps! In this section, we introduce performance metrics including TFLOP/s/device, MFU, and Tokens/s/device (throughput). We briefly cover learning metrics including loss and total weights. + ```none completed step: 0, seconds: 44.923, TFLOP/s/device: 17.022, Tokens/s/device: 1094.129, total_weights: 196608, loss: 12.038 completed step: 1, seconds: 0.319, TFLOP/s/device: 2400.734, Tokens/s/device: 154316.608, total_weights: 196608, loss: 12.038 @@ -210,15 +232,18 @@ completed step: 9, seconds: 5.667, TFLOP/s/device: 134.924, Tokens/s/device: 867 ``` Before we dive deep here, recall a few numbers from previous sections: + - $\text{max target length} = 2048$, $\text{per device batch size} = 24$ - $\text{model tflop per device} \approx 764.67$ (rounded), $\text{number of devices} = 4$ ### 4.1. Performance metrics The performance metrics fluctuate at the beginning, and become stable towards the end. Therefore, we usually read them from the last step. Let's take a closer look at Step 9. + ```none completed step: 9, seconds: 5.667, TFLOP/s/device: 134.924, Tokens/s/device: 8672.758, total_weights: 196608, loss: 10.374 ``` + As shown in `seconds: 5.667`, $\text{measured step time in seconds} \approx 5.667$ (rounded). **TFLOP per second per device** @@ -232,7 +257,7 @@ $$\text{tflop/s/device} = \frac{\text{model tflop per device}}{\text{measured st $$\text{MFU} = \frac{\text{tflop/s/device}}{\text{peak hardware tflop/s}}$$ - For TPU v5p, $\text{peak hardware tflop/s}=459$. Thus, $134.924 / 459 = 29.40$%. Note this is an example for explanation with small batch size and sequence length, so the MFU is not optimal. +For TPU v5p, $\text{peak hardware tflop/s}=459$. Thus, $134.924 / 459 = 29.40$%. Note this is an example for explanation with small batch size and sequence length, so the MFU is not optimal. **Tokens per second per device (throughput)** @@ -251,10 +276,11 @@ $$\text{number of tokens per device} = \text{per device batch size} \times \text **Loss**. The loss is the key indicator of learning progress, which should decrease over training steps. In this example, the loss is `12.038` at Step 0 and decreases to `10.374` at Step 9. Ideally, we want the loss to converge to a small value with sufficiently large training steps. **Total weights**. When discussing the throughput, we have $\text{number of tokens} = \text{per device batch size} \times \text{max target length} \times \text{number of device}$. In this example, $\text{number of tokens} = 24 \times 2048 \times 4 = 196608$. There are two types of tokens: real tokens and pad tokens. The pad tokens are placeholders introduced by data preprocessing: We truncate or pad each sentence to max target length. Only real tokens contribute to the learning signal (i.e., loss). Therefore, we monitor $\text{number of real tokens}$, which is shown as [total weights](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/train.py#L151). + - Here we see `total_weights: 196608` for all steps. This is because we are using `dataset_type=synthetic`, where all sentences are generated with a length of `max_target_length=2048`. As a result, there are no pad tokens and total weights = number of tokens. - However, in real datasets, sentences can have variable lengths and total weights < number of tokens. For example, we can set `dataset_type=tfds dataset_path=gs://maxtext-dataset dataset_name='c4/en:3.0.1'`, and will see total weights smaller than `196608`: ```none completed step: 8, seconds: 5.670, TFLOP/s/device: 134.856, Tokens/s/device: 8668.393, total_weights: 163259, loss: 9.596 completed step: 9, seconds: 5.669, TFLOP/s/device: 134.884, Tokens/s/device: 8670.184, total_weights: 155934, loss: 9.580 ``` -- For better convergence, we want to have large total weights. Towards this end, MaxText supports [packing](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/sequence_packing.py#L37) multiple short sequences into one. This is enabled by default with `packing=True` in [base.yml](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/configs/base.yml#L465). +- For better convergence, we want to have large total weights. Towards this end, MaxText supports [packing](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/MaxText/sequence_packing.py#L37) multiple short sequences into one. This is enabled by default with `packing=True` in [base.yml](https://github.com/AI-Hypercomputer/maxtext/blob/28e5097ac467ed8b1d17676d68aa5acc50f9d60d/src/maxtext/configs/base.yml#L465). diff --git a/docs/guides/optimization/sharding.md b/docs/guides/optimization/sharding.md index 70b7185131..6e8a69cd0a 100644 --- a/docs/guides/optimization/sharding.md +++ b/docs/guides/optimization/sharding.md @@ -120,15 +120,15 @@ arithmetic intensity analysis since they shard the batch, as we will illustrate Sharding in maxtext is split into 3 layers -- **Physical** mesh axes (e.g. `data`, `fsdp`, `tensor`) defined [here](https://github.com/AI-Hypercomputer/maxtext/blob/f269268bd622f6d2f40d38632ede7a7834a6024e/MaxText/configs/base.yml#L269) +- **Physical** mesh axes (e.g. `data`, `fsdp`, `tensor`) defined [here](https://github.com/AI-Hypercomputer/maxtext/blob/f269268bd622f6d2f40d38632ede7a7834a6024e/maxtext/configs/base.yml#L269) - - Mesh is created via [create_device_mesh](https://github.com/AI-Hypercomputer/maxtext/blob/f269268bd622f6d2f40d38632ede7a7834a6024e/MaxText/max_utils.py#L576-L580) +- Mesh is created via [create_device_mesh](https://github.com/AI-Hypercomputer/maxtext/blob/f269268bd622f6d2f40d38632ede7a7834a6024e/MaxText/max_utils.py#L576-L580) - - Mesh given names in train.py via [Mesh](https://github.com/AI-Hypercomputer/maxtext/blob/f269268bd622f6d2f40d38632ede7a7834a6024e/MaxText/train.py#L594) +- Mesh given names in train.py via [Mesh](https://github.com/AI-Hypercomputer/maxtext/blob/f269268bd622f6d2f40d38632ede7a7834a6024e/MaxText/train.py#L594) -- **Logical** axes which map a meaningful axes name to physical axes defined [here](https://github.com/AI-Hypercomputer/maxtext/blob/f269268bd622f6d2f40d38632ede7a7834a6024e/MaxText/configs/base.yml#L270) +- **Logical** axes which map a meaningful axes name to physical axes defined [here](https://github.com/AI-Hypercomputer/maxtext/blob/f269268bd622f6d2f40d38632ede7a7834a6024e/maxtext/configs/base.yml#L270) - - E.g. logical axes `activation_batch` is sharded by the physical axes of `data` and `fsdp` (among others) since those sharding strategies shard the batch. `Activation_batch` is a common axis among most activation tensors. Note that if we use `data_parallelism=4` and `fsdp_parallelism=2`, then the `activation_batch` dimension will get sharded over both, e.g. $4*2=8$ ways. +- E.g. logical axes `activation_batch` is sharded by the physical axes of `data` and `fsdp` (among others) since those sharding strategies shard the batch. `Activation_batch` is a common axis among most activation tensors. Note that if we use `data_parallelism=4` and `fsdp_parallelism=2`, then the `activation_batch` dimension will get sharded over both, e.g. $4*2=8$ ways. - **Individual tensors** have sharding constraints - generally specified by logical rules @@ -424,8 +424,8 @@ Note that for MoE models, this arithmetic intensity grows by a factor of `expert ## Context Autoregressive -Context Autoregressive shards the KV cache on the sequence dimension. It shards feed forward layer by experts for both activations and weights. This is used for inference only, see [inference.yml](https://github.com/AI-Hypercomputer/maxtext/blob/353a45d57eb1f1cc02e5c8d9e7b18eaf634d7edc/MaxText/configs/inference.yml#L4) for the modified logical axis rules for inference. +Context Autoregressive shards the KV cache on the sequence dimension. It shards feed forward layer by experts for both activations and weights. This is used for inference only, see [inference.yml](https://github.com/AI-Hypercomputer/maxtext/blob/353a45d57eb1f1cc02e5c8d9e7b18eaf634d7edc/maxtext/configs/inference/inference.yml#L4) for the modified logical axis rules for inference. ## Autoregressive -Autoregressive shards weights, but not activations. This is used for inference only. See [inference.yml](https://github.com/AI-Hypercomputer/maxtext/blob/353a45d57eb1f1cc02e5c8d9e7b18eaf634d7edc/MaxText/configs/inference.yml#L4) for the modified logical axis rules for inference. +Autoregressive shards weights, but not activations. This is used for inference only. See [inference.yml](https://github.com/AI-Hypercomputer/maxtext/blob/353a45d57eb1f1cc02e5c8d9e7b18eaf634d7edc/maxtext/configs/inference/inference.yml#L4) for the modified logical axis rules for inference. diff --git a/docs/reference/architecture/architecture_overview.md b/docs/reference/architecture/architecture_overview.md index 94224985ca..92af7e90f2 100644 --- a/docs/reference/architecture/architecture_overview.md +++ b/docs/reference/architecture/architecture_overview.md @@ -2,7 +2,7 @@ ## The MaxText philosophy -The architecture of MaxText is guided by a distinct and deliberate philosophy that prioritizes accessibility and scalability by deeply leveraging the power of the XLA compiler. This approach marks a strategic departure from frameworks that rely on extensive manual optimization. Instead, MaxText achieves its goals through a pure Python/JAX implementation that trusts the underlying compiler to handle the complexities of hardware optimization. Only for the most performance-critical operations, such as custom attention mechanisms or Mixture-of-Experts (MoE) routing, does MaxText use custom kernels written in Pallas. +The architecture of MaxText is guided by a distinct and deliberate philosophy that prioritizes accessibility and scalability by deeply leveraging the power of the XLA compiler. This approach marks a strategic departure from frameworks that rely on extensive manual optimization. Instead, MaxText achieves its goals through a pure Python/JAX implementation that trusts the underlying compiler to handle the complexities of hardware optimization. Only for the most performance-critical operations, such as custom attention mechanisms or Mixture-of-Experts (MoE) routing, does MaxText use custom kernels written in Pallas. ## Trusting the compiler @@ -33,14 +33,13 @@ The control plane of MaxText provides a structured yet flexible interface for us ### `base.yml`: the central configuration hub -Every MaxText job is governed by the same base YAML configuration file ([`src/MaxText/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/01c7137d4e13878e38baae44dc99e588eaa50a70/src/MaxText/configs/base.yml)) with model-specific details and overrides passed through a second config (e.g. [`src/MaxText/configs/models/deepseek3-671b.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/01c7137d4e13878e38baae44dc99e588eaa50a70/src/MaxText/configs/models/deepseek3-671b.yml)). Finally, experiment-specific settings are passed on the command line. The contents of these together comprise all the hyperparameters and settings that define a run: - -* Model architecture: Defines the core transformer structure, with parameters like `model_name` (e.g., 'llama2-7b'), `global_parameter_scale` for size, `base_emb_dim`, `base_num_heads`, the type of attention mechanism, and `quantization` settings (e.g., 'int8'). -* Training and optimization: Controls the training process with settings like `steps`, `learning_rate`, optimizer parameters such as `adam_b1`, and the `per_device_batch_size`. -* Data pipeline: Specifies the data source via `dataset_type` ('tfds', 'grain', 'hf'), the `dataset_path` on Cloud Storage, and Hugging Face-specific parameters like `hf_path` and `hf_train_files`. -* Hardware and parallelism: Defines the physical and logical device layout with `ici_parallelism` (intra-chip interconnect), `dcn_parallelism` (data center network), and `compile_topology` for ahead-of-time compilation. -* Checkpointing and logging: Manages run artifacts with `enable_checkpointing`, `async_checkpointing`, the `base_output_directory` in a Cloud Storage bucket, and a unique `run_name`. +Every MaxText job is governed by the same base YAML configuration file ([`src/maxtext/configs/base.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/01c7137d4e13878e38baae44dc99e588eaa50a70/src/maxtext/configs/base.yml)) with model-specific details and overrides passed through a second config (e.g. [`src/maxtext/configs/models/deepseek3-671b.yml`](https://github.com/AI-Hypercomputer/maxtext/blob/01c7137d4e13878e38baae44dc99e588eaa50a70/src/maxtext/configs/models/deepseek3-671b.yml)). Finally, experiment-specific settings are passed on the command line. The contents of these together comprise all the hyperparameters and settings that define a run: +- Model architecture: Defines the core transformer structure, with parameters like `model_name` (e.g., 'llama2-7b'), `global_parameter_scale` for size, `base_emb_dim`, `base_num_heads`, the type of attention mechanism, and `quantization` settings (e.g., 'int8'). +- Training and optimization: Controls the training process with settings like `steps`, `learning_rate`, optimizer parameters such as `adam_b1`, and the `per_device_batch_size`. +- Data pipeline: Specifies the data source via `dataset_type` ('tfds', 'grain', 'hf'), the `dataset_path` on Cloud Storage, and Hugging Face-specific parameters like `hf_path` and `hf_train_files`. +- Hardware and parallelism: Defines the physical and logical device layout with `ici_parallelism` (intra-chip interconnect), `dcn_parallelism` (data center network), and `compile_topology` for ahead-of-time compilation. +- Checkpointing and logging: Manages run artifacts with `enable_checkpointing`, `async_checkpointing`, the `base_output_directory` in a Cloud Storage bucket, and a unique `run_name`. A critical feature of this system is its flexibility. While `base.yml` provides the default values, any parameter can be overridden at runtime via command-line arguments. This allows for easy scripting of experiments and hyperparameter sweeps without needing to modify the configuration file for every run. At the same time, reproducibility can of course be maintained, by storing command line overrides in .sh files. @@ -48,25 +47,24 @@ A critical feature of this system is its flexibility. While `base.yml` provides MaxText can be executed trivially on a single TPU VM host and surprisingly easily on multi-host setups. -* Single-host development: This is the simplest entry point, designed for running MaxText on a single TPU VM (e.g., v5p-8) or a single GPU machine. It is ideal for initial setup, dependency installation, and small-scale debugging or experimentation. - -* GKE with XPK (recommended for production): This is the most scalable and robust method for running MaxText. It leverages the Accelerated Processing Kit (XPK) on Google Kubernetes Engine (GKE). XPK is an orchestration tool that standardizes best practices for large-scale ML jobs. It decouples the provisioning of compute capacity from the execution of the training job, allowing for more efficient resource management. This approach is recommended for production-grade training and serving due to its scalability, fault tolerance, and integration with the broader Google Cloud ecosystem. +- Single-host development: This is the simplest entry point, designed for running MaxText on a single TPU VM (e.g., v5p-8) or a single GPU machine. It is ideal for initial setup, dependency installation, and small-scale debugging or experimentation. +- GKE with XPK (recommended for production): This is the most scalable and robust method for running MaxText. It leverages the Accelerated Processing Kit (XPK) on Google Kubernetes Engine (GKE). XPK is an orchestration tool that standardizes best practices for large-scale ML jobs. It decouples the provisioning of compute capacity from the execution of the training job, allowing for more efficient resource management. This approach is recommended for production-grade training and serving due to its scalability, fault tolerance, and integration with the broader Google Cloud ecosystem. ### Summary The table below summarizes some of the most critical parameters in base.yml and the components of the architecture they control, serving as a quick reference for configuring a MaxText run. -| Parameter | Module(s) Affected | Description | -| :---- | :---- | :---- | -| model\_name | models.py, train.py | Selects the transformer architecture as specified in the corresponding model config file (e.g., 'llama2-7b'). | -| per\_device\_batch\_size | train.py, input\_pipeline.py | Sets the local batch size per accelerator chip. | -| ici\_parallelism, dcn\_parallelism | max\_utils.py, train.py | Defines the device mesh shape for intra-chip and data center network parallelism. | -| dataset\_type | input\_pipeline.py | Specifies the data loader backend ('tfds', 'grain', 'hf'). | -| enable\_checkpointing | checkpointing.py, train.py | Enables or disables saving model state. | -| async\_checkpointing | checkpointing.py, train.py | If True, saves checkpoints without blocking the training loop. | -| quantization | layers.py, optimizers.py | Enables quantization, e.g., 'int8' for AQT or Qwix. | -| compile\_topology | train\_compile.py | Specifies the target hardware topology for AOT compilation. | +| Parameter | Module(s) Affected | Description | +| :------------------------------- | :-------------------------- | :------------------------------------------------------------------------------------------------------------ | +| model_name | models.py, train.py | Selects the transformer architecture as specified in the corresponding model config file (e.g., 'llama2-7b'). | +| per_device_batch_size | train.py, input_pipeline.py | Sets the local batch size per accelerator chip. | +| ici_parallelism, dcn_parallelism | max_utils.py, train.py | Defines the device mesh shape for intra-chip and data center network parallelism. | +| dataset_type | input_pipeline.py | Specifies the data loader backend ('tfds', 'grain', 'hf'). | +| enable_checkpointing | checkpointing.py, train.py | Enables or disables saving model state. | +| async_checkpointing | checkpointing.py, train.py | If True, saves checkpoints without blocking the training loop. | +| quantization | layers.py, optimizers.py | Enables quantization, e.g., 'int8' for AQT or Qwix. | +| compile_topology | train_compile.py | Specifies the target hardware topology for AOT compilation. | ## Core architectural components @@ -80,18 +78,17 @@ The typical model comprises a decoder-only autoregressive transformer, but MaxTe While the base model implementations are typically simple, MaxText is equipped to handle a wide range of advanced, industry-standard features necessary for state-of-the-art performance and efficiency: -* Mixture-of-Experts (MoE): MaxText provides native support for sparse MoE models, such as DeepSeek. This includes efficient "dropping" and "dropless" MoE implementations leveraging the MegaBlox [Pallas](https://docs.jax.dev/en/latest/pallas/index.html) kernel, which can be enabled via configuration flags. +- Mixture-of-Experts (MoE): MaxText provides native support for sparse MoE models, such as DeepSeek. This includes efficient "dropping" and "dropless" MoE implementations leveraging the MegaBlox [Pallas](https://docs.jax.dev/en/latest/pallas/index.html) kernel, which can be enabled via configuration flags. -* Advanced attention mechanisms: The architecture is not limited to standard self-attention. It supports variants like Grouped-Query Attention (GQA), Multi-Query Attention (MQA) and Multi-headed Latent Attention (MLA). Since, like MoE, attention can be a performance hot-spot in transformers, attention is typically implemented in [Pallas](https://docs.jax.dev/en/latest/pallas/index.html) kernels, with Splash (Sparse, Flash) Attention being the default for training. +- Advanced attention mechanisms: The architecture is not limited to standard self-attention. It supports variants like Grouped-Query Attention (GQA), Multi-Query Attention (MQA) and Multi-headed Latent Attention (MLA). Since, like MoE, attention can be a performance hot-spot in transformers, attention is typically implemented in [Pallas](https://docs.jax.dev/en/latest/pallas/index.html) kernels, with Splash (Sparse, Flash) Attention being the default for training. -* Quantization: The framework seamlessly integrates with Google's Accurate Quantized Training (AQT) and Qwix libraries. Quantization logic is applied at the layer level. +- Quantization: The framework seamlessly integrates with Google's Accurate Quantized Training (AQT) and Qwix libraries. Quantization logic is applied at the layer level. - -The modularity of this design is clearly demonstrated by third-party extensions. For instance, the NVIDIA maxtext-jaxpp fork was able to add support for pipeline parallelism by inserting jaxpp.pipeline\_enter\_stage hooks directly into the \_\_call\_\_ method of the Decoder class, a testament to the codebase's modularity and extensibility. +The modularity of this design is clearly demonstrated by third-party extensions. For instance, the NVIDIA maxtext-jaxpp fork was able to add support for pipeline parallelism by inserting jaxpp.pipeline_enter_stage hooks directly into the \_\_call\_\_ method of the Decoder class, a testament to the codebase's modularity and extensibility. ### Data ingestion (`input_pipeline.py`) -[The data ingestion pipeline](../../guides/data_input_pipeline.md) is a critical component for performance at scale. In MaxText, the main training loop interfaces with the data pipeline through the create\_data\_iterator function, which is called from train.py. This function acts as a facade, abstracting the specific data loading implementation from the rest of the training logic. +[The data ingestion pipeline](../../guides/data_input_pipeline.md) is a critical component for performance at scale. In MaxText, the main training loop interfaces with the data pipeline through the create_data_iterator function, which is called from train.py. This function acts as a facade, abstracting the specific data loading implementation from the rest of the training logic. MaxText supports three primary data loading backends: @@ -99,7 +96,6 @@ MaxText supports three primary data loading backends: 2. TFDS (TensorFlow Datasets): For using datasets in the TFRecord format. 3. Grain: A data loading library optimized for large-scale, distributed environments. - While all three are supported, MaxText recommends the use of Grain, particularly for multi-host training scenarios. The rationale stems from performance and determinism considerations, at which Grain excels. Grain uses a data format called ArrayRecord, which supports efficient random access by index. This allows for true global shuffling of data across all hosts and eliminates the performance bottleneck associated with sequential reading. ### State management and persistence (`checkpointing.py`) @@ -107,19 +103,17 @@ While all three are supported, MaxText recommends the use of Grain, particularly MaxText's state management and persistence layer is built on [Orbax](https://orbax.readthedocs.io/en/latest/), a flexible and powerful open-source checkpointing library for JAX applications. The core logic is encapsulated within the checkpointing.py module, which provides a comprehensive suite of tools for saving and loading training state with high performance and resilience. -The central function is create\_orbax\_checkpoint\_manager, which configures and returns an Orbax CheckpointManager instance. This manager handles the core checkpointing operations and is configured with several key features: - -* Asynchronous checkpointing: By setting the `async_checkpointing` flag to true, users can enable non-blocking checkpoint saves. This is a critical performance optimization. The training loop can proceed with the next step on the accelerators while the CPU on each host handles the process of serializing the previous step's state and writing it to Google Cloud Storage. This effectively hides the I/O latency of checkpointing and maximizes accelerator utilization. -* Flexible state restoration: The `load_state_if_possible` function implements a sophisticated, prioritized logic for resuming a run. When a job starts, it first attempts to find and load a full checkpoint from the current run's output directory. If that fails, it checks if a path to a full state checkpoint from a different run has been provided via the `load_full_state_from_path` argument. If that also fails, it looks for a parameter-only checkpoint (without training/optimizer state) specified by `load_parameters_from_path`. -* Emergency and replicated checkpointing: For maximum resilience and rapid job resumption in large-scale, production environments like GKE, the module includes support for advanced Orbax features. +The central function is create_orbax_checkpoint_manager, which configures and returns an Orbax CheckpointManager instance. This manager handles the core checkpointing operations and is configured with several key features: +- Asynchronous checkpointing: By setting the `async_checkpointing` flag to true, users can enable non-blocking checkpoint saves. This is a critical performance optimization. The training loop can proceed with the next step on the accelerators while the CPU on each host handles the process of serializing the previous step's state and writing it to Google Cloud Storage. This effectively hides the I/O latency of checkpointing and maximizes accelerator utilization. +- Flexible state restoration: The `load_state_if_possible` function implements a sophisticated, prioritized logic for resuming a run. When a job starts, it first attempts to find and load a full checkpoint from the current run's output directory. If that fails, it checks if a path to a full state checkpoint from a different run has been provided via the `load_full_state_from_path` argument. If that also fails, it looks for a parameter-only checkpoint (without training/optimizer state) specified by `load_parameters_from_path`. +- Emergency and replicated checkpointing: For maximum resilience and rapid job resumption in large-scale, production environments like GKE, the module includes support for advanced Orbax features. A fundamental aspect of the MaxText workflow is the conversion of checkpoints between different formats. Scripts are provided to handle both ingestion and egress of model weights: -* Ingestion: Utilities like convert\_gemma\_chkpt.py and llama\_or\_mistral\_ckpt.py are used to transform checkpoints from standard frameworks (e.g., Hugging Face PyTorch) into the native MaxText Orbax format, which includes the full PyTree structure required for training. -* Preparation for inference: Conversely, the generate\_param\_only\_checkpoint.py script serves a crucial role in the path to deployment. It takes a full training checkpoint (which contains model parameters, optimizer state, and other metadata) and strips it down to only the essential model parameters. This script also performs a critical transformation from the "scanned" format used during training (an optimization where layers are stacked into a single tensor for efficient compilation) to the "unscanned" format required for autoregressive decoding. The resulting lightweight, parameter-only checkpoint is optimized for use with the decode.py script or for deployment with the JetStream inference engine. -* There also exist conversion scripts to convert weights to Hugging Face, e.g. `llama_mistral_mixtral_orbax_to_hf.py` - +- Ingestion: Utilities like convert_gemma_chkpt.py and llama_or_mistral_ckpt.py are used to transform checkpoints from standard frameworks (e.g., Hugging Face PyTorch) into the native MaxText Orbax format, which includes the full PyTree structure required for training. +- Preparation for inference: Conversely, the generate_param_only_checkpoint.py script serves a crucial role in the path to deployment. It takes a full training checkpoint (which contains model parameters, optimizer state, and other metadata) and strips it down to only the essential model parameters. This script also performs a critical transformation from the "scanned" format used during training (an optimization where layers are stacked into a single tensor for efficient compilation) to the "unscanned" format required for autoregressive decoding. The resulting lightweight, parameter-only checkpoint is optimized for use with the decode.py script or for deployment with the JetStream inference engine. +- There also exist conversion scripts to convert weights to Hugging Face, e.g. `llama_mistral_mixtral_orbax_to_hf.py` ### Utilities and distributed setup (`max_utils.py`) @@ -127,10 +121,9 @@ The `max_utils.py` module serves as a collection of common helper functions used The `maybe_initialize_jax_distributed_system` function is one example of this abstraction. This single function encapsulates the logic required to correctly call `jax.distributed.initialize()` in various deployment scenarios. It inspects the configuration and environment to determine the correct initialization parameters, handling cases for: -* Different hardware types, such as `gpu_multiprocess`. -* Configurations involving asynchronous checkpointing and multi-controller setups, which have specific distributed system requirements. -* Specialized environments like GKE with emergency checkpointing enabled. In this scenario, the JAX process ID and the coordinator's network address are not known beforehand but are written to a file by the GKE orchestrator. The function contains logic to poll for this file and parse the necessary information to initialize the distributed system correctly. - +- Different hardware types, such as `gpu_multiprocess`. +- Configurations involving asynchronous checkpointing and multi-controller setups, which have specific distributed system requirements. +- Specialized environments like GKE with emergency checkpointing enabled. In this scenario, the JAX process ID and the coordinator's network address are not known beforehand but are written to a file by the GKE orchestrator. The function contains logic to poll for this file and parse the necessary information to initialize the distributed system correctly. By centralizing this complex, environment-dependent logic into a single utility function, MaxText keeps the main training script cleaner and shields the end-user from the low-level details of distributed system bootstrapping. @@ -146,12 +139,12 @@ The foundation of MaxText's scaling strategy is JAX's `jit` transformation, whic This logical mesh abstraction enables the implementation of the standard parallelism strategies required for training large language models: -* Data parallelism (DP): The simplest form, where the entire model is replicated on each device (or group of devices), and the global data batch is split among the replicas. -* Fully sharded data parallelism (FSDP): An optimization over DP where the model's parameters, gradients, and optimizer states are sharded (split) across the data-parallel replicas, significantly reducing the memory footprint on each device. -* Tensor parallelism (TP): A model parallelism technique where individual operations within a transformer layer (such as large matrix multiplications) are split across multiple devices within a replica. -* Pipeline parallelism (PP): splitting multiple stages of the network (groups of layers) across devices +- Data parallelism (DP): The simplest form, where the entire model is replicated on each device (or group of devices), and the global data batch is split among the replicas. +- Fully sharded data parallelism (FSDP): An optimization over DP where the model's parameters, gradients, and optimizer states are sharded (split) across the data-parallel replicas, significantly reducing the memory footprint on each device. +- Tensor parallelism (TP): A model parallelism technique where individual operations within a transformer layer (such as large matrix multiplications) are split across multiple devices within a replica. +- Pipeline parallelism (PP): splitting multiple stages of the network (groups of layers) across devices -In MaxText, these strategies are implemented by annotating the model's PyTrees (the nested Python structures of arrays that hold the parameters and state) with sharding specifications. This is done using Flax's partitioning utilities, such as nn\_partitioning. These annotations provide requirements and hints to the compiler, telling it how each tensor should be distributed across the axes of the device mesh. The compiler then generates the appropriate collective communication operations (e.g., all-reduce, all-gather) needed to execute the parallel computation correctly and efficiently. +In MaxText, these strategies are implemented by annotating the model's PyTrees (the nested Python structures of arrays that hold the parameters and state) with sharding specifications. This is done using Flax's partitioning utilities, such as nn_partitioning. These annotations provide requirements and hints to the compiler, telling it how each tensor should be distributed across the axes of the device mesh. The compiler then generates the appropriate collective communication operations (e.g., all-reduce, all-gather) needed to execute the parallel computation correctly and efficiently. For more information on sharding see [our sharding documentation](../../guides/optimization/sharding.md). @@ -184,6 +177,6 @@ The critical technology enabling this strategy is the suite of checkpoint conver Debugging performance issues in a distributed system with thousands of accelerators is a notoriously difficult challenge. MaxText incorporates several built-in diagnostic features designed to provide visibility into the system's behavior at scale. -* Stack trace collection: To diagnose program hangs or faults, users can set `collect_stack_trace: True` in the configuration. This feature will periodically dump the Python stack traces from all worker processes. The traces can be directed to the console for immediate inspection or, more scalably, uploaded to Cloud Logging, where they can be aggregated and queried to identify misbehaving nodes. -* HLO dumping: For deep, low-level performance analysis, MaxText allows users to dump the XLA High-Level Optimizer (HLO) graph. By setting the `dump_hlo` flag, the compiled graph for a specific training step can be saved to a local directory or uploaded to Cloud Storage. This HLO representation is invaluable for compiler engineers and advanced users who need to understand exactly how XLA is interpreting and optimizing the model, making it possible to debug subtle performance regressions or compiler-related issues. -* Goodput monitoring: The framework integrates with the ml-goodput-measurement library, which provides a more holistic view of job efficiency than simple TFLOPs calculations. This allows for the tracking of metrics that capture overall "goodput," accounting for factors like data loading time, compilation overhead, and idle time, giving a truer picture of end-to-end performance. +- Stack trace collection: To diagnose program hangs or faults, users can set `collect_stack_trace: True` in the configuration. This feature will periodically dump the Python stack traces from all worker processes. The traces can be directed to the console for immediate inspection or, more scalably, uploaded to Cloud Logging, where they can be aggregated and queried to identify misbehaving nodes. +- HLO dumping: For deep, low-level performance analysis, MaxText allows users to dump the XLA High-Level Optimizer (HLO) graph. By setting the `dump_hlo` flag, the compiled graph for a specific training step can be saved to a local directory or uploaded to Cloud Storage. This HLO representation is invaluable for compiler engineers and advanced users who need to understand exactly how XLA is interpreting and optimizing the model, making it possible to debug subtle performance regressions or compiler-related issues. +- Goodput monitoring: The framework integrates with the ml-goodput-measurement library, which provides a more holistic view of job efficiency than simple TFLOPs calculations. This allows for the tracking of metrics that capture overall "goodput," accounting for factors like data loading time, compilation overhead, and idle time, giving a truer picture of end-to-end performance. diff --git a/docs/reference/core_concepts/checkpoints.md b/docs/reference/core_concepts/checkpoints.md index af344eb941..08bffa9861 100644 --- a/docs/reference/core_concepts/checkpoints.md +++ b/docs/reference/core_concepts/checkpoints.md @@ -20,10 +20,10 @@ Checkpoint formats in MaxText can be categorized along two axes: whether they include **training states** (e.g., optimizer properties) and whether the model's parameter weights are **stacked** or **unstacked** (aka scanned/unscanned). This results in the four types summarized below: -| | **Unstacked Weights** | **Stacked Weights** | -| :------------------------ | :------------------------------------- | :-------------------------------------- | -| **Without Train State** | Unstacked Inference Checkpoint | Stacked Inference Checkpoint | -| **With Train State** | Unstacked Training Checkpoint | Stacked Training Checkpoint | +| | **Unstacked Weights** | **Stacked Weights** | +| :---------------------- | :----------------------------- | :--------------------------- | +| **Without Train State** | Unstacked Inference Checkpoint | Stacked Inference Checkpoint | +| **With Train State** | Unstacked Training Checkpoint | Stacked Training Checkpoint | We discuss these two axes respectively: @@ -33,7 +33,7 @@ Checkpoints with a **training state** contain more than just the model's paramet In contrast, **inference checkpoints** contain only the parameter weights. We also call them parameter only/param-only checkpoints. This is the format most commonly used for sharing models on public platforms like HuggingFace, as they are smaller and ready for immediate use in inference or for fine-tuning. -### Stacked checkpoints and JAX scan function +### Stacked checkpoints and JAX scan function The concept of stacked vs. unstacked checkpoints is specific to JAX-based models that use the `jax.lax.scan` function ([doc](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)). `scan` is a powerful JAX feature that compiles sequential operations (like the layers of a Transformer) into a single, highly optimized kernel, avoiding the overhead of a Python for-loop. @@ -46,6 +46,7 @@ To work with `jax.lax.scan`, the model's parameters must be "stacked". For a Tra *Figure 1: A comparison of an unstacked checkpoint and a stacked checkpoint for a simple language model.* Their difference can also be represented in the following pytree structure: + ``` # Stacked (aka scanned) "params" : { @@ -74,9 +75,9 @@ To summarize the four checkpoint types: - **Stacked Inference Checkpoint:** Contains only model weights, but they are stacked for `scan`-compatible inference. Created by stripping the optimizer state from a stacked training checkpoint. - **Stacked Training Checkpoint:** The default format for saving and resuming training runs within MaxText. Contains both weights and optimizer state in a stacked format, optimized for `jax.lax.scan`. -In MaxText, we treat **Stacked Inference Checkpoints** as the default format for checkpoint conversion. For *saving and resuming* training, MaxText uses **Stacked Training Checkpoints** by default. +In MaxText, we treat **Stacked Inference Checkpoints** as the default format for checkpoint conversion. For *saving and resuming* training, MaxText uses **Stacked Training Checkpoints** by default. ---- +______________________________________________________________________ ## Using checkpoints in practice @@ -86,17 +87,17 @@ Beyond understanding the formats, it's crucial to know how to use checkpoints in MaxText automatically saves checkpoints periodically during a training run. These are **Stacked Training Checkpoints** that contain the full state needed to resume. -- `base_output_directory`: Specifies the GCS bucket directory where checkpoints will be saved. -- `enable_checkpointing`: A boolean to enable or disable checkpointing. -- `async_checkpoint`: Support training and checkpoint saving at the same time. -- `checkpoint_period`: The interval, in training steps, at which to save a new checkpoint. +- `base_output_directory`: Specifies the GCS bucket directory where checkpoints will be saved. +- `enable_checkpointing`: A boolean to enable or disable checkpointing. +- `async_checkpoint`: Support training and checkpoint saving at the same time. +- `checkpoint_period`: The interval, in training steps, at which to save a new checkpoint. Furthermore, MaxText supports emergency checkpointing, which saves a local copy of the checkpoint that can be restored quickly after an interruption. -- `enable_emergency_checkpoint`: A boolean to enable or disable this feature. -- `local_checkpoint_directory`: The local path for storing emergency checkpoints. -- `local_checkpoint_period`: The interval, in training steps, for saving local checkpoints. +- `enable_emergency_checkpoint`: A boolean to enable or disable this feature. +- `local_checkpoint_directory`: The local path for storing emergency checkpoints. +- `local_checkpoint_period`: The interval, in training steps, for saving local checkpoints. -More configs about checkpoints can be found in [here](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/configs/base.yml#L23-L65). +More configs about checkpoints can be found in [here](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/maxtext/configs/base.yml#L23-L65). For practical guides on checkpointing, please refer to [](checkpointing_solutions). diff --git a/docs/reference/core_concepts/moe_configuration.md b/docs/reference/core_concepts/moe_configuration.md index dfacd3fb60..c5e42b7153 100644 --- a/docs/reference/core_concepts/moe_configuration.md +++ b/docs/reference/core_concepts/moe_configuration.md @@ -16,28 +16,30 @@ # Mixture of Experts (MoE) Configuration -This document provides a detailed explanation of the configuration parameters related to Mixture of Experts (MoE) models in MaxText. These settings control the model architecture, routing mechanisms, and performance optimizations. Default values and parameter definitions are located in `src/MaxText/configs/base.yml` and are primarily used in `src/MaxText/layers/moe.py`. - +This document provides a detailed explanation of the configuration parameters related to Mixture of Experts (MoE) models in MaxText. These settings control the model architecture, routing mechanisms, and performance optimizations. Default values and parameter definitions are located in `src/maxtext/configs/base.yml` and are primarily used in `src/MaxText/layers/moe.py`. ## 1. Architecture ### MoE Strategy + MaxText supports both Dropless and Dropping strategies. Please refer to the decision tree below to determine the active strategy. ![Illustration of MoE strategy](../../_static/moe_strategy.png) *Figure 1: Decision Logic for MaxText MoE Strategies.* Dropless: -* [Tokamax Ragged Dot](https://github.com/openxla/tokamax/tree/main/tokamax/_src/ops/ragged_dot): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=True`. -* [Megablox](https://github.com/google/maxtext/tree/main/src/MaxText/kernels/megablox): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=False, megablox=True`. -* [JAX Ragged Dot](https://docs.jax.dev/en/latest/_autosummary/jax.lax.ragged_dot.html): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=False, megablox=False`. -* Dense Matmul: Enabled by setting `sparse_matmul=False, capacity_factor=-1`. + +- [Tokamax Ragged Dot](https://github.com/openxla/tokamax/tree/main/tokamax/_src/ops/ragged_dot): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=True`. +- [Megablox](https://github.com/google/maxtext/tree/main/src/MaxText/kernels/megablox): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=False, megablox=True`. +- [JAX Ragged Dot](https://docs.jax.dev/en/latest/_autosummary/jax.lax.ragged_dot.html): Enabled by setting `sparse_matmul=True, use_tokamax_gmm=False, megablox=False`. +- Dense Matmul: Enabled by setting `sparse_matmul=False, capacity_factor=-1`. Dropping: -* Dense Matmul: Enabled by setting `sparse_matmul=False, capacity_factor > 0` (commonly 1.0 to 1.25). +- Dense Matmul: Enabled by setting `sparse_matmul=False, capacity_factor > 0` (commonly 1.0 to 1.25). ### General Configuration + `num_experts`: The total number of routed experts available in the MoE layer. `num_experts_per_tok`: The number of experts selected for each token, often referred to as top-k strategy. @@ -53,6 +55,7 @@ Dropping: `float32_weight_sum`: If enabled, performs the summation of expert weights using float32 precision for improved numerical stability. Recommended specifically when lower precision types cause convergence or quality issues. ### Routing Mechanism + `use_random_routing`: If enabled, ignores the gate logits and routes tokens to random experts. This is designed to simulate load balancing for debugging and performance testing purposes. `n_routing_groups` and `topk_routing_group`: Experts are divided into n_routing_groups. The router first selects the top k highest-scoring groups (as `topk_routing_group`), and then selects experts only from those groups. @@ -70,17 +73,20 @@ Dropping: `norm_topk_prob`: If enabled, normalizes the router weights for the selected top-k experts. ### MLP Block & Computation + `sparse_matmul`: Determines whether to use efficient sparse matrix multiplication or dense matrix multiplication. - * `True`: Uses specialized kernels (like Tokamax Ragged Dot or Megablox) or JAX Ragged Dot to perform computation only on active tokens. This is generally faster for MoE. - * `False`: Performs dense computation with masking. This is typically used when checking numerical correctness or implementing dropping strategies. + +- `True`: Uses specialized kernels (like Tokamax Ragged Dot or Megablox) or JAX Ragged Dot to perform computation only on active tokens. This is generally faster for MoE. +- `False`: Performs dense computation with masking. This is typically used when checking numerical correctness or implementing dropping strategies. `use_tokamax_gmm`: If enabled, use Tokamax library's Ragged Dot for matmul. Recommended for dropless configurations. `megablox`: If enabled, use Megablox for sparse matrix operations. Effective only when `use_tokamax_gmm` is False. `capacity_factor`: A scalar multiplier for expert capacity. Effective only when `sparse_matmul` is False. - * Value > 0: Enforces a strict capacity limit; tokens exceeding this limit are dropped. - * Value = -1: Dropless with dense matrix multiplication, which is computationally expensive and typically used only as a baseline. + +- Value > 0: Enforces a strict capacity limit; tokens exceeding this limit are dropped. +- Value = -1: Dropless with dense matrix multiplication, which is computationally expensive and typically used only as a baseline. `use_custom_sort_vjp`: If enabled, use a custom Vector-Jacobian Product (VJP) sort for efficient backward pass processing in sparse matmul. Recommended to replace the inefficient scatter-add generated by the `jax.numpy.take` in the backward pass. @@ -89,9 +95,11 @@ Dropping: `use_batch_split_schedule` (experimental): If enabled, split batch into micro-batches to hide communications that yields performance benefits. ## 2. Sharding + `expert_shard_attention_option`: Determines how the "expert" axis is interpreted when sharding attention layers. Options include: - * `fsdp`: Treats the expert axis as a FSDP axis. - * `context`: Treats the expert axis as a context parallelism axis, useful for long context. + +- `fsdp`: Treats the expert axis as a FSDP axis. +- `context`: Treats the expert axis as a context parallelism axis, useful for long context. `use_ring_of_experts` (experimental): This feature requires expert parallelism. If enabled, it replaces the standard two All-to-All communications with All-Gather in dispatch and Reduce-Scatter in collect. By gathering inputs across all shards, it allows for local routing and Top-K calculations, followed by result aggregation via Reduce-Scatter. This approach is particularly effective for models with a large Top-K, as it gathers activations before they are replicated k times to reduce communication. @@ -100,25 +108,31 @@ Dropping: `shard_exp_on_fsdp`: If enabled, shard the expert dimension of the MLP weights on the FSDP axis, and recommended only when num_experts is a multiple of fsdp_parallelism. ## 3. Performance Tuning + These parameters provide granular control over the tiling dimensions for sparse matmul Pallas kernel. -* `wi_tile_...`: Tile size for the first layer of the MLP (Input -> Hidden). -* `wo_tile_...`: Tile size for the second layer of the MLP (Hidden -> Output). +- `wi_tile_...`: Tile size for the first layer of the MLP (Input -> Hidden). +- `wo_tile_...`: Tile size for the second layer of the MLP (Hidden -> Output). For each, you can control: -* `..._fwd_...`: Tile size for the forward pass. -* `..._dlhs_...`: Tile size for the backward pass gradient calculation w.r.t. activations. -* `..._drhs_...`: Tile size for the backward pass gradient calculation w.r.t. weights. + +- `..._fwd_...`: Tile size for the forward pass. +- `..._dlhs_...`: Tile size for the backward pass gradient calculation w.r.t. activations. +- `..._drhs_...`: Tile size for the backward pass gradient calculation w.r.t. weights. For each dimension, you can control: -* `..._batch_seq`: Tile size for batch x sequence dimension. -* `..._embed_dim`: Tile size for embedding dimension. -* `..._mlp_dim`: Tile size for MLP dimension. + +- `..._batch_seq`: Tile size for batch x sequence dimension. +- `..._embed_dim`: Tile size for embedding dimension. +- `..._mlp_dim`: Tile size for MLP dimension. Implementation Support: -* Megablox/JAX Ragged Dot: - * Supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`). - * Configs are enabled for INT8, FP8, and BF16. -* Tokamax Ragged Dot: - * Supports all 18 configurations. **Note**: Currently enabled for FP8 quantization; BF16 integration is in progress. +- Megablox/JAX Ragged Dot: + + - Supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`). + - Configs are enabled for INT8, FP8, and BF16. + +- Tokamax Ragged Dot: + + - Supports all 18 configurations. **Note**: Currently enabled for FP8 quantization; BF16 integration is in progress. diff --git a/docs/reference/core_concepts/quantization.md b/docs/reference/core_concepts/quantization.md index 701564b68d..1f9fb7ba87 100644 --- a/docs/reference/core_concepts/quantization.md +++ b/docs/reference/core_concepts/quantization.md @@ -13,29 +13,32 @@ See the License for the specific language governing permissions and limitations under the License. --> + (quantization)= + # Quantization Quantization in deep learning is the process of reducing the precision of numbers used to represent a model's weights and/or activations. Instead of using higher-precision floating-point formats like 32-bit floats (`float32`) or 16-bit brain floats (`bfloat16`), quantization maps these values to lower-precision numerical formats, most commonly 8-bit integers (`int8`) or floats (`fp8`). MaxText supports quantization via both the [AQT](https://github.com/google/aqt) and [Qwix](https://github.com/google/qwix) libraries. Qwix is the recommended approach, providing a non-intrusive way to apply Quantized Training (QT). -## Why use quantization? +## Why use quantization? The drive to use lower-precision formats like `int8` or `fp8` stems from significant performance advantages: **Faster computation**: Hardware accelerators like TPUs and GPUs often have specialized instructions for integer arithmetic. Operations on lower-precision data like `int8` or `fp8` can be significantly faster than on BF16 or FP32. For example, matrix multiplications with these formats can often be 2x or more faster on hardware supporting native lower-precision tensor cores. **Reduced memory footprint**: Storing weights and activations in `int8` or `fp8` requires 2x less memory compared to `bfloat16`. This reduces: + - **HBM usage**: Less memory is needed on the accelerator itself. - **Communication costs**: Less data needs to be transferred between memory and compute units, or across devices in distributed training, which makes these transfers faster and consumes less bandwidth. - **Reduced power consumption**: Lower precision operations and reduced memory access lead to less energy usage, which is crucial for deploying models on edge devices and for sustainable AI. The primary trade-off with quantization is between the model accuracy and computational performance: -* Reduced Dynamic Range & Precision: Lower-precision formats like `int8` or `fp8` can represent a much smaller range of values and with less precision than BF16. This can be problematic for models with wide distributions of weights or activations, potentially clipping large values or losing fine-grained details. -* Impact on Gradients: Gradients during backpropagation can have very different, often wider, distributions than weights or activations, making them more sensitive to quantization errors. -* Convergence Issues: The approximations introduced by quantization can sometimes hinder the model's ability to converge during training. +- Reduced Dynamic Range & Precision: Lower-precision formats like `int8` or `fp8` can represent a much smaller range of values and with less precision than BF16. This can be problematic for models with wide distributions of weights or activations, potentially clipping large values or losing fine-grained details. +- Impact on Gradients: Gradients during backpropagation can have very different, often wider, distributions than weights or activations, making them more sensitive to quantization errors. +- Convergence Issues: The approximations introduced by quantization can sometimes hinder the model's ability to converge during training. To overcome the challenges of quantization, libraries like Google's Accurate Quantized Training (AQT) and its successor Qwix (used in MaxText) employ a suite of advanced techniques. These methods ensure that models can be trained with low-precision arithmetic without significant loss in accuracy and with stable convergence. @@ -45,9 +48,9 @@ Quantized Training (QT) incorporates the effects of quantization into the traini Here’s how it works: -1. **Forward Pass**: During the forward pass, high-precision weights and activations are converted to a lower-precision format. This step simulates the information loss that occurs during quantization. The model then performs its computations using these lower-precision representations before they are converted back to a higher precision for the rest of the network. This process forces the model to become robust to the noise and reduced range of quantized values. +1. **Forward Pass**: During the forward pass, high-precision weights and activations are converted to a lower-precision format. This step simulates the information loss that occurs during quantization. The model then performs its computations using these lower-precision representations before they are converted back to a higher precision for the rest of the network. This process forces the model to become robust to the noise and reduced range of quantized values. -2. **Backward Pass**: Standard backpropagation cannot flow through the non-differentiable quantization operations (like rounding). To solve this, QT uses the **Straight-Through Estimator (STE)**. The STE essentially "ignores" the non-differentiable quantization step during the backward pass, passing the gradients through as if the operation was an identity function. This allows the high-precision weights to be updated based on the loss, enabling the model to learn effectively. +2. **Backward Pass**: Standard backpropagation cannot flow through the non-differentiable quantization operations (like rounding). To solve this, QT uses the **Straight-Through Estimator (STE)**. The STE essentially "ignores" the non-differentiable quantization step during the backward pass, passing the gradients through as if the operation was an identity function. This allows the high-precision weights to be updated based on the loss, enabling the model to learn effectively. By integrating the quantization simulation directly into the training, the model learns to minimize the impact of precision loss, resulting in a more accurate quantized model. @@ -59,11 +62,11 @@ You can enable quantization in MaxText by setting flags in your configuration fi The primary flags to control quantization are: -* `use_qwix_quantization`: A boolean flag. - * Set to `True` to enable quantization using the Qwix library. - * Set to `False` (or omit) to use the AQT library if `quantization` is set. -* `quantization`: A string that specifies the type of quantization to apply. The accepted values depend on whether you are using Qwix or AQT. -* `quantization_calibration_method`: The calibration method for weights and activations (e.g., `"absmax"`). This is mainly for Qwix. +- `use_qwix_quantization`: A boolean flag. + - Set to `True` to enable quantization using the Qwix library. + - Set to `False` (or omit) to use the AQT library if `quantization` is set. +- `quantization`: A string that specifies the type of quantization to apply. The accepted values depend on whether you are using Qwix or AQT. +- `quantization_calibration_method`: The calibration method for weights and activations (e.g., `"absmax"`). This is mainly for Qwix. ### Qwix Quantization (Recommended) @@ -73,18 +76,18 @@ To use Qwix, you must set `use_qwix_quantization=True`. Qwix is a powerful and n Common options for the `quantization` flag when using Qwix include: -* `"int8"`: 8-bit integer quantization. -* `"fp8"`: 8-bit floating-point quantization. -* `"fp8_full"`: FP8 quantization with static scaling. -* `"fp8_gpu"`: FP8 for NVIDIA GPUs. -* `"fp8_nanoo"`: FP8 for AMD MI300/MI325 GPUs. +- `"int8"`: 8-bit integer quantization. +- `"fp8"`: 8-bit floating-point quantization. +- `"fp8_full"`: FP8 quantization with static scaling. +- `"fp8_gpu"`: FP8 for NVIDIA GPUs. +- `"fp8_nanoo"`: FP8 for AMD MI300/MI325 GPUs. #### Example command for Qwix Here is an example of how to run a training job with int8 quantization enabled via Qwix: ```bash -python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$YOUR_JOB_NAME base_output_directory=gs:// dataset_type=synthetic use_qwix_quantization=true quantization='int8' +python3 -m MaxText.train src/maxtext/configs/base.yml run_name=$YOUR_JOB_NAME base_output_directory=gs:// dataset_type=synthetic use_qwix_quantization=true quantization='int8' ``` #### The Qwix Interception API @@ -96,23 +99,25 @@ Instead, you define a set of quantization rules externally. Qwix then uses a con A quantization rule can be defined as follows: ```python -rule = [qwix.QtRule( - module_path="decoder/.*layers.*", - weight_qtype=jnp.int8, - act_qtype=jnp.int8, - bwd_qtype=jnp.int8, - bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, - op_names=("dot_general",), - )] +rule = [ + qwix.QtRule( + module_path="decoder/.*layers.*", + weight_qtype=jnp.int8, + act_qtype=jnp.int8, + bwd_qtype=jnp.int8, + bwd_weight_grad_tile_size=1 / config.quantization_local_shard_count, + op_names=("dot_general",), + ) +] ``` **`QtRule` parameters**: -* `module_path`: A regex to match the layers to which this rule should be applied. -* `weight_qtype`: The target quantization type for weights (e.g., `jnp.int8`). -* `act_qtype`: The target quantization type for activations. -* `bwd_qtype`: The quantization type for the backward pass. -* `op_names`: The operations to be quantized (e.g., `"dot_general"`). +- `module_path`: A regex to match the layers to which this rule should be applied. +- `weight_qtype`: The target quantization type for weights (e.g., `jnp.int8`). +- `act_qtype`: The target quantization type for activations. +- `bwd_qtype`: The quantization type for the backward pass. +- `op_names`: The operations to be quantized (e.g., `"dot_general"`). This rule is then used within a `QtProvider` to quantize the model automatically: @@ -137,38 +142,44 @@ When using AQT, you can pass one of the following values to the `quantization` f #### Example command for AQT ```bash -python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$YOUR_JOB_NAME base_output_directory=gs:// dataset_type=synthetic use_qwix_quantization=false quantization='int8' +python3 -m MaxText.train src/maxtext/configs/base.yml run_name=$YOUR_JOB_NAME base_output_directory=gs:// dataset_type=synthetic use_qwix_quantization=false quantization='int8' ``` + Note that `use_qwix_quantization` is not set to `True`. For further reading, please refer to the [Qwix Read the Docs website](https://qwix.readthedocs.io/en/latest/get_started.html#). ## DeepSeek V3 Fine-tuning FP8 Recipe + To improve the performance of DeepSeek V3 fine-tuning, we developed a custom recipe optimized for FP8 throughput. The method prioritizes specific compute-intensive and bandwidth-heavy components while preserving training stability through a fine-grained scaling strategy. ### Quantization Scope + To realize these gains, the recipe employs a w8a8g8 (8-bit weights, activations and gradients) strategy targeting three primary areas: -* Megablox Kernels: Specifically the `gmm` and `tgmm` operations. +- Megablox Kernels: Specifically the `gmm` and `tgmm` operations. -* Attention Projections: Utilizing convolution fusion. +- Attention Projections: Utilizing convolution fusion. -* Communication: Specifically the weight All-Gathers. +- Communication: Specifically the weight All-Gathers. ### FP8 Recipe -* Rounding: rounding to nearest even -* Precision - * Activations and weights: e4m3fn - * Gradients: e5m2 -* Scaling granularity: per-axis -* Scaling mode: - * static for weights and activations - * dynamic for gradients + +- Rounding: rounding to nearest even +- Precision + - Activations and weights: e4m3fn + - Gradients: e5m2 +- Scaling granularity: per-axis +- Scaling mode: + - static for weights and activations + - dynamic for gradients ### Convergence + To validate this recipe, we utilized MaxText following the MLPerf Training framework by MLCommons to ensure a reproducible and standardized evaluation. Using the C4 dataset (loaded via TFDS) as the reference corpus, we tracked convergence by monitoring validation loss on a held-out split. This aligns with MLPerf’s time-to-quality principle, where the primary metric is the speed at which the model achieves target quality. For this specific case, we derived our training duration from the MLPerf 405B benchmark, targeting roughly 2–3 billion tokens after resuming from a checkpoint. In our configuration, we executed 300 steps with a sequence length of 4096 and a global batch size of 2048, resulting in a total of approximately 2.5 billion tokens. ### Performance Sensitivity + Please note that the FP8 benefits are highly sensitive to model parameters, the efficiency of the BF16 baseline, and hardware utilization; consequently, results will vary when this recipe is applied to other models. Any variance in these factors shifts the ratio of compute-bound to memory-bound operations, directly altering the potential gains. diff --git a/docs/run_maxtext/run_maxtext_localhost.md b/docs/run_maxtext/run_maxtext_localhost.md index 9e1d847d43..e56957788d 100644 --- a/docs/run_maxtext/run_maxtext_localhost.md +++ b/docs/run_maxtext/run_maxtext_localhost.md @@ -58,7 +58,7 @@ bash tools/setup/setup.sh DEVICE={tpu|gpu} After the installation is complete, run a short training job using synthetic data to confirm everything is working correctly. This command trains a model for just 10 steps. Remember to replace `$YOUR_JOB_NAME` with a unique name for your run and `gs://` with the path to the GCS bucket you configured in the prerequisites. ```bash -python3 -m MaxText.train src/MaxText/configs/base.yml \ +python3 -m MaxText.train src/maxtext/configs/base.yml \ run_name=$YOUR_JOB_NAME \ base_output_directory=gs:// \ dataset_type=synthetic \ @@ -72,7 +72,7 @@ python3 -m MaxText.train src/MaxText/configs/base.yml \ To demonstrate model output, run the following command: ```bash -python3 -m maxtext.decode src/MaxText/configs/base.yml \ +python3 -m maxtext.decode src/maxtext/configs/base.yml \ run_name=$YOUR_JOB_NAME \ base_output_directory=gs:// \ per_device_batch_size=1 @@ -82,17 +82,17 @@ python3 -m maxtext.decode src/MaxText/configs/base.yml \ ### Running models using provided configs -MaxText provides many OSS model configs that you can use directly to run training jobs on those model-specific architectures. These model-specific YAML files are located in `src/MaxText/configs/models` for TPU-oriented defaults, and `src/MaxText/configs/models/gpu` for GPU-oriented defaults. +MaxText provides many OSS model configs that you can use directly to run training jobs on those model-specific architectures. These model-specific YAML files are located in `src/maxtext/configs/models` for TPU-oriented defaults, and `src/maxtext/configs/models/gpu` for GPU-oriented defaults. #### Training on TPUs -To use a pre-configured model for TPUs, you override the `model_name` parameter, and MaxText will automatically load the corresponding configuration from the `src/MaxText/configs/models` directory and merge it with the settings from `src/MaxText/configs/base.yml`. +To use a pre-configured model for TPUs, you override the `model_name` parameter, and MaxText will automatically load the corresponding configuration from the `src/maxtext/configs/models` directory and merge it with the settings from `src/maxtext/configs/base.yml`.
llama3-8b (TPU) ```bash -python3 -m MaxText.train MaxText/configs/base.yml \ +python3 -m MaxText.train maxtext/configs/base.yml \ model_name=llama3-8b \ run_name=$YOUR_JOB_NAME \ base_output_directory=gs:// \ @@ -106,7 +106,7 @@ python3 -m MaxText.train MaxText/configs/base.yml \ qwen3-4b (TPU) ```bash -python3 -m MaxText.train MaxText/configs/base.yml \ +python3 -m MaxText.train maxtext/configs/base.yml \ model_name=qwen3-4b \ run_name=$YOUR_JOB_NAME \ base_output_directory=gs:// \ @@ -118,13 +118,13 @@ python3 -m MaxText.train MaxText/configs/base.yml \ #### Training on GPUs -To use a GPU-optimized configuration, you should specify the path to the model's YAML file within the `src/MaxText/configs/models/gpu` directory as the main config file in the command. These files typically inherit from `base.yml` and set the appropriate `model_name` internally, as well as GPU-specific settings. +To use a GPU-optimized configuration, you should specify the path to the model's YAML file within the `src/maxtext/configs/models/gpu` directory as the main config file in the command. These files typically inherit from `base.yml` and set the appropriate `model_name` internally, as well as GPU-specific settings.
mixtral-8x7b (GPU) ```bash -python3 -m MaxText.train src/MaxText/configs/models/gpu/mixtral_8x7b.yml \ +python3 -m MaxText.train src/maxtext/configs/gpu/models/mixtral_8x7b.yml \ run_name=$YOUR_JOB_NAME \ base_output_directory=gs:// \ dataset_type=synthetic \ @@ -139,7 +139,7 @@ This will load `gpu/mixtral_8x7b.yml`, which inherits from `base.yml`. llama3-8b (GPU) ```bash -python3 -m MaxText.train src/MaxText/configs/models/gpu/llama3-8b.yml \ +python3 -m MaxText.train src/maxtext/configs/gpu/models/llama3-8b.yml \ run_name=$YOUR_JOB_NAME \ base_output_directory=gs:// \ dataset_type=synthetic \ diff --git a/docs/run_maxtext/run_maxtext_single_host_gpu.md b/docs/run_maxtext/run_maxtext_single_host_gpu.md index 61cff49353..2335ba30dc 100644 --- a/docs/run_maxtext/run_maxtext_single_host_gpu.md +++ b/docs/run_maxtext/run_maxtext_single_host_gpu.md @@ -21,6 +21,7 @@ This is a short guide to run Maxtext on GPU. For this current set of instruction ## Create a GPU VM Follow the instructions to create a3 high or an a3 Mega VM + - https://cloud.google.com/compute/docs/gpus/create-gpu-vm-accelerator-optimized#console - Add enough disk space to work through the examples (at least 500GB) @@ -42,9 +43,9 @@ Related NVIDIA Content: - NVIDIA JAX Session: - Learn more about Jax on GPUs: - - https://www.nvidia.com/en-us/on-demand/session/gtc24-s62246/ + - https://www.nvidia.com/en-us/on-demand/session/gtc24-s62246/ - NVIDIA JAX Toolbox: - - https://github.com/NVIDIA/JAX-Toolbox + - https://github.com/NVIDIA/JAX-Toolbox ## Install Docker @@ -109,7 +110,7 @@ You should see the following: Note: If you only see CPUDevice, that means there is a issue with NVIDIA Container and you need to stop and fix the issue. -We will Run the next commands from inside the docker for convenience. +We will Run the next commands from inside the docker for convenience. ## SSH into the docker @@ -147,7 +148,7 @@ Hardware: GPU ``` ```bash -python3 -m MaxText.train src/MaxText/configs/base.yml run_name=gpu01 base_output_directory=/deps/output \ +python3 -m MaxText.train src/maxtext/configs/base.yml run_name=gpu01 base_output_directory=/deps/output \ dataset_type=synthetic enable_checkpointing=True steps=10 attention=cudnn_flash_te scan_layers=False \ use_iota_embed=True hardware=gpu per_device_batch_size=12 ``` @@ -156,7 +157,7 @@ python3 -m MaxText.train src/MaxText/configs/base.yml run_name=gpu01 base_output You can find the optimized running of LLama Models for various host configurations here: -https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/configs/a3/llama_2_7b +https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/configs/gpu/a3/llama_2_7b `1vm.sh` modified script below: @@ -167,7 +168,7 @@ echo "Running 1vm.sh" # python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \ # --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \ # --device-type ${DEVICE_TYPE} --num-slices 1 \ -# --command "bash src/MaxText/configs/a3/llama_2_7b/1vm.sh" +# --command "bash src/maxtext/configs/gpu/a3/llama_2_7b/1vm.sh" # Stop execution if any command exits with error set -e @@ -193,7 +194,7 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/ # 1 node, DATA_DP=1, ICI_FSDP=8 -python3 -m MaxText.train src/MaxText/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME dcn_data_parallelism=1 \ +python3 -m MaxText.train src/maxtext/configs/gpu/models/llama2_7b.yml run_name=$RUN_NAME dcn_data_parallelism=1 \ ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH attention=cudnn_flash_te scan_layers=False \ use_iota_embed=True hardware=gpu ``` diff --git a/docs/run_maxtext/run_maxtext_via_multihost_job.md b/docs/run_maxtext/run_maxtext_via_multihost_job.md index bdf71ae97d..15b9d8792c 100644 --- a/docs/run_maxtext/run_maxtext_via_multihost_job.md +++ b/docs/run_maxtext/run_maxtext_via_multihost_job.md @@ -15,61 +15,64 @@ --> (run-multihost-job)= + # Production Jobs on Multiple Slices (`multihost_job.py`) The workflow using `multihost_job.py` is optimized for long running experiments, providing resiliency against hardware failure and avoiding long running ssh connections. Its latency is much higher than `multihost_runner.py` because it needs to provision new capacity each time. The `multihost_job.py` script ends once the request to create the TPUs is issued. Logs are written both to gcloud in real time and also sent to GCS at the end of the job. The `multihost_job.py` script: -* Copies your code to your GCS bucket -* Spins up specified TPU VM(s) via CQR -* Directs the TPU's to download then run that code. Because this logic is within the CQR's startup script, if there hardware is interrupted, the job will be rescheduled and resumed. -* Logs to gcloud, and additionally sends the logs to GCS at the job end -* Delete the TPUs and QR at the end of the job. +- Copies your code to your GCS bucket +- Spins up specified TPU VM(s) via CQR +- Directs the TPU's to download then run that code. Because this logic is within the CQR's startup script, if there hardware is interrupted, the job will be rescheduled and resumed. +- Logs to gcloud, and additionally sends the logs to GCS at the job end +- Delete the TPUs and QR at the end of the job. 1. **Choose a directory on your runner machine to develop and clone MaxText into.** The runner machine can -either be a TPUVM or not. If your runner machine is a TPUVM, it needs service account roles that grant it permission to create queued resources and has write access to GCS, such as the `TPU ADMIN` and `STORAGE ADMIN` roles. Clone MaxText, and cd into the root of the repo. + either be a TPUVM or not. If your runner machine is a TPUVM, it needs service account roles that grant it permission to create queued resources and has write access to GCS, such as the `TPU ADMIN` and `STORAGE ADMIN` roles. Clone MaxText, and cd into the root of the repo. 2. **Set your project, zone.** - Set your gcloud config, see https://cloud.google.com/sdk/gcloud/reference/config for more. - ``` - PROJECT= - ``` + Set your gcloud config, see https://cloud.google.com/sdk/gcloud/reference/config for more. + + ``` + PROJECT= + ``` - ``` - ZONE= - ``` + ``` + ZONE= + ``` - ``` - gcloud config set project $PROJECT - gcloud config set compute/zone $ZONE - ``` + ``` + gcloud config set project $PROJECT + gcloud config set compute/zone $ZONE + ``` 3. **Link to a GCS bucket.** - Create a bucket if you don't already have one, see: https://cloud.google.com/storage/docs/creating-buckets for instructions to create one. Once you've identified your bucket: + Create a bucket if you don't already have one, see: https://cloud.google.com/storage/docs/creating-buckets for instructions to create one. Once you've identified your bucket: - ``` - BUCKET_NAME= - ``` + ``` + BUCKET_NAME= + ``` 4. **Run your training job.** - ```{important} - `multihost_job` creates a request for new capacity for each run! You cannot use this tool on existing capacity, instead we recommend `multihost_runner` for this purpose. - ``` + ```{important} + `multihost_job` creates a request for new capacity for each run! You cannot use this tool on existing capacity, instead we recommend `multihost_runner` for this purpose. + ``` - Choose the number of nodes (we use 2 below, but you may customize this and other feature of your TPU(s)) - ```sh - NODE_COUNT=2 - ``` - ```sh - RUN_NAME=$YOUR_JOB_NAME # You may set this to any unique name for a fresh run. - python3 multihost_job.py --NUM_SLICES=$NODE_COUNT --RUN_NAME=$RUN_NAME --BUCKET_NAME=$BUCKET_NAME --CQR_EXTRA_ARGS="--reserved" --COMMAND="bash tools/setup/setup.sh && python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$RUN_NAME" - ``` + Choose the number of nodes (we use 2 below, but you may customize this and other feature of your TPU(s)) - We tell `multihost_job` to target the `reserved` pool by by including `--reserved` as extra arguments to the CQR request, but you may instead target the `on-demand` pool by removing the `--CQR_EXTRA_ARGS` flag (on-demand is default), or the pre-emptible pool with `--CQR_EXTRA_ARGS="--best-effort"`, which may be necessary if your reservation is full. + ```sh + NODE_COUNT=2 + ``` -5. **View the job's logs in cloud logging.** + ```sh + RUN_NAME=$YOUR_JOB_NAME # You may set this to any unique name for a fresh run. + python3 multihost_job.py --NUM_SLICES=$NODE_COUNT --RUN_NAME=$RUN_NAME --BUCKET_NAME=$BUCKET_NAME --CQR_EXTRA_ARGS="--reserved" --COMMAND="bash tools/setup/setup.sh && python3 -m MaxText.train src/maxtext/configs/base.yml run_name=$RUN_NAME" + ``` - The link to your job's cloud logging is printed at the end of `multihost_job` output. Additionally logs are saved to GCS when your job finishes, and this bucket's URL is also printed by `multihost_job`. + We tell `multihost_job` to target the `reserved` pool by by including `--reserved` as extra arguments to the CQR request, but you may instead target the `on-demand` pool by removing the `--CQR_EXTRA_ARGS` flag (on-demand is default), or the pre-emptible pool with `--CQR_EXTRA_ARGS="--best-effort"`, which may be necessary if your reservation is full. + +5. **View the job's logs in cloud logging.** + The link to your job's cloud logging is printed at the end of `multihost_job` output. Additionally logs are saved to GCS when your job finishes, and this bucket's URL is also printed by `multihost_job`. diff --git a/docs/run_maxtext/run_maxtext_via_multihost_runner.md b/docs/run_maxtext/run_maxtext_via_multihost_runner.md index ac5adf9d72..143f5e9f84 100644 --- a/docs/run_maxtext/run_maxtext_via_multihost_runner.md +++ b/docs/run_maxtext/run_maxtext_via_multihost_runner.md @@ -15,6 +15,7 @@ --> (run-multihost-runner)= + # Quicks Experiments on Multiple Hosts or Multiple Slices (`multihost_runner.py`) This workflow using `multihost_runner.py` is optimized for quick experiments, repeatedly reusing the same TPUs. Because the `multihost_runner.py` script depends on long-lived `ssh` connections, we do not recommend it for any long-running jobs. @@ -23,79 +24,97 @@ We call the `runner` machine the one that `multihost_runner.py` is called from. If the runner machine is a cloud VM, it must be in the same project as the workers. The `multihost_runner.py` script: -* Distributes your code by recursively copying the current state of the chosen directory to multiple worker TPUVM. -* Runs the code on the workers -* Logs and monitors the processes' error statuses and brings the logs back to the runner machine. + +- Distributes your code by recursively copying the current state of the chosen directory to multiple worker TPUVM. +- Runs the code on the workers +- Logs and monitors the processes' error statuses and brings the logs back to the runner machine. Although there are several steps below, most are for the initial setup. Once setup you can continually make changes to your code and re-run your code with only step 5. 1. **Choose a directory on your runner machine to develop and clone MaxText into.** The runner machine can -either be a TPUVM or not, but it cannot be one of the workers. If your runner machine is a TPUVM, it needs service account roles that grant it permission to create queued resources and ssh into them, such as the `TPU ADMIN` role. Clone MaxText, and cd into the root of the repo. + either be a TPUVM or not, but it cannot be one of the workers. If your runner machine is a TPUVM, it needs service account roles that grant it permission to create queued resources and ssh into them, such as the `TPU ADMIN` role. Clone MaxText, and cd into the root of the repo. 2. **Set your project, zone, and ssh keys.** - Set your gcloud config, see https://cloud.google.com/sdk/gcloud/reference/config for more. - ``` - PROJECT= - ``` - ``` - ZONE= - ``` - ``` - gcloud config set project $PROJECT - gcloud config set compute/zone $ZONE - ``` - - Create ssh keys for gcloud, we recommend leaving a blank password (hit enter twice after running the below command). If you are prompted that the the file already exists you can choose not to overwrite by selecting "n". - ``` - ssh-keygen -f ~/.ssh/google_compute_engine - ``` + Set your gcloud config, see https://cloud.google.com/sdk/gcloud/reference/config for more. + + ``` + PROJECT= + ``` + + ``` + ZONE= + ``` + + ``` + gcloud config set project $PROJECT + gcloud config set compute/zone $ZONE + ``` + + Create ssh keys for gcloud, we recommend leaving a blank password (hit enter twice after running the below command). If you are prompted that the the file already exists you can choose not to overwrite by selecting "n". + + ``` + ssh-keygen -f ~/.ssh/google_compute_engine + ``` 3. **Create your instances via Queued Resource (QR).** - Choose names for your TPUs and QR: - ``` - TPU_PREFIX=$YOUR_TPU_NAME # Use new names when you create new TPUs - QR_ID=$TPU_PREFIX # Convenient to reuse the node names, but can be different - ``` - Choose the number of nodes (we use 2 below, but you may customize this and other feature of your TPU(s)) - ``` - NODE_COUNT=2 - ``` - Create a multislice environment of nodes using create queued resources - ``` - gcloud alpha compute tpus queued-resources create $QR_ID --accelerator-type=v4-8 --runtime-version=tpu-ubuntu2204-base --node-count=$NODE_COUNT --node-prefix=$TPU_PREFIX --reserved - ``` - We target the `reserved` pool above, but you may instead target the `on-demand` pool by omitting this flag, - or target pre-emptible capacity with the `--best-effort` flag, which may be necessary if your reservation is full. - - You have to wait for the QR to become `ACTIVE` (as opposed to `ACCEPTED` or `PROVISIONING`) which corresponds to the worker nodes becoming `READY` (as opposed to `CREATING`). This may take a minute or two and can be checked via - ``` - gcloud alpha compute tpus queued-resources list --filter=$QR_ID - ``` + Choose names for your TPUs and QR: + + ``` + TPU_PREFIX=$YOUR_TPU_NAME # Use new names when you create new TPUs + QR_ID=$TPU_PREFIX # Convenient to reuse the node names, but can be different + ``` + + Choose the number of nodes (we use 2 below, but you may customize this and other feature of your TPU(s)) + + ``` + NODE_COUNT=2 + ``` + + Create a multislice environment of nodes using create queued resources + + ``` + gcloud alpha compute tpus queued-resources create $QR_ID --accelerator-type=v4-8 --runtime-version=tpu-ubuntu2204-base --node-count=$NODE_COUNT --node-prefix=$TPU_PREFIX --reserved + ``` + + We target the `reserved` pool above, but you may instead target the `on-demand` pool by omitting this flag, + or target pre-emptible capacity with the `--best-effort` flag, which may be necessary if your reservation is full. + + You have to wait for the QR to become `ACTIVE` (as opposed to `ACCEPTED` or `PROVISIONING`) which corresponds to the worker nodes becoming `READY` (as opposed to `CREATING`). This may take a minute or two and can be checked via + + ``` + gcloud alpha compute tpus queued-resources list --filter=$QR_ID + ``` + 4. **Install dependencies.** - Install the dependencies of `train.py` on each worker using `multihost_runner.py`: - ``` - python3 multihost_runner.py --TPU_PREFIX=$TPU_PREFIX --COMMAND="bash tools/setup/setup.sh" - ``` - If you are running the `multihost_runner.py` script from a TPUVM, you will need to set `--INTERNAL_IP=true`. + Install the dependencies of `train.py` on each worker using `multihost_runner.py`: + + ``` + python3 multihost_runner.py --TPU_PREFIX=$TPU_PREFIX --COMMAND="bash tools/setup/setup.sh" + ``` + + If you are running the `multihost_runner.py` script from a TPUVM, you will need to set `--INTERNAL_IP=true`. 5. **Run your training job.** - Set a RUN_NAME for your job: - ``` - RUN_NAME=$YOUR_JOB_NAME # You may set this to any unique name for a fresh run. - ``` - Set config values for `base_output_directory` and `dataset_path` in `configs/base.yml` if not set already. - ``` - python3 multihost_runner.py --TPU_PREFIX=$TPU_PREFIX --COMMAND="python3 -m MaxText.train src/MaxText/configs/base.yml run_name=$RUN_NAME" - ``` - If you are running the `multihost_runner.py` script from a TPUVM, you will need to set `--INTERNAL_IP=true`. + Set a RUN_NAME for your job: -6. **Clean up TPUs and QR when finished.** + ``` + RUN_NAME=$YOUR_JOB_NAME # You may set this to any unique name for a fresh run. + ``` + + Set config values for `base_output_directory` and `dataset_path` in `configs/base.yml` if not set already. + + ``` + python3 multihost_runner.py --TPU_PREFIX=$TPU_PREFIX --COMMAND="python3 -m MaxText.train src/maxtext/configs/base.yml run_name=$RUN_NAME" + ``` - ``` - gcloud alpha compute tpus queued-resources delete $QR_ID --force --async - ``` + If you are running the `multihost_runner.py` script from a TPUVM, you will need to set `--INTERNAL_IP=true`. + +6. **Clean up TPUs and QR when finished.** - The `--force` flag deletes both the queued resources and the TPU VMs, without it only a `SUSPENDED` queued resource whose TPUs have already been deleted can itself be deleted. We highly recommend the `--async` flag since deleting the TPUs and QR will take a minute or two. + ``` + gcloud alpha compute tpus queued-resources delete $QR_ID --force --async + ``` + The `--force` flag deletes both the queued resources and the TPU VMs, without it only a `SUSPENDED` queued resource whose TPUs have already been deleted can itself be deleted. We highly recommend the `--async` flag since deleting the TPUs and QR will take a minute or two. diff --git a/docs/run_maxtext/run_maxtext_via_pathways.md b/docs/run_maxtext/run_maxtext_via_pathways.md index 5ef3e29b1b..46b70804c8 100644 --- a/docs/run_maxtext/run_maxtext_via_pathways.md +++ b/docs/run_maxtext/run_maxtext_via_pathways.md @@ -15,6 +15,7 @@ --> (run-pathways)= + # Via Pathways This guide provides a comprehensive walkthrough for running MaxText workloads on a Google Kubernetes Engine (GKE) cluster using Pathways. Pathways acts as a powerful orchestrator for large-scale JAX jobs on AI Hypercomputer infrastructure. @@ -22,31 +23,37 @@ This guide provides a comprehensive walkthrough for running MaxText workloads on This document assumes you have already created a Pathways GKE cluster using `xpk`. If you haven't, follow the instructions at the [Google Cloud Pathways & XPK documentation](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster#xpk). We will cover two primary modes of operation: -* **Batch workload**: Ideal for long-running, non-interactive training jobs. -* **Headless workload**: Ideal for interactive development, debugging, and running code from a local machine or CPU VM. + +- **Batch workload**: Ideal for long-running, non-interactive training jobs. +- **Headless workload**: Ideal for interactive development, debugging, and running code from a local machine or CPU VM. ## 1. Prerequisites Before you can run a MaxText workload, you must complete the following setup steps. -1. **Install XPK and its dependencies**. Ensure that the `xpk` command-line tool is installed. -2. **Create a GKE cluster** configured for Pathways. -3. **Build and upload a MaxText Docker image** to your project's Artifact Registry. +1. **Install XPK and its dependencies**. Ensure that the `xpk` command-line tool is installed. + +2. **Create a GKE cluster** configured for Pathways. + +3. **Build and upload a MaxText Docker image** to your project's Artifact Registry. - Step 1: Build the Docker image for a TPU device. This image contains MaxText and its dependencies. - ```shell - bash dependencies/scripts/docker_build_dependency_image.sh DEVICE=tpu MODE=stable - ``` + Step 1: Build the Docker image for a TPU device. This image contains MaxText and its dependencies. - Step 2: Configure Docker to authenticate with Google Cloud - ```shell - gcloud auth configure-docker - ``` + ```shell + bash dependencies/scripts/docker_build_dependency_image.sh DEVICE=tpu MODE=stable + ``` - Step 3: Upload the image to your project's registry. Replace `$USER_runner` with your desired image name. - ```shell - bash dependencies/scripts/docker_upload_runner.sh CLOUD_IMAGE_NAME=$USER_runner - ``` + Step 2: Configure Docker to authenticate with Google Cloud + + ```shell + gcloud auth configure-docker + ``` + + Step 3: Upload the image to your project's registry. Replace `$USER_runner` with your desired image name. + + ```shell + bash dependencies/scripts/docker_upload_runner.sh CLOUD_IMAGE_NAME=$USER_runner + ``` ## 2. Environment configuration @@ -87,7 +94,7 @@ xpk workload create-pathways \ --project=$PROJECT \ --zone=$ZONE \ --docker-image=${DOCKER_IMAGE} \ - --command="python3 -m MaxText.train src/MaxText/configs/base.yml \ + --command="python3 -m MaxText.train src/maxtext/configs/base.yml \ base_output_directory=gs://${BUCKET_NAME} \ per_device_batch_size=1 \ enable_checkpointing=false \ @@ -145,7 +152,7 @@ export JAX_PLATFORMS=proxy export JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 # Run the training script -python3 -m MaxText.train src/MaxText/configs/base.yml \ +python3 -m MaxText.train src/maxtext/configs/base.yml \ base_output_directory=gs://${BUCKET_NAME} \ per_device_batch_size=1 \ enable_checkpointing=false \ @@ -153,19 +160,20 @@ python3 -m MaxText.train src/MaxText/configs/base.yml \ enable_single_controller=True \ run_name=${RUN_NAME}-pathways-headless ``` + The output streams directly to your terminal, just as if you were running on a local accelerator. ## Troubleshooting -* **Permission denied errors for Cloud Storage bucket**: Check that the service account used by your GKE nodes has "Storage Object Admin" permissions on your GCS bucket. -* **`Image not found` or `ImagePullBackOff`**: - * Verify your `DOCKER_IMAGE` variable is correct. - * Ensure you have successfully pushed the image to your project's Artifact Registry. - * Check that your GKE cluster has permissions to pull from the registry. -* **`kubectl port-forward` fails**: - * Confirm that the pod from Step 1 is running (`kubectl get pods`). The name should match `${WORKLOAD_NAME}-pathways-head-0`. - * Ensure you are authenticated with `kubectl` and have the correct context set for your GKE cluster. -* Make sure you import `pathwaysutils` package and call `pathwaysutils.initialize()` in your script when running the workload. +- **Permission denied errors for Cloud Storage bucket**: Check that the service account used by your GKE nodes has "Storage Object Admin" permissions on your GCS bucket. +- **`Image not found` or `ImagePullBackOff`**: + - Verify your `DOCKER_IMAGE` variable is correct. + - Ensure you have successfully pushed the image to your project's Artifact Registry. + - Check that your GKE cluster has permissions to pull from the registry. +- **`kubectl port-forward` fails**: + - Confirm that the pod from Step 1 is running (`kubectl get pods`). The name should match `${WORKLOAD_NAME}-pathways-head-0`. + - Ensure you are authenticated with `kubectl` and have the correct context set for your GKE cluster. +- Make sure you import `pathwaysutils` package and call `pathwaysutils.initialize()` in your script when running the workload. ## More information diff --git a/docs/run_maxtext/run_maxtext_via_xpk.md b/docs/run_maxtext/run_maxtext_via_xpk.md index a61f1dd987..a29d4d207a 100644 --- a/docs/run_maxtext/run_maxtext_via_xpk.md +++ b/docs/run_maxtext/run_maxtext_via_xpk.md @@ -15,6 +15,7 @@ --> (run-xpk)= + # At scale with XPK This guide provides the recommended workflow for running MaxText on Google Kubernetes Engine (GKE) using the **Accelerated Processing Kit (XPK)**. For a complete reference on XPK, please see the [official XPK repository](https://github.com/AI-Hypercomputer/xpk). @@ -36,7 +37,7 @@ XPK abstracts away the complexity of cluster management and job submission, hand +--------------------------+ +--------------------+ +-------------------+ ``` -* * * * * +______________________________________________________________________ ## 1. Prerequisites @@ -44,59 +45,59 @@ Before you begin, you must have the necessary tools installed and permissions co ### Required tools -- **Python >= 3.12** with `pip` and `venv`. +- **Python >= 3.12** with `pip` and `venv`. -- **Google Cloud CLI (`gcloud`):** Install it from [here](https://cloud.google.com/sdk/docs/install) and then run `gcloud init`. +- **Google Cloud CLI (`gcloud`):** Install it from [here](https://cloud.google.com/sdk/docs/install) and then run `gcloud init`. -- **kubectl:** The Kubernetes command-line tool. +- **kubectl:** The Kubernetes command-line tool. -- **Docker:** Follow the [installation instructions](https://docs.docker.com/engine/install/) and complete the [post-install steps](https://docs.docker.com/engine/install/linux-postinstall/) to run Docker without `sudo`. +- **Docker:** Follow the [installation instructions](https://docs.docker.com/engine/install/) and complete the [post-install steps](https://docs.docker.com/engine/install/linux-postinstall/) to run Docker without `sudo`. ### GCP permissions Your Google Cloud user account needs the following IAM roles for the project you're using: -- Artifact Registry Writer +- Artifact Registry Writer -- Compute Admin +- Compute Admin -- Kubernetes Engine Admin +- Kubernetes Engine Admin -- Logging Admin +- Logging Admin -- Monitoring Admin +- Monitoring Admin -- Service Account User +- Service Account User -- Storage Admin +- Storage Admin -- Vertex AI Administrator +- Vertex AI Administrator -* * * * * +______________________________________________________________________ ## 2. One-time environment setup These commands configure your local environment to connect to Google Cloud services. -1. **Authenticate gcloud** +1. **Authenticate gcloud** - ``` - gcloud auth login - ``` + ``` + gcloud auth login + ``` -2. **Install GKE auth plugin** +2. **Install GKE auth plugin** - ``` - sudo apt-get update && sudo apt-get install google-cloud-sdk-gke-gcloud-auth-plugin - ``` + ``` + sudo apt-get update && sudo apt-get install google-cloud-sdk-gke-gcloud-auth-plugin + ``` -3. **Configure Docker credentials** +3. **Configure Docker credentials** - ``` - gcloud auth configure-docker - ``` + ``` + gcloud auth configure-docker + ``` -* * * * * +______________________________________________________________________ ## 3. Install XPK @@ -113,32 +114,32 @@ source ~/xpk_venv/bin/activate pip install xpk ``` -* * * * * +______________________________________________________________________ ## 4. Build the MaxText Docker image -1. **Clone the MaxText repository** +1. **Clone the MaxText repository** - ``` - git clone https://github.com/google/maxtext.git - cd maxtext - ``` + ``` + git clone https://github.com/google/maxtext.git + cd maxtext + ``` -2. **Build the image for your target hardware (TPU or GPU)** This script creates a local Docker image named `maxtext_base_image`. +2. **Build the image for your target hardware (TPU or GPU)** This script creates a local Docker image named `maxtext_base_image`. - - **For TPUs:** + - **For TPUs:** - ``` - bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable - ``` + ``` + bash docker_build_dependency_image.sh DEVICE=tpu MODE=stable + ``` - - **For GPUs:** + - **For GPUs:** - ``` - bash docker_build_dependency_image.sh DEVICE=gpu MODE=stable - ``` + ``` + bash docker_build_dependency_image.sh DEVICE=gpu MODE=stable + ``` -* * * * * +______________________________________________________________________ ## 5. Run your first MaxText job @@ -148,22 +149,22 @@ This section assumes you have an existing GKE cluster with either TPU or GPU nod This guide focuses on submitting workloads to an existing cluster. Cluster creation and management is a separate topic. For a comprehensive guide on all `xpk` commands, including `xpk cluster create`, please refer to the **[official XPK documentation](https://github.com/AI-Hypercomputer/xpk)**. ``` -1. **Set your configuration** +1. **Set your configuration** - ``` - export PROJECT_ID="your-gcp-project-id" - export ZONE="your-gcp-zone" # e.g., us-central1-a - export CLUSTER_NAME="your-existing-cluster-name" - export BASE_OUTPUT_DIR="gs://your-output-bucket/" - export DATASET_PATH="gs://your-dataset-bucket/" - ``` + ``` + export PROJECT_ID="your-gcp-project-id" + export ZONE="your-gcp-zone" # e.g., us-central1-a + export CLUSTER_NAME="your-existing-cluster-name" + export BASE_OUTPUT_DIR="gs://your-output-bucket/" + export DATASET_PATH="gs://your-dataset-bucket/" + ``` -2. **Configure gcloud CLI** +2. **Configure gcloud CLI** - ``` - gcloud config set project $PROJECT_ID - gcloud config set compute/zone $ZONE - ``` + ``` + gcloud config set project $PROJECT_ID + gcloud config set compute/zone $ZONE + ``` ### A Note on multi-slice and multi-node runs @@ -171,56 +172,56 @@ The examples below run on a single TPU slice (`--num-slices=1`) or a small numbe For instance, to run a job across **four TPU slices**, you would change `--num-slices=1` to `--num-slices=4`. This tells XPK to allocate four `v5litepod-256` slices and orchestrate the training job across all of them as a single workload. Similarly, for GPUs, you would increase the `--num-nodes` value. -3. **Create the workload (run the job)** +3. **Create the workload (run the job)** - - **On your TPU cluster:** + - **On your TPU cluster:** - ``` - xpk workload create\ - --cluster ${CLUSTER_NAME}\ - --workload ${USER}-tpu-job\ - --base-docker-image maxtext_base_image\ - --tpu-type v5litepod-256\ - --num-slices 1\ - --command "python3 -m MaxText.train src/MaxText/configs/base.yml run_name=${USER}-tpu-job base_output_directory=${BASE_OUTPUT_DIR} dataset_path=${DATASET_PATH} steps=100" - ``` + ``` + xpk workload create\ + --cluster ${CLUSTER_NAME}\ + --workload ${USER}-tpu-job\ + --base-docker-image maxtext_base_image\ + --tpu-type v5litepod-256\ + --num-slices 1\ + --command "python3 -m MaxText.train src/maxtext/configs/base.yml run_name=${USER}-tpu-job base_output_directory=${BASE_OUTPUT_DIR} dataset_path=${DATASET_PATH} steps=100" + ``` - - **On your GPU cluster:** + - **On your GPU cluster:** - ``` - xpk workload create\ - --cluster ${CLUSTER_NAME}\ - --workload ${USER}-gpu-job\ - --base-docker-image maxtext_base_image\ - --device-type h100-80gb-8\ - --num-nodes 2\ - --command "python3 -m MaxText.train src/MaxText/configs/base.yml run_name=${USER}-gpu-job base_output_directory=${BASE_OUTPUT_DIR} dataset_path=${DATASET_PATH} steps=100" - ``` + ``` + xpk workload create\ + --cluster ${CLUSTER_NAME}\ + --workload ${USER}-gpu-job\ + --base-docker-image maxtext_base_image\ + --device-type h100-80gb-8\ + --num-nodes 2\ + --command "python3 -m MaxText.train src/maxtext/configs/base.yml run_name=${USER}-gpu-job base_output_directory=${BASE_OUTPUT_DIR} dataset_path=${DATASET_PATH} steps=100" + ``` -* * * * * +______________________________________________________________________ ## 6. Managing and monitoring your job -- **View logs in real-time:** The easiest way to see the output of your training job is through the Google Cloud Console. +- **View logs in real-time:** The easiest way to see the output of your training job is through the Google Cloud Console. - 1. Navigate to the **Kubernetes Engine** section. + 1. Navigate to the **Kubernetes Engine** section. - 2. Go to **Workloads**. + 2. Go to **Workloads**. - 3. Find your workload (e.g., `${USER}-tpu-job`) and click on it. + 3. Find your workload (e.g., `${USER}-tpu-job`) and click on it. - 4. Select the **Logs** tab to view the container logs. + 4. Select the **Logs** tab to view the container logs. -- **List your jobs:** +- **List your jobs:** - ``` - xpk workload list --cluster ${CLUSTER_NAME} - ``` + ``` + xpk workload list --cluster ${CLUSTER_NAME} + ``` -- **Analyze output:** Checkpoints and other artifacts will be saved to the Google Cloud Storage bucket you specified in `BASE_OUTPUT_DIR`. +- **Analyze output:** Checkpoints and other artifacts will be saved to the Google Cloud Storage bucket you specified in `BASE_OUTPUT_DIR`. -- **Delete a job:** +- **Delete a job:** - ``` - xpk workload delete --cluster ${CLUSTER_NAME} --workload - ``` \ No newline at end of file + ``` + xpk workload delete --cluster ${CLUSTER_NAME} --workload + ``` diff --git a/docs/tutorials/first_run.md b/docs/tutorials/first_run.md index 960ecfbb98..1612385a22 100644 --- a/docs/tutorials/first_run.md +++ b/docs/tutorials/first_run.md @@ -24,7 +24,7 @@ This topic provides a basic introduction to get your MaxText workload up and run 1. To store logs and checkpoints, [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) in your project. To run MaxText, the TPU or GPU VMs must have read/write permissions for the bucket. These permissions are granted by service account roles, such as the `STORAGE ADMIN` role. -1. MaxText reads a yaml file for configuration. We also recommend reviewing the configurable options in `configs/base.yml`. This file includes a decoder-only model of ~1B parameters. The configurable options can be overwritten from the command line. For instance, you can change the `steps` or `log_period` by either modifying `configs/base.yml` or by passing in `steps` and `log_period` as additional arguments to the `train.py` call. Set `base_output_directory` to a folder in the bucket you just created. +2. MaxText reads a yaml file for configuration. We also recommend reviewing the configurable options in `configs/base.yml`. This file includes a decoder-only model of ~1B parameters. The configurable options can be overwritten from the command line. For instance, you can change the `steps` or `log_period` by either modifying `configs/base.yml` or by passing in `steps` and `log_period` as additional arguments to the `train.py` call. Set `base_output_directory` to a folder in the bucket you just created. ## Local development for single host @@ -36,8 +36,8 @@ Local development is a convenient way to run MaxText on a single host. It doesn' multiple hosts but is a good way to learn about MaxText. 1. [Create and SSH to the single host VM of your choice](https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm). You can use any available single host TPU, such as `v5litepod-8`, `v5p-8`, or `v4-8`. -1. Clone MaxText onto that TPU VM. -1. Within the root directory of the cloned repo, install dependencies and pre-commit hook by running: +2. Clone MaxText onto that TPU VM. +3. Within the root directory of the cloned repo, install dependencies and pre-commit hook by running: ```sh python3 -m venv ~/venv-maxtext @@ -49,7 +49,7 @@ pre-commit install 4. After installation completes, run training on synthetic data with the following command: ```sh -python3 -m MaxText.train src/MaxText/configs/base.yml \ +python3 -m MaxText.train src/maxtext/configs/base.yml \ run_name=$YOUR_JOB_NAME \ base_output_directory=gs:// \ dataset_type=synthetic \ @@ -61,7 +61,7 @@ Optional: If you want to try training on a Hugging Face dataset, see [Data Input 5. To demonstrate model output, run the following command: ```sh -python3 -m maxtext.decode src/MaxText/configs/base.yml \ +python3 -m maxtext.decode src/maxtext/configs/base.yml \ run_name=$YOUR_JOB_NAME \ base_output_directory=gs:// \ per_device_batch_size=1 @@ -80,10 +80,10 @@ You can use [demo_decoding.ipynb](https://github.com/AI-Hypercomputer/maxtext/bl ### Run MaxText on NVIDIA GPUs 1. Use `bash dependencies/scripts/docker_build_dependency_image.sh DEVICE=gpu` to build a container with the required dependencies. -1. After installation is complete, run training with the following command on synthetic data: +2. After installation is complete, run training with the following command on synthetic data: ```sh -python3 -m MaxText.train src/MaxText/configs/base.yml \ +python3 -m MaxText.train src/maxtext/configs/base.yml \ run_name=$YOUR_JOB_NAME \ base_output_directory=gs:// \ dataset_type=synthetic \ @@ -93,7 +93,7 @@ python3 -m MaxText.train src/MaxText/configs/base.yml \ 3. To demonstrate model output, run the following command: ```sh -python3 -m maxtext.decode src/MaxText/configs/base.yml \ +python3 -m maxtext.decode src/maxtext/configs/base.yml \ run_name=$YOUR_JOB_NAME \ base_output_directory=gs:// \ per_device_batch_size=1 diff --git a/docs/tutorials/posttraining/full_finetuning.md b/docs/tutorials/posttraining/full_finetuning.md index 53444bbdfe..ff9254cfc2 100644 --- a/docs/tutorials/posttraining/full_finetuning.md +++ b/docs/tutorials/posttraining/full_finetuning.md @@ -82,7 +82,7 @@ MaxText provides examples to work with [Common Crawl](https://commoncrawl.org/). Run these steps once per project prior to any local development or cluster experiments. 1. Create two gcs buckets in your project, one for downloading and retrieving the dataset and the other for storing the logs. -1. Download the dataset in your gcs bucket. +2. Download the dataset in your gcs bucket. MaxText assumes these GCS buckets are created in the same project and that it has permissions to read and write from them. @@ -101,7 +101,7 @@ Below is a sample training script. ```sh python3 -m MaxText.train \ - src/MaxText/configs/base.yml \ + src/maxtext/configs/base.yml \ run_name=${RUN_NAME} \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ load_parameters_path=${MODEL_CKPT_PATH} \ diff --git a/docs/tutorials/posttraining/knowledge_distillation.md b/docs/tutorials/posttraining/knowledge_distillation.md index a77803251d..8adb8bae3e 100644 --- a/docs/tutorials/posttraining/knowledge_distillation.md +++ b/docs/tutorials/posttraining/knowledge_distillation.md @@ -27,7 +27,7 @@ This tutorial focuses on **response-based knowledge distillation**, a technique - The pre-trained teacher model (running in vLLM) generates a new dataset of input-output pairs. - The student model is then trained on this teacher-generated dataset using standard fine-tuning techniques in MaxText. -1. **Online Distillation (Logit Matching):** +2. **Online Distillation (Logit Matching):** - During the training process, both the teacher model (which is typically frozen) and the student model process the same input data simultaneously. - The student model is trained by minimizing a loss function that encourages its output logits to match the logits produced by the teacher model for the same inputs. @@ -51,7 +51,7 @@ To install MaxText and its dependencies for post-training (including vLLM for th 1. Follow the [MaxText installation instructions](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#install-maxtext). -1. Install the additional dependencies for post-training: +2. Install the additional dependencies for post-training: ```bash bash tools/setup/setup_post_training_requirements.sh @@ -132,7 +132,7 @@ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu export PRE_TRAINED_MODEL_CKPT_DIRECTORY=${BASE_DIRECTORY}/llama3.1-8b-ckpt # Convert to MaxText format -python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \ +python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/maxtext/configs/base.yml \ model_name=llama3.1-8b \ hf_access_token=${HF_TOKEN} \ base_output_directory=${PRE_TRAINED_MODEL_CKPT_DIRECTORY} \ @@ -170,7 +170,7 @@ You can now fine-tune your smaller student model using supervised fine-tuning te Example command to run fine-tuning on a TPU v6e-8: ```bash -python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml \ +python3 -m MaxText.sft_trainer src/maxtext/configs/post_train/sft.yml \ run_name=${RUN_NAME} \ base_output_directory=${BASE_DIRECTORY}/distillation/qwen3-32b-distill-llama3.1-8b \ tokenizer_path=meta-llama/Llama-3.1-8B-Instruct tokenizer_type=huggingface \ @@ -209,7 +209,7 @@ largest_dir="${sorted_dirs[-1]}" FINE_TUNED_MODEL_CKPT_PATH=${CHECKPOINTS_PATH}/${largest_dir}/model_params # Fine-tune student model on original dataset -python3 -m MaxText.sft.sft_trainer src/MaxText/configs/sft.yml \ +python3 -m MaxText.sft.sft_trainer src/maxtext/configs/post_train/sft.yml \ run_name=${RUN_NAME}_stage2 \ base_output_directory=${BASE_DIRECTORY}/distillation/qwen3-32b-distill-llama3.1-8b \ tokenizer_path=meta-llama/Llama-3.1-8B-Instruct tokenizer_type=huggingface \ diff --git a/docs/tutorials/posttraining/multimodal.md b/docs/tutorials/posttraining/multimodal.md index 11c6982c66..c192f9c892 100644 --- a/docs/tutorials/posttraining/multimodal.md +++ b/docs/tutorials/posttraining/multimodal.md @@ -38,7 +38,7 @@ Then use this command to convert an unscanned checkpoint from HuggingFace to Max ```shell export HF_ACCESS_TOKEN=hf_... export MAXTEXT_CKPT_GCS_PATH=gs://... -python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \ +python -m MaxText.utils.ckpt_conversion.to_maxtext maxtext/configs/base.yml \ model_name=gemma3-4b \ hf_access_token=$HF_ACCESS_TOKEN \ base_output_directory=$MAXTEXT_CKPT_GCS_PATH \ @@ -73,7 +73,7 @@ To run a forward pass and verify the model's output, use the following command: ```shell # Gemma3 decode python -m maxtext.decode \ - MaxText/configs/base.yml \ + maxtext/configs/base.yml \ model_name=gemma3-4b \ hf_access_token=$HF_ACCESS_TOKEN \ tokenizer_path=src/maxtext/assets/tokenizers/tokenizer.gemma3 \ @@ -109,7 +109,7 @@ export TARGET_LENGTH=... # Adjust to fit expected output length export PREDICT_LENGTH=... # Adjust to fit image tokens + text prompt python -m maxtext.decode \ - MaxText/configs/base.yml \ + maxtext/configs/base.yml \ model_name=gemma3-4b \ ... \ max_prefill_predict_length=$PREDICT_LENGTH # Adjust to fit image tokens + text prompt \ @@ -123,14 +123,14 @@ For larger models such as Llama4-Scout/Maverick, we suggest to run the decoding ## Supervised Fine-Tuning -Supervised Fine-Tuning (SFT) of multimodal LLMs in MaxText focuses specifically on post-training; we don't yet support pre-training multimodal models from scratch. The SFT process typically involves training on Visual Question Answering (VQA) datasets where the model learns to generate accurate text responses based on both visual and textual inputs. During this fine-tuning phase, we recommend to freeze the pre-trained encoder layers (such as vision transformers) to preserve their learned visual representations, while the projection layers and LLM decoder components remain trainable. This selective training strategy allows the model to adapt the cross-modal alignment and text generation capabilities without disrupting the robust feature extraction abilities of the encoders, ultimately leading to improved performance on multimodal understanding and reasoning tasks while maintaining computational efficiency. This is achieved by setting `freeze_vision_encoder_params=True` in [sft-vision-chartqa.yml](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/sft-vision-chartqa.yml). +Supervised Fine-Tuning (SFT) of multimodal LLMs in MaxText focuses specifically on post-training; we don't yet support pre-training multimodal models from scratch. The SFT process typically involves training on Visual Question Answering (VQA) datasets where the model learns to generate accurate text responses based on both visual and textual inputs. During this fine-tuning phase, we recommend to freeze the pre-trained encoder layers (such as vision transformers) to preserve their learned visual representations, while the projection layers and LLM decoder components remain trainable. This selective training strategy allows the model to adapt the cross-modal alignment and text generation capabilities without disrupting the robust feature extraction abilities of the encoders, ultimately leading to improved performance on multimodal understanding and reasoning tasks while maintaining computational efficiency. This is achieved by setting `freeze_vision_encoder_params=True` in [sft-vision-chartqa.yml](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/post_train/sft-vision-chartqa.yml). Here, we use [ChartQA](https://huggingface.co/datasets/HuggingFaceM4/ChartQA) as an example to demonstrate SFT functionality: ```shell export UNSCANNED_CKPT_PATH=... # either set to an already available MaxText ckpt or to the one we just converted in the previous step python -m MaxText.sft_trainer \ - src/MaxText/configs/sft-vision-chartqa.yml \ + src/maxtext/configs/post_train/sft-vision-chartqa.yml \ run_name="chartqa-sft" \ model_name=gemma3-4b \ tokenizer_path="google/gemma-3-4b-it" \ diff --git a/docs/tutorials/posttraining/rl.md b/docs/tutorials/posttraining/rl.md index f77af73a80..91024344c0 100644 --- a/docs/tutorials/posttraining/rl.md +++ b/docs/tutorials/posttraining/rl.md @@ -102,15 +102,15 @@ and `vllm`, follow these steps: [MaxText Package Tests](https://github.com/AI-Hypercomputer/maxtext/actions/workflows/build_and_test_maxtext.yml?query=event%3Aschedule) GitHub Actions workflow. -1. Select the latest successful run. +2. Select the latest successful run. -1. Within the workflow run, find and click on the `maxtext_jupyter_notebooks (py312)` job, then expand the `run` job. +3. Within the workflow run, find and click on the `maxtext_jupyter_notebooks (py312)` job, then expand the `run` job. -1. Locate the `Record Commit IDs` step. The commit SHAs for `maxtext`, `tunix`, +4. Locate the `Record Commit IDs` step. The commit SHAs for `maxtext`, `tunix`, `tpu-inference`, and `vllm` that were used in that successful run are listed in the logs of this step. -1. Prior to installation, ensure that the `maxtext`, `tunix`, `vllm`, and `tpu-inference` repositories are synchronized to the specific commits recorded from the CI logs. For each repository, use the following command to switch to the correct commit: `git checkout `. +5. Prior to installation, ensure that the `maxtext`, `tunix`, `vllm`, and `tpu-inference` repositories are synchronized to the specific commits recorded from the CI logs. For each repository, use the following command to switch to the correct commit: `git checkout `. ## Setup environment variables @@ -153,7 +153,7 @@ export MAXTEXT_CKPT_PATH= # e.g., gs://my-bucke Run the following command for GRPO: ``` -python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \ +python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \ model_name=${MODEL} \ tokenizer_path=${TOKENIZER} \ load_parameters_path=${MAXTEXT_CKPT_PATH} \ @@ -166,9 +166,9 @@ The overview of what this run will do is as follows: 1. We load a policy model and a reference model. Both are copies of the model checkpoint you specified (e.g., `Llama3.1-8b-Instruct`). -1. Evaluate the policy model's performance on GSM8K math reasoning benchmark. -1. Train the policy model using GRPO. -1. Evaluate the policy model's performance on GSM8K math reasoning benchmark +2. Evaluate the policy model's performance on GSM8K math reasoning benchmark. +3. Train the policy model using GRPO. +4. Evaluate the policy model's performance on GSM8K math reasoning benchmark after the post-training with GRPO. ## Run GSPO @@ -176,7 +176,7 @@ The overview of what this run will do is as follows: Run the following command for GSPO: ``` -python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \ +python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \ model_name=${MODEL} \ tokenizer_path=${TOKENIZER} \ load_parameters_path=${MAXTEXT_CKPT_PATH} \ @@ -190,7 +190,7 @@ The overview of what this run will do is as follows: 1. We load a policy model and a reference model. Both are copies of the model checkpoint you specified (e.g., `Llama3.1-8b-Instruct`). -1. Evaluate the policy model's performance on GSM8K math reasoning benchmark. -1. Train the policy model using GSPO. -1. Evaluate the policy model's performance on GSM8K math reasoning benchmark +2. Evaluate the policy model's performance on GSM8K math reasoning benchmark. +3. Train the policy model using GSPO. +4. Evaluate the policy model's performance on GSM8K math reasoning benchmark after the post-training with GSPO. diff --git a/docs/tutorials/posttraining/rl_on_multi_host.md b/docs/tutorials/posttraining/rl_on_multi_host.md index fcee4ad20d..422b822131 100644 --- a/docs/tutorials/posttraining/rl_on_multi_host.md +++ b/docs/tutorials/posttraining/rl_on_multi_host.md @@ -159,15 +159,15 @@ docker image with these local sources. To get a set of compatible commit IDs for [MaxText Package Tests](https://github.com/AI-Hypercomputer/maxtext/actions/workflows/build_and_test_maxtext.yml?query=event%3Aschedule) GitHub Actions workflow. -1. Select the latest successful run. +2. Select the latest successful run. -1. Within the workflow run, find and click on the `maxtext_jupyter_notebooks (py312)` job, then expand the `run` job. +3. Within the workflow run, find and click on the `maxtext_jupyter_notebooks (py312)` job, then expand the `run` job. -1. Locate the `Record Commit IDs` step. The commit SHAs for `maxtext`, `tunix`, +4. Locate the `Record Commit IDs` step. The commit SHAs for `maxtext`, `tunix`, `tpu-inference`, and `vllm` that were used in that successful run are listed in the logs of this step. -1. Prior to installation, ensure that the `maxtext`, `tunix`, `vllm`, and `tpu-inference` repositories are synchronized to the specific commits recorded from the CI logs. For each repository, use the following command to switch to the correct commit: `git checkout `. +5. Prior to installation, ensure that the `maxtext`, `tunix`, `vllm`, and `tpu-inference` repositories are synchronized to the specific commits recorded from the CI logs. For each repository, use the following command to switch to the correct commit: `git checkout `. **Note:** Clone these repositories as siblings of the `maxtext` directory (e.g., in the same parent directory). After cloning, run the build from inside the @@ -208,7 +208,7 @@ xpk workload create-pathways --workload $WORKLOAD \ --tpu-type=$TPU_TYPE --num-slices=1 \ --project=$PROJECT_ID --priority=high \ --command "HF_TOKEN=${HF_TOKEN} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \ -python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \ +python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \ model_name=${MODEL} \ tokenizer_path=${TOKENIZER} \ load_parameters_path=${MAXTEXT_CKPT_PATH} \ @@ -225,7 +225,7 @@ xpk workload create-pathways --workload $WORKLOAD \ --tpu-type=$TPU_TYPE --num-slices=1 \ --project=$PROJECT_ID --priority=high \ --command "HF_TOKEN=${HF_TOKEN} TF_CPP_MIN_LOG_LEVEL=0 JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE='1' \ -python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \ +python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \ model_name=${MODEL} \ tokenizer_path=${TOKENIZER} \ load_parameters_path=${MAXTEXT_CKPT_PATH} \ diff --git a/docs/tutorials/posttraining/sft.md b/docs/tutorials/posttraining/sft.md index bb67b47a71..841a59879a 100644 --- a/docs/tutorials/posttraining/sft.md +++ b/docs/tutorials/posttraining/sft.md @@ -89,7 +89,7 @@ export PRE_TRAINED_MODEL_CKPT_PATH= # e.g., gs: Now you are ready to run SFT using the following command: ```sh -python3 -m maxtext.trainers.post_train.sft.train_sft src/MaxText/configs/sft.yml \ +python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \ run_name=${RUN_NAME} \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ model_name=${PRE_TRAINED_MODEL} \ diff --git a/docs/tutorials/posttraining/sft_on_multi_host.md b/docs/tutorials/posttraining/sft_on_multi_host.md index 26ff8b1d37..063a452ec8 100644 --- a/docs/tutorials/posttraining/sft_on_multi_host.md +++ b/docs/tutorials/posttraining/sft_on_multi_host.md @@ -143,7 +143,7 @@ xpk workload create \ --workload=${WORKLOAD_NAME} \ --tpu-type=${TPU_TYPE} \ --num-slices=${TPU_SLICE} \ ---command "python3 -m maxtext.trainers.post_train.sft.train_sft src/MaxText/configs/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=$TOKENIZER_PATH per_device_batch_size=1 steps=$STEPS profiler=xplane hf_path=$DATASET_NAME train_split=$TRAIN_SPLIT train_data_columns=$TRAIN_DATA_COLUMNS" +--command "python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=$TOKENIZER_PATH per_device_batch_size=1 steps=$STEPS profiler=xplane hf_path=$DATASET_NAME train_split=$TRAIN_SPLIT train_data_columns=$TRAIN_DATA_COLUMNS" ``` Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`. @@ -159,7 +159,7 @@ xpk workload create-pathways \ --workload=${WORKLOAD_NAME} \ --tpu-type=${TPU_TYPE} \ --num-slices=${TPU_SLICE} \ ---command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft src/MaxText/configs/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=$TOKENIZER_PATH per_device_batch_size=1 steps=$STEPS profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True" +--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml run_name=$WORKLOAD_NAME base_output_directory=$OUTPUT_PATH model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH hf_access_token=$HF_TOKEN tokenizer_path=$TOKENIZER_PATH per_device_batch_size=1 steps=$STEPS profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True" ``` Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`. diff --git a/docs/tutorials/pretraining.md b/docs/tutorials/pretraining.md index a59eda2406..dfb53da3eb 100644 --- a/docs/tutorials/pretraining.md +++ b/docs/tutorials/pretraining.md @@ -15,27 +15,27 @@ --> (pretraining)= + # Pre-training -In this tutorial, we introduce how to run pretraining with real datasets. While synthetic data is commonly used for benchmarking, we rely on real datasets to obtain meaningful weights. Currently, MaxText supports three dataset input pipelines: HuggingFace, Grain, and TensorFlow Datasets (TFDS). We will walk you through: setting up dataset, modifying the [dataset configs](https://github.com/AI-Hypercomputer/maxtext/blob/08d9f20329ab55b9b928543fedd28ad173e1cd97/src/MaxText/configs/base.yml#L486-L514) and [tokenizer configs](https://github.com/AI-Hypercomputer/maxtext/blob/08d9f20329ab55b9b928543fedd28ad173e1cd97/src/MaxText/configs/base.yml#L452-L455) for training, and optionally enabling evaluation. +In this tutorial, we introduce how to run pretraining with real datasets. While synthetic data is commonly used for benchmarking, we rely on real datasets to obtain meaningful weights. Currently, MaxText supports three dataset input pipelines: HuggingFace, Grain, and TensorFlow Datasets (TFDS). We will walk you through: setting up dataset, modifying the [dataset configs](https://github.com/AI-Hypercomputer/maxtext/blob/08d9f20329ab55b9b928543fedd28ad173e1cd97/src/maxtext/configs/base.yml#L486-L514) and [tokenizer configs](https://github.com/AI-Hypercomputer/maxtext/blob/08d9f20329ab55b9b928543fedd28ad173e1cd97/src/maxtext/configs/base.yml#L452-L455) for training, and optionally enabling evaluation. To start with, we focus on HuggingFace datasets for convenience. + - Later on, we will give brief examples for Grain and TFDS. For a comprehensive guide, see the [Data Input Pipeline](../guides/data_input_pipeline.md) topic. - For demonstration, we use Deepseek-V2-Lite model and C4 dataset. C4 stands for "Colossal Clean Crawled Corpus", a high-quality pretraining dataset first introduced by Google's [T5](https://arxiv.org/pdf/1910.10683) work. Feel free to try other models and datasets. - ## 1. HuggingFace pipeline We use the HuggingFace dataset [allenai/c4](https://huggingface.co/datasets/allenai/c4), a processed version of Google's C4. This dataset is organized into subsets (e.g., `en`, `es`), and each subset contains data splits (e.g., `train`, `validation`). - **Data preparation**: You don't need to download data, as the pipeline can stream data directly from the HuggingFace Hub. Alternatively, it can stream from a Cloud Storage bucket; see the [HuggingFace Pipeline](../guides/data_input_pipeline/data_input_hf.md) page. - We can use this **command** for pretraining: + ```bash # replace base_output_directory with your bucket -python3 -m MaxText.train MaxText/configs/base.yml \ +python3 -m MaxText.train maxtext/configs/base.yml \ base_output_directory=gs://runner-maxtext-logs run_name=demo \ model_name=deepseek2-16b per_device_batch_size=1 steps=10 max_target_length=2048 enable_checkpointing=false \ dataset_type=hf hf_path=allenai/c4 hf_data_dir=en train_split=train \ @@ -43,22 +43,26 @@ tokenizer_type=huggingface tokenizer_path=deepseek-ai/DeepSeek-V2-Lite ``` **Dataset config**: + - `dataset_type`: `hf` - `hf_path`: the HuggingFace dataset repository is `allenai/c4` -- `hf_data_dir`: the subset is `en`, corresponding to English data. +- `hf_data_dir`: the subset is `en`, corresponding to English data. - `train_split`: `train`. Training will use the `train` split. The above command runs training only: `steps=10` on the `train` split, for `en` subset of `allenai/c4`. The log shows: + ``` completed step: 1, seconds: 0.287, TFLOP/s/device: 110.951, Tokens/s/device: 7131.788, total_weights: 7517, loss: 12.021 ... completed step: 9, seconds: 1.010, TFLOP/s/device: 31.541, Tokens/s/device: 2027.424, total_weights: 7979, loss: 9.436 ``` + The total weights is the number of real tokens processed in each step. More explanation can be found in [Understand Logs and Metrics](understand-logs-and-metrics) page. **Evaluation config (optional)**: To add evaluation steps, we can specify a positive evaluation interval and the dataset split, for instance `eval_interval=5 eval_steps=10 hf_eval_split=validation`. For every 5 training step, we run evaluation for 10 steps, using the `validation` split. In the log, you will additionally see: + ``` Completed eval step 0 ... @@ -71,19 +75,20 @@ eval metrics after step: 9, loss=9.420, total_weights=75264.0 ``` **Tokenizer config**: + - `tokenizer_type`: `huggingface`. Note HuggingFace input pipeline only supports HuggingFace tokenizer. - `tokenizer_path`: `deepseek-ai/DeepSeek-V2-Lite`, corresponding to the HuggingFace [model repository](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/tree/main). **HuggingFace access token (optional)**: -- For a [gated dataset](https://huggingface.co/docs/hub/en/datasets-gated) or a tokenizer from a [gated model](https://huggingface.co/docs/hub/en/models-gated), you need to request access on HuggingFace and provide `hf_access_token=` in the command. For instance, [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) is a gated model. +- For a [gated dataset](https://huggingface.co/docs/hub/en/datasets-gated) or a tokenizer from a [gated model](https://huggingface.co/docs/hub/en/models-gated), you need to request access on HuggingFace and provide `hf_access_token=` in the command. For instance, [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) is a gated model. ## 2. Grain pipeline Grain is a library for reading data for training and evaluating JAX models. It is the recommended input pipeline for determinism and resilience! It supports data formats like ArrayRecord and Parquet. You can check [Grain pipeline](../guides/data_input_pipeline/data_input_grain.md) for more details. - **Data preparation**: You need to download data to a Cloud Storage bucket, and read data via Cloud Storage Fuse with [setup_gcsfuse.sh](https://github.com/AI-Hypercomputer/maxtext/blob/0baff00ac27bb7996c62057f235cc1d2f43d734e/setup_gcsfuse.sh#L18). + - For example, we can mount the bucket `gs://maxtext-dataset` on the local path `/tmp/gcsfuse` before training ```bash bash setup_gcsfuse.sh DATASET_GCS_BUCKET=maxtext-dataset MOUNT_PATH=/tmp/gcsfuse @@ -93,11 +98,11 @@ Grain is a library for reading data for training and evaluating JAX models. It i fusermount -u /tmp/gcsfuse ``` - This **command** shows pretraining with Grain pipeline, along with evaluation: + ```bash # replace DATASET_GCS_BUCKET and base_output_directory with your buckets -python3 -m MaxText.train MaxText/configs/base.yml \ +python3 -m MaxText.train maxtext/configs/base.yml \ base_output_directory=gs://runner-maxtext-logs run_name=demo \ model_name=deepseek2-16b per_device_batch_size=1 steps=10 max_target_length=2048 enable_checkpointing=false \ dataset_type=grain grain_file_type=arrayrecord grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record* grain_worker_count=2 \ @@ -106,31 +111,35 @@ tokenizer_type=huggingface tokenizer_path=deepseek-ai/DeepSeek-V2-Lite ``` **Dataset config**: + - `dataset_type`: `grain` - `grain_file_type`: `arrayrecord`. We also support `parquet`. -- `grain_train_files`: `/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*`, which is a regex pattern. +- `grain_train_files`: `/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*`, which is a regex pattern. - `grain_worker_count`: `2`. This parameter controls the number of child processes used by Grain, which should be tuned for performance. **Evaluation config (optional)**: + - `eval_interval=5 eval_steps=10`: after every 5 train steps, perform 10 evaluation steps - `grain_eval_files`: `/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-validation.array_record*`, which is a regex pattern. -**Tokenizer config**: +**Tokenizer config**: + - The Grain pipeline supports tokenizer_type: `sentencepiece, huggingface` - Here we use the same `huggingface` tokenizer as in Section 1. If you use a HuggingFace tokenizer from a gated model, you will need to provide `hf_access_token`. - ## 3. TFDS pipeline The TensorFlow Datasets (TFDS) pipeline uses dataset in the TFRecord format. You can check [TFDS Pipeline](../guides/data_input_pipeline/data_input_tfds.md) for more details. **Data preparation**: You need to download data to a [Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets), and the pipeline streams data from the bucket. + - To download the AllenAI C4 dataset to your bucket, you can use [download_dataset.sh](https://github.com/AI-Hypercomputer/maxtext/blob/08d9f20329ab55b9b928543fedd28ad173e1cd97/download_dataset.sh#L19): `bash download_dataset.sh ` This **command** shows pretraining with TFDS pipeline, along with evaluation: + ```bash # replace base_output_directory and dataset_path with your buckets -python3 -m MaxText.train MaxText/configs/base.yml \ +python3 -m MaxText.train maxtext/configs/base.yml \ base_output_directory=gs://runner-maxtext-logs run_name=demo \ model_name=deepseek2-16b per_device_batch_size=1 steps=10 max_target_length=2048 enable_checkpointing=false \ dataset_type=tfds dataset_path=gs://maxtext-dataset dataset_name='c4/en:3.0.1' train_split=train \ @@ -139,18 +148,21 @@ tokenizer_type=huggingface tokenizer_path=deepseek-ai/DeepSeek-V2-Lite ``` **Dataset config**: + - `dataset_type`: `tfds` - `dataset_path`: the cloud storage bucket is `gs://maxtext-dataset` - `dataset_name`: `c4/en:3.0.1` corresponds to the subdirectory inside dataset_path `gs://maxtext-dataset/c4/en/3.0.1` -- `train_split`: `train`, corresponds to `*-train.tfrecord-*` files +- `train_split`: `train`, corresponds to `*-train.tfrecord-*` files - Putting together, we are training on files like `gs://maxtext-dataset/c4/en/3.0.1/c4-train.tfrecord-0000-of-01024` **Evaluation config (optional)**: + - `eval_interval=5 eval_steps=10`: after every 5 train steps, perform 10 evaluation steps - `eval_dataset_name`: `c4/en:3.0.1`, corresponds to the subdirectory inside dataset_path `gs://maxtext-dataset/c4/en/3.0.1`. It can be different from `dataset_name`. -- `eval_split`: `validation`, corresponds to `*-validation.tfrecord-*` files +- `eval_split`: `validation`, corresponds to `*-validation.tfrecord-*` files - Putting together, we are evaluating on files like `gs://maxtext-dataset/c4/en/3.0.1/c4-validation.tfrecord-00000-of-00008` -**Tokenizer config**: +**Tokenizer config**: + - TFDS pipeline supports tokenizer_type: `sentencepiece, huggingface, tiktoken` - Here we use the same `huggingface` tokenizer as in Section 1. If you use a HuggingFace tokenizer from a gated model, you will need to provide `hf_access_token`. diff --git a/src/MaxText/configs/v5p/gpt3_175b/v5p_1024.sh b/src/MaxText/configs/v5p/gpt3_175b/v5p_1024.sh deleted file mode 100644 index 703863c7da..0000000000 --- a/src/MaxText/configs/v5p/gpt3_175b/v5p_1024.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash -# GPT-3 175B Model. -# Train GPT-3 175B on v5p-1024 slice. - -# Example to invoke this script: -# bash src/MaxText/configs/v5p/gpt3_175b/v5p_1024.sh YOUR_RUN gs://YOUR_BUCKET" - -set -euox pipefail - -# Read arguments or use defaults from environment variables -RUNNAME=${1:-${RUNNAME:-some-run}} -BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} - -chmod +x "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/v5p/gpt3_175b/gpt3_175b_base.sh -./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 4 "full" 1 64 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" diff --git a/src/MaxText/configs/v5p/gpt3_175b/v5p_2048.sh b/src/MaxText/configs/v5p/gpt3_175b/v5p_2048.sh deleted file mode 100644 index e7617d13c7..0000000000 --- a/src/MaxText/configs/v5p/gpt3_175b/v5p_2048.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash -# GPT-3 175B Model. -# Train GPT-3 175B on v5p-2048 slice. - -# Example to invoke this script: -# bash src/MaxText/configs/v5p/gpt3_175b/v5p_2048.sh YOUR_RUN gs://YOUR_BUCKET" - -set -euox pipefail - -# Read arguments or use defaults from environment variables -RUNNAME=${1:-${RUNNAME:-some-run}} -BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} - -chmod +x "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/v5p/gpt3_175b/gpt3_175b_base.sh -./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 2 "save_dot_except_mlpwi" 8 16 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" \ No newline at end of file diff --git a/src/MaxText/configs/v5p/gpt3_175b/v5p_3072.sh b/src/MaxText/configs/v5p/gpt3_175b/v5p_3072.sh deleted file mode 100644 index c5085d6b3e..0000000000 --- a/src/MaxText/configs/v5p/gpt3_175b/v5p_3072.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash -# GPT-3 175B Model. -# Train GPT-3 175B on v5p-3072 slice. - -# Example to invoke this script: -# bash src/MaxText/configs/v5p/gpt3_175b/v5p_3072.sh YOUR_RUN gs://YOUR_BUCKET" - -set -euox pipefail - -# Read arguments or use defaults from environment variables -RUNNAME=${1:-${RUNNAME:-some-run}} -BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} - -chmod +x "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/v5p/gpt3_175b/gpt3_175b_base.sh -./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 2 "save_dot_except_mlpwi" 12 16 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" \ No newline at end of file diff --git a/src/MaxText/experimental/agent/ckpt_conversion_agent/README.md b/src/MaxText/experimental/agent/ckpt_conversion_agent/README.md index 63bf0775e8..a8a94426e0 100644 --- a/src/MaxText/experimental/agent/ckpt_conversion_agent/README.md +++ b/src/MaxText/experimental/agent/ckpt_conversion_agent/README.md @@ -18,7 +18,7 @@ pip install -q -U "google-genai>=1.0.0" The agent requires context files about the target and source model's parameter names and tensor shapes. You can generate them using the [`save_param.py`](../ckpt_conversion_agent/utils/save_param.py) script. The output directory defined by `config.base_output_directory`. The default is `src/MaxText/experimental/agent/ckpt_conversion_agent/context/` folder. ```bash -python3 -m MaxText.experimental.agent.ckpt_conversion_agent.utils.save_param src/MaxText/configs/base.yml \ +python3 -m MaxText.experimental.agent.ckpt_conversion_agent.utils.save_param src/maxtext/configs/base.yml \ per_device_batch_size=1 run_name=param_ model_name= scan_layers=false \ --hf_model_config= ``` @@ -66,7 +66,7 @@ If a ground-truth version isn't available, you'll need to debug the conversion m 3. After the conversion is done, run a decode to check the correctness of the generated code. Example command: ```bash -python3 -m maxtext.decode src/MaxText/configs/base.yml model_name=gemma3-4b tokenizer_path=src/maxtext/assets/tokenizers/tokenizer.gemma3 \ +python3 -m maxtext.decode src/maxtext/configs/base.yml model_name=gemma3-4b tokenizer_path=src/maxtext/assets/tokenizers/tokenizer.gemma3 \ load_parameters_path= per_device_batch_size=1 run_name=ht_test \ max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=true \ prompt='I love to' attention='dot_product' @@ -75,7 +75,7 @@ If outputs are wrong, you can use jax.debug.print() to print the layer-wise mean 4. To further validate the converted checkpoint, we recommend to use the [forward_pass_logit_checker.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/README.md#verifying-conversion-correctness) to compare the original ckpt with the converted ckpt: ```bash -python3 -m tests.utils.forward_pass_logit_checker src/MaxText/configs/base.yml \ +python3 -m tests.utils.forward_pass_logit_checker src/maxtext/configs/base.yml \ tokenizer_path=assets/tokenizers/ \ load_parameters_path= \ model_name= \ diff --git a/src/MaxText/get_flops.py b/src/MaxText/get_flops.py index de06f57b62..763b79e249 100644 --- a/src/MaxText/get_flops.py +++ b/src/MaxText/get_flops.py @@ -14,9 +14,9 @@ """ A wrapper file for easily calculating training TFLOPs. """ -from MaxText.maxtext_utils import calculate_tflops_training_per_device from MaxText import pyconfig from MaxText.globals import MAXTEXT_PKG_DIR +from maxtext.utils.maxtext_utils import calculate_tflops_training_per_device import os from typing import Sequence, cast from absl import app diff --git a/src/MaxText/globals.py b/src/MaxText/globals.py index 547a1ae964..87129fc333 100644 --- a/src/MaxText/globals.py +++ b/src/MaxText/globals.py @@ -17,7 +17,7 @@ import os.path # This is the MaxText root: with "max_utils.py"; &etc. TODO: Replace `os.path.basename` with `os.path.abspath` -MAXTEXT_PKG_DIR = os.environ.get("MAXTEXT_PKG_DIR", os.path.basename(os.path.dirname(__file__))) +MAXTEXT_PKG_DIR = os.environ.get("MAXTEXT_PKG_DIR", os.path.join("src", "maxtext")) # This is the maxtext repo root: with ".git" folder; "README.md"; "pyproject.toml"; &etc. MAXTEXT_REPO_ROOT = os.environ.get( @@ -27,6 +27,11 @@ else MAXTEXT_PKG_DIR, ) +# This is the configs root: with "base.yml"; "models/"; &etc. +MAXTEXT_CONFIGS_DIR = os.environ.get( + "MAXTEXT_CONFIGS_DIR", os.path.join(os.path.dirname(os.path.dirname(__file__)), "maxtext", "configs") +) + # This is the assets root: with "tokenizers/"; &etc. MAXTEXT_ASSETS_ROOT = os.environ.get("MAXTEXT_ASSETS_ROOT", os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "assets")) @@ -40,6 +45,7 @@ "DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE", "EPS", "MAXTEXT_ASSETS_ROOT", + "MAXTEXT_CONFIGS_DIR", "MAXTEXT_PKG_DIR", "MAXTEXT_REPO_ROOT", "MAXTEXT_TEST_ASSETS_ROOT", diff --git a/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py index 3281a1eb0a..ea63a371ae 100644 --- a/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -27,6 +27,7 @@ from maxtext.utils import max_logging from maxtext.utils import model_creation_utils + try: from tpu_inference.layers.common.attention_metadata import AttentionMetadata except ImportError: diff --git a/src/MaxText/layerwise_quantization.py b/src/MaxText/layerwise_quantization.py index 83e269099e..8ac86013f7 100644 --- a/src/MaxText/layerwise_quantization.py +++ b/src/MaxText/layerwise_quantization.py @@ -19,7 +19,7 @@ Example cmd: -python3 -m MaxText.layerwise_quantization src/MaxText/configs/base.yml \ +python3 -m MaxText.layerwise_quantization src/maxtext/configs/base.yml \ tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${LOAD_PARAMS_PATH} \ model_name=deepseek2-16b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 \ ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=1 \ diff --git a/src/MaxText/pyconfig.py b/src/MaxText/pyconfig.py index 6a3a160a28..025ec429a0 100644 --- a/src/MaxText/pyconfig.py +++ b/src/MaxText/pyconfig.py @@ -27,8 +27,9 @@ from MaxText import pyconfig_deprecated from MaxText.common_types import DecoderBlockType, ShardMode -from MaxText.configs import types -from MaxText.configs.types import MaxTextConfig +from MaxText.globals import MAXTEXT_PKG_DIR +from maxtext.configs import types +from maxtext.configs.types import MaxTextConfig from maxtext.inference.inference_utils import str2bool from maxtext.utils import max_utils @@ -83,8 +84,7 @@ def _load_config(config_name: str) -> omegaconf.DictConfig: # Search relative to current config, then in the default configs folder loaded_parent_config_filename = os.path.join(os.path.dirname(config_name), base_path) if not os.path.isfile(loaded_parent_config_filename): - dir_path = os.path.dirname(os.path.realpath(__file__)) - loaded_parent_config_filename = os.path.join(dir_path, "configs", base_path) + loaded_parent_config_filename = os.path.join(MAXTEXT_PKG_DIR, "configs", base_path) else: loaded_parent_config_filename = base_path diff --git a/src/MaxText/rl/train_rl.py b/src/MaxText/rl/train_rl.py index bb227d34b6..e5584b3094 100644 --- a/src/MaxText/rl/train_rl.py +++ b/src/MaxText/rl/train_rl.py @@ -23,7 +23,7 @@ Usage Examples: # GRPO on Llama3.1-8B-Instruct -python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \ +python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \ model_name=llama3.1-8b \ tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ load_parameters_path=gs://path/to/checkpoint/0/items \ @@ -32,7 +32,7 @@ hf_access_token=$HF_TOKEN # GSPO on Llama3.1-70B-Instruct -python3 -m src.MaxText.rl.train_rl src/MaxText/configs/rl.yml \ +python3 -m src.MaxText.rl.train_rl src/maxtext/configs/post_train/rl.yml \ model_name=llama3.1-70b \ tokenizer_path=meta-llama/Llama-3.1-70B-Instruct \ load_parameters_path=gs://path/to/checkpoint/0/items \ @@ -487,8 +487,8 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics." ) - pkg_dir = os.environ.get("MAXTEXT_PKG_DIR", MAXTEXT_PKG_DIR) - vllm_config_path = epath.Path(pkg_dir) / "configs" / "vllm.yml" + configs_dir = os.environ.get("MAXTEXT_CONFIGS_DIR", os.path.join(MAXTEXT_PKG_DIR, "configs")) + vllm_config_path = epath.Path(configs_dir) / "vllm.yml" argv_list = ["", str(vllm_config_path), "log_config=False"] vllm_config = pyconfig.initialize(argv_list) diff --git a/src/MaxText/utils/ckpt_conversion/compare_hf_ckpt.py b/src/MaxText/utils/ckpt_conversion/compare_hf_ckpt.py index 2d32dcee79..65c8a65c33 100644 --- a/src/MaxText/utils/ckpt_conversion/compare_hf_ckpt.py +++ b/src/MaxText/utils/ckpt_conversion/compare_hf_ckpt.py @@ -16,7 +16,7 @@ Verify the converted safetensor checkpoint (GCS or local) matches the HuggingFace checkpoint reference. Usage to compare converted safetensor with remote HF reference: -JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_conversion.compare_hf_ckpt src/MaxText/configs/base.yml \ +JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_conversion.compare_hf_ckpt src/maxtext/configs/base.yml \ model_name= \ hf_access_token= \ hardware=cpu \ @@ -24,7 +24,7 @@ --atol=1e-2 --rtol=1e-2 --max_workers=12 Usage to compare converted safetensor with GCS/Local HF reference: -JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_conversion.compare_hf_ckpt src/MaxText/configs/base.yml \ +JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_conversion.compare_hf_ckpt src/maxtext/configs/base.yml \ hardware=cpu \ --candidate_path= \ --reference_path= \ diff --git a/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_hf.sh b/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_hf.sh index e23f253ad4..035b4f2b43 100644 --- a/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_hf.sh +++ b/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_hf.sh @@ -20,7 +20,7 @@ SCAN_LAYERS=false echo "Starting Hugging Face model conversion for gemma2-2b..." python3 -m MaxText.utils.ckpt_conversion.to_huggingface \ - "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}configs/base.yml" \ + "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml" \ model_name="${MODEL_NAME}" \ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}/tokenizer.gemma" \ load_parameters_path="${MAXTEXT_CHECKPOINT_DIR}" \ @@ -47,7 +47,7 @@ gsutil -m cp -r "${HF_CHECKPOINT_GCS_PATH}/*" "${LOCAL_HF_CHECKPOINT_DIR}/" echo "Download complete." python3 -m tests.utils.forward_pass_logit_checker \ - "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}configs/base.yml" \ + "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml" \ tokenizer_path=${TOKENIZER_PATH} \ load_parameters_path="${MAXTEXT_CHECKPOINT_DIR}"\ run_name=forward_pass_test_${MODEL_NAME}\ diff --git a/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_mt.sh b/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_mt.sh index 5e580f6137..f35aab0065 100644 --- a/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_mt.sh +++ b/src/MaxText/utils/ckpt_conversion/examples/convert_gemma2_to_mt.sh @@ -21,7 +21,7 @@ PROMPT="I love to" # --- Step 1: Convert Checkpoint to MaxText Format --- echo "--- Starting Checkpoint Conversion ---" python3 -m MaxText.utils.ckpt_conversion.to_maxtext \ - "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}configs/base.yml" \ + "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml" \ model_name="${MODEL_NAME}" \ base_output_directory="${OUTPUT_BASE_DIR}" \ per_device_batch_size="${PER_DEVICE_BATCH_SIZE}" \ @@ -35,7 +35,7 @@ echo "--- Checkpoint Conversion Complete ---" echo "--- Starting Decoding ---" python3 -m maxtext.decode \ - ${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}configs/base.yml \ + ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml \ model_name="${MODEL_NAME}" \ tokenizer_path="${TOKENIZER_PATH}" \ load_parameters_path="${OUTPUT_BASE_DIR}/0/items" \ @@ -55,7 +55,7 @@ echo "--- Decoding Complete ---" echo "--- Starting Comparing Logits and Predicted Tokens ---" python3 -m tests.utils.forward_pass_logit_checker \ - ${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml \ + ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml \ tokenizer_path=${TOKENIZER_PATH} \ load_parameters_path="${OUTPUT_BASE_DIR}/0/items"\ run_name=forward_pass_test_${MODEL_NAME}\ diff --git a/src/MaxText/utils/ckpt_conversion/examples/convert_gemma3_to_hf.sh b/src/MaxText/utils/ckpt_conversion/examples/convert_gemma3_to_hf.sh index e00ba56974..373dd179d0 100644 --- a/src/MaxText/utils/ckpt_conversion/examples/convert_gemma3_to_hf.sh +++ b/src/MaxText/utils/ckpt_conversion/examples/convert_gemma3_to_hf.sh @@ -19,7 +19,7 @@ SCAN_LAYERS=false echo "Starting Hugging Face model conversion for gemma3-4b..." python3 -m MaxText.utils.ckpt_conversion.to_huggingface \ - "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml" \ + "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml" \ model_name="gemma3-4b" \ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}/tokenizer.gemma3" \ load_parameters_path="${MAXTEXT_CHECKPOINT_DIR}" \ @@ -47,7 +47,7 @@ gsutil -m cp -r "${HF_CHECKPOINT_GCS_PATH}/*" "${LOCAL_HF_CHECKPOINT_DIR}/" echo "Download complete." python3 -m tests.utils.forward_pass_logit_checker \ - "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml" \ + "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml" \ tokenizer_path=${TOKENIZER_PATH} \ load_parameters_path="${MAXTEXT_CHECKPOINT_DIR}"\ run_name=forward_pass_test_${MODEL_NAME}\ diff --git a/src/MaxText/utils/ckpt_conversion/to_huggingface.py b/src/MaxText/utils/ckpt_conversion/to_huggingface.py index 4cf65f464f..40a0fae77c 100644 --- a/src/MaxText/utils/ckpt_conversion/to_huggingface.py +++ b/src/MaxText/utils/ckpt_conversion/to_huggingface.py @@ -40,7 +40,7 @@ export HF_AUTH_TOKEN="hf_YOUR_TOKEN" python src/MaxText/utils/ckpt_conversion/to_huggingface.py \ - src/MaxText/configs/base.yml \ + src/maxtext/configs/base.yml \ model_name="gemma2-2b" \ load_parameters_path="/path/to/your/maxtext/checkpoint/" \ base_output_directory="/path/to/your/output/directory" \ diff --git a/src/MaxText/utils/ckpt_conversion/to_maxtext.py b/src/MaxText/utils/ckpt_conversion/to_maxtext.py index e3c0fb6106..d738a8cbf9 100644 --- a/src/MaxText/utils/ckpt_conversion/to_maxtext.py +++ b/src/MaxText/utils/ckpt_conversion/to_maxtext.py @@ -41,7 +41,7 @@ To convert a gemma2-2b model and save it to a specific directory: /usr/bin/time -v python src/MaxText/utils/ckpt_conversion/to_maxtext.py \ - MaxText/configs/base.yml model_name="gemma2-2b" \ + maxtext/configs/base.yml model_name="gemma2-2b" \ base_output_directory="/path/to/your/output/directory" \ hf_access_token=$HF_TOKEN hardware=cpu skip_jax_distributed_system=True \ scan_layers=False @@ -52,7 +52,7 @@ To convert a 70B model with minimal RAM usage: /usr/bin/time -v python src/MaxText/utils/ckpt_conversion/to_maxtext.py \ - MaxText/configs/base.yml model_name="meta-llama/Llama-3.1-70B" \ + maxtext/configs/base.yml model_name="meta-llama/Llama-3.1-70B" \ base_output_directory="gs://my-bucket/maxtext-checkpoints" \ hf_access_token=$HF_TOKEN hardware=cpu skip_jax_distributed_system=True \ --lazy_load_tensors=True diff --git a/src/MaxText/utils/ckpt_scripts/llama_mistral_mixtral_orbax_to_hf.py b/src/MaxText/utils/ckpt_scripts/llama_mistral_mixtral_orbax_to_hf.py index af917bf401..d26bab5f58 100644 --- a/src/MaxText/utils/ckpt_scripts/llama_mistral_mixtral_orbax_to_hf.py +++ b/src/MaxText/utils/ckpt_scripts/llama_mistral_mixtral_orbax_to_hf.py @@ -23,7 +23,7 @@ python3 -m MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt --base-model-path \ --maxtext-model-path --model-size llama2-7b -python3 -m MaxText.utils.ckpt_scripts.llama_mistral_mixtral_orbax_to_hf src/MaxText/configs/base.yml +python3 -m MaxText.utils.ckpt_scripts.llama_mistral_mixtral_orbax_to_hf src/maxtext/configs/base.yml base_output_directory=path/to/saving/intermediate_MaxText_files load_parameters_path=/path/to/src/MaxText/checkpoint run_name= model_name= hardware=gpu diff --git a/src/MaxText/configs/README.md b/src/maxtext/configs/README.md similarity index 62% rename from src/MaxText/configs/README.md rename to src/maxtext/configs/README.md index d487074cfc..9e175f2204 100644 --- a/src/MaxText/configs/README.md +++ b/src/maxtext/configs/README.md @@ -18,7 +18,7 @@ This directory contains high performance model configurations for different generations of TPU and GPU hardware. These configurations do 3 things: -* Sets various XLA compiler flags (see [below](/src/MaxText/configs#xla-flags-used-by-maxtext)) as `LIBTPU_INIT_ARGS` to optimize runtime performance. +* Sets various XLA compiler flags (see [below](/src/maxtext/configs#xla-flags-used-by-maxtext)) as `LIBTPU_INIT_ARGS` to optimize runtime performance. * Runs [rto_setup.sh](https://github.com/google/maxtext/blob/main/rto_setup.sh) to optimize communication protocols for network performance. (This only needs to be run once on each worker) * Runs [train.py](https://github.com/google/maxtext/blob/main/src/MaxText/train.py) with specific hyper-parameters (batch size, etc.) @@ -59,19 +59,19 @@ These configurations do 3 things: Running with `multihost_runner.py` on GCE: ``` - python3 multihost_runner.py --TPU_PREFIX=${TPU_PREFIX} --COMMAND="bash setup.sh && bash src/MaxText/configs/v5p/128b.sh RUN_NAME=${YOUR_RUN_NAME} OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} DATASET_PATH=${MAXTEXT_DATASET_PATH} PLATFORM=gce" + python3 multihost_runner.py --TPU_PREFIX=${TPU_PREFIX} --COMMAND="bash setup.sh && bash src/maxtext/configs/tpu/v5p/128b.sh RUN_NAME=${YOUR_RUN_NAME} OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} DATASET_PATH=${MAXTEXT_DATASET_PATH} PLATFORM=gce" ``` Running with `multihost_job.py` on GCE: ``` - python3 multihost_job.py --NUM_SLICES=${NUM_SLICES} --TPU_TYPE=${ACCELERATOR_TYPE} --VERSION=${RUNTIME_VERSION} --RUN_NAME=${RUN_NAME} --BUCKET_NAME=${GCS_BUCKET_NAME} --COMMAND="bash setup.sh && bash src/MaxText/configs/v5p/128b.sh RUN_NAME=${YOUR_RUN_NAME} OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} DATASET_PATH=${MAXTEXT_DATASET_PATH} PLATFORM=gce" + python3 multihost_job.py --NUM_SLICES=${NUM_SLICES} --TPU_TYPE=${ACCELERATOR_TYPE} --VERSION=${RUNTIME_VERSION} --RUN_NAME=${RUN_NAME} --BUCKET_NAME=${GCS_BUCKET_NAME} --COMMAND="bash setup.sh && bash src/maxtext/configs/tpu/v5p/128b.sh RUN_NAME=${YOUR_RUN_NAME} OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} DATASET_PATH=${MAXTEXT_DATASET_PATH} PLATFORM=gce" # Add --CQR_EXTRA_ARGS="--network=mtu9k" to the command if you would like to use the custom MTU network. ``` Running with `XPK` on GKE: ``` - xpk workload create --cluster ${YOUR_CLUSTER_NAME} --docker-image gcr.io/${PROJECT}/${YOUR_IMAGE_NAME} --workload ${YOUR_RUN_NAME} --tpu-type=${ACCELERATOR_TYPE} --num-slices=${NUM_SLICES} --command "bash src/MaxText/configs/v5p/128b.sh OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} DATASET_PATH=${MAXTEXT_DATASET_PATH} PLATFORM=gke" + xpk workload create --cluster ${YOUR_CLUSTER_NAME} --docker-image gcr.io/${PROJECT}/${YOUR_IMAGE_NAME} --workload ${YOUR_RUN_NAME} --tpu-type=${ACCELERATOR_TYPE} --num-slices=${NUM_SLICES} --command "bash src/maxtext/configs/tpu/v5p/128b.sh OUTPUT_PATH=${MAXTEXT_OUTPUT_PATH} DATASET_PATH=${MAXTEXT_DATASET_PATH} PLATFORM=gke" ``` Note: When running these scripts, be sure to specify the `PLATFORM` flag with the correct platform you are running on `"gce"` or `"gke"`. @@ -81,29 +81,29 @@ Here are some of the most common XLA compiler flags used by MaxText. | Flag | Type | Notes | | ---- | ---- |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| xla_tpu_enable_data_parallel_all_reduce_opt | Boolean (true/false) | Optimization to increase overlap opportunities for DCN (data center networking) all-reduces used for data parallel sharding.
**Usage:** [v5p/32B](src/MaxText/configs/v5p/32b.sh) [v5p/64B](/MaxText/configs/v5p/64b.sh) [v5p/128B](/MaxText/configs/v5p/128b.sh) [v5p/256B](/MaxText/configs/v5p/256b.sh) [v5p/512B](/MaxText/configs/v5p/512b.sh) [v5p/1024B](/MaxText/configs/v5p/1024b.sh) [v5e/16B](/MaxText/configs/v5e/16B.sh) [v5e/32B](/MaxText/configs/v5e/32b.sh) [v5e/64B](/MaxText/configs/v5e/64b.sh) [v5e/128B](/MaxText/configs/v5e/128b.sh) [v5e/Llama2-7B](/MaxText/configs/v5e/llama2_7b.sh) [v5e/Llama2-13B](/MaxText/configs/v5e/llama2_13b.sh) [v5e/Llama2-70B](/MaxText/configs/v5e/llama2_70b.sh) [v5e/GPT3-175B](/MaxText/configs/v5e/gpt3_175b.sh) | -| xla_tpu_data_parallel_opt_different_sized_ops | Boolean (true/false) | Enables pipelining of data parallel ops across multiple iterations even if their output sizes doesn't match what can Be saved in place in the stacked variables. Can increase memory pressure.
**Usage:** [v5p/32B](/MaxText/configs/v5p/32b.sh) [v5p/64B](/MaxText/configs/v5p/64b.sh) [v5p/128B](/MaxText/configs/v5p/128b.sh) [v5p/256B](/MaxText/configs/v5p/256b.sh) [v5p/512B](/MaxText/configs/v5p/512b.sh) [v5p/1024B](/MaxText/configs/v5p/1024b.sh) [v5e/16B](/MaxText/configs/v5e/16B.sh) [v5e/32B](/MaxText/configs/v5e/32b.sh) [v5e/64B](/MaxText/configs/v5e/64b.sh) [v5e/128B](/MaxText/configs/v5e/128b.sh) [v5e/Llama2-7B](/MaxText/configs/v5e/llama2_7b.sh) [v5e/Llama2-13B](/MaxText/configs/v5e/llama2_13b.sh) [v5e/Llama2-70B](/MaxText/configs/v5e/llama2_70b.sh) [v5e/GPT3-175B](/MaxText/configs/v5e/gpt3_175b.sh) | -| xla_tpu_enable_async_collective_fusion | Boolean (true/false) | Enables the pass which fuses async collective communications with compute ops (output/loop-fusion or convolution) that are scheduled between their -start and -done instructions.
**Usage:** [v5p/32B](/MaxText/configs/v5p/32b.sh) [v5p/64B](/MaxText/configs/v5p/64b.sh) [v5p/128B](/MaxText/configs/v5p/128b.sh) [v5p/256B](/MaxText/configs/v5p/256b.sh) [v5p/512B](/MaxText/configs/v5p/512b.sh) [v5p/1024B](/MaxText/configs/v5p/1024b.sh) [v5e/16B](/MaxText/configs/v5e/16B.sh) [v5e/32B](/MaxText/configs/v5e/32b.sh) [v5e/64B](/MaxText/configs/v5e/64b.sh) [v5e/128B](/MaxText/configs/v5e/128b.sh) [v5e/Llama2-7B](/MaxText/configs/v5e/llama2_7b.sh) [v5e/Llama2-13B](/MaxText/configs/v5e/llama2_13b.sh) [v5e/Llama2-70B](/MaxText/configs/v5e/llama2_70b.sh) [v5e/GPT3-175B](/MaxText/configs/v5e/gpt3_175b.sh) | -| xla_tpu_enable_async_collective_fusion_fuse_all_gather | TristateFlag (true/false/kAuto) | Enables fusing all-gathers within the AsyncCollectiveFusion pass.
If set to ``kAuto``, it will be enabled based on the target."
**Usage:** [v5p/32B](/MaxText/configs/v5p/32b.sh) [v5p/64B](/MaxText/configs/v5p/64b.sh) [v5p/128B](/MaxText/configs/v5p/128b.sh) [v5p/256B](/MaxText/configs/v5p/256b.sh) [v5p/512B](/MaxText/configs/v5p/512b.sh) [v5p/1024B](/MaxText/configs/v5p/1024b.sh) [v5e/16B](/MaxText/configs/v5e/16B.sh) [v5e/32B](/MaxText/configs/v5e/32b.sh) [v5e/64B](/MaxText/configs/v5e/64b.sh) [v5e/128B](/MaxText/configs/v5e/128b.sh) [v5e/Llama2-7B](/MaxText/configs/v5e/llama2_7b.sh) [v5e/Llama2-13B](/MaxText/configs/v5e/llama2_13b.sh) [v5e/Llama2-70B](/MaxText/configs/v5e/llama2_70b.sh) [v5e/GPT3-175B](/MaxText/configs/v5e/gpt3_175b.sh) | -| xla_tpu_enable_async_collective_fusion_multiple_steps | Boolean (true/false) | Enables continuing the same async collective in multiple steps (fusions) in the AsyncCollectiveFusion pass.
**Usage:** [v5p/32B](/MaxText/configs/v5p/32b.sh) [v5p/64B](/MaxText/configs/v5p/64b.sh) [v5p/128B](/MaxText/configs/v5p/128b.sh) [v5p/256B](/MaxText/configs/v5p/256b.sh) [v5p/512B](/MaxText/configs/v5p/512b.sh) [v5p/1024B](/MaxText/configs/v5p/1024b.sh) [v5e/16B](/MaxText/configs/v5e/16B.sh) [v5e/32B](/MaxText/configs/v5e/32b.sh) [v5e/64B](/MaxText/configs/v5e/64b.sh) [v5e/128B](/MaxText/configs/v5e/128b.sh) [v5e/Llama2-7B](/MaxText/configs/v5e/llama2_7b.sh) [v5e/Llama2-13B](/MaxText/configs/v5e/llama2_13b.sh) [v5e/Llama2-70B](/MaxText/configs/v5e/llama2_70b.sh) [v5e/GPT3-175B](/MaxText/configs/v5e/gpt3_175b.sh) | -| xla_tpu_overlap_compute_collective_tc | Boolean (true/false) | Enables the overlap of compute and communication on a single TensorCore, i.e., one core equivalent of MegaCore fusion.
**Usage:** [v5p/32B](/MaxText/configs/v5p/32b.sh) [v5p/64B](/MaxText/configs/v5p/64b.sh) [v5p/128B](/MaxText/configs/v5p/128b.sh) [v5p/256B](/MaxText/configs/v5p/256b.sh) [v5p/512B](/MaxText/configs/v5p/512b.sh) [v5p/1024B](/MaxText/configs/v5p/1024b.sh) [v5e/16B](/MaxText/configs/v5e/16B.sh) [v5e/32B](/MaxText/configs/v5e/32b.sh) [v5e/64B](/MaxText/configs/v5e/64b.sh) [v5e/128B](/MaxText/configs/v5e/128b.sh) [v5e/Llama2-7B](/MaxText/configs/v5e/llama2_7b.sh) [v5e/Llama2-13B](/MaxText/configs/v5e/llama2_13b.sh) [v5e/Llama2-70B](/MaxText/configs/v5e/llama2_70b.sh) [v5e/GPT3-175B](/MaxText/configs/v5e/gpt3_175b.sh) | -| xla_enable_async_all_gather | TristateFlag (true/false/kAuto) | If set to true, enables async all gather. If ``kAuto``, enables only for platforms that implement async all-gather. The implementation (such as BC-offload or continuation fusion) is chosen based on other flag values.
**Usage:** [v4/22B](/MaxText/configs/v4/22B.sh) [v4/52B](/MaxText/configs/v4/52B.sh) [v5p/32B](/MaxText/configs/v5p/32b.sh) [v5p/64B](/MaxText/configs/v5p/64b.sh) [v5p/128B](/MaxText/configs/v5p/128b.sh) [v5p/256B](/MaxText/configs/v5p/256b.sh) [v5p/512B](/MaxText/configs/v5p/512b.sh) [v5p/1024B](/MaxText/configs/v5p/1024b.sh) [v5e/16B](/MaxText/configs/v5e/16B.sh) [v5e/32B](/MaxText/configs/v5e/32b.sh) [v5e/64B](/MaxText/configs/v5e/64b.sh) [v5e/128B](/MaxText/configs/v5e/128b.sh) [v5e/Llama2-7B](/MaxText/configs/v5e/llama2_7b.sh) [v5e/Llama2-13B](/MaxText/configs/v5e/llama2_13b.sh) [v5e/Llama2-70B](/MaxText/configs/v5e/llama2_70b.sh) [v5e/GPT3-175B](/MaxText/configs/v5e/gpt3_175b.sh) | -| xla_tpu_spmd_rng_bit_generator_unsafe | Boolean (true/false) | Whether to run RngBitGenerator HLO in a partitioned way, which is unsafe if deterministic results are expected with different shardings on different parts of the computation.
**Usage:** [v5e/GPT3-175B](/MaxText/configs/v5e/gpt3_175b.sh) | -| xla_tpu_megacore_fusion_allow_ags | Boolean (true/false) | Allows fusing all-gathers with convolutions/all-reduces.
**Usage:** [v5p/32B](/MaxText/configs/v5p/32b.sh) [v5p/64B](/MaxText/configs/v5p/64b.sh) [v5p/128B](/MaxText/configs/v5p/128b.sh) [v5p/256B](/MaxText/configs/v5p/256b.sh) [v5p/512B](/MaxText/configs/v5p/512b.sh) [v5p/1024B](/MaxText/configs/v5p/1024b.sh) | -| xla_tpu_enable_ag_backward_pipelining | Boolean (true/false) | Pipelines all-gathers (currently megascale all-gathers) backwards through scan loops.
**Usage:** [v5p/32B](/MaxText/configs/v5p/32b.sh) [v5p/64B](/MaxText/configs/v5p/64b.sh) [v5p/128B](/MaxText/configs/v5p/128b.sh) [v5p/256B](/MaxText/configs/v5p/256b.sh) [v5p/512B](/MaxText/configs/v5p/512b.sh) [v5p/1024B](/MaxText/configs/v5p/1024b.sh) | -| xla_enable_async_collective_permute | TristateFlag (true/false/kAuto) | Rewrites all collective-permute operations to their asynchronous variants. When set to ``kAuto``, XLA can turn on async collective based on other configurations or conditions automatically.
**Usage:** [v5p/32B](/MaxText/configs/v5p/32b.sh) [v5p/64B](/MaxText/configs/v5p/64b.sh) [v5p/128B](/MaxText/configs/v5p/128b.sh) [v5p/256B](/MaxText/configs/v5p/256b.sh) [v5p/512B](/MaxText/configs/v5p/512b.sh) [v5p/1024B](/MaxText/configs/v5p/1024b.sh) | -| xla_dump_to | String (filepath) | The folder where pre-optimization HLO files and other artifacts will be placed (see [XLA Tools](https://openxla.org/xla/tools)).
**Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) | -| xla_gpu_enable_latency_hiding_scheduler | Boolean (true/false) | This flag enables latency hiding schedulers to overlap asynchronous communication with computation efficiently. The default value is False.
**Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) | -| xla_gpu_enable_triton_gemm | Boolean (true/false) | Use Triton-based matrix multiplication.
**Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) | -| xla_gpu_enable_command_buffer | List of CommandBufferCmdType | Which kind of commands should be captured in command buffers.
**Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) | -| xla_gpu_all_reduce_combine_threshold_bytes | Integer (bytes) | These flags tune when to combine multiple small AllGather / ReduceScatter / AllReduce into one big AllGather / ReduceScatter / AllReduce to reduce time spent on cross-device communication. For example, for the AllGather / ReduceScatter thresholds on a Transformer-based workload, consider tuning them high enough so as to combine at least a Transformer Layer’s weight AllGather / ReduceScatter. By default, the combine_threshold_bytes is set to 256.
**Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) | -| xla_gpu_all_gather_combine_threshold_bytes | Integer (bytes) | See xla_gpu_all_reduce_combine_threshold_bytes above.
**Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) | -| xla_gpu_reduce_scatter_combine_threshold_bytes | Integer (bytes) | See xla_gpu_all_reduce_combine_threshold_bytes above.
**Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) | -| xla_gpu_enable_pipelined_all_gather | Boolean (true/false) | Enable pipelinling of all-gather instructions.
**Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) | -| xla_gpu_enable_pipelined_reduce_scatter | Boolean (true/false) | Enable pipelinling of reduce-scatter instructions.
**Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) | -| xla_gpu_enable_pipelined_all_reduce | Boolean (true/false) | Enable pipelinling of all-reduce instructions.
**Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) | -| xla_gpu_enable_while_loop_double_buffering | Boolean (true/false) | Enable double-buffering for while loop.
**Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) | -| xla_gpu_enable_all_gather_combine_by_dim | Boolean (true/false) | Combine all-gather ops with the same gather dimension or irrespective of their dimension.
**Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) | -| xla_gpu_enable_reduce_scatter_combine_by_dim | Boolean (true/false) | Combine reduce-scatter ops with the same dimension or irrespective of their dimension.
**Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) | -| xla_disable_hlo_passes | String (comma-separated list of pass names) | Comma-separated list of HLO passes to be disabled. These names must exactly match the pass name (no whitespace around commas).
**Usage:** [a3/Llama2-7B 1vm](/MaxText/configs/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/MaxText/configs/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/MaxText/configs/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/MaxText/configs/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/MaxText/configs/a3/llama_2_7B/16vm.sh) | +| xla_tpu_enable_data_parallel_all_reduce_opt | Boolean (true/false) | Optimization to increase overlap opportunities for DCN (data center networking) all-reduces used for data parallel sharding.
**Usage:** [v5p/32B](src/maxtext/configs/tpu/v5p/32b.sh) [v5p/64B](/maxtext/configs/tpu/v5p/64b.sh) [v5p/128B](/maxtext/configs/tpu/v5p/128b.sh) [v5p/256B](/maxtext/configs/tpu/v5p/256b.sh) [v5p/512B](/maxtext/configs/tpu/v5p/512b.sh) [v5p/1024B](/maxtext/configs/tpu/v5p/1024b.sh) [v5e/16B](/maxtext/configs/tpu/v5e/16B.sh) [v5e/32B](/maxtext/configs/tpu/v5e/32b.sh) [v5e/64B](/maxtext/configs/tpu/v5e/64b.sh) [v5e/128B](/maxtext/configs/tpu/v5e/128b.sh) [v5e/Llama2-7B](/maxtext/configs/tpu/v5e/llama2_7b.sh) [v5e/Llama2-13B](/maxtext/configs/tpu/v5e/llama2_13b.sh) [v5e/Llama2-70B](/maxtext/configs/tpu/v5e/llama2_70b.sh) [v5e/GPT3-175B](/maxtext/configs/tpu/v5e/gpt3_175b.sh) | +| xla_tpu_data_parallel_opt_different_sized_ops | Boolean (true/false) | Enables pipelining of data parallel ops across multiple iterations even if their output sizes doesn't match what can Be saved in place in the stacked variables. Can increase memory pressure.
**Usage:** [v5p/32B](/maxtext/configs/tpu/v5p/32b.sh) [v5p/64B](/maxtext/configs/tpu/v5p/64b.sh) [v5p/128B](/maxtext/configs/tpu/v5p/128b.sh) [v5p/256B](/maxtext/configs/tpu/v5p/256b.sh) [v5p/512B](/maxtext/configs/tpu/v5p/512b.sh) [v5p/1024B](/maxtext/configs/tpu/v5p/1024b.sh) [v5e/16B](/maxtext/configs/tpu/v5e/16B.sh) [v5e/32B](/maxtext/configs/tpu/v5e/32b.sh) [v5e/64B](/maxtext/configs/tpu/v5e/64b.sh) [v5e/128B](/maxtext/configs/tpu/v5e/128b.sh) [v5e/Llama2-7B](/maxtext/configs/tpu/v5e/llama2_7b.sh) [v5e/Llama2-13B](/maxtext/configs/tpu/v5e/llama2_13b.sh) [v5e/Llama2-70B](/maxtext/configs/tpu/v5e/llama2_70b.sh) [v5e/GPT3-175B](/maxtext/configs/tpu/v5e/gpt3_175b.sh) | +| xla_tpu_enable_async_collective_fusion | Boolean (true/false) | Enables the pass which fuses async collective communications with compute ops (output/loop-fusion or convolution) that are scheduled between their -start and -done instructions.
**Usage:** [v5p/32B](/maxtext/configs/tpu/v5p/32b.sh) [v5p/64B](/maxtext/configs/tpu/v5p/64b.sh) [v5p/128B](/maxtext/configs/tpu/v5p/128b.sh) [v5p/256B](/maxtext/configs/tpu/v5p/256b.sh) [v5p/512B](/maxtext/configs/tpu/v5p/512b.sh) [v5p/1024B](/maxtext/configs/tpu/v5p/1024b.sh) [v5e/16B](/maxtext/configs/tpu/v5e/16B.sh) [v5e/32B](/maxtext/configs/tpu/v5e/32b.sh) [v5e/64B](/maxtext/configs/tpu/v5e/64b.sh) [v5e/128B](/maxtext/configs/tpu/v5e/128b.sh) [v5e/Llama2-7B](/maxtext/configs/tpu/v5e/llama2_7b.sh) [v5e/Llama2-13B](/maxtext/configs/tpu/v5e/llama2_13b.sh) [v5e/Llama2-70B](/maxtext/configs/tpu/v5e/llama2_70b.sh) [v5e/GPT3-175B](/maxtext/configs/tpu/v5e/gpt3_175b.sh) | +| xla_tpu_enable_async_collective_fusion_fuse_all_gather | TristateFlag (true/false/kAuto) | Enables fusing all-gathers within the AsyncCollectiveFusion pass.
If set to ``kAuto``, it will be enabled based on the target."
**Usage:** [v5p/32B](/maxtext/configs/tpu/v5p/32b.sh) [v5p/64B](/maxtext/configs/tpu/v5p/64b.sh) [v5p/128B](/maxtext/configs/tpu/v5p/128b.sh) [v5p/256B](/maxtext/configs/tpu/v5p/256b.sh) [v5p/512B](/maxtext/configs/tpu/v5p/512b.sh) [v5p/1024B](/maxtext/configs/tpu/v5p/1024b.sh) [v5e/16B](/maxtext/configs/tpu/v5e/16B.sh) [v5e/32B](/maxtext/configs/tpu/v5e/32b.sh) [v5e/64B](/maxtext/configs/tpu/v5e/64b.sh) [v5e/128B](/maxtext/configs/tpu/v5e/128b.sh) [v5e/Llama2-7B](/maxtext/configs/tpu/v5e/llama2_7b.sh) [v5e/Llama2-13B](/maxtext/configs/tpu/v5e/llama2_13b.sh) [v5e/Llama2-70B](/maxtext/configs/tpu/v5e/llama2_70b.sh) [v5e/GPT3-175B](/maxtext/configs/tpu/v5e/gpt3_175b.sh) | +| xla_tpu_enable_async_collective_fusion_multiple_steps | Boolean (true/false) | Enables continuing the same async collective in multiple steps (fusions) in the AsyncCollectiveFusion pass.
**Usage:** [v5p/32B](/maxtext/configs/tpu/v5p/32b.sh) [v5p/64B](/maxtext/configs/tpu/v5p/64b.sh) [v5p/128B](/maxtext/configs/tpu/v5p/128b.sh) [v5p/256B](/maxtext/configs/tpu/v5p/256b.sh) [v5p/512B](/maxtext/configs/tpu/v5p/512b.sh) [v5p/1024B](/maxtext/configs/tpu/v5p/1024b.sh) [v5e/16B](/maxtext/configs/tpu/v5e/16B.sh) [v5e/32B](/maxtext/configs/tpu/v5e/32b.sh) [v5e/64B](/maxtext/configs/tpu/v5e/64b.sh) [v5e/128B](/maxtext/configs/tpu/v5e/128b.sh) [v5e/Llama2-7B](/maxtext/configs/tpu/v5e/llama2_7b.sh) [v5e/Llama2-13B](/maxtext/configs/tpu/v5e/llama2_13b.sh) [v5e/Llama2-70B](/maxtext/configs/tpu/v5e/llama2_70b.sh) [v5e/GPT3-175B](/maxtext/configs/tpu/v5e/gpt3_175b.sh) | +| xla_tpu_overlap_compute_collective_tc | Boolean (true/false) | Enables the overlap of compute and communication on a single TensorCore, i.e., one core equivalent of MegaCore fusion.
**Usage:** [v5p/32B](/maxtext/configs/tpu/v5p/32b.sh) [v5p/64B](/maxtext/configs/tpu/v5p/64b.sh) [v5p/128B](/maxtext/configs/tpu/v5p/128b.sh) [v5p/256B](/maxtext/configs/tpu/v5p/256b.sh) [v5p/512B](/maxtext/configs/tpu/v5p/512b.sh) [v5p/1024B](/maxtext/configs/tpu/v5p/1024b.sh) [v5e/16B](/maxtext/configs/tpu/v5e/16B.sh) [v5e/32B](/maxtext/configs/tpu/v5e/32b.sh) [v5e/64B](/maxtext/configs/tpu/v5e/64b.sh) [v5e/128B](/maxtext/configs/tpu/v5e/128b.sh) [v5e/Llama2-7B](/maxtext/configs/tpu/v5e/llama2_7b.sh) [v5e/Llama2-13B](/maxtext/configs/tpu/v5e/llama2_13b.sh) [v5e/Llama2-70B](/maxtext/configs/tpu/v5e/llama2_70b.sh) [v5e/GPT3-175B](/maxtext/configs/tpu/v5e/gpt3_175b.sh) | +| xla_enable_async_all_gather | TristateFlag (true/false/kAuto) | If set to true, enables async all gather. If ``kAuto``, enables only for platforms that implement async all-gather. The implementation (such as BC-offload or continuation fusion) is chosen based on other flag values.
**Usage:** [v4/22B](/maxtext/configs/tpu/v4/22B.sh) [v4/52B](/maxtext/configs/tpu/v4/52B.sh) [v5p/32B](/maxtext/configs/tpu/v5p/32b.sh) [v5p/64B](/maxtext/configs/tpu/v5p/64b.sh) [v5p/128B](/maxtext/configs/tpu/v5p/128b.sh) [v5p/256B](/maxtext/configs/tpu/v5p/256b.sh) [v5p/512B](/maxtext/configs/tpu/v5p/512b.sh) [v5p/1024B](/maxtext/configs/tpu/v5p/1024b.sh) [v5e/16B](/maxtext/configs/tpu/v5e/16B.sh) [v5e/32B](/maxtext/configs/tpu/v5e/32b.sh) [v5e/64B](/maxtext/configs/tpu/v5e/64b.sh) [v5e/128B](/maxtext/configs/tpu/v5e/128b.sh) [v5e/Llama2-7B](/maxtext/configs/tpu/v5e/llama2_7b.sh) [v5e/Llama2-13B](/maxtext/configs/tpu/v5e/llama2_13b.sh) [v5e/Llama2-70B](/maxtext/configs/tpu/v5e/llama2_70b.sh) [v5e/GPT3-175B](/maxtext/configs/tpu/v5e/gpt3_175b.sh) | +| xla_tpu_spmd_rng_bit_generator_unsafe | Boolean (true/false) | Whether to run RngBitGenerator HLO in a partitioned way, which is unsafe if deterministic results are expected with different shardings on different parts of the computation.
**Usage:** [v5e/GPT3-175B](/maxtext/configs/tpu/v5e/gpt3_175b.sh) | +| xla_tpu_megacore_fusion_allow_ags | Boolean (true/false) | Allows fusing all-gathers with convolutions/all-reduces.
**Usage:** [v5p/32B](/maxtext/configs/tpu/v5p/32b.sh) [v5p/64B](/maxtext/configs/tpu/v5p/64b.sh) [v5p/128B](/maxtext/configs/tpu/v5p/128b.sh) [v5p/256B](/maxtext/configs/tpu/v5p/256b.sh) [v5p/512B](/maxtext/configs/tpu/v5p/512b.sh) [v5p/1024B](/maxtext/configs/tpu/v5p/1024b.sh) | +| xla_tpu_enable_ag_backward_pipelining | Boolean (true/false) | Pipelines all-gathers (currently megascale all-gathers) backwards through scan loops.
**Usage:** [v5p/32B](/maxtext/configs/tpu/v5p/32b.sh) [v5p/64B](/maxtext/configs/tpu/v5p/64b.sh) [v5p/128B](/maxtext/configs/tpu/v5p/128b.sh) [v5p/256B](/maxtext/configs/tpu/v5p/256b.sh) [v5p/512B](/maxtext/configs/tpu/v5p/512b.sh) [v5p/1024B](/maxtext/configs/tpu/v5p/1024b.sh) | +| xla_enable_async_collective_permute | TristateFlag (true/false/kAuto) | Rewrites all collective-permute operations to their asynchronous variants. When set to ``kAuto``, XLA can turn on async collective based on other configurations or conditions automatically.
**Usage:** [v5p/32B](/maxtext/configs/tpu/v5p/32b.sh) [v5p/64B](/maxtext/configs/tpu/v5p/64b.sh) [v5p/128B](/maxtext/configs/tpu/v5p/128b.sh) [v5p/256B](/maxtext/configs/tpu/v5p/256b.sh) [v5p/512B](/maxtext/configs/tpu/v5p/512b.sh) [v5p/1024B](/maxtext/configs/tpu/v5p/1024b.sh) | +| xla_dump_to | String (filepath) | The folder where pre-optimization HLO files and other artifacts will be placed (see [XLA Tools](https://openxla.org/xla/tools)).
**Usage:** [a3/Llama2-7B 1vm](/maxtext/configs/gpu/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/maxtext/configs/gpu/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/maxtext/configs/gpu/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/maxtext/configs/gpu/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/maxtext/configs/gpu/a3/llama_2_7B/16vm.sh) | +| xla_gpu_enable_latency_hiding_scheduler | Boolean (true/false) | This flag enables latency hiding schedulers to overlap asynchronous communication with computation efficiently. The default value is False.
**Usage:** [a3/Llama2-7B 1vm](/maxtext/configs/gpu/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/maxtext/configs/gpu/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/maxtext/configs/gpu/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/maxtext/configs/gpu/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/maxtext/configs/gpu/a3/llama_2_7B/16vm.sh) | +| xla_gpu_enable_triton_gemm | Boolean (true/false) | Use Triton-based matrix multiplication.
**Usage:** [a3/Llama2-7B 1vm](/maxtext/configs/gpu/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/maxtext/configs/gpu/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/maxtext/configs/gpu/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/maxtext/configs/gpu/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/maxtext/configs/gpu/a3/llama_2_7B/16vm.sh) | +| xla_gpu_enable_command_buffer | List of CommandBufferCmdType | Which kind of commands should be captured in command buffers.
**Usage:** [a3/Llama2-7B 1vm](/maxtext/configs/gpu/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/maxtext/configs/gpu/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/maxtext/configs/gpu/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/maxtext/configs/gpu/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/maxtext/configs/gpu/a3/llama_2_7B/16vm.sh) | +| xla_gpu_all_reduce_combine_threshold_bytes | Integer (bytes) | These flags tune when to combine multiple small AllGather / ReduceScatter / AllReduce into one big AllGather / ReduceScatter / AllReduce to reduce time spent on cross-device communication. For example, for the AllGather / ReduceScatter thresholds on a Transformer-based workload, consider tuning them high enough so as to combine at least a Transformer Layer’s weight AllGather / ReduceScatter. By default, the combine_threshold_bytes is set to 256.
**Usage:** [a3/Llama2-7B 1vm](/maxtext/configs/gpu/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/maxtext/configs/gpu/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/maxtext/configs/gpu/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/maxtext/configs/gpu/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/maxtext/configs/gpu/a3/llama_2_7B/16vm.sh) | +| xla_gpu_all_gather_combine_threshold_bytes | Integer (bytes) | See xla_gpu_all_reduce_combine_threshold_bytes above.
**Usage:** [a3/Llama2-7B 1vm](/maxtext/configs/gpu/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/maxtext/configs/gpu/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/maxtext/configs/gpu/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/maxtext/configs/gpu/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/maxtext/configs/gpu/a3/llama_2_7B/16vm.sh) | +| xla_gpu_reduce_scatter_combine_threshold_bytes | Integer (bytes) | See xla_gpu_all_reduce_combine_threshold_bytes above.
**Usage:** [a3/Llama2-7B 1vm](/maxtext/configs/gpu/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/maxtext/configs/gpu/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/maxtext/configs/gpu/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/maxtext/configs/gpu/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/maxtext/configs/gpu/a3/llama_2_7B/16vm.sh) | +| xla_gpu_enable_pipelined_all_gather | Boolean (true/false) | Enable pipelinling of all-gather instructions.
**Usage:** [a3/Llama2-7B 1vm](/maxtext/configs/gpu/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/maxtext/configs/gpu/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/maxtext/configs/gpu/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/maxtext/configs/gpu/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/maxtext/configs/gpu/a3/llama_2_7B/16vm.sh) | +| xla_gpu_enable_pipelined_reduce_scatter | Boolean (true/false) | Enable pipelinling of reduce-scatter instructions.
**Usage:** [a3/Llama2-7B 1vm](/maxtext/configs/gpu/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/maxtext/configs/gpu/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/maxtext/configs/gpu/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/maxtext/configs/gpu/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/maxtext/configs/gpu/a3/llama_2_7B/16vm.sh) | +| xla_gpu_enable_pipelined_all_reduce | Boolean (true/false) | Enable pipelinling of all-reduce instructions.
**Usage:** [a3/Llama2-7B 1vm](/maxtext/configs/gpu/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/maxtext/configs/gpu/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/maxtext/configs/gpu/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/maxtext/configs/gpu/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/maxtext/configs/gpu/a3/llama_2_7B/16vm.sh) | +| xla_gpu_enable_while_loop_double_buffering | Boolean (true/false) | Enable double-buffering for while loop.
**Usage:** [a3/Llama2-7B 1vm](/maxtext/configs/gpu/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/maxtext/configs/gpu/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/maxtext/configs/gpu/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/maxtext/configs/gpu/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/maxtext/configs/gpu/a3/llama_2_7B/16vm.sh) | +| xla_gpu_enable_all_gather_combine_by_dim | Boolean (true/false) | Combine all-gather ops with the same gather dimension or irrespective of their dimension.
**Usage:** [a3/Llama2-7B 1vm](/maxtext/configs/gpu/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/maxtext/configs/gpu/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/maxtext/configs/gpu/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/maxtext/configs/gpu/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/maxtext/configs/gpu/a3/llama_2_7B/16vm.sh) | +| xla_gpu_enable_reduce_scatter_combine_by_dim | Boolean (true/false) | Combine reduce-scatter ops with the same dimension or irrespective of their dimension.
**Usage:** [a3/Llama2-7B 1vm](/maxtext/configs/gpu/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/maxtext/configs/gpu/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/maxtext/configs/gpu/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/maxtext/configs/gpu/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/maxtext/configs/gpu/a3/llama_2_7B/16vm.sh) | +| xla_disable_hlo_passes | String (comma-separated list of pass names) | Comma-separated list of HLO passes to be disabled. These names must exactly match the pass name (no whitespace around commas).
**Usage:** [a3/Llama2-7B 1vm](/maxtext/configs/gpu/a3/llama_2_7B/1vm.sh) [a3/Llama2-7B 2vm](/maxtext/configs/gpu/a3/llama_2_7B/2vm.sh) [a3/Llama2-7B 4vm](/maxtext/configs/gpu/a3/llama_2_7B/4vm.sh) [a3/Llama2-7B 8vm](/maxtext/configs/gpu/a3/llama_2_7B/8vm.sh) [a3/Llama2-7B 16vm](/maxtext/configs/gpu/a3/llama_2_7B/16vm.sh) | diff --git a/src/MaxText/configs/__init__.py b/src/maxtext/configs/__init__.py similarity index 100% rename from src/MaxText/configs/__init__.py rename to src/maxtext/configs/__init__.py diff --git a/src/MaxText/configs/base.yml b/src/maxtext/configs/base.yml similarity index 99% rename from src/MaxText/configs/base.yml rename to src/maxtext/configs/base.yml index 651bfd3267..75d63c5cfb 100644 --- a/src/MaxText/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -97,7 +97,7 @@ dtype: "bfloat16" # used to configure quantization in the transformer layers, defaults to null implying bf16. # possible alternative settings are as follows: # 'int8' for dynamic range quantization using 8-bits -# 'intmp' for mixed precision quantization for inference as described here: src/MaxText/configs/quantization/readme.md +# 'intmp' for mixed precision quantization for inference as described here: src/maxtext/configs/quantization/readme.md # 'fp8' for 8-bit floating-point gemms on nvidia gpus. # 'nanoo_fp8' for 8-bit floating-point gemms on amd mi300/mi325 gpus. # 'fp8_full' for fp8 quantization with static scaling. diff --git a/src/MaxText/configs/decoupled_base_test.yml b/src/maxtext/configs/decoupled_base_test.yml similarity index 100% rename from src/MaxText/configs/decoupled_base_test.yml rename to src/maxtext/configs/decoupled_base_test.yml diff --git a/src/MaxText/configs/experimental/1024b.sh b/src/maxtext/configs/experimental/1024b.sh similarity index 87% rename from src/MaxText/configs/experimental/1024b.sh rename to src/maxtext/configs/experimental/1024b.sh index c4a2061628..a7c1658ff3 100644 --- a/src/MaxText/configs/experimental/1024b.sh +++ b/src/maxtext/configs/experimental/1024b.sh @@ -1,6 +1,6 @@ echo "Running 1024b.sh" # Example command to invoke this script -# bash src/MaxText/configs/experimental/1024b.sh +# bash src/maxtext/configs/experimental/1024b.sh # Stop execution if any command exits with error set -e @@ -19,7 +19,7 @@ bash preflight.sh PLATFORM=$PLATFORM # Train export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml run_name=$RUN_NAME\ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml run_name=$RUN_NAME\ steps=20 per_device_batch_size=2 enable_checkpointing=false\ remat_policy=full global_parameter_scale=1024\ ici_fsdp_parallelism=-1 ici_tensor_parallelism=16\ diff --git a/src/MaxText/configs/experimental/128b.sh b/src/maxtext/configs/experimental/128b.sh similarity index 88% rename from src/MaxText/configs/experimental/128b.sh rename to src/maxtext/configs/experimental/128b.sh index c6022bdfad..033e92556e 100644 --- a/src/MaxText/configs/experimental/128b.sh +++ b/src/maxtext/configs/experimental/128b.sh @@ -1,6 +1,6 @@ echo "Running 128b.sh" # Example command to invoke this script -# bash src/MaxText/configs/experimental/128b.sh +# bash src/maxtext/configs/experimental/128b.sh # Stop execution if any command exits with error set -e @@ -19,7 +19,7 @@ bash preflight.sh PLATFORM=$PLATFORM # Train export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml run_name=$RUN_NAME\ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml run_name=$RUN_NAME\ steps=30 per_device_batch_size=2 enable_checkpointing=false\ remat_policy=minimal global_parameter_scale=128\ ici_fsdp_parallelism=-1 ici_tensor_parallelism=8\ diff --git a/src/MaxText/configs/experimental/256b.sh b/src/maxtext/configs/experimental/256b.sh similarity index 88% rename from src/MaxText/configs/experimental/256b.sh rename to src/maxtext/configs/experimental/256b.sh index 8e13a5d094..8ebd301d4c 100644 --- a/src/MaxText/configs/experimental/256b.sh +++ b/src/maxtext/configs/experimental/256b.sh @@ -1,6 +1,6 @@ echo "Running 256b.sh" # Example command to invoke this script -# bash src/MaxText/configs/experimental/256b.sh +# bash src/maxtext/configs/experimental/256b.sh # Stop execution if any command exits with error set -e @@ -19,7 +19,7 @@ bash preflight.sh PLATFORM=$PLATFORM # Train export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml run_name=$RUN_NAME\ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml run_name=$RUN_NAME\ steps=20 per_device_batch_size=2 enable_checkpointing=false\ remat_policy=minimal global_parameter_scale=256\ ici_fsdp_parallelism=-1 ici_tensor_parallelism=8\ diff --git a/src/MaxText/configs/experimental/32b.sh b/src/maxtext/configs/experimental/32b.sh similarity index 88% rename from src/MaxText/configs/experimental/32b.sh rename to src/maxtext/configs/experimental/32b.sh index 6d656277bd..f95f296cbb 100644 --- a/src/MaxText/configs/experimental/32b.sh +++ b/src/maxtext/configs/experimental/32b.sh @@ -1,6 +1,6 @@ echo "Running 32b.sh" # Example command to invoke this script -# bash src/MaxText/configs/experimental/32b.sh +# bash src/maxtext/configs/experimental/32b.sh # Stop execution if any command exits with error set -e @@ -19,7 +19,7 @@ bash preflight.sh PLATFORM=$PLATFORM # Train export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml run_name=$RUN_NAME\ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml run_name=$RUN_NAME\ steps=30 per_device_batch_size=8 enable_checkpointing=false\ remat_policy=minimal global_parameter_scale=32\ ici_fsdp_parallelism=-1 ici_tensor_parallelism=4\ diff --git a/src/MaxText/configs/experimental/512b.sh b/src/maxtext/configs/experimental/512b.sh similarity index 87% rename from src/MaxText/configs/experimental/512b.sh rename to src/maxtext/configs/experimental/512b.sh index 9c5b76a00a..5bd2006b8f 100644 --- a/src/MaxText/configs/experimental/512b.sh +++ b/src/maxtext/configs/experimental/512b.sh @@ -1,6 +1,6 @@ echo "Running 512b.sh" # Example command to invoke this script -# bash src/MaxText/configs/experimental/512b.sh +# bash src/maxtext/configs/experimental/512b.sh # Stop execution if any command exits with error set -e @@ -19,7 +19,7 @@ bash preflight.sh PLATFORM=$PLATFORM # Train export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml run_name=$RUN_NAME\ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml run_name=$RUN_NAME\ steps=20 per_device_batch_size=2 enable_checkpointing=false\ remat_policy=full global_parameter_scale=512\ ici_fsdp_parallelism=-1 ici_tensor_parallelism=16\ diff --git a/src/MaxText/configs/experimental/64b.sh b/src/maxtext/configs/experimental/64b.sh similarity index 88% rename from src/MaxText/configs/experimental/64b.sh rename to src/maxtext/configs/experimental/64b.sh index d58ed87346..f6756eb621 100644 --- a/src/MaxText/configs/experimental/64b.sh +++ b/src/maxtext/configs/experimental/64b.sh @@ -1,6 +1,6 @@ echo "Running 64b.sh" # Example command to invoke this script -# bash src/MaxText/configs/experimental/64b.sh +# bash src/maxtext/configs/experimental/64b.sh # Stop execution if any command exits with error set -e @@ -19,7 +19,7 @@ bash preflight.sh PLATFORM=$PLATFORM # Train export LIBTPU_INIT_ARGS="--xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml run_name=$RUN_NAME\ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml run_name=$RUN_NAME\ steps=30 per_device_batch_size=4 enable_checkpointing=false\ remat_policy=minimal global_parameter_scale=64\ ici_fsdp_parallelism=-1 ici_tensor_parallelism=4\ diff --git a/src/MaxText/configs/a3/llama_2_7b/16vm.sh b/src/maxtext/configs/gpu/a3/llama_2_7b/16vm.sh similarity index 87% rename from src/MaxText/configs/a3/llama_2_7b/16vm.sh rename to src/maxtext/configs/gpu/a3/llama_2_7b/16vm.sh index 5acdd9cb81..f49f1121a4 100644 --- a/src/MaxText/configs/a3/llama_2_7b/16vm.sh +++ b/src/maxtext/configs/gpu/a3/llama_2_7b/16vm.sh @@ -3,7 +3,7 @@ echo "Running 16vm.sh" # python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \ # --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \ # --device-type ${DEVICE_TYPE} --num-slices 16 --priority=high \ -# --command "bash src/MaxText/configs/a3/llama_2_7b/16vm.sh" +# --command "bash src/maxtext/configs/gpu/a3/llama_2_7b/16vm.sh" # Stop execution if any command exits with error set -e @@ -30,5 +30,5 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/ --xla_disable_hlo_passes=rematerialization" # 16 nodes -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/models/gpu/llama2_7b.yml run_name=$RUN_NAME \ dcn_data_parallelism=16 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane diff --git a/src/MaxText/configs/a3/llama_2_7b/1vm.sh b/src/maxtext/configs/gpu/a3/llama_2_7b/1vm.sh similarity index 87% rename from src/MaxText/configs/a3/llama_2_7b/1vm.sh rename to src/maxtext/configs/gpu/a3/llama_2_7b/1vm.sh index 1ba7549cd8..2fb05edc10 100644 --- a/src/MaxText/configs/a3/llama_2_7b/1vm.sh +++ b/src/maxtext/configs/gpu/a3/llama_2_7b/1vm.sh @@ -4,7 +4,7 @@ echo "Running 1vm.sh" # python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \ # --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \ # --device-type ${DEVICE_TYPE} --num-slices 1 \ -# --command "bash src/MaxText/configs/a3/llama_2_7b/1vm.sh" +# --command "bash src/maxtext/configs/gpu/a3/llama_2_7b/1vm.sh" # Stop execution if any command exits with error set -e @@ -30,5 +30,5 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/ # 1 node, DATA_DP=1, ICI_FSDP=8 -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/models/gpu/llama2_7b.yml run_name=$RUN_NAME \ dcn_data_parallelism=1 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane diff --git a/src/MaxText/configs/a3/llama_2_7b/2vm.sh b/src/maxtext/configs/gpu/a3/llama_2_7b/2vm.sh similarity index 87% rename from src/MaxText/configs/a3/llama_2_7b/2vm.sh rename to src/maxtext/configs/gpu/a3/llama_2_7b/2vm.sh index 68e03268c3..cc2ce97d7e 100644 --- a/src/MaxText/configs/a3/llama_2_7b/2vm.sh +++ b/src/maxtext/configs/gpu/a3/llama_2_7b/2vm.sh @@ -4,7 +4,7 @@ echo "Running 2vm.sh" # python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \ # --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \ # --device-type ${DEVICE_TYPE} --num-slices 2 \ -# --command "bash src/MaxText/configs/a3/llama_2_7b/2vm.sh" +# --command "bash src/maxtext/configs/gpu/a3/llama_2_7b/2vm.sh" # Stop execution if any command exits with error set -e @@ -31,5 +31,5 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/ # 2 nodes -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/models/gpu/llama2_7b.yml run_name=$RUN_NAME \ dcn_data_parallelism=2 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane diff --git a/src/MaxText/configs/a3/llama_2_7b/4vm.sh b/src/maxtext/configs/gpu/a3/llama_2_7b/4vm.sh similarity index 87% rename from src/MaxText/configs/a3/llama_2_7b/4vm.sh rename to src/maxtext/configs/gpu/a3/llama_2_7b/4vm.sh index aeef6c0a98..5fccd31877 100644 --- a/src/MaxText/configs/a3/llama_2_7b/4vm.sh +++ b/src/maxtext/configs/gpu/a3/llama_2_7b/4vm.sh @@ -3,7 +3,7 @@ echo "Running 4vm.sh" # python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \ # --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \ # --device-type ${DEVICE_TYPE} --num-slices 4 \ -# --command "bash src/MaxText/configs/a3/llama_2_7b/4vm.sh" +# --command "bash src/maxtext/configs/gpu/a3/llama_2_7b/4vm.sh" # Stop execution if any command exits with error set -e @@ -28,5 +28,5 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/ --xla_disable_hlo_passes=rematerialization" # 4 nodes -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/models/gpu/llama2_7b.yml run_name=$RUN_NAME \ dcn_data_parallelism=4 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane diff --git a/src/MaxText/configs/a3/llama_2_7b/8vm.sh b/src/maxtext/configs/gpu/a3/llama_2_7b/8vm.sh similarity index 87% rename from src/MaxText/configs/a3/llama_2_7b/8vm.sh rename to src/maxtext/configs/gpu/a3/llama_2_7b/8vm.sh index d588284909..3388bbb52b 100644 --- a/src/MaxText/configs/a3/llama_2_7b/8vm.sh +++ b/src/maxtext/configs/gpu/a3/llama_2_7b/8vm.sh @@ -3,7 +3,7 @@ echo "Running 8vm.sh" # python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \ # --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \ # --device-type ${DEVICE_TYPE} --num-slices 8 \ -# --command "bash src/MaxText/configs/a3/llama_2_7b/8vm.sh" +# --command "bash src/maxtext/configs/gpu/a3/llama_2_7b/8vm.sh" # Stop execution if any command exits with error set -e @@ -29,5 +29,5 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/ --xla_disable_hlo_passes=rematerialization" # 8 nodes -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/models/gpu/llama2_7b.yml run_name=$RUN_NAME \ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/models/gpu/llama2_7b.yml run_name=$RUN_NAME \ dcn_data_parallelism=8 ici_fsdp_parallelism=8 base_output_directory=$OUTPUT_PATH profiler=xplane diff --git a/src/MaxText/configs/a3/llama_2_7b/README.md b/src/maxtext/configs/gpu/a3/llama_2_7b/README.md similarity index 100% rename from src/MaxText/configs/a3/llama_2_7b/README.md rename to src/maxtext/configs/gpu/a3/llama_2_7b/README.md diff --git a/src/MaxText/configs/a3/llama_3.1_405b/128vm.sh b/src/maxtext/configs/gpu/a3/llama_3.1_405b/128vm.sh similarity index 90% rename from src/MaxText/configs/a3/llama_3.1_405b/128vm.sh rename to src/maxtext/configs/gpu/a3/llama_3.1_405b/128vm.sh index 20ff807f9f..bc24a9be64 100644 --- a/src/MaxText/configs/a3/llama_3.1_405b/128vm.sh +++ b/src/maxtext/configs/gpu/a3/llama_3.1_405b/128vm.sh @@ -1,6 +1,6 @@ echo "Running 128vm.sh" # Example command to invoke this script via XPK, assume you've installed xpk -# COMMAND="bash src/MaxText/configs/a3/llama_3.1_405b/128vm.sh" +# COMMAND="bash src/maxtext/configs/gpu/a3/llama_3.1_405b/128vm.sh" # COMMAND='export LD_LIBRARY_PATH=/usr/local/cuda-12.6/compat:$LD_LIBRARY_PATH;'"${COMMAND}"; # # xpk workload create --project=${PROJECT}--cluster=${CLUSTER_NAME} --zone=${ZONE} \ @@ -33,7 +33,7 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/ --xla_disable_hlo_passes=rematerialization" # 128 nodes -python3 -m MaxText.$EXECUTABLE ${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}configs/models/llama3.1_405b.yml run_name=$RUN_NAME \ +python3 -m MaxText.$EXECUTABLE ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/models/llama3.1_405b.yml run_name=$RUN_NAME \ base_config=base.yml \ run_name=gpu_train_test \ hardware=gpu \ diff --git a/src/MaxText/configs/gpu_smoke_test.yml b/src/maxtext/configs/gpu/gpu_smoke_test.yml similarity index 100% rename from src/MaxText/configs/gpu_smoke_test.yml rename to src/maxtext/configs/gpu/gpu_smoke_test.yml diff --git a/src/MaxText/configs/models/gpu/llama2_70b.yml b/src/maxtext/configs/gpu/models/llama2_70b.yml similarity index 100% rename from src/MaxText/configs/models/gpu/llama2_70b.yml rename to src/maxtext/configs/gpu/models/llama2_70b.yml diff --git a/src/MaxText/configs/models/gpu/llama2_7b.yml b/src/maxtext/configs/gpu/models/llama2_7b.yml similarity index 100% rename from src/MaxText/configs/models/gpu/llama2_7b.yml rename to src/maxtext/configs/gpu/models/llama2_7b.yml diff --git a/src/MaxText/configs/models/gpu/llama3.1_405b.yml b/src/maxtext/configs/gpu/models/llama3.1_405b.yml similarity index 100% rename from src/MaxText/configs/models/gpu/llama3.1_405b.yml rename to src/maxtext/configs/gpu/models/llama3.1_405b.yml diff --git a/src/MaxText/configs/models/gpu/llama3_70b.yml b/src/maxtext/configs/gpu/models/llama3_70b.yml similarity index 100% rename from src/MaxText/configs/models/gpu/llama3_70b.yml rename to src/maxtext/configs/gpu/models/llama3_70b.yml diff --git a/src/MaxText/configs/models/gpu/llama3_8b.yml b/src/maxtext/configs/gpu/models/llama3_8b.yml similarity index 100% rename from src/MaxText/configs/models/gpu/llama3_8b.yml rename to src/maxtext/configs/gpu/models/llama3_8b.yml diff --git a/src/MaxText/configs/models/gpu/mixtral_8x1b.yml b/src/maxtext/configs/gpu/models/mixtral_8x1b.yml similarity index 100% rename from src/MaxText/configs/models/gpu/mixtral_8x1b.yml rename to src/maxtext/configs/gpu/models/mixtral_8x1b.yml diff --git a/src/MaxText/configs/models/gpu/mixtral_8x2b.yml b/src/maxtext/configs/gpu/models/mixtral_8x2b.yml similarity index 100% rename from src/MaxText/configs/models/gpu/mixtral_8x2b.yml rename to src/maxtext/configs/gpu/models/mixtral_8x2b.yml diff --git a/src/MaxText/configs/models/gpu/mixtral_8x7b.yml b/src/maxtext/configs/gpu/models/mixtral_8x7b.yml similarity index 100% rename from src/MaxText/configs/models/gpu/mixtral_8x7b.yml rename to src/maxtext/configs/gpu/models/mixtral_8x7b.yml diff --git a/src/MaxText/configs/inference.yml b/src/maxtext/configs/inference/inference.yml similarity index 100% rename from src/MaxText/configs/inference.yml rename to src/maxtext/configs/inference/inference.yml diff --git a/src/MaxText/configs/inference_jetstream.yml b/src/maxtext/configs/inference/inference_jetstream.yml similarity index 100% rename from src/MaxText/configs/inference_jetstream.yml rename to src/maxtext/configs/inference/inference_jetstream.yml diff --git a/src/maxtext/inference/configs/multi_host/disaggregation/llama3_405b_v6e-16-16.yml b/src/maxtext/configs/inference/multihost/disaggregation/llama3_405b_v6e-16-16.yml similarity index 97% rename from src/maxtext/inference/configs/multi_host/disaggregation/llama3_405b_v6e-16-16.yml rename to src/maxtext/configs/inference/multihost/disaggregation/llama3_405b_v6e-16-16.yml index 8cbcc612de..aedbc64853 100644 --- a/src/maxtext/inference/configs/multi_host/disaggregation/llama3_405b_v6e-16-16.yml +++ b/src/maxtext/configs/inference/multihost/disaggregation/llama3_405b_v6e-16-16.yml @@ -1,4 +1,4 @@ -base_config: "inference_jetstream.yml" +base_config: "inference/inference_jetstream.yml" model_name: "llama3.1-405b" sharding_strategy: "experimental" diff --git a/src/MaxText/configs/v5e/llama2_70b_v5e-16.yml b/src/maxtext/configs/inference/multihost/interleaved/llama2_70b_v5e-16.yml similarity index 97% rename from src/MaxText/configs/v5e/llama2_70b_v5e-16.yml rename to src/maxtext/configs/inference/multihost/interleaved/llama2_70b_v5e-16.yml index dae967e20e..121efe248e 100644 --- a/src/MaxText/configs/v5e/llama2_70b_v5e-16.yml +++ b/src/maxtext/configs/inference/multihost/interleaved/llama2_70b_v5e-16.yml @@ -1,4 +1,4 @@ -base_config: "inference_jetstream.yml" +base_config: "inference/inference_jetstream.yml" # tensor = 8, autoregressive=2 # per_device_batch_size=6 diff --git a/src/maxtext/inference/configs/multi_host/interleaved/llama3_405b_v5e-64.yml b/src/maxtext/configs/inference/multihost/interleaved/llama3_405b_v5e-64.yml similarity index 97% rename from src/maxtext/inference/configs/multi_host/interleaved/llama3_405b_v5e-64.yml rename to src/maxtext/configs/inference/multihost/interleaved/llama3_405b_v5e-64.yml index 0e874c5cab..b91bb85fb3 100644 --- a/src/maxtext/inference/configs/multi_host/interleaved/llama3_405b_v5e-64.yml +++ b/src/maxtext/configs/inference/multihost/interleaved/llama3_405b_v5e-64.yml @@ -1,4 +1,4 @@ -base_config: "inference_jetstream.yml" +base_config: "inference/inference_jetstream.yml" # v5e-64 # tensor = 8, autoregressive=8 diff --git a/src/maxtext/inference/configs/multi_host/interleaved/llama3_70b_v5e-16.yml b/src/maxtext/configs/inference/multihost/interleaved/llama3_70b_v5e-16.yml similarity index 97% rename from src/maxtext/inference/configs/multi_host/interleaved/llama3_70b_v5e-16.yml rename to src/maxtext/configs/inference/multihost/interleaved/llama3_70b_v5e-16.yml index e7cc70310f..b3ca2d1465 100644 --- a/src/maxtext/inference/configs/multi_host/interleaved/llama3_70b_v5e-16.yml +++ b/src/maxtext/configs/inference/multihost/interleaved/llama3_70b_v5e-16.yml @@ -1,4 +1,4 @@ -base_config: "inference_jetstream.yml" +base_config: "inference/inference_jetstream.yml" # tensor = 8, autoregressive=2 # per_device_batch_size=6 diff --git a/src/MaxText/configs/models/deepseek-custom.yml b/src/maxtext/configs/models/deepseek-custom.yml similarity index 100% rename from src/MaxText/configs/models/deepseek-custom.yml rename to src/maxtext/configs/models/deepseek-custom.yml diff --git a/src/MaxText/configs/models/deepseek2-16b.yml b/src/maxtext/configs/models/deepseek2-16b.yml similarity index 100% rename from src/MaxText/configs/models/deepseek2-16b.yml rename to src/maxtext/configs/models/deepseek2-16b.yml diff --git a/src/MaxText/configs/models/deepseek2-236b.yml b/src/maxtext/configs/models/deepseek2-236b.yml similarity index 100% rename from src/MaxText/configs/models/deepseek2-236b.yml rename to src/maxtext/configs/models/deepseek2-236b.yml diff --git a/src/MaxText/configs/models/deepseek3-671b-2dfsdp.yml b/src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml similarity index 100% rename from src/MaxText/configs/models/deepseek3-671b-2dfsdp.yml rename to src/maxtext/configs/models/deepseek3-671b-2dfsdp.yml diff --git a/src/MaxText/configs/models/deepseek3-671b.yml b/src/maxtext/configs/models/deepseek3-671b.yml similarity index 100% rename from src/MaxText/configs/models/deepseek3-671b.yml rename to src/maxtext/configs/models/deepseek3-671b.yml diff --git a/src/MaxText/configs/models/deepseek3-test.yml b/src/maxtext/configs/models/deepseek3-test.yml similarity index 100% rename from src/MaxText/configs/models/deepseek3-test.yml rename to src/maxtext/configs/models/deepseek3-test.yml diff --git a/src/maxtext/configs/models/deepseek3-tiny.yml b/src/maxtext/configs/models/deepseek3-tiny.yml new file mode 100644 index 0000000000..4448df0693 --- /dev/null +++ b/src/maxtext/configs/models/deepseek3-tiny.yml @@ -0,0 +1,50 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Tiny version of DeepSeek V3 for testing. + +base_emb_dim: 64 +base_num_query_heads: 4 +base_num_kv_heads: 4 +base_mlp_dim: 64 +base_moe_mlp_dim: 64 +base_num_decoder_layers: 61 +first_num_dense_layers: 3 +mlp_activations: ["silu","linear"] +vocab_size: 129280 +enable_dropout: False +logits_via_embedding: False +normalization_layer_epsilon: 1.0e-6 +num_experts: 16 +num_experts_per_tok: 8 +shared_experts: 1 +routed_scaling_factor: 2.5 +routed_score_func: "sigmoid" +routed_bias: True +decoder_block: "deepseek" +# MLA +attention_type: "mla" +q_lora_rank: 32 +kv_lora_rank: 16 +qk_nope_head_dim: 128 +qk_rope_head_dim: 64 +v_head_dim: 128 +mscale: 1.0 +# RoPE +rope_type: "yarn" +rope_max_timescale: 10_000 # DeepSeek uses "rope_theta": 10000 +max_position_embeddings: 163840 +original_max_position_embeddings: 4096 +rope_factor: 40 +beta_fast: 32 diff --git a/src/MaxText/configs/models/deepseek3.2-671b.yml b/src/maxtext/configs/models/deepseek3.2-671b.yml similarity index 100% rename from src/MaxText/configs/models/deepseek3.2-671b.yml rename to src/maxtext/configs/models/deepseek3.2-671b.yml diff --git a/src/MaxText/configs/models/gemma-2b.yml b/src/maxtext/configs/models/gemma-2b.yml similarity index 100% rename from src/MaxText/configs/models/gemma-2b.yml rename to src/maxtext/configs/models/gemma-2b.yml diff --git a/src/MaxText/configs/models/gemma-7b.yml b/src/maxtext/configs/models/gemma-7b.yml similarity index 100% rename from src/MaxText/configs/models/gemma-7b.yml rename to src/maxtext/configs/models/gemma-7b.yml diff --git a/src/MaxText/configs/models/gemma2-27b.yml b/src/maxtext/configs/models/gemma2-27b.yml similarity index 100% rename from src/MaxText/configs/models/gemma2-27b.yml rename to src/maxtext/configs/models/gemma2-27b.yml diff --git a/src/MaxText/configs/models/gemma2-2b.yml b/src/maxtext/configs/models/gemma2-2b.yml similarity index 100% rename from src/MaxText/configs/models/gemma2-2b.yml rename to src/maxtext/configs/models/gemma2-2b.yml diff --git a/src/MaxText/configs/models/gemma2-9b.yml b/src/maxtext/configs/models/gemma2-9b.yml similarity index 100% rename from src/MaxText/configs/models/gemma2-9b.yml rename to src/maxtext/configs/models/gemma2-9b.yml diff --git a/src/MaxText/configs/models/gemma3-12b.yml b/src/maxtext/configs/models/gemma3-12b.yml similarity index 100% rename from src/MaxText/configs/models/gemma3-12b.yml rename to src/maxtext/configs/models/gemma3-12b.yml diff --git a/src/MaxText/configs/models/gemma3-27b.yml b/src/maxtext/configs/models/gemma3-27b.yml similarity index 100% rename from src/MaxText/configs/models/gemma3-27b.yml rename to src/maxtext/configs/models/gemma3-27b.yml diff --git a/src/MaxText/configs/models/gemma3-4b.yml b/src/maxtext/configs/models/gemma3-4b.yml similarity index 100% rename from src/MaxText/configs/models/gemma3-4b.yml rename to src/maxtext/configs/models/gemma3-4b.yml diff --git a/src/MaxText/configs/models/gpt-oss-120b.yml b/src/maxtext/configs/models/gpt-oss-120b.yml similarity index 100% rename from src/MaxText/configs/models/gpt-oss-120b.yml rename to src/maxtext/configs/models/gpt-oss-120b.yml diff --git a/src/MaxText/configs/models/gpt-oss-20b.yml b/src/maxtext/configs/models/gpt-oss-20b.yml similarity index 100% rename from src/MaxText/configs/models/gpt-oss-20b.yml rename to src/maxtext/configs/models/gpt-oss-20b.yml diff --git a/src/MaxText/configs/models/gpt3-175b.yml b/src/maxtext/configs/models/gpt3-175b.yml similarity index 100% rename from src/MaxText/configs/models/gpt3-175b.yml rename to src/maxtext/configs/models/gpt3-175b.yml diff --git a/src/MaxText/configs/models/gpt3-22b.yml b/src/maxtext/configs/models/gpt3-22b.yml similarity index 100% rename from src/MaxText/configs/models/gpt3-22b.yml rename to src/maxtext/configs/models/gpt3-22b.yml diff --git a/src/MaxText/configs/models/gpt3-52k.yml b/src/maxtext/configs/models/gpt3-52k.yml similarity index 100% rename from src/MaxText/configs/models/gpt3-52k.yml rename to src/maxtext/configs/models/gpt3-52k.yml diff --git a/src/MaxText/configs/models/gpt3-6b.yml b/src/maxtext/configs/models/gpt3-6b.yml similarity index 100% rename from src/MaxText/configs/models/gpt3-6b.yml rename to src/maxtext/configs/models/gpt3-6b.yml diff --git a/src/MaxText/configs/models/kimi-k2-1t.yml b/src/maxtext/configs/models/kimi-k2-1t.yml similarity index 100% rename from src/MaxText/configs/models/kimi-k2-1t.yml rename to src/maxtext/configs/models/kimi-k2-1t.yml diff --git a/src/MaxText/configs/models/llama2-13b.yml b/src/maxtext/configs/models/llama2-13b.yml similarity index 100% rename from src/MaxText/configs/models/llama2-13b.yml rename to src/maxtext/configs/models/llama2-13b.yml diff --git a/src/MaxText/configs/models/llama2-70b.yml b/src/maxtext/configs/models/llama2-70b.yml similarity index 100% rename from src/MaxText/configs/models/llama2-70b.yml rename to src/maxtext/configs/models/llama2-70b.yml diff --git a/src/MaxText/configs/models/llama2-7b.yml b/src/maxtext/configs/models/llama2-7b.yml similarity index 100% rename from src/MaxText/configs/models/llama2-7b.yml rename to src/maxtext/configs/models/llama2-7b.yml diff --git a/src/MaxText/configs/models/llama3-405b.yml b/src/maxtext/configs/models/llama3-405b.yml similarity index 100% rename from src/MaxText/configs/models/llama3-405b.yml rename to src/maxtext/configs/models/llama3-405b.yml diff --git a/src/MaxText/configs/models/llama3-70b.yml b/src/maxtext/configs/models/llama3-70b.yml similarity index 100% rename from src/MaxText/configs/models/llama3-70b.yml rename to src/maxtext/configs/models/llama3-70b.yml diff --git a/src/MaxText/configs/models/llama3-8b.yml b/src/maxtext/configs/models/llama3-8b.yml similarity index 100% rename from src/MaxText/configs/models/llama3-8b.yml rename to src/maxtext/configs/models/llama3-8b.yml diff --git a/src/MaxText/configs/models/llama3.1-405b.yml b/src/maxtext/configs/models/llama3.1-405b.yml similarity index 100% rename from src/MaxText/configs/models/llama3.1-405b.yml rename to src/maxtext/configs/models/llama3.1-405b.yml diff --git a/src/MaxText/configs/models/llama3.1-70b.yml b/src/maxtext/configs/models/llama3.1-70b.yml similarity index 100% rename from src/MaxText/configs/models/llama3.1-70b.yml rename to src/maxtext/configs/models/llama3.1-70b.yml diff --git a/src/MaxText/configs/models/llama3.1-8b.yml b/src/maxtext/configs/models/llama3.1-8b.yml similarity index 100% rename from src/MaxText/configs/models/llama3.1-8b.yml rename to src/maxtext/configs/models/llama3.1-8b.yml diff --git a/src/MaxText/configs/models/llama3.3-70b.yml b/src/maxtext/configs/models/llama3.3-70b.yml similarity index 100% rename from src/MaxText/configs/models/llama3.3-70b.yml rename to src/maxtext/configs/models/llama3.3-70b.yml diff --git a/src/MaxText/configs/models/llama4-17b-128e.yml b/src/maxtext/configs/models/llama4-17b-128e.yml similarity index 100% rename from src/MaxText/configs/models/llama4-17b-128e.yml rename to src/maxtext/configs/models/llama4-17b-128e.yml diff --git a/src/MaxText/configs/models/llama4-17b-16e.yml b/src/maxtext/configs/models/llama4-17b-16e.yml similarity index 100% rename from src/MaxText/configs/models/llama4-17b-16e.yml rename to src/maxtext/configs/models/llama4-17b-16e.yml diff --git a/src/MaxText/configs/models/mistral-7b.yml b/src/maxtext/configs/models/mistral-7b.yml similarity index 100% rename from src/MaxText/configs/models/mistral-7b.yml rename to src/maxtext/configs/models/mistral-7b.yml diff --git a/src/MaxText/configs/models/mixtral-8x22b.yml b/src/maxtext/configs/models/mixtral-8x22b.yml similarity index 100% rename from src/MaxText/configs/models/mixtral-8x22b.yml rename to src/maxtext/configs/models/mixtral-8x22b.yml diff --git a/src/MaxText/configs/models/mixtral-8x7b.yml b/src/maxtext/configs/models/mixtral-8x7b.yml similarity index 100% rename from src/MaxText/configs/models/mixtral-8x7b.yml rename to src/maxtext/configs/models/mixtral-8x7b.yml diff --git a/src/MaxText/configs/models/olmo3_32b.yml b/src/maxtext/configs/models/olmo3_32b.yml similarity index 100% rename from src/MaxText/configs/models/olmo3_32b.yml rename to src/maxtext/configs/models/olmo3_32b.yml diff --git a/src/MaxText/configs/models/olmo3_7b.yml b/src/maxtext/configs/models/olmo3_7b.yml similarity index 100% rename from src/MaxText/configs/models/olmo3_7b.yml rename to src/maxtext/configs/models/olmo3_7b.yml diff --git a/src/MaxText/configs/models/qwen3-0.6b.yml b/src/maxtext/configs/models/qwen3-0.6b.yml similarity index 100% rename from src/MaxText/configs/models/qwen3-0.6b.yml rename to src/maxtext/configs/models/qwen3-0.6b.yml diff --git a/src/MaxText/configs/models/qwen3-14b.yml b/src/maxtext/configs/models/qwen3-14b.yml similarity index 100% rename from src/MaxText/configs/models/qwen3-14b.yml rename to src/maxtext/configs/models/qwen3-14b.yml diff --git a/src/MaxText/configs/models/qwen3-235b-a22b.yml b/src/maxtext/configs/models/qwen3-235b-a22b.yml similarity index 100% rename from src/MaxText/configs/models/qwen3-235b-a22b.yml rename to src/maxtext/configs/models/qwen3-235b-a22b.yml diff --git a/src/MaxText/configs/models/qwen3-30b-a3b.yml b/src/maxtext/configs/models/qwen3-30b-a3b.yml similarity index 100% rename from src/MaxText/configs/models/qwen3-30b-a3b.yml rename to src/maxtext/configs/models/qwen3-30b-a3b.yml diff --git a/src/MaxText/configs/models/qwen3-32b.yml b/src/maxtext/configs/models/qwen3-32b.yml similarity index 100% rename from src/MaxText/configs/models/qwen3-32b.yml rename to src/maxtext/configs/models/qwen3-32b.yml diff --git a/src/MaxText/configs/models/qwen3-480b-a35b.yml b/src/maxtext/configs/models/qwen3-480b-a35b.yml similarity index 100% rename from src/MaxText/configs/models/qwen3-480b-a35b.yml rename to src/maxtext/configs/models/qwen3-480b-a35b.yml diff --git a/src/MaxText/configs/models/qwen3-4b-thinking-2507.yml b/src/maxtext/configs/models/qwen3-4b-thinking-2507.yml similarity index 100% rename from src/MaxText/configs/models/qwen3-4b-thinking-2507.yml rename to src/maxtext/configs/models/qwen3-4b-thinking-2507.yml diff --git a/src/MaxText/configs/models/qwen3-4b.yml b/src/maxtext/configs/models/qwen3-4b.yml similarity index 100% rename from src/MaxText/configs/models/qwen3-4b.yml rename to src/maxtext/configs/models/qwen3-4b.yml diff --git a/src/MaxText/configs/models/qwen3-8b.yml b/src/maxtext/configs/models/qwen3-8b.yml similarity index 100% rename from src/MaxText/configs/models/qwen3-8b.yml rename to src/maxtext/configs/models/qwen3-8b.yml diff --git a/src/MaxText/configs/models/qwen3-next-80b-a3b.yml b/src/maxtext/configs/models/qwen3-next-80b-a3b.yml similarity index 96% rename from src/MaxText/configs/models/qwen3-next-80b-a3b.yml rename to src/maxtext/configs/models/qwen3-next-80b-a3b.yml index 6f362ba4f5..ecdd9cceda 100644 --- a/src/MaxText/configs/models/qwen3-next-80b-a3b.yml +++ b/src/maxtext/configs/models/qwen3-next-80b-a3b.yml @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# MaxText/configs/models/qwen3-next-80b-a3b.yml +# maxtext/configs/models/qwen3-next-80b-a3b.yml # Set the decoder block to our new implementation decoder_block: "qwen3_next" diff --git a/src/MaxText/configs/models/qwen3-omni-30b-a3b.yml b/src/maxtext/configs/models/qwen3-omni-30b-a3b.yml similarity index 100% rename from src/MaxText/configs/models/qwen3-omni-30b-a3b.yml rename to src/maxtext/configs/models/qwen3-omni-30b-a3b.yml diff --git a/src/MaxText/configs/distillation.yml b/src/maxtext/configs/post_train/distillation.yml similarity index 100% rename from src/MaxText/configs/distillation.yml rename to src/maxtext/configs/post_train/distillation.yml diff --git a/src/MaxText/configs/dpo.yml b/src/maxtext/configs/post_train/dpo.yml similarity index 100% rename from src/MaxText/configs/dpo.yml rename to src/maxtext/configs/post_train/dpo.yml diff --git a/src/MaxText/configs/rl.yml b/src/maxtext/configs/post_train/rl.yml similarity index 100% rename from src/MaxText/configs/rl.yml rename to src/maxtext/configs/post_train/rl.yml diff --git a/src/MaxText/configs/rl_mt_jt.yml b/src/maxtext/configs/post_train/rl_mt_jt.yml similarity index 100% rename from src/MaxText/configs/rl_mt_jt.yml rename to src/maxtext/configs/post_train/rl_mt_jt.yml diff --git a/src/MaxText/configs/sft-vision-chartqa.yml b/src/maxtext/configs/post_train/sft-vision-chartqa.yml similarity index 100% rename from src/MaxText/configs/sft-vision-chartqa.yml rename to src/maxtext/configs/post_train/sft-vision-chartqa.yml diff --git a/src/MaxText/configs/sft-vision-slidevqa.yml b/src/maxtext/configs/post_train/sft-vision-slidevqa.yml similarity index 100% rename from src/MaxText/configs/sft-vision-slidevqa.yml rename to src/maxtext/configs/post_train/sft-vision-slidevqa.yml diff --git a/src/MaxText/configs/sft.yml b/src/maxtext/configs/post_train/sft.yml similarity index 100% rename from src/MaxText/configs/sft.yml rename to src/maxtext/configs/post_train/sft.yml diff --git a/src/MaxText/configs/quantization/README.md b/src/maxtext/configs/quantization/README.md similarity index 100% rename from src/MaxText/configs/quantization/README.md rename to src/maxtext/configs/quantization/README.md diff --git a/src/MaxText/configs/quantization/dense_llm_subchannel.json b/src/maxtext/configs/quantization/dense_llm_subchannel.json similarity index 100% rename from src/MaxText/configs/quantization/dense_llm_subchannel.json rename to src/maxtext/configs/quantization/dense_llm_subchannel.json diff --git a/src/MaxText/configs/quantization/dense_llm_weight_only_scale.json b/src/maxtext/configs/quantization/dense_llm_weight_only_scale.json similarity index 100% rename from src/MaxText/configs/quantization/dense_llm_weight_only_scale.json rename to src/maxtext/configs/quantization/dense_llm_weight_only_scale.json diff --git a/src/MaxText/configs/quantization/int4_weight_only.json b/src/maxtext/configs/quantization/int4_weight_only.json similarity index 100% rename from src/MaxText/configs/quantization/int4_weight_only.json rename to src/maxtext/configs/quantization/int4_weight_only.json diff --git a/src/MaxText/configs/quantization/int8_weight_only.json b/src/maxtext/configs/quantization/int8_weight_only.json similarity index 100% rename from src/MaxText/configs/quantization/int8_weight_only.json rename to src/maxtext/configs/quantization/int8_weight_only.json diff --git a/src/MaxText/configs/tpu_smoke_test.yml b/src/maxtext/configs/tpu/tpu_smoke_test.yml similarity index 100% rename from src/MaxText/configs/tpu_smoke_test.yml rename to src/maxtext/configs/tpu/tpu_smoke_test.yml diff --git a/src/MaxText/configs/v4/22b.sh b/src/maxtext/configs/tpu/v4/22b.sh similarity index 83% rename from src/MaxText/configs/v4/22b.sh rename to src/maxtext/configs/tpu/v4/22b.sh index f811a296a0..15a99874aa 100644 --- a/src/MaxText/configs/v4/22b.sh +++ b/src/maxtext/configs/tpu/v4/22b.sh @@ -22,10 +22,10 @@ echo "Running 22b.sh" # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script for training: -# bash src/MaxText/configs/v4/22b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v4/22b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v4/22b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v4-128 M_COMPILE_TOPOLOGY_NUM_SLICES=2 +# bash src/maxtext/configs/tpu/v4/22b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v4-128 M_COMPILE_TOPOLOGY_NUM_SLICES=2 # Stop execution if any command exits with error @@ -55,7 +55,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true TPU_MEGACORE=MEGACORE_DENSE" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml\ ici_fsdp_parallelism=64 steps=10 per_device_batch_size=13 profiler=xplane remat_policy=full\ base_emb_dim=6144 base_num_kv_heads=24 base_num_query_heads=24 base_mlp_dim=24576 base_num_decoder_layers=48\ base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH diff --git a/src/MaxText/configs/v4/52b.sh b/src/maxtext/configs/tpu/v4/52b.sh similarity index 84% rename from src/MaxText/configs/v4/52b.sh rename to src/maxtext/configs/tpu/v4/52b.sh index 03cb89fedf..26f26109b8 100644 --- a/src/MaxText/configs/v4/52b.sh +++ b/src/maxtext/configs/tpu/v4/52b.sh @@ -22,10 +22,10 @@ echo "Running 52b.sh" # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/v4/52b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v4/52b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v4/52b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v4-384 M_COMPILE_TOPOLOGY_NUM_SLICES=2 +# bash src/maxtext/configs/tpu/v4/52b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v4-384 M_COMPILE_TOPOLOGY_NUM_SLICES=2 # Stop execution if any command exits with error @@ -55,7 +55,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_enable_async_all_gather=true TPU_MEGACORE=MEGACORE_DENSE" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml\ profiler=xplane enable_checkpointing=false steps=10\ ici_fsdp_parallelism=192 ici_tensor_parallelism=1 per_device_batch_size=7 remat_policy=full\ base_num_decoder_layers=32 base_emb_dim=12288 base_mlp_dim=49152 base_num_query_heads=32 base_num_kv_heads=32 learning_rate=1e-8\ diff --git a/src/MaxText/configs/v4/README.md b/src/maxtext/configs/tpu/v4/README.md similarity index 100% rename from src/MaxText/configs/v4/README.md rename to src/maxtext/configs/tpu/v4/README.md diff --git a/src/MaxText/configs/v5e/128b.sh b/src/maxtext/configs/tpu/v5e/128b.sh similarity index 82% rename from src/MaxText/configs/v5e/128b.sh rename to src/maxtext/configs/tpu/v5e/128b.sh index b0bd949d73..58f7f43888 100644 --- a/src/MaxText/configs/v5e/128b.sh +++ b/src/maxtext/configs/tpu/v5e/128b.sh @@ -8,10 +8,10 @@ echo "Running 128b.sh" # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/v5e/128b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v5e/128b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v5e/128b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 +# bash src/maxtext/configs/tpu/v5e/128b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 # Stop execution if any command exits with error @@ -42,7 +42,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml\ steps=15 per_device_batch_size=1 enable_checkpointing=false\ remat_policy=qkv_proj_offloaded global_parameter_scale=128\ ici_fsdp_parallelism=16 ici_tensor_parallelism=16\ diff --git a/src/MaxText/configs/v5e/16b.sh b/src/maxtext/configs/tpu/v5e/16b.sh similarity index 81% rename from src/MaxText/configs/v5e/16b.sh rename to src/maxtext/configs/tpu/v5e/16b.sh index ede6164a04..8bcc4402d8 100644 --- a/src/MaxText/configs/v5e/16b.sh +++ b/src/maxtext/configs/tpu/v5e/16b.sh @@ -8,10 +8,10 @@ echo "Running 16b.sh" # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/v5e/16b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v5e/16b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v5e/16b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 +# bash src/maxtext/configs/tpu/v5e/16b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 # Stop execution if any command exits with error @@ -41,7 +41,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml\ steps=15 per_device_batch_size=6 enable_checkpointing=false\ remat_policy=full global_parameter_scale=16\ max_target_length=2048 base_output_directory=$OUTPUT_PATH\ diff --git a/src/MaxText/configs/v5e/32b.sh b/src/maxtext/configs/tpu/v5e/32b.sh similarity index 81% rename from src/MaxText/configs/v5e/32b.sh rename to src/maxtext/configs/tpu/v5e/32b.sh index dcb6c29a07..82659b3539 100644 --- a/src/MaxText/configs/v5e/32b.sh +++ b/src/maxtext/configs/tpu/v5e/32b.sh @@ -8,10 +8,10 @@ echo "Running 32b.sh" # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/v5e/32b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v5e/32b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v5e/32b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 +# bash src/maxtext/configs/tpu/v5e/32b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 # Stop execution if any command exits with error @@ -41,7 +41,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml\ steps=15 per_device_batch_size=4 enable_checkpointing=false\ remat_policy=full global_parameter_scale=32\ max_target_length=2048 base_output_directory=$OUTPUT_PATH\ diff --git a/src/MaxText/configs/v5e/64b.sh b/src/maxtext/configs/tpu/v5e/64b.sh similarity index 81% rename from src/MaxText/configs/v5e/64b.sh rename to src/maxtext/configs/tpu/v5e/64b.sh index 0239e11cbe..1cca87de65 100644 --- a/src/MaxText/configs/v5e/64b.sh +++ b/src/maxtext/configs/tpu/v5e/64b.sh @@ -8,10 +8,10 @@ echo "Running 64b.sh" # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/v5e/64b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v5e/64b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v5e/64b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 +# bash src/maxtext/configs/tpu/v5e/64b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 # Stop execution if any command exits with error @@ -41,7 +41,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml\ steps=15 per_device_batch_size=2 enable_checkpointing=false\ remat_policy=full global_parameter_scale=64\ max_target_length=2048 base_output_directory=$OUTPUT_PATH\ diff --git a/src/MaxText/configs/v5e/README.md b/src/maxtext/configs/tpu/v5e/README.md similarity index 100% rename from src/MaxText/configs/v5e/README.md rename to src/maxtext/configs/tpu/v5e/README.md diff --git a/src/MaxText/configs/v5e/gpt3_175b.sh b/src/maxtext/configs/tpu/v5e/gpt3_175b.sh similarity index 80% rename from src/MaxText/configs/v5e/gpt3_175b.sh rename to src/maxtext/configs/tpu/v5e/gpt3_175b.sh index cc29ee174a..a60294209f 100644 --- a/src/MaxText/configs/v5e/gpt3_175b.sh +++ b/src/maxtext/configs/tpu/v5e/gpt3_175b.sh @@ -7,10 +7,10 @@ # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/v5e/gpt3_175b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v5e/gpt3_175b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v5e/gpt3_175b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=8 +# bash src/maxtext/configs/tpu/v5e/gpt3_175b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=8 # Stop execution if any command exits with error @@ -41,7 +41,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_tpu_spmd_rng_bit_generator_unsafe=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=gpt3-175b\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=gpt3-175b\ steps=15 per_device_batch_size=0.5 enable_checkpointing=false\ remat_policy=full ici_fsdp_parallelism=-1 ici_tensor_parallelism=16\ max_target_length=2048 base_output_directory=$OUTPUT_PATH\ diff --git a/src/MaxText/configs/v5e/llama2_13b.sh b/src/maxtext/configs/tpu/v5e/llama2_13b.sh similarity index 80% rename from src/MaxText/configs/v5e/llama2_13b.sh rename to src/maxtext/configs/tpu/v5e/llama2_13b.sh index 0604d97220..47937688a0 100644 --- a/src/MaxText/configs/v5e/llama2_13b.sh +++ b/src/maxtext/configs/tpu/v5e/llama2_13b.sh @@ -7,10 +7,10 @@ # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/v5e/llama2_13b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v5e/llama2_13b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v5e/llama2_13b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 +# bash src/maxtext/configs/tpu/v5e/llama2_13b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 # Stop execution if any command exits with error @@ -41,7 +41,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=llama2-13b\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=llama2-13b\ base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH}\ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 per_device_batch_size=8 remat_policy=qkv_proj_offloaded\ steps=15 enable_checkpointing=false use_iota_embed=true diff --git a/src/MaxText/configs/v5e/llama2_70b.sh b/src/maxtext/configs/tpu/v5e/llama2_70b.sh similarity index 80% rename from src/MaxText/configs/v5e/llama2_70b.sh rename to src/maxtext/configs/tpu/v5e/llama2_70b.sh index bf2cb73d62..825275aef4 100644 --- a/src/MaxText/configs/v5e/llama2_70b.sh +++ b/src/maxtext/configs/tpu/v5e/llama2_70b.sh @@ -7,10 +7,10 @@ # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/v5e/llama2_70b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v5e/llama2_70b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v5e/llama2_70b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 +# bash src/maxtext/configs/tpu/v5e/llama2_70b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 # Stop execution if any command exits with error @@ -41,7 +41,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=llama2-70b\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=llama2-70b\ base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH}\ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 per_device_batch_size=2 remat_policy=qkv_proj_offloaded\ steps=15 enable_checkpointing=false use_iota_embed=true diff --git a/src/maxtext/inference/configs/multi_host/interleaved/llama2_70b_v5e-16.yml b/src/maxtext/configs/tpu/v5e/llama2_70b_v5e-16.yml similarity index 97% rename from src/maxtext/inference/configs/multi_host/interleaved/llama2_70b_v5e-16.yml rename to src/maxtext/configs/tpu/v5e/llama2_70b_v5e-16.yml index dae967e20e..121efe248e 100644 --- a/src/maxtext/inference/configs/multi_host/interleaved/llama2_70b_v5e-16.yml +++ b/src/maxtext/configs/tpu/v5e/llama2_70b_v5e-16.yml @@ -1,4 +1,4 @@ -base_config: "inference_jetstream.yml" +base_config: "inference/inference_jetstream.yml" # tensor = 8, autoregressive=2 # per_device_batch_size=6 diff --git a/src/MaxText/configs/v5e/llama2_7b.sh b/src/maxtext/configs/tpu/v5e/llama2_7b.sh similarity index 80% rename from src/MaxText/configs/v5e/llama2_7b.sh rename to src/maxtext/configs/tpu/v5e/llama2_7b.sh index 3fa110e03f..80af3c58e1 100644 --- a/src/MaxText/configs/v5e/llama2_7b.sh +++ b/src/maxtext/configs/tpu/v5e/llama2_7b.sh @@ -7,10 +7,10 @@ # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/v5e/llama2_7b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v5e/llama2_7b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v5e/llama2_7b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 +# bash src/maxtext/configs/tpu/v5e/llama2_7b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5e-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 # Stop execution if any command exits with error @@ -41,7 +41,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=llama2-7b\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=llama2-7b\ base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH}\ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 per_device_batch_size=4 remat_policy=save_qkv_proj\ steps=15 enable_checkpointing=false use_iota_embed=true \ No newline at end of file diff --git a/src/MaxText/configs/v5e/llama3_405b_v5e-64.yml b/src/maxtext/configs/tpu/v5e/llama3_405b_v5e-64.yml similarity index 97% rename from src/MaxText/configs/v5e/llama3_405b_v5e-64.yml rename to src/maxtext/configs/tpu/v5e/llama3_405b_v5e-64.yml index 0e874c5cab..b91bb85fb3 100644 --- a/src/MaxText/configs/v5e/llama3_405b_v5e-64.yml +++ b/src/maxtext/configs/tpu/v5e/llama3_405b_v5e-64.yml @@ -1,4 +1,4 @@ -base_config: "inference_jetstream.yml" +base_config: "inference/inference_jetstream.yml" # v5e-64 # tensor = 8, autoregressive=8 diff --git a/src/MaxText/configs/v5e/llama3_70b_v5e-16.yml b/src/maxtext/configs/tpu/v5e/llama3_70b_v5e-16.yml similarity index 97% rename from src/MaxText/configs/v5e/llama3_70b_v5e-16.yml rename to src/maxtext/configs/tpu/v5e/llama3_70b_v5e-16.yml index e7cc70310f..b3ca2d1465 100644 --- a/src/MaxText/configs/v5e/llama3_70b_v5e-16.yml +++ b/src/maxtext/configs/tpu/v5e/llama3_70b_v5e-16.yml @@ -1,4 +1,4 @@ -base_config: "inference_jetstream.yml" +base_config: "inference/inference_jetstream.yml" # tensor = 8, autoregressive=2 # per_device_batch_size=6 diff --git a/src/MaxText/configs/v5p/1024b.sh b/src/maxtext/configs/tpu/v5p/1024b.sh similarity index 83% rename from src/MaxText/configs/v5p/1024b.sh rename to src/maxtext/configs/tpu/v5p/1024b.sh index 22850295c3..b628da1916 100644 --- a/src/MaxText/configs/v5p/1024b.sh +++ b/src/maxtext/configs/tpu/v5p/1024b.sh @@ -8,10 +8,10 @@ echo "Running 1024b.sh" # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/v5p/1024b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v5p/1024b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v5p/1024b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5p-2048 M_COMPILE_TOPOLOGY_NUM_SLICES=2 +# bash src/maxtext/configs/tpu/v5p/1024b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5p-2048 M_COMPILE_TOPOLOGY_NUM_SLICES=2 # Stop execution if any command exits with error @@ -41,7 +41,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml\ steps=15 per_device_batch_size=2 enable_checkpointing=false\ remat_policy=full global_parameter_scale=1024\ ici_fsdp_parallelism=-1 ici_tensor_parallelism=16\ diff --git a/src/MaxText/configs/v5p/128b.sh b/src/maxtext/configs/tpu/v5p/128b.sh similarity index 83% rename from src/MaxText/configs/v5p/128b.sh rename to src/maxtext/configs/tpu/v5p/128b.sh index 19b7644b2b..ea853fbab8 100644 --- a/src/MaxText/configs/v5p/128b.sh +++ b/src/maxtext/configs/tpu/v5p/128b.sh @@ -8,10 +8,10 @@ echo "Running 128b.sh" # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/v5p/128b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v5p/128b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v5p/128b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5p-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 +# bash src/maxtext/configs/tpu/v5p/128b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5p-256 M_COMPILE_TOPOLOGY_NUM_SLICES=2 # Stop execution if any command exits with error @@ -41,7 +41,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml\ steps=15 per_device_batch_size=1 enable_checkpointing=false\ remat_policy=minimal global_parameter_scale=128\ ici_fsdp_parallelism=-1 ici_tensor_parallelism=8\ diff --git a/src/MaxText/configs/v5p/256b.sh b/src/maxtext/configs/tpu/v5p/256b.sh similarity index 83% rename from src/MaxText/configs/v5p/256b.sh rename to src/maxtext/configs/tpu/v5p/256b.sh index 49068be7f8..b66d142732 100644 --- a/src/MaxText/configs/v5p/256b.sh +++ b/src/maxtext/configs/tpu/v5p/256b.sh @@ -9,10 +9,10 @@ echo "Running 256b.sh" # PLATFORM (Optional, can be "gke" or "gce", default is "gce") # # Example to invoke this script: -# bash src/MaxText/configs/v5p/256b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v5p/256b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v5p/256b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5p-1024 M_COMPILE_TOPOLOGY_NUM_SLICES=2 +# bash src/maxtext/configs/tpu/v5p/256b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5p-1024 M_COMPILE_TOPOLOGY_NUM_SLICES=2 # Stop execution if any command exits with error @@ -42,7 +42,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml\ steps=15 per_device_batch_size=1 enable_checkpointing=false\ remat_policy=minimal global_parameter_scale=256\ ici_fsdp_parallelism=-1 ici_tensor_parallelism=8\ diff --git a/src/MaxText/configs/v5p/32b.sh b/src/maxtext/configs/tpu/v5p/32b.sh similarity index 84% rename from src/MaxText/configs/v5p/32b.sh rename to src/maxtext/configs/tpu/v5p/32b.sh index f5ba53b236..3fd96ddeb6 100644 --- a/src/MaxText/configs/v5p/32b.sh +++ b/src/maxtext/configs/tpu/v5p/32b.sh @@ -8,10 +8,10 @@ echo "Running 32b.sh" # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/v5p/32b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v5p/32b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v5p/32b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5p-128 M_COMPILE_TOPOLOGY_NUM_SLICES=2 +# bash src/maxtext/configs/tpu/v5p/32b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5p-128 M_COMPILE_TOPOLOGY_NUM_SLICES=2 # Stop execution if any command exits with error @@ -41,7 +41,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true --xla_sc_disable_megacore_partitioning=true --xla_tpu_use_tc_device_shape_on_sc=true --xla_tpu_enable_sparse_core_collective_offload_all_gather=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=false" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml\ steps=15 per_device_batch_size=6 enable_checkpointing=false\ remat_policy=minimal global_parameter_scale=32\ ici_fsdp_parallelism=-1 ici_tensor_parallelism=4\ diff --git a/src/MaxText/configs/v5p/512b.sh b/src/maxtext/configs/tpu/v5p/512b.sh similarity index 83% rename from src/MaxText/configs/v5p/512b.sh rename to src/maxtext/configs/tpu/v5p/512b.sh index 5422979cd1..c015ee6534 100644 --- a/src/MaxText/configs/v5p/512b.sh +++ b/src/maxtext/configs/tpu/v5p/512b.sh @@ -9,10 +9,10 @@ echo "Running 512b.sh" # PLATFORM (Optional, can be "gke" or "gce", default is "gce") # # Example to invoke this script: -# bash src/MaxText/configs/v5p/512b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v5p/512b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v5p/512b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5p-1024 M_COMPILE_TOPOLOGY_NUM_SLICES=2 +# bash src/maxtext/configs/tpu/v5p/512b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5p-1024 M_COMPILE_TOPOLOGY_NUM_SLICES=2 # Stop execution if any command exits with error @@ -42,7 +42,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml\ steps=15 per_device_batch_size=2 enable_checkpointing=false\ remat_policy=full global_parameter_scale=512\ ici_fsdp_parallelism=-1 ici_tensor_parallelism=8\ diff --git a/src/MaxText/configs/v5p/64b.sh b/src/maxtext/configs/tpu/v5p/64b.sh similarity index 83% rename from src/MaxText/configs/v5p/64b.sh rename to src/maxtext/configs/tpu/v5p/64b.sh index f9eaf64df4..33a689c01f 100644 --- a/src/MaxText/configs/v5p/64b.sh +++ b/src/maxtext/configs/tpu/v5p/64b.sh @@ -8,10 +8,10 @@ echo "Running 64b.sh" # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/v5p/64b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v5p/64b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v5p/64b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5p-128 M_COMPILE_TOPOLOGY_NUM_SLICES=2 +# bash src/maxtext/configs/tpu/v5p/64b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5p-128 M_COMPILE_TOPOLOGY_NUM_SLICES=2 # Stop execution if any command exits with error @@ -41,7 +41,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml\ steps=15 per_device_batch_size=3 enable_checkpointing=false\ remat_policy=minimal global_parameter_scale=64\ ici_fsdp_parallelism=-1 ici_tensor_parallelism=4\ diff --git a/src/MaxText/configs/v5p/README.md b/src/maxtext/configs/tpu/v5p/README.md similarity index 100% rename from src/MaxText/configs/v5p/README.md rename to src/maxtext/configs/tpu/v5p/README.md diff --git a/src/MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh b/src/maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh similarity index 85% rename from src/MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh rename to src/maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh index 7fe0931108..197a407b50 100644 --- a/src/MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh +++ b/src/maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh @@ -5,10 +5,10 @@ # directory to actually run the training. # Example to invoke this script for compilation -# ./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 4 "full" 1 64 8 "some_run" "gs://some_bucket" "train_compile.py" "v5p-1024" 1 +# ./maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh 4 "full" 1 64 8 "some_run" "gs://some_bucket" "train_compile.py" "v5p-1024" 1 # Example to invoke this script for training -# ./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 4 "full" 1 64 8 "some_run" "gs://some_bucket" +# ./maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh 4 "full" 1 64 8 "some_run" "gs://some_bucket" set -euox pipefail @@ -40,7 +40,7 @@ if [[ "$EXECUTABLE" == "train_compile" ]]; then COMPILE_TOPOLOGY=${9} COMPILE_TOPOLOGY_NUM_SLICES=${10} - python3 -m MaxText."$EXECUTABLE" "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml run_name="${RUNNAME}" model_name=gpt3-175b\ + python3 -m MaxText."$EXECUTABLE" "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml run_name="${RUNNAME}" model_name=gpt3-175b\ base_output_directory="${BASE_OUTPUT_DIRECTORY}"\ enable_checkpointing=false async_checkpointing=false\ steps=20\ @@ -53,7 +53,7 @@ if [[ "$EXECUTABLE" == "train_compile" ]]; then compile_topology="${COMPILE_TOPOLOGY}"\ compile_topology_num_slices="${COMPILE_TOPOLOGY_NUM_SLICES}" else - python3 -m MaxText."$EXECUTABLE" "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml run_name="${RUNNAME}" model_name=gpt3-175b\ + python3 -m MaxText."$EXECUTABLE" "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml run_name="${RUNNAME}" model_name=gpt3-175b\ base_output_directory="${BASE_OUTPUT_DIRECTORY}"\ enable_checkpointing=false async_checkpointing=false\ steps=20\ diff --git a/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_1024.sh b/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_1024.sh new file mode 100644 index 0000000000..777ebb2776 --- /dev/null +++ b/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_1024.sh @@ -0,0 +1,15 @@ +#!/bin/bash +# GPT-3 175B Model. +# Train GPT-3 175B on v5p-1024 slice. + +# Example to invoke this script: +# bash src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_1024.sh YOUR_RUN gs://YOUR_BUCKET" + +set -euox pipefail + +# Read arguments or use defaults from environment variables +RUNNAME=${1:-${RUNNAME:-some-run}} +BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} + +chmod +x "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/tpu/v5p/gpt3_175b/gpt3_175b_base.sh +./maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh 4 "full" 1 64 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" diff --git a/src/MaxText/configs/v5p/gpt3_175b/v5p_12288.sh b/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_12288.sh similarity index 68% rename from src/MaxText/configs/v5p/gpt3_175b/v5p_12288.sh rename to src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_12288.sh index 49d51a769e..926dff0abc 100644 --- a/src/MaxText/configs/v5p/gpt3_175b/v5p_12288.sh +++ b/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_12288.sh @@ -3,7 +3,7 @@ # Train GPT-3 175B on v5p-12288 slice, with a custom topology of 8x16x48. # Example to invoke this script: -# bash src/MaxText/configs/v5p/gpt3_175b/v5p_12288.sh YOUR_RUN gs://YOUR_BUCKET" +# bash src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_12288.sh YOUR_RUN gs://YOUR_BUCKET" set -euox pipefail @@ -12,4 +12,4 @@ RUNNAME=${1:-${RUNNAME:-some-run}} BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} chmod +x "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/v5p/gpt3_175b/gpt3_175b_base.sh -./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 1 "minimal" 16 48 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" \ No newline at end of file +./maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh 1 "minimal" 16 48 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" \ No newline at end of file diff --git a/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_2048.sh b/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_2048.sh new file mode 100644 index 0000000000..955ecae377 --- /dev/null +++ b/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_2048.sh @@ -0,0 +1,15 @@ +#!/bin/bash +# GPT-3 175B Model. +# Train GPT-3 175B on v5p-2048 slice. + +# Example to invoke this script: +# bash src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_2048.sh YOUR_RUN gs://YOUR_BUCKET" + +set -euox pipefail + +# Read arguments or use defaults from environment variables +RUNNAME=${1:-${RUNNAME:-some-run}} +BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} + +chmod +x "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/tpu/v5p/gpt3_175b/gpt3_175b_base.sh +./maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh 2 "save_dot_except_mlpwi" 8 16 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" \ No newline at end of file diff --git a/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_3072.sh b/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_3072.sh new file mode 100644 index 0000000000..53635e155c --- /dev/null +++ b/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_3072.sh @@ -0,0 +1,15 @@ +#!/bin/bash +# GPT-3 175B Model. +# Train GPT-3 175B on v5p-3072 slice. + +# Example to invoke this script: +# bash src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_3072.sh YOUR_RUN gs://YOUR_BUCKET" + +set -euox pipefail + +# Read arguments or use defaults from environment variables +RUNNAME=${1:-${RUNNAME:-some-run}} +BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} + +chmod +x "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/tpu/v5p/gpt3_175b/gpt3_175b_base.sh +./maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh 2 "save_dot_except_mlpwi" 12 16 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" \ No newline at end of file diff --git a/src/MaxText/configs/v5p/gpt3_175b/v5p_4096.sh b/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_4096.sh similarity index 50% rename from src/MaxText/configs/v5p/gpt3_175b/v5p_4096.sh rename to src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_4096.sh index 0e7f185650..59ec31a986 100644 --- a/src/MaxText/configs/v5p/gpt3_175b/v5p_4096.sh +++ b/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_4096.sh @@ -3,7 +3,7 @@ # Train GPT-3 175B on v5p-4096 slice, with a custom topology of 4x8x64. # Example to invoke this script: -# bash src/MaxText/configs/v5p/gpt3_175b/v5p_4096.sh YOUR_RUN gs://YOUR_BUCKET" +# bash src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_4096.sh YOUR_RUN gs://YOUR_BUCKET" set -euox pipefail @@ -11,5 +11,5 @@ set -euox pipefail RUNNAME=${1:-${RUNNAME:-some-run}} BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} -chmod +x "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/v5p/gpt3_175b/gpt3_175b_base.sh -./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 1 "minimal" 4 64 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" \ No newline at end of file +chmod +x "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/tpu/v5p/gpt3_175b/gpt3_175b_base.sh +./maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh 1 "minimal" 4 64 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" \ No newline at end of file diff --git a/src/MaxText/configs/v5p/gpt3_175b/v5p_8192.sh b/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_8192.sh similarity index 50% rename from src/MaxText/configs/v5p/gpt3_175b/v5p_8192.sh rename to src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_8192.sh index 9ebce2b0a5..e3abac9565 100644 --- a/src/MaxText/configs/v5p/gpt3_175b/v5p_8192.sh +++ b/src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_8192.sh @@ -3,7 +3,7 @@ # Train GPT-3 175B on v5p-8192 slice, with a custom topology of 8x16x32. # Example to invoke this script: -# bash src/MaxText/configs/v5p/gpt3_175b/v5p_8192.sh YOUR_RUN gs://YOUR_BUCKET" +# bash src/maxtext/configs/tpu/v5p/gpt3_175b/v5p_8192.sh YOUR_RUN gs://YOUR_BUCKET" set -euox pipefail @@ -11,5 +11,5 @@ set -euox pipefail RUNNAME=${1:-${RUNNAME:-some-run}} BASE_OUTPUT_DIRECTORY=${2:-${BASE_OUTPUT_DIRECTORY:-gs://some-bucket}} -chmod +x "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/v5p/gpt3_175b/gpt3_175b_base.sh -./MaxText/configs/v5p/gpt3_175b/gpt3_175b_base.sh 1 "minimal" 16 32 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" \ No newline at end of file +chmod +x "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/tpu/v5p/gpt3_175b/gpt3_175b_base.sh +./maxtext/configs/tpu/v5p/gpt3_175b/gpt3_175b_base.sh 1 "minimal" 16 32 8 "${RUNNAME}" "${BASE_OUTPUT_DIRECTORY}" \ No newline at end of file diff --git a/src/MaxText/configs/v5p/llama2_70b.sh b/src/maxtext/configs/tpu/v5p/llama2_70b.sh similarity index 83% rename from src/MaxText/configs/v5p/llama2_70b.sh rename to src/maxtext/configs/tpu/v5p/llama2_70b.sh index bf69b52cef..b233e8a607 100644 --- a/src/MaxText/configs/v5p/llama2_70b.sh +++ b/src/maxtext/configs/tpu/v5p/llama2_70b.sh @@ -9,10 +9,10 @@ # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/v5p/llama2_70b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v5p/llama2_70b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v5p/llama2_70b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5p-512 M_COMPILE_TOPOLOGY_NUM_SLICES=1 +# bash src/maxtext/configs/tpu/v5p/llama2_70b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5p-512 M_COMPILE_TOPOLOGY_NUM_SLICES=1 # Stop execution if any command exits with error @@ -44,7 +44,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=llama2-70b\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=llama2-70b\ base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH}\ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 remat_policy=save_dot_except_mlpwi per_device_batch_size=4\ steps=30 enable_checkpointing=false use_iota_embed=true max_target_length=4096\ diff --git a/src/MaxText/configs/v5p/llama2_7b.sh b/src/maxtext/configs/tpu/v5p/llama2_7b.sh similarity index 84% rename from src/MaxText/configs/v5p/llama2_7b.sh rename to src/maxtext/configs/tpu/v5p/llama2_7b.sh index 8d3e0a9206..af78271ba0 100644 --- a/src/MaxText/configs/v5p/llama2_7b.sh +++ b/src/maxtext/configs/tpu/v5p/llama2_7b.sh @@ -9,10 +9,10 @@ # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/v5p/llama2_7b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v5p/llama2_7b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # # Example to AOT compile: -# bash src/MaxText/configs/v5p/llama2_7b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5p-512 M_COMPILE_TOPOLOGY_NUM_SLICES=1 +# bash src/maxtext/configs/tpu/v5p/llama2_7b.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=v5p-512 M_COMPILE_TOPOLOGY_NUM_SLICES=1 # Stop execution if any command exits with error @@ -44,7 +44,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=llama2-7b\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=llama2-7b\ base_output_directory=$OUTPUT_PATH dataset_path=${DATASET_PATH}\ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 remat_policy=minimal per_device_batch_size=4\ steps=30 enable_checkpointing=false use_iota_embed=true max_target_length=4096\ diff --git a/src/MaxText/configs/trillium/gemma2_27b.sh b/src/maxtext/configs/tpu/v6e/gemma2_27b.sh similarity index 85% rename from src/MaxText/configs/trillium/gemma2_27b.sh rename to src/maxtext/configs/tpu/v6e/gemma2_27b.sh index 4bbe89298b..c20928abe1 100644 --- a/src/MaxText/configs/trillium/gemma2_27b.sh +++ b/src/maxtext/configs/tpu/v6e/gemma2_27b.sh @@ -7,7 +7,7 @@ # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/trillium/gemma2_27b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v6e/gemma2_27b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # @@ -39,7 +39,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=122880 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=gemma2-27b\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=gemma2-27b\ steps=15 per_device_batch_size=2 enable_checkpointing=false\ remat_policy=full ici_fsdp_transpose_parallelism=256 ici_fsdp_parallelism=-1\ max_target_length=8192 base_output_directory=$OUTPUT_PATH\ diff --git a/src/MaxText/configs/trillium/gemma2_9b.sh b/src/maxtext/configs/tpu/v6e/gemma2_9b.sh similarity index 86% rename from src/MaxText/configs/trillium/gemma2_9b.sh rename to src/maxtext/configs/tpu/v6e/gemma2_9b.sh index c93481a7f1..90acccf01e 100644 --- a/src/MaxText/configs/trillium/gemma2_9b.sh +++ b/src/maxtext/configs/tpu/v6e/gemma2_9b.sh @@ -7,7 +7,7 @@ # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/trillium/gemma2_9b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v6e/gemma2_9b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # @@ -39,7 +39,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=114688 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=gemma2-9b\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=gemma2-9b\ steps=15 per_device_batch_size=3 enable_checkpointing=false\ remat_policy=full ici_fsdp_transpose_parallelism=256 ici_fsdp_parallelism=-1\ max_target_length=8192 base_output_directory=$OUTPUT_PATH\ diff --git a/src/MaxText/configs/trillium/gemma3_27b.sh b/src/maxtext/configs/tpu/v6e/gemma3_27b.sh similarity index 85% rename from src/MaxText/configs/trillium/gemma3_27b.sh rename to src/maxtext/configs/tpu/v6e/gemma3_27b.sh index 4b7a2294b4..e5b3569fe3 100644 --- a/src/MaxText/configs/trillium/gemma3_27b.sh +++ b/src/maxtext/configs/tpu/v6e/gemma3_27b.sh @@ -7,7 +7,7 @@ # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/trillium/gemma3_27b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v6e/gemma3_27b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # @@ -39,7 +39,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=122880 --xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_assign_all_reduce_scatter_layout --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=gemma3-27b\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=gemma3-27b\ steps=15 per_device_batch_size=2 enable_checkpointing=false\ remat_policy=full ici_fsdp_transpose_parallelism=256 ici_fsdp_parallelism=-1\ max_target_length=8192 base_output_directory=$OUTPUT_PATH\ diff --git a/src/MaxText/configs/trillium/gpt3_175b.sh b/src/maxtext/configs/tpu/v6e/gpt3_175b.sh similarity index 86% rename from src/MaxText/configs/trillium/gpt3_175b.sh rename to src/maxtext/configs/tpu/v6e/gpt3_175b.sh index cbb4af0ee8..85ab0668d1 100644 --- a/src/MaxText/configs/trillium/gpt3_175b.sh +++ b/src/maxtext/configs/tpu/v6e/gpt3_175b.sh @@ -7,7 +7,7 @@ # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/trillium/gpt3_175b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v6e/gpt3_175b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # @@ -39,7 +39,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_scoped_vmem_limit_kib=98304 --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=gpt3-175b\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=gpt3-175b\ steps=15 per_device_batch_size=3 enable_checkpointing=false\ remat_policy=full ici_fsdp_parallelism=-1\ max_target_length=2048 base_output_directory=$OUTPUT_PATH\ diff --git a/src/MaxText/configs/v6e/inference/llama4_maverick_v6e-64.yml b/src/maxtext/configs/tpu/v6e/inference/llama4_maverick_v6e-64.yml similarity index 98% rename from src/MaxText/configs/v6e/inference/llama4_maverick_v6e-64.yml rename to src/maxtext/configs/tpu/v6e/inference/llama4_maverick_v6e-64.yml index 0aa2ae04e4..165500da8d 100644 --- a/src/MaxText/configs/v6e/inference/llama4_maverick_v6e-64.yml +++ b/src/maxtext/configs/tpu/v6e/inference/llama4_maverick_v6e-64.yml @@ -4,7 +4,7 @@ # tensor parallelism = 8, autoregressive parallelism = 8 # weight bf16, kv cache bf16 -base_config: "inference_jetstream.yml" +base_config: "inference/inference_jetstream.yml" sharding_strategy: "experimental" attention: 'dot_product' diff --git a/src/MaxText/configs/trillium/llama2_7b_4096.sh b/src/maxtext/configs/tpu/v6e/llama2_7b_4096.sh similarity index 84% rename from src/MaxText/configs/trillium/llama2_7b_4096.sh rename to src/maxtext/configs/tpu/v6e/llama2_7b_4096.sh index ac6818903b..587cf82946 100644 --- a/src/MaxText/configs/trillium/llama2_7b_4096.sh +++ b/src/maxtext/configs/tpu/v6e/llama2_7b_4096.sh @@ -7,7 +7,7 @@ # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/trillium/llama2_7b_4096.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v6e/llama2_7b_4096.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # @@ -39,7 +39,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=98304 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=llama2-7b\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=llama2-7b\ steps=15 per_device_batch_size=12 enable_checkpointing=false\ remat_policy=full ici_fsdp_parallelism=-1\ max_target_length=4096 base_output_directory=$OUTPUT_PATH\ diff --git a/src/MaxText/configs/trillium/mixtral_8x7b.sh b/src/maxtext/configs/tpu/v6e/mixtral_8x7b.sh similarity index 84% rename from src/MaxText/configs/trillium/mixtral_8x7b.sh rename to src/maxtext/configs/tpu/v6e/mixtral_8x7b.sh index dddbc06dd5..8f5d51930a 100644 --- a/src/MaxText/configs/trillium/mixtral_8x7b.sh +++ b/src/maxtext/configs/tpu/v6e/mixtral_8x7b.sh @@ -7,7 +7,7 @@ # RUN_NAME (Required, unless run_name is already set in base.yml or running with XPK/GKE) # # Example to invoke this script: -# bash src/MaxText/configs/trillium/mixtral_8x7b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" +# bash src/maxtext/configs/tpu/v6e/mixtral_8x7b.sh RUN_NAME="" OUTPUT_PATH="gs://" DATASET_PATH="gs://" # @@ -39,7 +39,7 @@ fi # Train export LIBTPU_INIT_ARGS="--xla_tpu_scoped_vmem_limit_kib=81920 --xla_enable_async_all_gather=true --xla_tpu_overlap_compute_collective_tc=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true" -python3 -m MaxText.$EXECUTABLE "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=mixtral-8x7b\ +python3 -m MaxText.$EXECUTABLE "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=mixtral-8x7b\ steps=15 per_device_batch_size=32 enable_checkpointing=false\ remat_policy=full ici_fsdp_parallelism=-1\ max_target_length=1024 base_output_directory=$OUTPUT_PATH\ diff --git a/src/MaxText/configs/types.py b/src/maxtext/configs/types.py similarity index 100% rename from src/MaxText/configs/types.py rename to src/maxtext/configs/types.py diff --git a/src/MaxText/configs/vllm.yml b/src/maxtext/configs/vllm.yml similarity index 100% rename from src/MaxText/configs/vllm.yml rename to src/maxtext/configs/vllm.yml diff --git a/src/maxtext/examples/demo_decoding.ipynb b/src/maxtext/examples/demo_decoding.ipynb index 4698a260a0..8475596567 100644 --- a/src/maxtext/examples/demo_decoding.ipynb +++ b/src/maxtext/examples/demo_decoding.ipynb @@ -214,7 +214,7 @@ "%%capture\n", "argv = [\n", " \"\", # This is a placeholder, it's not actually used by the script's logic\n", - " f\"{MAXTEXT_PKG_DIR}/configs/base.yml\",\n", + " f\"{MAXTEXT_CONFIGS_DIR}/base.yml\",\n", " f\"model_name={MODEL_NAME}\",\n", " f\"base_output_directory={MODEL_CHECKPOINT_PATH}\",\n", " f\"hf_access_token={HF_TOKEN}\",\n", @@ -252,7 +252,7 @@ "source": [ "%%capture\n", "config = pyconfig.initialize(\n", - " [\"\", f\"{MAXTEXT_PKG_DIR}/configs/base.yml\"],\n", + " [\"\", f\"{MAXTEXT_CONFIGS_DIR}/base.yml\"],\n", " per_device_batch_size=1.0,\n", " run_name=\"test\",\n", " max_target_length=4,\n", diff --git a/src/maxtext/examples/multimodal_gemma3_demo.ipynb b/src/maxtext/examples/multimodal_gemma3_demo.ipynb index 4df0157314..a7f9cdfdab 100644 --- a/src/maxtext/examples/multimodal_gemma3_demo.ipynb +++ b/src/maxtext/examples/multimodal_gemma3_demo.ipynb @@ -99,7 +99,7 @@ "outputs": [], "source": [ "!python3 -m MaxText.utils.ckpt_conversion.to_maxtext \\\n", - " $MAXTEXT_PKG_DIR/configs/base.yml \\\n", + " $MAXTEXT_CONFIGS_DIR/base.yml \\\n", " model_name=$MODEL_NAME \\\n", " hf_access_token=$HF_TOKEN \\\n", " base_output_directory=$MODEL_CHECKPOINT_PATH \\\n", @@ -121,7 +121,7 @@ "outputs": [], "source": [ "!python -m maxtext.decode \\\n", - " $MAXTEXT_PKG_DIR/configs/base.yml \\\n", + " $MAXTEXT_CONFIGS_DIR/base.yml \\\n", " model_name=$MODEL_NAME \\\n", " tokenizer_path=$MAXTEXT_ASSETS_ROOT/tokenizers/tokenizer.gemma3 \\\n", " load_parameters_path=$MODEL_CHECKPOINT_PATH/0/items \\\n", @@ -165,7 +165,7 @@ "PER_DEVICE_BATCH_SIZE=1\n", "\n", "!python -m MaxText.sft_trainer \\\n", - " $MAXTEXT_PKG_DIR/configs/sft-vision-chartqa.yml \\\n", + " $MAXTEXT_CONFIGS_DIR/sft-vision-chartqa.yml \\\n", " run_name=$WORKLOAD_NAME \\\n", " model_name=$MODEL_NAME \\\n", " tokenizer_path=$PRE_TRAINED_MODEL_TOKENIZER \\\n", diff --git a/src/maxtext/examples/rl_llama3_demo.ipynb b/src/maxtext/examples/rl_llama3_demo.ipynb index 4eacf0b34a..9471f53ab3 100644 --- a/src/maxtext/examples/rl_llama3_demo.ipynb +++ b/src/maxtext/examples/rl_llama3_demo.ipynb @@ -343,7 +343,7 @@ "## 📚 Learn More\n", "\n", "- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/rl.html\n", - "- **Configuration**: See `src/MaxText/configs/rl.yml` for all available options\n", + "- **Configuration**: See `src/maxtext/configs/rl.yml` for all available options\n", "- **Documentation**: Check `src/MaxText/rl/train_rl.py` for the `rl_train` function implementation" ] } diff --git a/src/maxtext/examples/sft_llama3_demo.ipynb b/src/maxtext/examples/sft_llama3_demo.ipynb index 0b7dd227ce..f4e497cd50 100644 --- a/src/maxtext/examples/sft_llama3_demo.ipynb +++ b/src/maxtext/examples/sft_llama3_demo.ipynb @@ -145,15 +145,13 @@ "source": [ "import datetime\n", "import os\n", - "import subprocess\n", - "import sys\n", - "import MaxText\n", "from MaxText import pyconfig\n", + "from MaxText.globals import MAXTEXT_PKG_DIR\n", "from maxtext.trainers.post_train.sft import train_sft\n", "import jax\n", "from huggingface_hub import login\n", "\n", - "MAXTEXT_PKG_DIR = os.path.dirname(MaxText.__file__)\n", + "\n", "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")" ] }, @@ -234,17 +232,40 @@ "outputs": [], "source": [ "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", - " # install torch for the conversion script\n", - " !python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu\n", - "\n", - " !JAX_PLATFORMS=cpu PYTHONPATH={MAXTEXT_PKG_DIR} {sys.executable} -m MaxText.utils.ckpt_conversion.to_maxtext \\\n", - " {MAXTEXT_PKG_DIR}/configs/base.yml \\\n", - " model_name={MODEL_NAME} \\\n", - " base_output_directory={MODEL_CHECKPOINT_PATH} \\\n", - " hf_access_token={HF_TOKEN} \\\n", - " use_multimodal=false \\\n", - " scan_layers=true \\\n", - " skip_jax_distributed_system=True\n", + " import subprocess\n", + " import sys\n", + "\n", + " # Install torch for the conversion script\n", + " print(\"Installing torch...\")\n", + " subprocess.run(\n", + " [\n", + " sys.executable, \"-m\", \"pip\", \"install\",\n", + " \"torch\", \"--index-url\", \"https://download.pytorch.org/whl/cpu\"\n", + " ],\n", + " check=True\n", + " )\n", + "\n", + " # Run checkpoint conversion with environment variables\n", + " print(\"Converting checkpoint from HuggingFace...\")\n", + " env = os.environ.copy()\n", + " env[\"JAX_PLATFORMS\"] = \"cpu\"\n", + " # env[\"PYTHONPATH\"] = MAXTEXT_PKG_DIR\n", + "\n", + " subprocess.run(\n", + " [\n", + " sys.executable,\n", + " \"-m\", \"MaxText.utils.ckpt_conversion.to_maxtext\",\n", + " f\"{MAXTEXT_PKG_DIR}/configs/base.yml\",\n", + " f\"model_name={MODEL_NAME}\",\n", + " f\"base_output_directory={MODEL_CHECKPOINT_PATH}\",\n", + " f\"hf_access_token={HF_TOKEN}\",\n", + " \"use_multimodal=false\",\n", + " \"scan_layers=true\",\n", + " \"skip_jax_distributed_system=True\",\n", + " ],\n", + " check=True,\n", + " env=env\n", + " )\n", "\n", "if not os.path.exists(MODEL_CHECKPOINT_PATH):\n", " raise ValueError(\"Model checkpoint conversion failed. Check the logs above.\")" @@ -268,7 +289,7 @@ "# Load configuration for SFT training\n", "config_argv = [\n", " \"\",\n", - " f\"{MAXTEXT_PKG_DIR}/configs/sft.yml\",\n", + " f\"{MAXTEXT_PKG_DIR}/configs/post_train/sft.yml\",\n", " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items\",\n", " f\"model_name={MODEL_NAME}\",\n", " \"steps=100\",\n", @@ -333,7 +354,7 @@ "## 📚 Learn More\n", "\n", "- **CLI Usage**: https://maxtext.readthedocs.io/en/latest/tutorials/posttraining/sft.html\n", - "- **Configuration**: See `src/MaxText/configs/sft.yml` for all available options\n", + "- **Configuration**: See `src/maxtext/configs/post_train/sft.yml` for all available options\n", "- **Documentation**: Check `src/MaxText/sft/sft_trainer.py` for the `sft_train` function implementation" ] } diff --git a/src/maxtext/examples/sft_qwen3_demo.ipynb b/src/maxtext/examples/sft_qwen3_demo.ipynb index 5adb71f93e..d7abf629b3 100644 --- a/src/maxtext/examples/sft_qwen3_demo.ipynb +++ b/src/maxtext/examples/sft_qwen3_demo.ipynb @@ -201,6 +201,7 @@ "from MaxText import pyconfig\n", "from maxtext.examples.sft_train_and_evaluate import evaluate_model, get_test_dataset\n", "from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter\n", + "from MaxText.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n", "from maxtext.trainers.post_train.sft import train_sft\n", "\n", "# Suppress vLLM logging with a severity level below ERROR\n", @@ -212,8 +213,6 @@ "from flax import nnx\n", "from huggingface_hub import login\n", "\n", - "MAXTEXT_PKG_DIR = os.path.dirname(MaxText.__file__)\n", - "MAXTEXT_REPO_ROOT = os.sep.join([\"maxtext\" if p == \"MaxText\" else p for p in MAXTEXT_PKG_DIR.split(os.sep)])\n", "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")" ] }, @@ -350,7 +349,7 @@ "TEST_DATA_SPLIT = \"test\"\n", "HF_DATA_DIR = \"main\"\n", "TRAIN_DATA_COLUMNS = [\"question\", \"answer\"]\n", - "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/examples/chat_templates/math_qa.json\"\n", + "CHAT_TEMPLATE_PATH = f\"{MAXTEXT_REPO_ROOT}/src/maxtext/examples/chat_templates/math_qa.json\"\n", "if not os.path.exists(CHAT_TEMPLATE_PATH):\n", " raise FileNotFoundError(f\"Chat template not found: {CHAT_TEMPLATE_PATH}\")\n", "NUM_TEST_SAMPLES = 20 # Total number of samples to test\n", diff --git a/src/maxtext/examples/sft_train_and_evaluate.py b/src/maxtext/examples/sft_train_and_evaluate.py index 7263169362..e7c9c9008e 100644 --- a/src/maxtext/examples/sft_train_and_evaluate.py +++ b/src/maxtext/examples/sft_train_and_evaluate.py @@ -35,7 +35,7 @@ export MODEL_CHECKPOINT_PATH= export HF_ACCESS_TOKEN= -python3 -m maxtext.examples.sft_train_and_evaluate MaxText/configs/sft.yml \ +python3 -m maxtext.examples.sft_train_and_evaluate maxtext/configs/post_train/sft.yml \ run_name=$RUN_NAME base_output_directory=$OUTPUT_PATH \ model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH \ hf_access_token=$HF_ACCESS_TOKEN tokenizer_path=$TOKENIZER_PATH @@ -67,7 +67,7 @@ --workload=sft-${RUN_NAME} \ --tpu-type ${TPU_TYPE} --num-slices=1 --zone=${ZONE} \ --project=${PROJECT} \ ---command "HF_TOKEN=$HF_ACCESS_TOKEN python3 -m maxtext.examples.sft_train_and_evaluate MaxText/configs/sft.yml \ +--command "HF_TOKEN=$HF_ACCESS_TOKEN python3 -m maxtext.examples.sft_train_and_evaluate maxtext/configs/post_train/sft.yml \ run_name=$RUN_NAME base_output_directory=$OUTPUT_PATH \ model_name=$MODEL_NAME load_parameters_path=$MODEL_CHECKPOINT_PATH \ hf_access_token=$HF_ACCESS_TOKEN tokenizer_path=$TOKENIZER_PATH" diff --git a/src/maxtext/inference/gpu/microbenchmark_llama2-70b_h100-8.sh b/src/maxtext/inference/gpu/microbenchmark_llama2-70b_h100-8.sh index 0bb5bf26ab..e4b3163964 100755 --- a/src/maxtext/inference/gpu/microbenchmark_llama2-70b_h100-8.sh +++ b/src/maxtext/inference/gpu/microbenchmark_llama2-70b_h100-8.sh @@ -57,7 +57,7 @@ if [[ -z ${INFERENCE_LOG_FILE_PATH} ]] ; then export INFERENCE_LOG_FILE_PATH="${BASE_OUTPUT_DIRECTORY}/microbenchmark_llama2-70b_h100-8_results.txt" fi if [[ -z ${MAXENGINE_CONFIG_FILEPATH} ]] ; then - MAXENGINE_CONFIG_FILEPATH="$(dirname $0)/../../configs/inference.yml" + MAXENGINE_CONFIG_FILEPATH="$(dirname $0)/../../configs/inference/inference.yml" fi if [[ -z ${QUANTIZATION} ]] ; then QUANTIZATION="aqt_fp8" diff --git a/src/maxtext/inference/maxengine_server/maxengine_server_entrypoint.sh b/src/maxtext/inference/maxengine_server/maxengine_server_entrypoint.sh index 9b4d1b6f8a..3199c9a732 100644 --- a/src/maxtext/inference/maxengine_server/maxengine_server_entrypoint.sh +++ b/src/maxtext/inference/maxengine_server/maxengine_server_entrypoint.sh @@ -16,4 +16,4 @@ cd /maxtext python3 -m MaxText.maxengine_server \ -MaxText/configs/base.yml $@ +maxtext/configs/base.yml $@ diff --git a/src/maxtext/inference/mlperf/README.md b/src/maxtext/inference/mlperf/README.md index 82e23471ce..e87859e95e 100644 --- a/src/maxtext/inference/mlperf/README.md +++ b/src/maxtext/inference/mlperf/README.md @@ -100,7 +100,7 @@ export SAVE_QUANT_PARAMS_PATH=gs://${USER}-bkt/quantized/llama2-70b-chat # other tokenizers under src/maxtext/assets/ directory. export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets}}"'/tokenizer.llama2' cd maxtext && \ -python3 -m maxtext.decode src/MaxText/configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${LOAD_PARAMS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-70b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8 save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH} +python3 -m maxtext.decode src/maxtext/configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${LOAD_PARAMS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=llama2-70b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=11 attention=dot_product quantization=int8 save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH} ``` Your checkpoint is generated at `$SAVE_QUANT_PARAMS_PATH`. This is used to set `load_parameters_path` param below in `MAXENGINE_ARGS` env variable. @@ -125,7 +125,7 @@ export MODEL_SIZE=llama3.1-405b export QUANTIZE_TYPE=int8 cd maxtext && \ -python3 -m maxtext.checkpoint_conversion.load_and_quantize_checkpoint src/MaxText/configs/base.yml tokenizer_path=${TOKENIZER} load_parameters_path=${LOAD_PARAMS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=${MODEL_SIZE} ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=1 attention=dot_product quantization=${QUANTIZE_TYPE} save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH} async_checkpointing=false +python3 -m maxtext.checkpoint_conversion.load_and_quantize_checkpoint src/maxtext/configs/base.yml tokenizer_path=${TOKENIZER} load_parameters_path=${LOAD_PARAMS_PATH} max_prefill_predict_length=1024 max_target_length=2048 model_name=${MODEL_SIZE} ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 scan_layers=false weight_dtype=bfloat16 per_device_batch_size=1 attention=dot_product quantization=${QUANTIZE_TYPE} save_quantized_params_path=${SAVE_QUANT_PARAMS_PATH} async_checkpointing=false ``` The quantized checkpoint is saved at `${SAVE_QUANT_PARAMS_PATH}` diff --git a/src/maxtext/inference/mlperf/gpu/benchmarks_llama2-70b-h100_8.sh b/src/maxtext/inference/mlperf/gpu/benchmarks_llama2-70b-h100_8.sh index 7bfd4cdc39..a541459404 100755 --- a/src/maxtext/inference/mlperf/gpu/benchmarks_llama2-70b-h100_8.sh +++ b/src/maxtext/inference/mlperf/gpu/benchmarks_llama2-70b-h100_8.sh @@ -64,7 +64,7 @@ export XLA_PYTHON_CLIENT_MEM_FRACTION=0.94 echo XLA_FLAGS: $XLA_FLAGS if [[ -z ${MAXENGINE_CONFIG_FILEPATH} ]] ; then - export MAXENGINE_CONFIG_FILEPATH="$(dirname $0)/../../configs/inference.yml" + export MAXENGINE_CONFIG_FILEPATH="$(dirname $0)/../../configs/inference/inference.yml" fi if [[ -z ${QUANTIZATION} ]] ; then diff --git a/src/maxtext/inference/mlperf/llama_offline_run.sh b/src/maxtext/inference/mlperf/llama_offline_run.sh index 52181195e3..9d2991df49 100755 --- a/src/maxtext/inference/mlperf/llama_offline_run.sh +++ b/src/maxtext/inference/mlperf/llama_offline_run.sh @@ -84,7 +84,7 @@ then fi if [[ -z ${MAXENGINE_CONFIG_FILEPATH} ]] ; then - MAXENGINE_CONFIG_FILEPATH="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml" + MAXENGINE_CONFIG_FILEPATH="${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml" fi diff --git a/src/maxtext/inference/mlperf/trillium/benchmarks_llama2-70b-trillium_2x4.sh b/src/maxtext/inference/mlperf/trillium/benchmarks_llama2-70b-trillium_2x4.sh index 5a80aa4c0a..900d84cf92 100644 --- a/src/maxtext/inference/mlperf/trillium/benchmarks_llama2-70b-trillium_2x4.sh +++ b/src/maxtext/inference/mlperf/trillium/benchmarks_llama2-70b-trillium_2x4.sh @@ -74,7 +74,7 @@ if [[ -z ${QUANTIZATION} ]] ; then export QUANT_PATH="" # export QUANTIZATION="intmp" # export QUANT_MP="qkv_subchannel_512" -# export QUANT_PATH="/home/${USER}/maxtext/MaxText/configs/quantization/${QUANT_MP}.json" +# export QUANT_PATH="/home/${USER}/maxtext/maxtext/configs/quantization/${QUANT_MP}.json" fi if [[ -z ${KV_QUANT_DTYPE} ]] ; then diff --git a/src/maxtext/scratch_code/gemma_7b.sh b/src/maxtext/scratch_code/gemma_7b.sh index 0e69af9dc4..1c09ac0100 100644 --- a/src/maxtext/scratch_code/gemma_7b.sh +++ b/src/maxtext/scratch_code/gemma_7b.sh @@ -3,6 +3,6 @@ export M_PER_DEVICE_BATCH_SIZE=24 export M_MAX_PREFILL_PREDICT_LENGTH=1024 export M_MAX_TARGET_LENGTH=2048 -#python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false +#python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false -python3 -m MaxText.maxengine_server "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false \ No newline at end of file +python3 -m MaxText.maxengine_server "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma run_name=runner_2024-03-06-04-17 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=gemma-7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 scan_layers=false \ No newline at end of file diff --git a/src/maxtext/scratch_code/run_inference_microbenchmark.sh b/src/maxtext/scratch_code/run_inference_microbenchmark.sh index c9b6a878c4..989875c4d6 100644 --- a/src/maxtext/scratch_code/run_inference_microbenchmark.sh +++ b/src/maxtext/scratch_code/run_inference_microbenchmark.sh @@ -1,6 +1,6 @@ # llama2-7b python3 -m maxtext.inference.inference_microbenchmark \ -"${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml \ +"${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ async_checkpointing=false \ attention=autoselected \ dataset_path=gs://maxtext-dataset \ diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index 9d71c2029c..9f309cb909 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -15,11 +15,11 @@ """ SFT training script that calls a trainer in Tunix to run SFT on a MaxText model using `HuggingFaceH4/ultrachat_200k` dataset. The configurations for the dataset -are defined inside `src/MaxText/configs/sft.yml`. +are defined inside `src/maxtext/configs/post_train/sft.yml`. Example command: Training & Evaluation: - python3 -m maxtext.trainers.post_train.sft.train_sft src/MaxText/configs/sft.yml \ + python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \ run_name=$RUN_NAME base_output_directory=$BASE_OUTPUT_DIRECTORY \ model_name=$MODEL_NAME load_parameters_path=$CHECKPOINT_PATH \ hf_access_token=$HF_ACCESS_TOKEN tokenizer_path=$TOKENIZER_PATH \ @@ -27,7 +27,7 @@ eval_interval=2 eval_steps=2 steps=10 profiler=xplane weight_dtype=bfloat16 Training: - python3 -m maxtext.trainers.post_train.sft.train_sft src/MaxText/configs/sft.yml \ + python3 -m maxtext.trainers.post_train.sft.train_sft src/maxtext/configs/post_train/sft.yml \ run_name=$RUN_NAME base_output_directory=$BASE_OUTPUT_DIRECTORY \ model_name=$MODEL_NAME load_parameters_path=$CHECKPOINT_PATH \ hf_access_token=$HF_ACCESS_TOKEN tokenizer_path=$TOKENIZER_PATH \ diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index bcd7d5ddd8..197b23fcdb 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -37,8 +37,8 @@ import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager from MaxText import sharding -from MaxText.configs import types from MaxText.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE +from maxtext.configs import types from maxtext.inference.page_manager import PageState from maxtext.common import checkpointing from maxtext.multimodal import processor as mm_processor diff --git a/src/maxtext/vllm_decode.py b/src/maxtext/vllm_decode.py index 67d2fa0179..bc27e13e34 100644 --- a/src/maxtext/vllm_decode.py +++ b/src/maxtext/vllm_decode.py @@ -15,7 +15,7 @@ An example script to perform decoding using vLLM via Tunix or via MaxText on vLLM. Example usage with Tunix: - python3 -m maxtext.vllm_decode MaxText/configs/base.yml \ + python3 -m maxtext.vllm_decode maxtext/configs/base.yml \ model_name=llama3.1-8b tokenizer_path=meta-llama/Llama-3.1-8B-Instruct \ tokenizer_type=huggingface hf_access_token= \ load_parameters_path= \ @@ -109,6 +109,7 @@ def decode_with_vllm( decode_sampling_nucleus_p: float, decode_sampling_top_k: float, debug_sharding: bool, + vllm_config_path: str | None = None, ) -> None: """Decode using vLLM with a MaxText model implementation. @@ -129,6 +130,7 @@ def decode_with_vllm( decode_sampling_temperature: Temperature for sampling. decode_sampling_nucleus_p: Nucleus sampling probability. decode_sampling_top_k: Top-k sampling probability. + vllm_config_path: Path to vLLM config file. Defaults to MAXTEXT_PKG_DIR/configs/vllm.yml. """ # Prepare vLLM Arguments diff --git a/tests/end_to_end/gpu/a3/test_convergence_125m_params.sh b/tests/end_to_end/gpu/a3/test_convergence_125m_params.sh index 7858dc263c..20f2fe0591 100644 --- a/tests/end_to_end/gpu/a3/test_convergence_125m_params.sh +++ b/tests/end_to_end/gpu/a3/test_convergence_125m_params.sh @@ -47,7 +47,7 @@ then CMD_DATA=" hf_path=parquet hf_data_files=gs://maxtext-dataset/hf/c4/c4-train-*.parquet dataset_type=hf tokenizer_path=${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}/llama2-tokenizer" fi -TRAIN_CMD="python3 -m MaxText.train ${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml run_name=$RUN_NAME hardware=gpu \ +TRAIN_CMD="python3 -m MaxText.train ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml run_name=$RUN_NAME hardware=gpu \ steps=$STEPS dcn_data_parallelism=1 learning_rate=3e-4 \ base_emb_dim=1024 base_num_query_heads=8 base_num_kv_heads=8 base_mlp_dim=3584 base_num_decoder_layers=8 \ ici_fsdp_parallelism=8 metrics_file=metrics.txt per_device_batch_size=4 \ diff --git a/tests/end_to_end/gpu/a3/test_convergence_1b_params.sh b/tests/end_to_end/gpu/a3/test_convergence_1b_params.sh index 4c49889ea6..4ea75426c2 100644 --- a/tests/end_to_end/gpu/a3/test_convergence_1b_params.sh +++ b/tests/end_to_end/gpu/a3/test_convergence_1b_params.sh @@ -47,7 +47,7 @@ then CMD_DATA=" hf_path=parquet hf_data_files=gs://maxtext-dataset/hf/c4/c4-train-*.parquet dataset_type=hf tokenizer_path=${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}/llama2-tokenizer" fi -TRAIN_CMD="python3 -m MaxText.train ${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml run_name=$RUN_NAME hardware=gpu \ +TRAIN_CMD="python3 -m MaxText.train ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml run_name=$RUN_NAME hardware=gpu \ steps=$STEPS dcn_data_parallelism=1 learning_rate=3e-4 \ ici_fsdp_parallelism=8 metrics_file=metrics.txt per_device_batch_size=4 \ max_target_length=2048 enable_checkpointing=false attention=dot_product \ diff --git a/tests/end_to_end/gpu/a3/test_gemma3_logits.sh b/tests/end_to_end/gpu/a3/test_gemma3_logits.sh index e5ee235c6c..9bd61b2a8c 100644 --- a/tests/end_to_end/gpu/a3/test_gemma3_logits.sh +++ b/tests/end_to_end/gpu/a3/test_gemma3_logits.sh @@ -39,10 +39,10 @@ python3 -m MaxText.utils.ckpt_scripts.convert_gemma3_chkpt --base_model_path ${C # export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `src/MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. -#JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_NAME} force_unroll=true +#JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_NAME} force_unroll=true export UNSCANNED_CKPT_PATH=gs://runner-maxtext-logs/unscanned_chkpt_2025-04-16-00-01/checkpoints/0/items export NVTE_FUSED_ATTN=1 # # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} hardware=gpu attention=cudnn_flash_te per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} hardware=gpu attention=cudnn_flash_te per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0 diff --git a/tests/end_to_end/gpu/a3/test_llama2_7b.sh b/tests/end_to_end/gpu/a3/test_llama2_7b.sh index 449bd6b234..b2ce7175a9 100644 --- a/tests/end_to_end/gpu/a3/test_llama2_7b.sh +++ b/tests/end_to_end/gpu/a3/test_llama2_7b.sh @@ -38,7 +38,7 @@ export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `src/MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint_${idx} -python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' hardware=gpu async_checkpointing=${ASYNC_CHECKPOINTING} +python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' hardware=gpu async_checkpointing=${ASYNC_CHECKPOINTING} export RUN_NAME="llama-2-1vm-$(date +%Y-%m-%d-%H-%M)" @@ -59,9 +59,9 @@ export XLA_FLAGS="--xla_dump_to=$BASE_OUTPUT_PATH/$RUN_NAME/HLO_dumps/ --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false --xla_disable_hlo_passes=rematerialization" -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml run_name=$RUN_NAME hardware=gpu steps=30 dcn_data_parallelism=1 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b enable_checkpointing=true attention=cudnn_flash_te remat_policy=minimal_with_context use_iota_embed=true scan_layers=false dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} base_output_directory=$BASE_OUTPUT_DIRECTORY +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml run_name=$RUN_NAME hardware=gpu steps=30 dcn_data_parallelism=1 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b enable_checkpointing=true attention=cudnn_flash_te remat_policy=minimal_with_context use_iota_embed=true scan_layers=false dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} base_output_directory=$BASE_OUTPUT_DIRECTORY export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 export TF_FORCE_GPU_ALLOW_GROWTH=true -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false hardware=gpu async_checkpointing=${ASYNC_CHECKPOINTING} +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false hardware=gpu async_checkpointing=${ASYNC_CHECKPOINTING} diff --git a/tests/end_to_end/gpu/mixtral/test_8x7b.sh b/tests/end_to_end/gpu/mixtral/test_8x7b.sh index 19e9d5a6d1..6ba351c835 100644 --- a/tests/end_to_end/gpu/mixtral/test_8x7b.sh +++ b/tests/end_to_end/gpu/mixtral/test_8x7b.sh @@ -25,7 +25,7 @@ fi export DATASET_PATH=gs://maxtext-dataset # Run pre-training - dropping implementation -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=mixtral-8x7b hardware=gpu \ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=mixtral-8x7b hardware=gpu \ base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} \ run_name=dropping_pre_training async_checkpointing=false \ attention=cudnn_flash_te capacity_factor=1.25 dtype=bfloat16 \ @@ -36,7 +36,7 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT echo "Finished pre-training" # Run fine-tuning - dropping implementation -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=mixtral-8x7b hardware=gpu \ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=mixtral-8x7b hardware=gpu \ load_parameters_path=${SCANNED_CHECKPOINT} \ base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} \ run_name=dropping_pre_training async_checkpointing=true \ @@ -49,7 +49,7 @@ echo "Finished fine-tuning" # # TODO(b/391864113): Add this once the bug is fixed # # Run decoding with converted ckpt - dropping implementation -# python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=mixtral-8x7b hardware=gpu \ +# python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=mixtral-8x7b hardware=gpu \ # run_name=unscanned_decoding load_parameters_path=${UNSCANNED_CKPT_PATH} \ # async_checkpointing=false attention=dot_product capacity_factor=0.1 \ # ici_expert_parallelism=8 ici_fsdp_parallelism=1 max_prefill_predict_length=11 \ diff --git a/tests/end_to_end/gpu/te/run_single_node_model_parallel.sh b/tests/end_to_end/gpu/te/run_single_node_model_parallel.sh index f62704a218..9d50ecdb8c 100755 --- a/tests/end_to_end/gpu/te/run_single_node_model_parallel.sh +++ b/tests/end_to_end/gpu/te/run_single_node_model_parallel.sh @@ -154,7 +154,7 @@ if [[ "$TRACE" == "true" ]]; then fi # Updating the model config file as we can't pass base_num_decoder_layers=1 in additional-args if [ -n "$NUM_DECODER_LAYERS" ]; then - MODEL_CONFIG="$MAXTEXT_DIR/MaxText/configs/models/$MODEL.yml" + MODEL_CONFIG="$MAXTEXT_DIR/maxtext/configs/models/$MODEL.yml" original_num_decoder_layers=$(grep "base_num_decoder_layers" "$MODEL_CONFIG" | awk -F': ' '{print $2}') sed -i "s/base_num_decoder_layers: .*/base_num_decoder_layers: $NUM_DECODER_LAYERS/" "$MODEL_CONFIG" echo "=== Setting base_num_decoder_layers=$NUM_DECODER_LAYERS in $MODEL_CONFIG" diff --git a/tests/end_to_end/gpu/test_collective_matmul_llama2_7b.sh b/tests/end_to_end/gpu/test_collective_matmul_llama2_7b.sh index d5caa9dcc4..7587301f1e 100755 --- a/tests/end_to_end/gpu/test_collective_matmul_llama2_7b.sh +++ b/tests/end_to_end/gpu/test_collective_matmul_llama2_7b.sh @@ -42,7 +42,7 @@ export XLA_FLAGS="--xla_dump_hlo_as_text --xla_gpu_multi_streamed_windowed_einsum=true" python3 -m MaxText.train \ - "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ + "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ model_name=${MODEL} \ per_device_batch_size=0.125 \ steps=1 \ diff --git a/tests/end_to_end/gpu/test_fp8_gemm_llama2_7b.sh b/tests/end_to_end/gpu/test_fp8_gemm_llama2_7b.sh index 72b8f42ab6..5f25c5fae9 100755 --- a/tests/end_to_end/gpu/test_fp8_gemm_llama2_7b.sh +++ b/tests/end_to_end/gpu/test_fp8_gemm_llama2_7b.sh @@ -39,7 +39,7 @@ export XLA_FLAGS="--xla_dump_hlo_as_text --xla_disable_hlo_passes=rematerialization" python3 -m MaxText.train \ - "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ + "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ model_name=${MODEL} \ quantization=fp8 \ per_device_batch_size=0.125 \ diff --git a/tests/end_to_end/test_checkpoint_compatibility.sh b/tests/end_to_end/test_checkpoint_compatibility.sh index 39c6c41c77..af134e4cd2 100644 --- a/tests/end_to_end/test_checkpoint_compatibility.sh +++ b/tests/end_to_end/test_checkpoint_compatibility.sh @@ -20,7 +20,7 @@ bash tools/setup/setup_gcsfuse.sh DATASET_GCS_BUCKET=$DATASET_PATH MOUNT_PATH=/t echo "Run_1: Starting the first run using the grain input pipeline" -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml run_name=$RUN_NAME steps=3 ${model_params}\ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml run_name=$RUN_NAME steps=3 ${model_params}\ max_target_length=128 per_device_batch_size=1\ metrics_file=run_1_metrics.txt checkpoint_period=2 async_checkpointing=false\ dataset_path=/tmp/gcsfuse base_output_directory=$OUTPUT_PATH\ @@ -32,7 +32,7 @@ echo "Finished Run_1 at step 2" echo "Run_2: Resuming using the tfds input pipeline" echo -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml run_name=$RUN_NAME steps=5 ${model_params}\ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml run_name=$RUN_NAME steps=5 ${model_params}\ max_target_length=128 per_device_batch_size=1 attention=$ATTENTION\ metrics_file=run_2_metrics.txt checkpoint_period=2 async_checkpointing=false\ dataset_path=/tmp/gcsfuse base_output_directory=$OUTPUT_PATH\ @@ -42,7 +42,7 @@ echo "Finished Run_2 at step 4" echo "Run_3: Resuming using the grain input pipeline" echo -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml run_name=$RUN_NAME steps=7 ${model_params}\ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml run_name=$RUN_NAME steps=7 ${model_params}\ max_target_length=128 per_device_batch_size=1\ metrics_file=run_3_metrics.txt checkpoint_period=2 async_checkpointing=false\ dataset_path=/tmp/gcsfuse base_output_directory=$OUTPUT_PATH\ diff --git a/tests/end_to_end/test_checkpointing.sh b/tests/end_to_end/test_checkpointing.sh index b596ebe976..bb9b2ba57b 100644 --- a/tests/end_to_end/test_checkpointing.sh +++ b/tests/end_to_end/test_checkpointing.sh @@ -36,7 +36,7 @@ then fi # This command runs training for some steps and saves a checkpoint. -CMD1="python3 -m MaxText.train ${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml run_name=$RUN_NAME steps=5 max_target_length=128 per_device_batch_size=1\ +CMD1="python3 -m MaxText.train ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml run_name=$RUN_NAME steps=5 max_target_length=128 per_device_batch_size=1\ metrics_file=saved_metrics.txt checkpoint_period=3 base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ async_checkpointing=$ASYNC_CHECKPOINTING collect_stack_trace=$COLLECT_STACK_TRACE attention=$ATTENTION" CMD1+=$model_params @@ -44,7 +44,7 @@ CMD1+=$CMD_DATA # This command restores the checkpoint from the previous run and continue training from the restored checkpoint. # This ensures actual new training steps are executed after restoring checkpoint from the above training run. -CMD2="python3 -m MaxText.train ${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml run_name=$RUN_NAME steps=10 max_target_length=128 per_device_batch_size=1\ +CMD2="python3 -m MaxText.train ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml run_name=$RUN_NAME steps=10 max_target_length=128 per_device_batch_size=1\ metrics_file=restored_metrics.txt base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ async_checkpointing=$ASYNC_CHECKPOINTING collect_stack_trace=$COLLECT_STACK_TRACE attention=$ATTENTION" CMD2+=$model_params diff --git a/tests/end_to_end/test_generate_param_only_checkpoint.sh b/tests/end_to_end/test_generate_param_only_checkpoint.sh index 4c7b9381ed..d9ac2c7a88 100644 --- a/tests/end_to_end/test_generate_param_only_checkpoint.sh +++ b/tests/end_to_end/test_generate_param_only_checkpoint.sh @@ -58,7 +58,7 @@ model_params="base_emb_dim=384 base_num_query_heads=8 base_num_kv_heads=8 base_m echo echo "Create a test training checkpoint" echo -$cmd python3 -m MaxText.train ${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml \ +$cmd python3 -m MaxText.train ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml \ run_name=${training_ckpt_run_id} \ base_output_directory=${base_output_directory} \ dataset_path=${dataset_path} attention=${attention} \ @@ -82,7 +82,7 @@ echo echo "Generate a decode checkpoint from the test training checkpoint" echo -$cmd python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +$cmd python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ run_name=${decode_ckpt_run_id} attention=${attention} \ base_output_directory=${base_output_directory} \ dataset_path=${dataset_path} async_checkpointing=false \ @@ -104,7 +104,7 @@ fi echo echo "Run decode using the generated checkpoint" echo -$cmd python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +$cmd python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ run_name=${run_id}-decode-steps-50 \ base_output_directory=${base_output_directory} \ dataset_path=${dataset_path} \ diff --git a/tests/end_to_end/test_mtc_phase_2_save_path.sh b/tests/end_to_end/test_mtc_phase_2_save_path.sh index bf53a48781..2239022d47 100644 --- a/tests/end_to_end/test_mtc_phase_2_save_path.sh +++ b/tests/end_to_end/test_mtc_phase_2_save_path.sh @@ -9,5 +9,5 @@ export TPU_PREMAPPED_BUFFER_SIZE=20000014336 export TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES=20000014336 # Train and save checkpoint -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"/configs/base.yml remat_policy=full base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml remat_policy=full base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \ steps=100 checkpoint_period=200 run_name=$RUN_NAME local_checkpoint_directory=/local local_checkpoint_period=20 enable_multi_tier_checkpointing=True multi_tier_checkpointing_backup_interval_minutes=5 metrics_file='saved_metrics.txt' diff --git a/tests/end_to_end/test_multi_tier_checkpointing.sh b/tests/end_to_end/test_multi_tier_checkpointing.sh index 258941e71e..68600f1388 100644 --- a/tests/end_to_end/test_multi_tier_checkpointing.sh +++ b/tests/end_to_end/test_multi_tier_checkpointing.sh @@ -9,11 +9,11 @@ export TPU_PREMAPPED_BUFFER_SIZE=20000014336 export TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES=20000014336 # Train and save checkpoint -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml remat_policy=full base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml remat_policy=full base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \ steps=100 enable_emergency_checkpoint=true checkpoint_period=200 local_checkpoint_directory=/local local_checkpoint_period=20 run_name=$RUN_NAME metrics_file='saved_metrics.txt' # Retrieve checkpoint -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml remat_policy=full base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml remat_policy=full base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \ steps=110 enable_emergency_checkpoint=true checkpoint_period=200 local_checkpoint_directory=/local local_checkpoint_period=20 run_name=$RUN_NAME metrics_file='restored_metrics.txt' diff --git a/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md b/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md index abfebf43d3..1af5fe3ed6 100644 --- a/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md +++ b/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md @@ -34,7 +34,7 @@ DeepSeek is a novel family of open-weights sparse MoE models by DeepSeek AI. The You can train from scratch to generate a new checkpoint. One example command to run pretraining with V3 on v5p-256. ```sh -python3 -m MaxText.train src/MaxText/configs/base.yml \ +python3 -m MaxText.train src/maxtext/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ run_name=matmul_pre_training \ per_device_batch_size=4 \ @@ -68,7 +68,7 @@ After you have a MaxText compatible checkpoint, you could fine-tune it with diff One example command to run general finetuning with V3 on v5p-256. ```sh -python3 -m MaxText.train src/MaxText/configs/base.yml \ +python3 -m MaxText.train src/maxtext/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ dataset_path=${DATASET_PATH} \ load_parameters_path=${CONVERTED_CHECKPOINT} \ @@ -93,7 +93,7 @@ python3 -m MaxText.train src/MaxText/configs/base.yml \ Fine-tuning with MTP on v5p-256 ```sh -python3 -m MaxText.train src/MaxText/configs/base.yml \ +python3 -m MaxText.train src/maxtext/configs/base.yml \ base_output_directory=gs://your-output-bucket/ \ dataset_path=gs://your-dataset-bucket/ \ load_parameters_path=gs://your-bucket/deepseek-v3/0/items \ @@ -114,7 +114,7 @@ python3 -m MaxText.train src/MaxText/configs/base.yml \ One example command to run supervised finetuning with V3 on v5p-256. Supervised fine-tuning is only working with HuggingFace conversational datasets. And, you can customize the dataset path using the `hf_path` config and provide your access token with `hf_access_token` config. ```sh -python3 -m MaxText.sft_trainer src/MaxText/configs/sft.yml \ +python3 -m MaxText.sft_trainer src/maxtext/configs/post_train/sft.yml \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ load_parameters_path=${CONVERTED_CHECKPOINT} \ run_name=matmul_supervised_fine_tuning \ @@ -140,7 +140,7 @@ python3 -m MaxText.sft_trainer src/MaxText/configs/sft.yml \ One example command to run decoding with V3 on v5p-256 with unscanned checkpoint for fast decoding. ```sh -python3 -m maxtext.decode src/MaxText/configs/base.yml \ +python3 -m maxtext.decode src/maxtext/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ load_parameters_path=${CONVERTED_CHECKPOINT} \ run_name=decode \ @@ -188,7 +188,7 @@ Run command below to compare logits between HuggingFace and MaxText. ```sh python3 -m tests.utils.forward_pass_logit_checker \ - src/MaxText/configs/base.yml \ + src/maxtext/configs/base.yml \ tokenizer_type=huggingface \ tokenizer_path=deepseek-ai/DeepSeek-V2-Lite \ load_parameters_path=${CONVERTED_CHECKPOINT} \ diff --git a/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh b/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh index 989952466e..8c791f1814 100644 --- a/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh +++ b/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh @@ -65,14 +65,14 @@ if [ ! -f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then gcloud storage cp ${GOLDEN_LOGITS_PATH} ${GOLDEN_LOGITS_DISK_LOCATION} fi -python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=false ici_fsdp_parallelism=1 ici_expert_parallelism=4 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --atol=1e-4 --rtol=1e-4 --max_kl_div=5e-6 +python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=false ici_fsdp_parallelism=1 ici_expert_parallelism=4 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --atol=1e-4 --rtol=1e-4 --max_kl_div=5e-6 # Run pre-training - tokamax_gmm implementation -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=tokamax_gmm_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=4 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=tokamax_gmm_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=4 # Run fine-tuning - tokamax_gmm implementation -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=tokamax_gmm_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4 checkpoint_storage_concurrent_gb=1024 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=tokamax_gmm_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4 checkpoint_storage_concurrent_gb=1024 # Run decoding - tokamax_gmm implementation # Note decode requires the access token for huggingface tokenizer even if the model is not gated -python3 -m maxtext.decode ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=512 max_target_length=1024 ici_fsdp_parallelism=1 ici_tensor_parallelism=4 ici_expert_parallelism=1 checkpoint_storage_concurrent_gb=1024 mla_naive_kvcache=false prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is " +python3 -m maxtext.decode ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=512 max_target_length=1024 ici_fsdp_parallelism=1 ici_tensor_parallelism=4 ici_expert_parallelism=1 checkpoint_storage_concurrent_gb=1024 mla_naive_kvcache=false prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is " diff --git a/tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh b/tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh index e59c7be5b3..8cc2bb87c4 100644 --- a/tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh +++ b/tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh @@ -49,8 +49,8 @@ if [ ! -f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then gcloud storage cp ${GOLDEN_LOGITS_PATH} ${GOLDEN_LOGITS_DISK_LOCATION} fi -python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=false ici_fsdp_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --atol=1.5 --rtol=1.5 --max_kl_div=0.1 +python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=false ici_fsdp_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --atol=1.5 --rtol=1.5 --max_kl_div=0.1 # Run decoding - tokamax_gmm implementation # Note decode requires the access token for huggingface tokenizer even if the model is not gated -python3 -m maxtext.decode ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=512 max_target_length=1024 ici_fsdp_parallelism=1 ici_tensor_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 mla_naive_kvcache=false prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is " +python3 -m maxtext.decode ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=512 max_target_length=1024 ici_fsdp_parallelism=1 ici_tensor_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 mla_naive_kvcache=false prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is " diff --git a/tests/end_to_end/tpu/deepseek/v3-671b/test_deepseek_mtp.sh b/tests/end_to_end/tpu/deepseek/v3-671b/test_deepseek_mtp.sh index 875e2b3e74..e405e121c4 100644 --- a/tests/end_to_end/tpu/deepseek/v3-671b/test_deepseek_mtp.sh +++ b/tests/end_to_end/tpu/deepseek/v3-671b/test_deepseek_mtp.sh @@ -33,7 +33,7 @@ export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs # Run fine-tuning with MTP enabled # We add `mtp_num_layers=1` and `mtp_loss_scaling_factor=0.1` to activate the MTP block. -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ dataset_path=${DATASET_PATH} \ load_parameters_path=${CONVERTED_CHECKPOINT} \ diff --git a/tests/end_to_end/tpu/gemma/2b/test_gemma.sh b/tests/end_to_end/tpu/gemma/2b/test_gemma.sh index 5d48de944f..16aeba4751 100644 --- a/tests/end_to_end/tpu/gemma/2b/test_gemma.sh +++ b/tests/end_to_end/tpu/gemma/2b/test_gemma.sh @@ -33,37 +33,37 @@ export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items export RUN_NAME=unscanned_chkpt_${idx} # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `src/MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. -python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma-2b' force_unroll=true +python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma-2b' force_unroll=true export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune_${idx} -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-2b checkpoint_period=5 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-2b checkpoint_period=5 # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma per_device_batch_size=1 run_name=runner_pretrain_${idx} max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-2b +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma per_device_batch_size=1 run_name=runner_pretrain_${idx} max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-2b # Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. # So, we can use the `src/MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run. # `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding export PARAM_RUN_NAME=param_chkpt_${idx} -python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-2b' force_unroll=true +python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-2b' force_unroll=true # Now, run decoding on the checkpoint generated from our finetune run. -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to" # We also test whether the forward pass logits match the golden logits for Gemma-2b -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2b per_device_batch_size=1 model_name=gemma-2b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false attention=dot_product --max_kl_div=0.015 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2b per_device_batch_size=1 model_name=gemma-2b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false attention=dot_product --max_kl_div=0.015 # We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance. # This below command does Ahead Of Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 2B. # To actually run it on real v5e-256's simple replace the train_compile.py with a train.py and get rid of compile_topology args. -python3 -m MaxText.train_compile "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=gemma-2b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1 +python3 -m MaxText.train_compile "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=gemma-2b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1 diff --git a/tests/end_to_end/tpu/gemma/7b/1_test_gemma.sh b/tests/end_to_end/tpu/gemma/7b/1_test_gemma.sh index 917ea6d73c..3adcc3d1f5 100644 --- a/tests/end_to_end/tpu/gemma/7b/1_test_gemma.sh +++ b/tests/end_to_end/tpu/gemma/7b/1_test_gemma.sh @@ -39,5 +39,5 @@ export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `src/MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. export RUN_NAME=unscanned_chkpt -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma-7b' force_unroll=true skip_jax_distributed_system=true +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma-7b' force_unroll=true skip_jax_distributed_system=true echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items" diff --git a/tests/end_to_end/tpu/gemma/7b/2_test_gemma.sh b/tests/end_to_end/tpu/gemma/7b/2_test_gemma.sh index c6c8d2d678..8b1fa83841 100644 --- a/tests/end_to_end/tpu/gemma/7b/2_test_gemma.sh +++ b/tests/end_to_end/tpu/gemma/7b/2_test_gemma.sh @@ -39,28 +39,28 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items export ASYNC_CHECKPOINTING=True # True so that the jax distributed system is initialized # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=gemma-7b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=gemma-7b attention=dot_product prompt="I love to" # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=gemma-7b checkpoint_period=5 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=gemma-7b checkpoint_period=5 # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets}}"/tokenizer.gemma per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-7b +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets}}"/tokenizer.gemma per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-7b # Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. # So, we can use the `src/MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_PATH` from our previous finetuning run. # `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding export PARAM_RUN_NAME=param_chkpt -python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-7b' force_unroll=true +python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-7b' force_unroll=true # Now, run decoding on the checkpoint generated from our finetune run. -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" # We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance. # This below command does Ahead Of Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 7B. # To actually run it on real v5e-256's simple replace the train_compile.py with a train.py and get rid of compile_topology args. -python3 -m MaxText.train_compile "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=gemma-7b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1 +python3 -m MaxText.train_compile "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=gemma-7b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1 diff --git a/tests/end_to_end/tpu/gemma2/27b/1_test_gemma.sh b/tests/end_to_end/tpu/gemma2/27b/1_test_gemma.sh index 6425e1617f..68be2d1bae 100644 --- a/tests/end_to_end/tpu/gemma2/27b/1_test_gemma.sh +++ b/tests/end_to_end/tpu/gemma2/27b/1_test_gemma.sh @@ -42,5 +42,5 @@ export RUN_NAME=unscanned_chkpt export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=gemma2-27b force_unroll=true +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=gemma2-27b force_unroll=true echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/${RUN_NAME}/checkpoints/0/items" diff --git a/tests/end_to_end/tpu/gemma2/27b/2_test_gemma.sh b/tests/end_to_end/tpu/gemma2/27b/2_test_gemma.sh index bed4444cca..212c06c307 100644 --- a/tests/end_to_end/tpu/gemma2/27b/2_test_gemma.sh +++ b/tests/end_to_end/tpu/gemma2/27b/2_test_gemma.sh @@ -39,11 +39,11 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-27b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-27b attention=dot_product prompt="I love to" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-27b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-27b attention=dot_product prompt="I love to" # We also test whether the forward pass logits match the golden logits for Gemma2-27b # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2_27b per_device_batch_size=1 model_name=gemma2-27b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false dtype='float32' --atol=1.0 --rtol=1.0 --max_kl_div=0.15 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2_27b per_device_batch_size=1 model_name=gemma2-27b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false dtype='float32' --atol=1.0 --rtol=1.0 --max_kl_div=0.15 diff --git a/tests/end_to_end/tpu/gemma2/2b/test_gemma2.sh b/tests/end_to_end/tpu/gemma2/2b/test_gemma2.sh index ed051b012e..9a76fb606f 100644 --- a/tests/end_to_end/tpu/gemma2/2b/test_gemma2.sh +++ b/tests/end_to_end/tpu/gemma2/2b/test_gemma2.sh @@ -35,33 +35,33 @@ export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items export RUN_NAME=unscanned_chkpt_${idx} # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `src/MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. -python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma2-2b' force_unroll=true +python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma2-2b' force_unroll=true export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-2b prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-2b prompt="I love to" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-2b prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-2b prompt="I love to" # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune_${idx} -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma2-2b checkpoint_period=5 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma2-2b checkpoint_period=5 # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma per_device_batch_size=1 run_name=runner_pretrain_${idx} max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma2-2b +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma per_device_batch_size=1 run_name=runner_pretrain_${idx} max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma2-2b # Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. # So, we can use the `src/MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run. # `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding export PARAM_RUN_NAME=param_chkpt_${idx} -python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma2-2b' force_unroll=true +python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma2-2b' force_unroll=true # Now, run decoding on the checkpoint generated from our finetune run. -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-2b prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-2b prompt="I love to" # We also test whether the forward pass logits match the golden logits for Gemma2-2b # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2_2b per_device_batch_size=1 model_name=gemma2-2b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false dtype='float32' --atol=1.0 --rtol=1.0 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2_2b per_device_batch_size=1 model_name=gemma2-2b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false dtype='float32' --atol=1.0 --rtol=1.0 diff --git a/tests/end_to_end/tpu/gemma2/2b/test_gemma2_to_hf.sh b/tests/end_to_end/tpu/gemma2/2b/test_gemma2_to_hf.sh index 0701ef2c46..36c93557c0 100644 --- a/tests/end_to_end/tpu/gemma2/2b/test_gemma2_to_hf.sh +++ b/tests/end_to_end/tpu/gemma2/2b/test_gemma2_to_hf.sh @@ -31,7 +31,7 @@ export CKPT_PATH=gs://maxtext-gemma/unified/gemma2/2b/unscanned/2025-08-05-18-06 # export HF_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/hf/${idx} export LOCAL_PATH=./tmp/hf/${MODEL_NAME}/${idx} -python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ model_name=${MODEL_NAME} \ hf_access_token=${HF_TOKEN} \ load_parameters_path=${CKPT_PATH} \ @@ -44,7 +44,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_PKG_DIR:-${MA # We also test whether the forward pass logits match the original HF model # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ tokenizer_path=${TOKENIZER_PATH} \ load_parameters_path=${CKPT_PATH} \ model_name=${MODEL_NAME} \ diff --git a/tests/end_to_end/tpu/gemma2/2b/test_gemma2_to_mt.sh b/tests/end_to_end/tpu/gemma2/2b/test_gemma2_to_mt.sh index c0afc085b3..26e584e616 100644 --- a/tests/end_to_end/tpu/gemma2/2b/test_gemma2_to_mt.sh +++ b/tests/end_to_end/tpu/gemma2/2b/test_gemma2_to_mt.sh @@ -29,7 +29,7 @@ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu export MODEL_BUCKET=gs://maxtext-gemma/unified/gemma2 # To get unscanned ckpt: -python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ model_name=${MODEL_NAME} \ hf_access_token=${HF_TOKEN} \ base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx} \ @@ -38,7 +38,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_PKG_DIR:-${MAXTEX export UNSCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx}/0/items # To get scanned ckpt, flip the scan_layers: -python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ model_name=${MODEL_NAME} \ hf_access_token=${HF_TOKEN} \ base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/scanned/${idx} \ @@ -49,7 +49,7 @@ export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/scanned/${idx}/0/ite # We also test whether the forward pass logits match the original HF model # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ tokenizer_path=${TOKENIZER_PATH} \ load_parameters_path=${UNSCANNED_CKPT_PATH} \ model_name=${MODEL_NAME} \ @@ -59,7 +59,7 @@ python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_ --run_hf_model=true # We can run decoding for unscanned checkpoints. -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=false prompt='I love to' attention=\'dot_product\' +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=false prompt='I love to' attention=\'dot_product\' # Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data export DATASET_PATH=gs://maxtext-dataset @@ -69,7 +69,7 @@ export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/gemm2-2b # We can also run finetuning by using the scanned converted checkpoint. # Note that scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune_${idx} -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${SCANNED_CKPT_PATH} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=${MODEL_NAME} checkpoint_period=5 scan_layers=true +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${SCANNED_CKPT_PATH} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=${MODEL_NAME} checkpoint_period=5 scan_layers=true # Now, run decoding on the checkpoint generated from our finetune run. -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=decode_test_${FINETUNE_RUN_NAME} max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=true prompt='I love to' attention=\'dot_product\' +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=decode_test_${FINETUNE_RUN_NAME} max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=true prompt='I love to' attention=\'dot_product\' diff --git a/tests/end_to_end/tpu/gemma2/9b/1_test_gemma.sh b/tests/end_to_end/tpu/gemma2/9b/1_test_gemma.sh index ae27315a78..9438647103 100644 --- a/tests/end_to_end/tpu/gemma2/9b/1_test_gemma.sh +++ b/tests/end_to_end/tpu/gemma2/9b/1_test_gemma.sh @@ -42,5 +42,5 @@ export RUN_NAME=unscanned_chkpt export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=gemma2-9b force_unroll=true +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=gemma2-9b force_unroll=true echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/${RUN_NAME}/checkpoints/0/items" diff --git a/tests/end_to_end/tpu/gemma2/9b/2_test_gemma.sh b/tests/end_to_end/tpu/gemma2/9b/2_test_gemma.sh index 5834e3155b..60e8a88642 100644 --- a/tests/end_to_end/tpu/gemma2/9b/2_test_gemma.sh +++ b/tests/end_to_end/tpu/gemma2/9b/2_test_gemma.sh @@ -40,11 +40,11 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-9b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma2-9b attention=dot_product prompt="I love to" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-9b attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma2-9b attention=dot_product prompt="I love to" # We also test whether the forward pass logits match the golden logits for Gemma2-9b # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2_9b per_device_batch_size=1 model_name=gemma2-9b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false dtype='float32' --atol=1.0 --rtol=1.0 --max_kl_div=0.15 \ No newline at end of file +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_gemma2_9b per_device_batch_size=1 model_name=gemma2-9b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false dtype='float32' --atol=1.0 --rtol=1.0 --max_kl_div=0.15 \ No newline at end of file diff --git a/tests/end_to_end/tpu/gemma3/12b/test_gemma3.sh b/tests/end_to_end/tpu/gemma3/12b/test_gemma3.sh index 0261c5109c..718f171e4f 100644 --- a/tests/end_to_end/tpu/gemma3/12b/test_gemma3.sh +++ b/tests/end_to_end/tpu/gemma3/12b/test_gemma3.sh @@ -37,22 +37,22 @@ export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items export RUN_NAME=unscanned_chkpt_${idx} # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `MaxText.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_NAME} force_unroll=true +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_NAME} force_unroll=true export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" # # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0 # Finetune by using the scanned converted checkpoint by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. For Googlers, uncomment the line below if you want to use the pre-converted checkpoint. # export CONVERTED_CHECKPOINT=gs://maxtext-model-checkpoints/gemma3-12b/2025-03-19-21-16/scanned/0/items FINETUNE_RUN_NAME=runner_finetune_${idx} -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$FINETUNE_RUN_NAME steps=10 enable_checkpointing=true sharding_tolerance=0.03 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$FINETUNE_RUN_NAME steps=10 enable_checkpointing=true sharding_tolerance=0.03 # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load_parameters_path PRETRAIN_RUN_NAME=runner_pretrain_${idx} -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$PRETRAIN_RUN_NAME steps=10 enable_checkpointing=false sharding_tolerance=0.03 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$PRETRAIN_RUN_NAME steps=10 enable_checkpointing=false sharding_tolerance=0.03 diff --git a/tests/end_to_end/tpu/gemma3/27b/test_gemma3.sh b/tests/end_to_end/tpu/gemma3/27b/test_gemma3.sh index c30e09a701..07d61b36b7 100644 --- a/tests/end_to_end/tpu/gemma3/27b/test_gemma3.sh +++ b/tests/end_to_end/tpu/gemma3/27b/test_gemma3.sh @@ -37,22 +37,22 @@ export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items export RUN_NAME=unscanned_chkpt_${idx} # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `MaxText.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_NAME} force_unroll=true +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_NAME} force_unroll=true export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" # # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0 # Finetune by using the scanned converted checkpoint by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. For Googlers, uncomment the line below if you want to use the pre-converted checkpoint. # export CONVERTED_CHECKPOINT=gs://maxtext-model-checkpoints/gemma3-27b/2025-03-20-00-12/scanned/0/items FINETUNE_RUN_NAME=runner_finetune_${idx} -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$FINETUNE_RUN_NAME steps=10 enable_checkpointing=true sharding_tolerance=0.03 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$FINETUNE_RUN_NAME steps=10 enable_checkpointing=true sharding_tolerance=0.03 # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load_parameters_path PRETRAIN_RUN_NAME=runner_pretrain_${idx} -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$PRETRAIN_RUN_NAME steps=10 enable_checkpointing=false sharding_tolerance=0.03 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$PRETRAIN_RUN_NAME steps=10 enable_checkpointing=false sharding_tolerance=0.03 diff --git a/tests/end_to_end/tpu/gemma3/4b/test_gemma3.sh b/tests/end_to_end/tpu/gemma3/4b/test_gemma3.sh index bc55f04555..c36139acbf 100644 --- a/tests/end_to_end/tpu/gemma3/4b/test_gemma3.sh +++ b/tests/end_to_end/tpu/gemma3/4b/test_gemma3.sh @@ -37,22 +37,22 @@ export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items export RUN_NAME=unscanned_chkpt_${idx} # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `MaxText.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_NAME} force_unroll=true +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_NAME} force_unroll=true export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" # # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_NAME} per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false --atol=1.0 --rtol=1.0 # Finetune by using the scanned converted checkpoint by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. For Googlers, uncomment the line below if you want to use the pre-converted checkpoint. # export CONVERTED_CHECKPOINT=gs://maxtext-model-checkpoints/gemma3-4b/2025-03-18-19-03/scanned/0/items FINETUNE_RUN_NAME=runner_finetune_${idx} -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$FINETUNE_RUN_NAME steps=10 enable_checkpointing=true sharding_tolerance=0.03 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$FINETUNE_RUN_NAME steps=10 enable_checkpointing=true sharding_tolerance=0.03 # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load_parameters_path PRETRAIN_RUN_NAME=runner_pretrain_${idx} -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$PRETRAIN_RUN_NAME steps=10 enable_checkpointing=false sharding_tolerance=0.03 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml model_name=$MODEL_NAME base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=$PRETRAIN_RUN_NAME steps=10 enable_checkpointing=false sharding_tolerance=0.03 diff --git a/tests/end_to_end/tpu/gemma3/4b/test_gemma3_multimodal_sft.sh b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_multimodal_sft.sh index e5ed8b495f..e7b26c1f0d 100644 --- a/tests/end_to_end/tpu/gemma3/4b/test_gemma3_multimodal_sft.sh +++ b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_multimodal_sft.sh @@ -31,7 +31,7 @@ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu export MODEL_BUCKET=gs://maxtext-gemma/unified/gemma3 # 1. Convert the HuggingFace checkpoint to MaxText unscanned ckpt: -python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ model_name=${MODEL_NAME} \ hf_access_token=${HF_TOKEN} \ base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx} \ @@ -40,11 +40,11 @@ python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_PKG_DIR:-${MAXTEX # 2. Decode the converted checkpoint to make sure it works export UNSCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx}/0/items -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=$SCAN_LAYERS use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=$SCAN_LAYERS use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' # 3. SFT the MaxText converted checkpoint on ChartQA dataset export BASE_OUTPUT_DIRECTORY=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/sft -python -m MaxText.sft_trainer "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/sft-vision-chartqa.yml \ +python -m MaxText.sft_trainer "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//sft-vision-chartqa.yml \ run_name=$idx \ model_name=$MODEL_NAME tokenizer_path="google/gemma-3-4b-pt" \ per_device_batch_size=1 \ @@ -61,12 +61,12 @@ python -m MaxText.sft_trainer "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src # 4. Decode from the finetuned checkpoint from step 3 export FINAL_CKPT_STEP=$((SFT_STEPS - 1)) export FINETUNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${idx}/checkpoints/${FINAL_CKPT_STEP}/items -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${FINETUNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=$SCAN_LAYERS use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${FINETUNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=$SCAN_LAYERS use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' # 5. Convert the SFT checkpoint back to HuggingFace format. export LOCAL_PATH=./tmp/hf/${MODEL_NAME}/${idx} export CKPT_PATH="gs://maxtext-gemma/unified/gemma3/4b/unscanned/sft/2025-08-08-18-28/2025-08-08-18-28/checkpoints/9/items" -python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ model_name=${MODEL_NAME} \ hf_access_token=${HF_TOKEN} \ load_parameters_path=${CKPT_PATH} \ diff --git a/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh index a1d4fa727d..163e3f1e02 100644 --- a/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh +++ b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh @@ -32,7 +32,7 @@ export CKPT_PATH=gs://maxtext-gemma/unified/gemma3/4b/unscanned/2025-08-05-18-18 # export HF_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/hf/${idx} export LOCAL_PATH=./tmp/hf/${MODEL_NAME}/${idx} -python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ model_name=${MODEL_NAME} \ hf_access_token=${HF_TOKEN} \ load_parameters_path=${CKPT_PATH} \ @@ -47,7 +47,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_PKG_DIR:-${MA # We also test whether the forward pass logits match the original HF model # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ tokenizer_path=${TOKENIZER_PATH} \ load_parameters_path=${CKPT_PATH} \ model_name=${MODEL_NAME} \ diff --git a/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh index ed2284e3ff..8d705d4598 100644 --- a/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh +++ b/tests/end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh @@ -31,7 +31,7 @@ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu export MODEL_BUCKET=gs://maxtext-gemma/unified/gemma3 # To get unscanned ckpt: -python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ model_name=${MODEL_NAME} \ hf_access_token=${HF_TOKEN} \ base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx} \ @@ -42,7 +42,7 @@ export UNSCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx}/0 # # To get scanned ckpt, flip the scan_layers. # ToDo: gemma3 multimodal scanned ckpt -# python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \ +# python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/maxtext/configs/base.yml \ # model_name=${MODEL_NAME} \ # hf_access_token=${HF_TOKEN} \ # base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/scanned/${idx} \ @@ -55,7 +55,7 @@ export UNSCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx}/0 # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` # ToDo: improve forward_pass_logit_checker to test multi-modal prompt -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ tokenizer_path=${TOKENIZER_PATH} \ load_parameters_path=${UNSCANNED_CKPT_PATH} \ model_name=${MODEL_NAME} \ @@ -67,9 +67,9 @@ python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_ # We can run decoding for unscanned checkpoints. if [ ${USE_MULTIMODAL} == true ]; then - python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=false use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' + python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=false use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' else - python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=false prompt='I love to' attention=\'dot_product\' + python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=false prompt='I love to' attention=\'dot_product\' fi # Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data @@ -80,11 +80,11 @@ export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs/gemma3-4b # We can also run finetuning by using the scanned converted checkpoint. # Note that scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune_${idx} -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=${MODEL_NAME} checkpoint_period=5 scan_layers=false +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=${MODEL_NAME} checkpoint_period=5 scan_layers=false # Now, run decoding on the checkpoint generated from our finetune run. if [ ${USE_MULTIMODAL} == true ]; then - python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=false use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' + python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=false use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \\' image_path=\'tests/assets/test_image.jpg\' attention=\'dot_product\' else - python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=false prompt='I love to' attention=\'dot_product\' + python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=false prompt='I love to' attention=\'dot_product\' fi diff --git a/tests/end_to_end/tpu/gemma3/Run_Gemma3.md b/tests/end_to_end/tpu/gemma3/Run_Gemma3.md index ab33b388aa..4b38cef3b8 100644 --- a/tests/end_to_end/tpu/gemma3/Run_Gemma3.md +++ b/tests/end_to_end/tpu/gemma3/Run_Gemma3.md @@ -25,7 +25,7 @@ We provide examples for checkpoint conversion and decoding/training/finetuning G You can train from scratch to generate a new checkpoint. One example command to run pretraining Gemma3-4B model is as follows: ```sh -python3 -m MaxText.train src/MaxText/configs/base.yml model_name=gemma3-4b base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=runner_pretrain_gemma3_4b steps=10 enable_checkpointing=false sharding_tolerance=0.03 +python3 -m MaxText.train src/maxtext/configs/base.yml model_name=gemma3-4b base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=runner_pretrain_gemma3_4b steps=10 enable_checkpointing=false sharding_tolerance=0.03 ``` ## Checkpoint Conversion @@ -35,12 +35,12 @@ To obtain the Gemma3 model weights, follow the instructions provided on [Kaggle] After the conversion, you will have a MaxText compatible checkpoint which allows you to fine-tune it with different datasets. One example command to fine-tune a Gemma3-4B model is as follows: ``` -python3 -m MaxText.train src/MaxText/configs/base.yml model_name=gemma3-4b base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=runner_finetune_gemma3_4b steps=10 enable_checkpointing=true sharding_tolerance=0.03 +python3 -m MaxText.train src/maxtext/configs/base.yml model_name=gemma3-4b base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 per_device_batch_size=1 run_name=runner_finetune_gemma3_4b steps=10 enable_checkpointing=true sharding_tolerance=0.03 ``` ## Decoding One example to use a converted checkpoint to decode with prompt "I love to": ``` -python3 -m maxtext.decode src/MaxText/configs/base.yml model_name=gemma3-4b tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_decode_gemma3_4b max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false prompt="I love to" +python3 -m maxtext.decode src/maxtext/configs/base.yml model_name=gemma3-4b tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma3 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_decode_gemma3_4b max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false prompt="I love to" ``` \ No newline at end of file diff --git a/tests/end_to_end/tpu/gpt_oss/120b/test_gpt_oss.sh b/tests/end_to_end/tpu/gpt_oss/120b/test_gpt_oss.sh index 751d2e7865..f5527af03c 100644 --- a/tests/end_to_end/tpu/gpt_oss/120b/test_gpt_oss.sh +++ b/tests/end_to_end/tpu/gpt_oss/120b/test_gpt_oss.sh @@ -51,17 +51,17 @@ export DATASET_PATH=gs://maxtext-dataset # Test whether the forward pass logits match the golden logits # default golden_logits_path=/deps/tests/assets/golden_logits/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check model_name=${MODEL_NAME} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=false attention=dot_product sparse_matmul=True megablox=True per_device_batch_size=1 max_target_length=4 max_prefill_predict_length=4 dtype=float32 --atol=0.5 --rtol=0.5 --max_kl_div=3e-3 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check model_name=${MODEL_NAME} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=false attention=dot_product sparse_matmul=True megablox=True per_device_batch_size=1 max_target_length=4 max_prefill_predict_length=4 dtype=float32 --atol=0.5 --rtol=0.5 --max_kl_div=3e-3 # Run pre-training - megablox implementation -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=32 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=32 # Run fine-tuning - megablox implementation -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=32 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=32 # Run supervised fine-tuning - megablox implementation -python3 -m MaxText.sft_trainer "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=32 +python3 -m MaxText.sft_trainer "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=32 # Run decoding - megablox implementation # Note decode requires the access token for huggingface tokenizer even if the model is not gated -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=64 max_target_length=128 prompt="I love to" ici_fsdp_parallelism=32 ici_tensor_parallelism=1 +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=64 max_target_length=128 prompt="I love to" ici_fsdp_parallelism=32 ici_tensor_parallelism=1 diff --git a/tests/end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh b/tests/end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh index b71e19c415..734d2f059e 100644 --- a/tests/end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh +++ b/tests/end_to_end/tpu/gpt_oss/20b/test_gpt_oss.sh @@ -55,17 +55,17 @@ export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=81920' # Test whether the forward pass logits match the golden logits # default golden_logits_path=/deps/tests/assets/golden_logits/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check model_name=${MODEL_NAME} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=false attention=dot_product sparse_matmul=True megablox=True per_device_batch_size=1 max_target_length=4 max_prefill_predict_length=4 dtype=float32 --atol=0.1 --rtol=0.1 --max_kl_div=3e-4 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check model_name=${MODEL_NAME} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=false attention=dot_product sparse_matmul=True megablox=True per_device_batch_size=1 max_target_length=4 max_prefill_predict_length=4 dtype=float32 --atol=0.1 --rtol=0.1 --max_kl_div=3e-4 # Run pre-training - megablox implementation -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=4 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=4 # Run fine-tuning - megablox implementation -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4 # Run supervised fine-tuning - megablox implementation -python3 -m MaxText.sft_trainer "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4 +python3 -m MaxText.sft_trainer "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//sft.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=megablox_supervised_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=hf enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=1 ici_expert_parallelism=4 # Run decoding - megablox implementation # Note decode requires the access token for huggingface tokenizer even if the model is not gated -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=64 max_target_length=128 prompt="I love to" ici_fsdp_parallelism=1 ici_tensor_parallelism=4 +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True megablox=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=64 max_target_length=128 prompt="I love to" ici_fsdp_parallelism=1 ici_tensor_parallelism=4 diff --git a/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md b/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md index d7a159a4dc..71388dd444 100644 --- a/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md +++ b/tests/end_to_end/tpu/gpt_oss/run_gpt_oss.md @@ -58,7 +58,7 @@ python3 -m MaxText.utils.ckpt_scripts.convert_gpt_oss_unscanned_ckpt --base-mode You can train from scratch to generate a new checkpoint. One example command to run pretraining with gpt-oss-20b on v5p-8. ```sh -python3 -m MaxText.train src/MaxText/configs/base.yml \ +python3 -m MaxText.train src/maxtext/configs/base.yml \ base_output_directory=${BASE_OUTPUT_PATH} \ run_name=megablox_pre_training \ model_name=gpt-oss-20b \ @@ -84,7 +84,7 @@ After you have a MaxText-compatible scanned checkpoint, you could finetune it wi One example command to run general finetuning with gpt-oss-20b on v5p-8. ```sh -python3 -m MaxText.train src/MaxText/configs/base.yml \ +python3 -m MaxText.train src/maxtext/configs/base.yml \ base_output_directory=${BASE_OUTPUT_PATH} \ run_name=megablox_fine_tuning \ model_name=gpt-oss-20b \ @@ -110,7 +110,7 @@ python3 -m MaxText.train src/MaxText/configs/base.yml \ One example command to run supervised finetuning with gpt-oss-20b on v5p-8. Supervised finetuning is only working with HuggingFace conversational datasets. And, you can customize the dataset path using the `hf_path` config. If using [gated dataset](https://huggingface.co/docs/hub/en/datasets-gated) or [gated model](https://huggingface.co/docs/hub/en/models-gated), you need additionally provide the access token with `hf_access_token` config. ```sh -python3 -m MaxText.sft_trainer src/MaxText/configs/sft.yml \ +python3 -m MaxText.sft_trainer src/maxtext/configs/post_train/sft.yml \ base_output_directory=${BASE_OUTPUT_PATH} \ run_name=megablox_supervised_fine_tuning \ model_name=gpt-oss-20b \ @@ -137,7 +137,7 @@ python3 -m MaxText.sft_trainer src/MaxText/configs/sft.yml \ One example command to run decoding with gpt-oss-20b on v5p-8 with unscanned checkpoint for fast decoding. ```sh -python3 -m maxtext.decode src/MaxText/configs/base.yml \ +python3 -m maxtext.decode src/maxtext/configs/base.yml \ base_output_directory=${BASE_OUTPUT_PATH} \ run_name=decode \ model_name=gpt-oss-20b \ @@ -182,7 +182,7 @@ Run command below to compare logits between HuggingFace and MaxText. ```sh python3 -m tests.utils.forward_pass_logit_checker \ - src/MaxText/configs/base.yml \ + src/maxtext/configs/base.yml \ base_output_directory=${BASE_OUTPUT_PATH} \ run_name=forward_logits_check \ model_name=gpt-oss-20b \ diff --git a/tests/end_to_end/tpu/llama2/13b/1_test_llama2_13b.sh b/tests/end_to_end/tpu/llama2/13b/1_test_llama2_13b.sh index 25b45cf4d1..20d0244c50 100644 --- a/tests/end_to_end/tpu/llama2/13b/1_test_llama2_13b.sh +++ b/tests/end_to_end/tpu/llama2/13b/1_test_llama2_13b.sh @@ -45,5 +45,5 @@ export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `MaxText.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. export RUN_NAME=unscanned_chkpt -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='llama2-13b' force_unroll=true +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='llama2-13b' force_unroll=true echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items" diff --git a/tests/end_to_end/tpu/llama2/13b/2_test_llama2_13b.sh b/tests/end_to_end/tpu/llama2/13b/2_test_llama2_13b.sh index f7b2934ae6..ef10b82cb8 100644 --- a/tests/end_to_end/tpu/llama2/13b/2_test_llama2_13b.sh +++ b/tests/end_to_end/tpu/llama2/13b/2_test_llama2_13b.sh @@ -39,23 +39,23 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` # We compare our decoded results by asserting with golden outputs using `autoregressive_decode_assert` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to teach." +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to teach." # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to teach." +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" autoregressive_decode_assert="read. I love to write. I love to teach." # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=false model_name=${MODEL_VARIATION} checkpoint_period=5 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=false model_name=${MODEL_VARIATION} checkpoint_period=5 # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION} +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION} # Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. # So, we can use the `src/MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_PATH` from our previous finetuning run. # `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding export PARAM_RUN_NAME=param_chkpt -python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true +python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true # Now, run decoding on the checkpoint generated from our finetune run. -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" diff --git a/tests/end_to_end/tpu/llama2/70b/1_test_llama2_70b.sh b/tests/end_to_end/tpu/llama2/70b/1_test_llama2_70b.sh index 96c7925c8d..c87e7f75ac 100644 --- a/tests/end_to_end/tpu/llama2/70b/1_test_llama2_70b.sh +++ b/tests/end_to_end/tpu/llama2/70b/1_test_llama2_70b.sh @@ -45,5 +45,5 @@ export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `MaxText.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. export RUN_NAME=unscanned_chkpt -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='llama2-70b' force_unroll=true skip_jax_distributed_system=true +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='llama2-70b' force_unroll=true skip_jax_distributed_system=true echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items" diff --git a/tests/end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh b/tests/end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh index fb40ce99aa..34e0150b09 100644 --- a/tests/end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh +++ b/tests/end_to_end/tpu/llama2/70b/2_test_llama2_70b.sh @@ -44,26 +44,26 @@ export ASYNC_CHECKPOINTING=true # True so that jax distributed system is initial # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=${MODEL_VARIATION} checkpoint_period=5 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} model_name=${MODEL_VARIATION} checkpoint_period=5 # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION} +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION} # Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. # So, we can use the `src/MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_PATH` from our previous finetuning run. # `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding export PARAM_RUN_NAME=param_chkpt -python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true +python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true # Now, run decoding on the checkpoint generated from our finetune run. -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=gs://runner-maxtext-logs tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.llama2 load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=${ASYNC_CHECKPOINTING} scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # We also test whether the forward pass logits match the golden logits for Llama2-70b -python3 -m tests.utils.forward_pass_logit_checker --atol=0.2 --rtol=0.2 "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=llama2-70b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false async_checkpointing=${ASYNC_CHECKPOINTING} +python3 -m tests.utils.forward_pass_logit_checker --atol=0.2 --rtol=0.2 "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=llama2-70b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false async_checkpointing=${ASYNC_CHECKPOINTING} diff --git a/tests/end_to_end/tpu/llama2/7b/test_llama2_7b.sh b/tests/end_to_end/tpu/llama2/7b/test_llama2_7b.sh index b29d8994aa..5365c84d04 100644 --- a/tests/end_to_end/tpu/llama2/7b/test_llama2_7b.sh +++ b/tests/end_to_end/tpu/llama2/7b/test_llama2_7b.sh @@ -41,43 +41,43 @@ export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `MaxText.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint_${idx} -python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true +python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true # Like before, we define `UNSCANNED_CKPT_PATH` to refer to the checkpoint subdirectory exactly export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${DIRECT_PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint converted directly from Meta's PyTorch checkpoint aka `CONVERTED_CHECKPOINT`. Note that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=runner_decode_unscanned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=runner_decode_unscanned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning # We run this with context parallelism and the `ici_context_parallelism` flag as an integration test for context parallelism -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_finetuning_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} async_checkpointing=${ASYNC_CHECKPOINTING} per_device_batch_size=1 model_name='llama2-7b' ici_context_parallelism=4 steps=10 per_device_batch_size=1 checkpoint_period=5 packing=false +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml load_parameters_path=${CONVERTED_CHECKPOINT} run_name=runner_finetuning_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} async_checkpointing=${ASYNC_CHECKPOINTING} per_device_batch_size=1 model_name='llama2-7b' ici_context_parallelism=4 steps=10 per_device_batch_size=1 checkpoint_period=5 packing=false # We also run pre-training of Llama2-7b, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from # We run this with context parallelism and the `ici_context_parallelism` flag as an integration test for context parallelism -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml run_name=runner_pretraining_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} async_checkpointing=${ASYNC_CHECKPOINTING} per_device_batch_size=1 model_name='llama2-7b' ici_context_parallelism=4 steps=10 per_device_batch_size=1 packing=false +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml run_name=runner_pretraining_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} async_checkpointing=${ASYNC_CHECKPOINTING} per_device_batch_size=1 model_name='llama2-7b' ici_context_parallelism=4 steps=10 per_device_batch_size=1 packing=false # Now, run decoding on the checkpoint generated from our finetune run. Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. So, we can use the `MaxText.generate_param_only_checkpoint` to convert # the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run, say the checkpoint saved at finetuning step #5 # Also, `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding export PARAMETER_CHECKPOINT_RUN=generate_param_only_checkpoint_${idx} -python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/runner_finetuning_${idx}/checkpoints/5/items run_name=${PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true +python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/runner_finetuning_${idx}/checkpoints/5/items run_name=${PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' force_unroll=true # Like before, we define `NEW_CKPT_PATH` to refer to the checkpoint subdirectory exactly export NEW_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${PARAMETER_CHECKPOINT_RUN}/checkpoints/0/items # We run decoding on the fine-tuned parameter checkpoint -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${NEW_CKPT_PATH} run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml load_parameters_path=${NEW_CKPT_PATH} run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false # We also test whether the forward pass logits match the golden logits for Llama2-7b -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=llama2-7b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false --rtol=0.1 --atol=0.1 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=llama2-7b ici_tensor_parallelism=4 max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 scan_layers=false --rtol=0.1 --atol=0.1 # Converting MaxText orbax checkpoint to HF -JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.llama_mistral_mixtral_orbax_to_hf "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=gs://runner-maxtext-logs load_parameters_path=${CONVERTED_CHECKPOINT} run_name=convert_to_hf model_name=llama2-7b hf_model_path=/tmp/hf_llama2 +JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.llama_mistral_mixtral_orbax_to_hf "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=gs://runner-maxtext-logs load_parameters_path=${CONVERTED_CHECKPOINT} run_name=convert_to_hf model_name=llama2-7b hf_model_path=/tmp/hf_llama2 # Test whether the forward pass logits match the golden logits for Huggingface checkpoint converted from MaxText orbax checkpoint # We run this with context parallelism and the `ici_context_parallelism` flag as an integration test for context parallelism -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_hf ici_context_parallelism=4 model_name=llama2-7b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --hf_model_path=/tmp/hf_llama2 --max_kl_div=1e-4 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_hf ici_context_parallelism=4 model_name=llama2-7b max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --hf_model_path=/tmp/hf_llama2 --max_kl_div=1e-4 diff --git a/tests/end_to_end/tpu/llama3.1/405b/2_test_llama3.1_405b.sh b/tests/end_to_end/tpu/llama3.1/405b/2_test_llama3.1_405b.sh index d7e46fd78f..b5f74b4044 100644 --- a/tests/end_to_end/tpu/llama3.1/405b/2_test_llama3.1_405b.sh +++ b/tests/end_to_end/tpu/llama3.1/405b/2_test_llama3.1_405b.sh @@ -40,10 +40,10 @@ export UNSCANNED_CHECKPOINT=gs://maxtext-llama/llama3.1_405b_bf16/unscanned/0/it # We run finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning. # We use a small per_device_batch_size and SGD optimizer for the model to fit on a v4-128. This config is only used for unit testing. export FINETUNE_RUN_NAME=runner_finetune -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_type=synthetic tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${SCANNED_CHECKCKPOINT} per_device_batch_size=0.25 ici_tensor_parallelism=4 run_name=${FINETUNE_RUN_NAME} steps=10 enable_checkpointing=false model_name=${MODEL_VARIATION} logits_dot_in_fp32=false weight_dtype=bfloat16 opt_type=sgd +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_type=synthetic tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${SCANNED_CHECKCKPOINT} per_device_batch_size=0.25 ici_tensor_parallelism=4 run_name=${FINETUNE_RUN_NAME} steps=10 enable_checkpointing=false model_name=${MODEL_VARIATION} logits_dot_in_fp32=false weight_dtype=bfloat16 opt_type=sgd # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${SCANNED_CHECKCKPOINT} per_device_batch_size=0.0625 ici_tensor_parallelism=4 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product weight_dtype=bfloat16 prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${SCANNED_CHECKCKPOINT} per_device_batch_size=0.0625 ici_tensor_parallelism=4 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product weight_dtype=bfloat16 prompt="I love to" # We also test whether the forward pass logits match the golden logits for Llama3.1-405B -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${UNSCANNED_CHECKPOINT} run_name=forward_pass_test per_device_batch_size=0.0625 ici_tensor_parallelism=4 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false dtype=float32 activations_in_float32=true matmul_precision=float32 weight_dtype=float32 async_checkpointing=false --max_kl_div=1e-4 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${UNSCANNED_CHECKPOINT} run_name=forward_pass_test per_device_batch_size=0.0625 ici_tensor_parallelism=4 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic scan_layers=false dtype=float32 activations_in_float32=true matmul_precision=float32 weight_dtype=float32 async_checkpointing=false --max_kl_div=1e-4 diff --git a/tests/end_to_end/tpu/llama3.1/405b/3_test_llama3.1_405b.sh b/tests/end_to_end/tpu/llama3.1/405b/3_test_llama3.1_405b.sh index 43eda1db98..aafa98c276 100644 --- a/tests/end_to_end/tpu/llama3.1/405b/3_test_llama3.1_405b.sh +++ b/tests/end_to_end/tpu/llama3.1/405b/3_test_llama3.1_405b.sh @@ -17,7 +17,7 @@ export SAVE_QUANT_PARAMS_PATH=gs://maxtext-llama/llama3.1_405b_int8 export QUANTIZE_TYPE="int8" JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.load_and_quantize_checkpoint \ - "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ + "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken \ tokenizer_type=tiktoken \ load_parameters_path=${UNSCANNED_CHECKPOINT} \ diff --git a/tests/end_to_end/tpu/llama3.1/70b/1_test_llama3.1_70b.sh b/tests/end_to_end/tpu/llama3.1/70b/1_test_llama3.1_70b.sh index d9a3086c88..85c15a0726 100644 --- a/tests/end_to_end/tpu/llama3.1/70b/1_test_llama3.1_70b.sh +++ b/tests/end_to_end/tpu/llama3.1/70b/1_test_llama3.1_70b.sh @@ -44,5 +44,5 @@ export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `MaxText.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. export RUN_NAME=unscanned_chkpt -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items" diff --git a/tests/end_to_end/tpu/llama3.1/70b/2_test_llama3.1_70b.sh b/tests/end_to_end/tpu/llama3.1/70b/2_test_llama3.1_70b.sh index 2531fc32a2..97d8bf10bb 100644 --- a/tests/end_to_end/tpu/llama3.1/70b/2_test_llama3.1_70b.sh +++ b/tests/end_to_end/tpu/llama3.1/70b/2_test_llama3.1_70b.sh @@ -43,14 +43,14 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items # TODO(mohitkhatwani): Fix XLAResourceExhaustion when loading unscanned model # # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -# python3 -m maxtext.decode src/MaxText/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +# python3 -m maxtext.decode src/maxtext/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=false model_name=${MODEL_VARIATION} checkpoint_period=5 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=false model_name=${MODEL_VARIATION} checkpoint_period=5 # We also test whether the forward pass logits match the golden logits for Llama3.1-70B -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} run_name=forward_pass_test per_device_batch_size=1 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 async_checkpointing=false dtype=float32 activations_in_float32=true matmul_precision=float32 --max_kl_div=1e-4 \ No newline at end of file +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} run_name=forward_pass_test per_device_batch_size=1 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 async_checkpointing=false dtype=float32 activations_in_float32=true matmul_precision=float32 --max_kl_div=1e-4 \ No newline at end of file diff --git a/tests/end_to_end/tpu/llama3.1/70b/3_test_llama3.1_70b.sh b/tests/end_to_end/tpu/llama3.1/70b/3_test_llama3.1_70b.sh index 842c075895..5a44b4d6ea 100644 --- a/tests/end_to_end/tpu/llama3.1/70b/3_test_llama3.1_70b.sh +++ b/tests/end_to_end/tpu/llama3.1/70b/3_test_llama3.1_70b.sh @@ -16,27 +16,27 @@ rm $CHECKPOINT_ORIGINAL/scanned_chkpt $CHECKPOINT_ORIGINAL/unscanned_chkpt ${CHE JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt --base-model-path=$CHECKPOINT_ORIGINAL --model-size=$MODEL_SIZE --maxtext-model-path=$CHECKPOINT_TPU_SCANNED --huggingface-checkpoint=true # Let's verify the original checkpoint to see if it matches with Huggingface golden logits -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=$TOKENIZER tokenizer_type=tiktoken load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=forward_pass_test_hf per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=true --max_kl_div=1e-4 --hf_model_path=$CHECKPOINT_ORIGINAL --golden_logits_path=$GOLDEN_LOGITS +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=$TOKENIZER tokenizer_type=tiktoken load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=forward_pass_test_hf per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=true --max_kl_div=1e-4 --hf_model_path=$CHECKPOINT_ORIGINAL --golden_logits_path=$GOLDEN_LOGITS # Let's verify the generated scanned checkpoint to see if it matches with Huggingface golden logits -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=$TOKENIZER tokenizer_type=tiktoken load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=forward_pass_test_hf per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=true --max_kl_div=1e-4 --golden_logits_path=$GOLDEN_LOGITS +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=$TOKENIZER tokenizer_type=tiktoken load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=forward_pass_test_hf per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=true --max_kl_div=1e-4 --golden_logits_path=$GOLDEN_LOGITS # If not, we can convert the checkpoint back from MaxText to Huggingface and compare with the original one -JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.llama_mistral_mixtral_orbax_to_hf "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=gs://runner-maxtext-logs load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=convert_to_hf model_name=${MODEL_SIZE} hf_model_path=$CHECKPOINT_TPU_CONVERTED_BACK +JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.llama_mistral_mixtral_orbax_to_hf "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=gs://runner-maxtext-logs load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=convert_to_hf model_name=${MODEL_SIZE} hf_model_path=$CHECKPOINT_TPU_CONVERTED_BACK python3 -m tests.utils.hf_checkpoint_conversion_checker --original_ckpt=${CHECKPOINT_ORIGINAL} --converted_ckpt=${CHECKPOINT_TPU_CONVERTED_BACK} # If everything looks good, we move on to convert to the unrolled checkpoint for performant serving -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=${RUN_NAME} model_name=${MODEL_SIZE} force_unroll=true +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=${RUN_NAME} model_name=${MODEL_SIZE} force_unroll=true # Let's verify the generated unscanned checkpoint to see if it matches with Huggingface golden logits -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=$TOKENIZER tokenizer_type=tiktoken load_parameters_path=${CHECKPOINT_TPU_UNSCANNED} run_name=forward_pass_test_hf per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --max_kl_div=1e-4 --golden_logits_path=$GOLDEN_LOGITS +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=$TOKENIZER tokenizer_type=tiktoken load_parameters_path=${CHECKPOINT_TPU_UNSCANNED} run_name=forward_pass_test_hf per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --max_kl_div=1e-4 --golden_logits_path=$GOLDEN_LOGITS # Now we are good to go, serve with performance! -JAX_PLATFORMS=tpu python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=$TOKENIZER tokenizer_type=tiktoken run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=false load_parameters_path=$CHECKPOINT_TPU_UNSCANNED +JAX_PLATFORMS=tpu python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path=$TOKENIZER tokenizer_type=tiktoken run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=false load_parameters_path=$CHECKPOINT_TPU_UNSCANNED # You can also check the results from scanned version, just double check, not necessary -JAX_PLATFORMS=tpu python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=$TOKENIZER tokenizer_type=tiktoken run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=true load_parameters_path=$CHECKPOINT_TPU_SCANNED/0/items +JAX_PLATFORMS=tpu python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path=$TOKENIZER tokenizer_type=tiktoken run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=true load_parameters_path=$CHECKPOINT_TPU_SCANNED/0/items # Example output # Input `I love to` -> ` read, but I don't have much time. How can I read more books? diff --git a/tests/end_to_end/tpu/llama3.1/8b/1_test_llama3.1_8b.sh b/tests/end_to_end/tpu/llama3.1/8b/1_test_llama3.1_8b.sh index d5dcc39b16..a904fd8713 100644 --- a/tests/end_to_end/tpu/llama3.1/8b/1_test_llama3.1_8b.sh +++ b/tests/end_to_end/tpu/llama3.1/8b/1_test_llama3.1_8b.sh @@ -44,6 +44,6 @@ export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `MaxText.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. export RUN_NAME=unscanned_chkpt -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items" diff --git a/tests/end_to_end/tpu/llama3.1/8b/2_test_llama3.1_8b.sh b/tests/end_to_end/tpu/llama3.1/8b/2_test_llama3.1_8b.sh index c783bddd89..c4fc94d960 100644 --- a/tests/end_to_end/tpu/llama3.1/8b/2_test_llama3.1_8b.sh +++ b/tests/end_to_end/tpu/llama3.1/8b/2_test_llama3.1_8b.sh @@ -43,25 +43,25 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune # We run this with context parallelism and the `ici_context_parallelism` flag as an integration test for context parallelism -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} ici_context_parallelism=4 steps=10 async_checkpointing=false model_name=${MODEL_VARIATION} checkpoint_period=5 packing=false +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} ici_context_parallelism=4 steps=10 async_checkpointing=false model_name=${MODEL_VARIATION} checkpoint_period=5 packing=false # We also test whether the forward pass logits match the golden logits for LLama3.1-8B # We run this with context parallelism and the `ici_context_parallelism` flag as an integration test for context parallelism -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 ici_context_parallelism=4 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --max_kl_div=1e-4 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 ici_context_parallelism=4 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --max_kl_div=1e-4 # Converting MaxText orbax checkpoint to HF -JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.llama_mistral_mixtral_orbax_to_hf "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=gs://runner-maxtext-logs load_parameters_path=${CONVERTED_CHECKPOINT} run_name=convert_to_hf model_name=${MODEL_VARIATION} hf_model_path=/tmp/hf_llama3_1 +JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.llama_mistral_mixtral_orbax_to_hf "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=gs://runner-maxtext-logs load_parameters_path=${CONVERTED_CHECKPOINT} run_name=convert_to_hf model_name=${MODEL_VARIATION} hf_model_path=/tmp/hf_llama3_1 # Installing torch for running forward pass of a Huggingface checkpoint python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu # Test whether the forward pass logits match the golden logits for Huggingface checkpoint converted from MaxText orbax checkpoint # We run this with context parallelism and the `ici_context_parallelism` flag as an integration test for context parallelism -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_hf per_device_batch_size=1 ici_context_parallelism=4 model_name=${MODEL_VARIATION} max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --hf_model_path=/tmp/hf_llama3_1 --max_kl_div=1e-4 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_hf per_device_batch_size=1 ici_context_parallelism=4 model_name=${MODEL_VARIATION} max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --hf_model_path=/tmp/hf_llama3_1 --max_kl_div=1e-4 diff --git a/tests/end_to_end/tpu/llama3.1/8b/3_test_llama3.1_8b.sh b/tests/end_to_end/tpu/llama3.1/8b/3_test_llama3.1_8b.sh index 0595b00824..0c10b421d4 100644 --- a/tests/end_to_end/tpu/llama3.1/8b/3_test_llama3.1_8b.sh +++ b/tests/end_to_end/tpu/llama3.1/8b/3_test_llama3.1_8b.sh @@ -22,27 +22,27 @@ rm $CHECKPOINT_ORIGINAL/scanned_chkpt $CHECKPOINT_ORIGINAL/unscanned_chkpt ${CHE JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt --base-model-path=$CHECKPOINT_ORIGINAL --model-size=$MODEL_SIZE --maxtext-model-path=$CHECKPOINT_TPU_SCANNED --huggingface-checkpoint=true # Let's verify the original checkpoint to see if it matches with Huggingface golden logits -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=$TOKENIZER tokenizer_type=tiktoken load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=forward_pass_test_hf per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=true --max_kl_div=1e-4 --hf_model_path=$CHECKPOINT_ORIGINAL --golden_logits_path=$GOLDEN_LOGITS +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=$TOKENIZER tokenizer_type=tiktoken load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=forward_pass_test_hf per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=true --max_kl_div=1e-4 --hf_model_path=$CHECKPOINT_ORIGINAL --golden_logits_path=$GOLDEN_LOGITS # Let's verify the generated scanned checkpoint to see if it matches with Huggingface golden logits -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=$TOKENIZER tokenizer_type=tiktoken load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=forward_pass_test_hf per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=true --max_kl_div=1e-4 --golden_logits_path=$GOLDEN_LOGITS +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=$TOKENIZER tokenizer_type=tiktoken load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=forward_pass_test_hf per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=true --max_kl_div=1e-4 --golden_logits_path=$GOLDEN_LOGITS # If not, we can convert the checkpoint back from MaxText to Huggingface and compare with the original one -JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.llama_mistral_mixtral_orbax_to_hf "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=gs://runner-maxtext-logs load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=convert_to_hf model_name=${MODEL_SIZE} hf_model_path=$CHECKPOINT_TPU_CONVERTED_BACK +JAX_PLATFORMS=cpu python3 -m MaxText.utils.ckpt_scripts.llama_mistral_mixtral_orbax_to_hf "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=gs://runner-maxtext-logs load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=convert_to_hf model_name=${MODEL_SIZE} hf_model_path=$CHECKPOINT_TPU_CONVERTED_BACK python3 -m tests.utils.hf_checkpoint_conversion_checker --original_ckpt=${CHECKPOINT_ORIGINAL} --converted_ckpt=${CHECKPOINT_TPU_CONVERTED_BACK} # If everything looks good, we move on to convert to the unrolled checkpoint for performant serving -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=${RUN_NAME} model_name=${MODEL_SIZE} force_unroll=true +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CHECKPOINT_TPU_SCANNED}/0/items run_name=${RUN_NAME} model_name=${MODEL_SIZE} force_unroll=true # Let's verify the generated unscanned checkpoint to see if it matches with Huggingface golden logits -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=$TOKENIZER tokenizer_type=tiktoken load_parameters_path=${CHECKPOINT_TPU_UNSCANNED} run_name=forward_pass_test_hf per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --max_kl_div=1e-4 --golden_logits_path=$GOLDEN_LOGITS +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path=$TOKENIZER tokenizer_type=tiktoken load_parameters_path=${CHECKPOINT_TPU_UNSCANNED} run_name=forward_pass_test_hf per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=3 max_target_length=4 dataset_type=synthetic dtype=float32 activations_in_float32=true matmul_precision=float32 async_checkpointing=false scan_layers=false --max_kl_div=1e-4 --golden_logits_path=$GOLDEN_LOGITS # Now we are good to go, serve with performance! -JAX_PLATFORMS=tpu python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=$TOKENIZER tokenizer_type=tiktoken run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=false load_parameters_path=$CHECKPOINT_TPU_UNSCANNED +JAX_PLATFORMS=tpu python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path=$TOKENIZER tokenizer_type=tiktoken run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=false load_parameters_path=$CHECKPOINT_TPU_UNSCANNED # You can also check the results from scanned version, just double check, not necessary -JAX_PLATFORMS=tpu python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=$TOKENIZER tokenizer_type=tiktoken run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=true load_parameters_path=$CHECKPOINT_TPU_SCANNED/0/items +JAX_PLATFORMS=tpu python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path=$TOKENIZER tokenizer_type=tiktoken run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name=$MODEL_SIZE ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=true load_parameters_path=$CHECKPOINT_TPU_SCANNED/0/items ##### Output from huggingface llama 8B Instruct checkpoint on MaxText: #Input `I love to` -> ` travel and explore new places, but I also love to stay at home and relax. I'm a bit of a homebody, and I enjoy spending time with my family and friends. I'm a bit of a foodie, and I love trying new recipes and experimenting with different flavors and ingredients. I'm also a bit of a movie buff, and I love watching classic films and new releases alike. diff --git a/tests/end_to_end/tpu/llama3.1/8b/run_sft.sh b/tests/end_to_end/tpu/llama3.1/8b/run_sft.sh index 5729cad242..0982739a91 100644 --- a/tests/end_to_end/tpu/llama3.1/8b/run_sft.sh +++ b/tests/end_to_end/tpu/llama3.1/8b/run_sft.sh @@ -47,7 +47,7 @@ PRE_TRAINED_MODEL_TOKENIZER=meta-llama/Llama-3.1-8B-Instruct if [ -z "${PRE_TRAINED_MODEL_CKPT_PATH}" ]; then echo "PRE_TRAINED_MODEL_CKPT_PATH is not set. Converting Hugging Face checkpoint to MaxText format." CONVERTED_CKPT_DIR=${BASE_OUTPUT_DIRECTORY}/${PRE_TRAINED_MODEL}/${RUN_NAME}/maxtext-checkpoint - python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml \ + python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ model_name=${PRE_TRAINED_MODEL} \ hf_access_token=${HF_TOKEN} \ base_output_directory=${CONVERTED_CKPT_DIR} \ @@ -57,7 +57,7 @@ fi echo "Running fine-tuning on checkpoint: ${PRE_TRAINED_MODEL_CKPT_PATH}" # Run Supervised Fine-Tuning on MaxText checkpoint using HuggingFaceH4/ultrachat_200k dataset -python3 -m maxtext.trainers.post_train.sft.train_sft "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/sft.yml \ +python3 -m maxtext.trainers.post_train.sft.train_sft "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/sft.yml \ run_name=${RUN_NAME} base_output_directory=${BASE_OUTPUT_DIRECTORY}/${PRE_TRAINED_MODEL} \ model_name=${PRE_TRAINED_MODEL} load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH} \ hf_access_token=$HF_TOKEN tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER} \ diff --git a/tests/end_to_end/tpu/llama3.3/70b/1_test_llama3.3_70b.sh b/tests/end_to_end/tpu/llama3.3/70b/1_test_llama3.3_70b.sh index b5cb4da1ce..5603272311 100644 --- a/tests/end_to_end/tpu/llama3.3/70b/1_test_llama3.3_70b.sh +++ b/tests/end_to_end/tpu/llama3.3/70b/1_test_llama3.3_70b.sh @@ -44,6 +44,6 @@ export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `MaxText.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. export RUN_NAME=unscanned_chkpt -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items" diff --git a/tests/end_to_end/tpu/llama3.3/70b/2_test_llama3.3_70b.sh b/tests/end_to_end/tpu/llama3.3/70b/2_test_llama3.3_70b.sh index 89adba1909..425b23cde7 100644 --- a/tests/end_to_end/tpu/llama3.3/70b/2_test_llama3.3_70b.sh +++ b/tests/end_to_end/tpu/llama3.3/70b/2_test_llama3.3_70b.sh @@ -47,14 +47,14 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items # TODO(mohitkhatwani): Fix XLAResourceExhaustion when loading unscanned model # # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -# python3 -m maxtext.decode src/MaxText/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +# python3 -m maxtext.decode src/maxtext/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_VARIATION} run_name=${FINETUNE_RUN_NAME} base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken steps=10 per_device_batch_size=1 load_parameters_path=${CONVERTED_CHECKPOINT} +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml model_name=${MODEL_VARIATION} run_name=${FINETUNE_RUN_NAME} base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken tokenizer_type=tiktoken steps=10 per_device_batch_size=1 load_parameters_path=${CONVERTED_CHECKPOINT} # We also test whether the forward pass logits match the golden logits for Llama3.3-70B-Instruct -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} run_name=forward_pass_test per_device_batch_size=1 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 async_checkpointing=false dtype=float32 activations_in_float32=true matmul_precision=float32 --max_kl_div=1e-4 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} run_name=forward_pass_test per_device_batch_size=1 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 async_checkpointing=false dtype=float32 activations_in_float32=true matmul_precision=float32 --max_kl_div=1e-4 diff --git a/tests/end_to_end/tpu/llama3/70b/1_test_llama3_70b.sh b/tests/end_to_end/tpu/llama3/70b/1_test_llama3_70b.sh index 0dcb4d8891..466dea9964 100644 --- a/tests/end_to_end/tpu/llama3/70b/1_test_llama3_70b.sh +++ b/tests/end_to_end/tpu/llama3/70b/1_test_llama3_70b.sh @@ -44,5 +44,5 @@ export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `MaxText.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. export RUN_NAME=unscanned_chkpt -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items" diff --git a/tests/end_to_end/tpu/llama3/70b/2_test_llama3_70b.sh b/tests/end_to_end/tpu/llama3/70b/2_test_llama3_70b.sh index 242693b580..2b772d1c15 100644 --- a/tests/end_to_end/tpu/llama3/70b/2_test_llama3_70b.sh +++ b/tests/end_to_end/tpu/llama3/70b/2_test_llama3_70b.sh @@ -42,26 +42,26 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=false model_name=${MODEL_VARIATION} checkpoint_period=5 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=false model_name=${MODEL_VARIATION} checkpoint_period=5 # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION} +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION} # Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. # So, we can use the `src/MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_PATH` from our previous finetuning run. # `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding export PARAM_RUN_NAME=param_chkpt -python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true +python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true # Now, run decoding on the checkpoint generated from our finetune run. -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # We also test whether the forward pass logits match the golden logits for Llama3-70B -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 async_checkpointing=false scan_layers=false --atol=0.1 --rtol=0.1 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 async_checkpointing=false scan_layers=false --atol=0.1 --rtol=0.1 diff --git a/tests/end_to_end/tpu/llama3/8b/1_test_llama3_8b.sh b/tests/end_to_end/tpu/llama3/8b/1_test_llama3_8b.sh index 8f7c66942f..f2d1edbd26 100644 --- a/tests/end_to_end/tpu/llama3/8b/1_test_llama3_8b.sh +++ b/tests/end_to_end/tpu/llama3/8b/1_test_llama3_8b.sh @@ -44,6 +44,6 @@ export CONVERTED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_chkpt # Note that the `CONVERTED_CHECKPOINT` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format. # We can do this by running `MaxText.generate_param_only_checkpoint` on `CONVERTED_CHECKPOINT` with `force_unroll=true`. export RUN_NAME=unscanned_chkpt -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='llama3-8b' force_unroll=true +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='llama3-8b' force_unroll=true echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items" diff --git a/tests/end_to_end/tpu/llama3/8b/2_test_llama3_8b.sh b/tests/end_to_end/tpu/llama3/8b/2_test_llama3_8b.sh index 42f4a8509b..82cfec2152 100644 --- a/tests/end_to_end/tpu/llama3/8b/2_test_llama3_8b.sh +++ b/tests/end_to_end/tpu/llama3/8b/2_test_llama3_8b.sh @@ -42,26 +42,26 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints/0/items # We run decoding on the `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state. # So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}` -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic async_checkpointing=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # Alternatively, we skip to running finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying`load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=false model_name=${MODEL_VARIATION} checkpoint_period=5 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} steps=10 async_checkpointing=false model_name=${MODEL_VARIATION} checkpoint_period=5 # We also run pre-training, this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION} +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) steps=5 enable_checkpointing=false model_name=${MODEL_VARIATION} # Note that the finetune run checkpoint generates the `full state` which has both parameters and optimizer state. For decoding, we only need to use the parameters. # So, we can use the `src/MaxText/generate_param_only_checkpoint.py` to convert the full state checkpoint into a parameter only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_PATH` from our previous finetuning run. # `force_unroll=true` is converting the output parameter only checkpoint into an unscanned format for efficient decoding export PARAM_RUN_NAME=param_chkpt -python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true +python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} load_full_state_path=${BASE_OUTPUT_PATH}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name=${MODEL_VARIATION} force_unroll=true # Now, run decoding on the checkpoint generated from our finetune run. -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${BASE_OUTPUT_PATH}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=4 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_VARIATION} attention=dot_product prompt="I love to" # We also test whether the forward pass logits match the golden logits for Llama3-8B -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 async_checkpointing=false scan_layers=false --atol=0.1 --rtol=0.1 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer_llama3.tiktoken load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test per_device_batch_size=1 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 dataset_type=synthetic dtype=float32 async_checkpointing=false scan_layers=false --atol=0.1 --rtol=0.1 diff --git a/tests/end_to_end/tpu/llama4/2_test_llama4.sh b/tests/end_to_end/tpu/llama4/2_test_llama4.sh index e3ca54c652..d23ef4cd50 100644 --- a/tests/end_to_end/tpu/llama4/2_test_llama4.sh +++ b/tests/end_to_end/tpu/llama4/2_test_llama4.sh @@ -36,4 +36,4 @@ echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH} export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items # Step 2: run logit checking -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_VARIATION} attention=dot_product per_device_batch_size=1 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 scan_layers=false --atol=0.01 --rtol=0.01 async_checkpointing=false sparse_matmul=false weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=float32 float32_logits=true float32_qk_product=true ici_expert_parallelism=16 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_VARIATION} attention=dot_product per_device_batch_size=1 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 scan_layers=false --atol=0.01 --rtol=0.01 async_checkpointing=false sparse_matmul=false weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=float32 float32_logits=true float32_qk_product=true ici_expert_parallelism=16 diff --git a/tests/end_to_end/tpu/llama4/Run_Llama4.md b/tests/end_to_end/tpu/llama4/Run_Llama4.md index c4660b3bb5..9f1ccc3002 100644 --- a/tests/end_to_end/tpu/llama4/Run_Llama4.md +++ b/tests/end_to_end/tpu/llama4/Run_Llama4.md @@ -43,7 +43,7 @@ JAX_PLATFORMS=CPU python -m MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt --b You can train from scratch to generate a new checkpoint. One example command to run pretraining with Llama4 Maverick on a v5p-128. ```sh -python3 -m MaxText.train src/MaxText/configs/base.yml \ +python3 -m MaxText.train src/maxtext/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ run_name=matmul_pre_training \ per_device_batch_size=1 \ @@ -65,7 +65,7 @@ python3 -m MaxText.train src/MaxText/configs/base.yml \ In order to run an example decoding with Llama4 Scout, you can use a command such as the following: ```sh -python3 -m maxtext.decode src/MaxText/configs/base.yml \ +python3 -m maxtext.decode src/maxtext/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ run_name=decode \ model_name=llama4-17b-16e \ diff --git a/tests/end_to_end/tpu/llama_finetuning_test.sh b/tests/end_to_end/tpu/llama_finetuning_test.sh index 490fc500d6..695245de61 100644 --- a/tests/end_to_end/tpu/llama_finetuning_test.sh +++ b/tests/end_to_end/tpu/llama_finetuning_test.sh @@ -13,7 +13,7 @@ DATASET_PATH=gs://maxtext-dataset export LOSS_THRESHOLD=2.5 -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${base_ckpt_path} model_name='llama2-7b' dataset_path=${DATASET_PATH} async_checkpointing=false model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 per_device_batch_size=.25 metrics_file='metrics.txt' +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml run_name=runner_direct_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${base_ckpt_path} model_name='llama2-7b' dataset_path=${DATASET_PATH} async_checkpointing=false model_name='llama2-7b' ici_tensor_parallelism=4 steps=10 per_device_batch_size=.25 metrics_file='metrics.txt' # Assert training loss is smaller than input LOSS_THRESHOLD python3 tests/end_to_end/tpu/eval_assert.py final_loss metrics.txt $LOSS_THRESHOLD \ No newline at end of file diff --git a/tests/end_to_end/tpu/mistral/7b/test_mistral-7b.sh b/tests/end_to_end/tpu/mistral/7b/test_mistral-7b.sh index 4bd83aa3a4..c362d4906e 100644 --- a/tests/end_to_end/tpu/mistral/7b/test_mistral-7b.sh +++ b/tests/end_to_end/tpu/mistral/7b/test_mistral-7b.sh @@ -34,13 +34,13 @@ export SCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/0/ # Generate unscanned ckpt for efficient decoding test export RUN_NAME=unscanned_ckpt -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=${RUN_NAME} model_name='mistral-7b' force_unroll=true +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=${RUN_NAME} model_name='mistral-7b' force_unroll=true echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints" export DATASET_PATH=gs://maxtext-dataset # Run decoding with converted ckpt - matmul implementation -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${SCANNED_CHECKPOINT} run_name=scanned_decoding per_device_batch_size=1 model_name=mistral-7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=16 prompt='"[INST] I love to [/INST]"' attention=dot_product megablox=False sparse_matmul=False +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml load_parameters_path=${SCANNED_CHECKPOINT} run_name=scanned_decoding per_device_batch_size=1 model_name=mistral-7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=16 prompt='"[INST] I love to [/INST]"' attention=dot_product megablox=False sparse_matmul=False # Test whether the forward pass logits match the golden logits - matmul implementation -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=matmul_forward_pass_test per_device_batch_size=1 model_name=mistral-7b tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 megablox=False sparse_matmul=False --atol=3 --rtol=1 --token_size=4 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=matmul_forward_pass_test per_device_batch_size=1 model_name=mistral-7b tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 max_prefill_predict_length=11 max_target_length=11 dataset_type=synthetic dtype=float32 megablox=False sparse_matmul=False --atol=3 --rtol=1 --token_size=4 diff --git a/tests/end_to_end/tpu/mixtral/8x22b/2_test_mixtral.sh b/tests/end_to_end/tpu/mixtral/8x22b/2_test_mixtral.sh index 81ad187bd8..c2af7c04ce 100644 --- a/tests/end_to_end/tpu/mixtral/8x22b/2_test_mixtral.sh +++ b/tests/end_to_end/tpu/mixtral/8x22b/2_test_mixtral.sh @@ -23,7 +23,7 @@ export DATASET_PATH=gs://maxtext-dataset export TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v3 # Run pre-training without load_parameters_path - megablox implementation -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} \ run_name=pre_training_megablox per_device_batch_size=4 enable_checkpointing=false \ model_name=mixtral-8x22b ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 \ @@ -32,7 +32,7 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT weight_dtype=bfloat16 megablox=True sparse_matmul=True # Run pre-training without load_parameters_path - matmul implementation -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} \ run_name=pre_training_matmul per_device_batch_size=4 enable_checkpointing=false \ model_name=mixtral-8x22b ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 \ @@ -41,7 +41,7 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT weight_dtype=bfloat16 megablox=False sparse_matmul=False # Run pre-training without load_parameters_path - dropping implementation -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} \ run_name=pre_training_dropping per_device_batch_size=4 enable_checkpointing=false \ model_name=mixtral-8x22b ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 \ diff --git a/tests/end_to_end/tpu/mixtral/8x7b/1_test_mixtral.sh b/tests/end_to_end/tpu/mixtral/8x7b/1_test_mixtral.sh index 4e1da17f81..a62633a208 100644 --- a/tests/end_to_end/tpu/mixtral/8x7b/1_test_mixtral.sh +++ b/tests/end_to_end/tpu/mixtral/8x7b/1_test_mixtral.sh @@ -40,5 +40,5 @@ fusermount -u "$PARAM_DIR" # Generate unscanned ckpt for efficient decoding test export SCANNED_CHECKPOINT=${BASE_OUTPUT_PATH}/${MODEL_VARIATION}/scanned_ckpt/0/items export RUN_NAME=unscanned_ckpt -JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=${RUN_NAME} model_name='mixtral-8x7b' force_unroll=true skip_jax_distributed_system=True +JAX_PLATFORMS=cpu python3 -m MaxText.generate_param_only_checkpoint "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${SCANNED_CHECKPOINT} run_name=${RUN_NAME} model_name='mixtral-8x7b' force_unroll=true skip_jax_distributed_system=True echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/${RUN_NAME}/checkpoints" diff --git a/tests/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh b/tests/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh index d7145259a2..c79dfb2afd 100644 --- a/tests/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh +++ b/tests/end_to_end/tpu/mixtral/8x7b/2_test_mixtral.sh @@ -35,21 +35,21 @@ export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned_ckpt/checkpoints/0/item # Run decoding with converted ckpt - matmul implementation # TODO(ranran): add decoding test for megablox implementation -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=24 prompt='"[INST] I love to [/INST]"' megablox=False sparse_matmul=False scan_layers=false +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=24 prompt='"[INST] I love to [/INST]"' megablox=False sparse_matmul=False scan_layers=false # Run decoding with converted ckpt - dropping implementation -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=24 prompt='"[INST] I love to [/INST]"' megablox=False sparse_matmul=False scan_layers=false capacity_factor=1.25 +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=unscanned_decoding per_device_batch_size=1 model_name=mixtral-8x7b async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=24 prompt='"[INST] I love to [/INST]"' megablox=False sparse_matmul=False scan_layers=false capacity_factor=1.25 # Test whether the forward pass logits match the golden logits - matmul implementation -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=matmul_forward_pass_test per_device_batch_size=1 model_name=mixtral-8x7b tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=11 dtype=float32 megablox=False sparse_matmul=False scan_layers=false --token_size=4 --max_kl_div=3e-3 +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=matmul_forward_pass_test per_device_batch_size=1 model_name=mixtral-8x7b tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 ici_tensor_parallelism=1 ici_fsdp_parallelism=-1 max_prefill_predict_length=11 max_target_length=11 dtype=float32 megablox=False sparse_matmul=False scan_layers=false --token_size=4 --max_kl_div=3e-3 # To repeat duplicate tests, we have MoE unit test to verify outputs matching for matmul, megablox, and ragged_dot implementation at https://github.com/AI-Hypercomputer/maxtext/blob/5c4090b8d5713a1a25cab85df89b0ec9c9862635/MaxText/tests/unit/moe_test.py#L338-L411 # Run pre-training - megablox implementation -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=megablox_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=-1 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=megablox_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=-1 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16 # Run pre-training - matmul implementation -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=matmul_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=-1 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False sparse_matmul=False +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=matmul_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=-1 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False sparse_matmul=False # Run pre-training - dropping implementation -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=dropping_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=-1 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False sparse_matmul=False capacity_factor=1.25 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_PATH} dataset_path=${DATASET_PATH} run_name=dropping_pre_training per_device_batch_size=4 enable_checkpointing=false model_name=mixtral-8x7b ici_fsdp_parallelism=-1 steps=5 max_target_length=1024 async_checkpointing=false tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.mistral-v1 attention=flash dtype=bfloat16 weight_dtype=bfloat16 megablox=False sparse_matmul=False capacity_factor=1.25 diff --git a/tests/end_to_end/tpu/qwen/moe/qwen3-235b-a22b/1_test_qwen3_235b_a22b.sh b/tests/end_to_end/tpu/qwen/moe/qwen3-235b-a22b/1_test_qwen3_235b_a22b.sh index df16ef8ed6..f60705b75b 100644 --- a/tests/end_to_end/tpu/qwen/moe/qwen3-235b-a22b/1_test_qwen3_235b_a22b.sh +++ b/tests/end_to_end/tpu/qwen/moe/qwen3-235b-a22b/1_test_qwen3_235b_a22b.sh @@ -40,7 +40,7 @@ echo "Validating MaxText checkpoint at ${MAXTEXT_CHECKPOINT_PATH}" echo "Against original HF model: ${HF_MODEL_PATH}" # This command runs the core validation logic. -JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ tokenizer_type=huggingface \ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/qwen3-tokenizer \ megablox=False \ diff --git a/tests/end_to_end/tpu/qwen/moe/qwen3-30b-a3b/1_test_qwen3_30b_a3b.sh b/tests/end_to_end/tpu/qwen/moe/qwen3-30b-a3b/1_test_qwen3_30b_a3b.sh index 2f7446b692..db836b27ab 100644 --- a/tests/end_to_end/tpu/qwen/moe/qwen3-30b-a3b/1_test_qwen3_30b_a3b.sh +++ b/tests/end_to_end/tpu/qwen/moe/qwen3-30b-a3b/1_test_qwen3_30b_a3b.sh @@ -40,7 +40,7 @@ echo "Validating MaxText checkpoint at ${MAXTEXT_CHECKPOINT_PATH}" echo "Against original HF model: ${HF_MODEL_PATH}" # This command runs the core validation logic. -JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml \ +JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ tokenizer_type=huggingface \ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/qwen3-tokenizer \ megablox=False \ diff --git a/tests/end_to_end/tpu/qwen/moe/qwen3-480b-a35b/1_test_qwen3_480b_a35b.sh b/tests/end_to_end/tpu/qwen/moe/qwen3-480b-a35b/1_test_qwen3_480b_a35b.sh index 34526877c5..417f7def76 100644 --- a/tests/end_to_end/tpu/qwen/moe/qwen3-480b-a35b/1_test_qwen3_480b_a35b.sh +++ b/tests/end_to_end/tpu/qwen/moe/qwen3-480b-a35b/1_test_qwen3_480b_a35b.sh @@ -40,7 +40,7 @@ echo "Validating MaxText checkpoint at ${MAXTEXT_CHECKPOINT_PATH}" echo "Against original HF model: ${HF_MODEL_PATH}" # This command runs the core validation logic. -JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml \ +JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ tokenizer_type=huggingface \ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/qwen3-tokenizer \ megablox=False \ diff --git a/tests/end_to_end/tpu/qwen/moe/run_qwen_moe.md b/tests/end_to_end/tpu/qwen/moe/run_qwen_moe.md index 356fd0111f..617b17d943 100644 --- a/tests/end_to_end/tpu/qwen/moe/run_qwen_moe.md +++ b/tests/end_to_end/tpu/qwen/moe/run_qwen_moe.md @@ -44,7 +44,7 @@ Pre-training and Fine-tuning After converting the checkpoint, you can use it for fine-tuning or start a pre-training run from scratch. The command below is an example for fine-tuning on a v5p-512 slice. To pre-train, simply remove the `load_parameters_path` argument. ``` -python3 -m MaxText.train src/MaxText/configs/base.yml\ +python3 -m MaxText.train src/maxtext/configs/base.yml\ base_output_directory=${BASE_OUTPUT_DIRECTORY}\ dataset_path=${DATASET_PATH}\ load_parameters_path=gs://your-gcs-bucket/qwen3_maxtext_ckpt/0/items\ @@ -67,7 +67,7 @@ Decoding To generate text with a trained model, use the `decode` command. The command below is an example for decoding on a v5p-512 slice. ``` -python3 -m maxtext.decode src/MaxText/configs/base.yml\ +python3 -m maxtext.decode src/maxtext/configs/base.yml\ load_parameters_path=gs://your-gcs-bucket/qwen3_maxtext_ckpt/0/items\ tokenizer_type=huggingface\ tokenizer_path=src/maxtext/assets/tokenizers/qwen3-tokenizer\ diff --git a/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh b/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh index d444fea988..77476eaf6b 100644 --- a/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh +++ b/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh @@ -40,7 +40,7 @@ echo "Validating MaxText checkpoint at ${MAXTEXT_CHECKPOINT_PATH}" echo "Against original HF model: ${HF_MODEL_PATH}" # This command runs the core validation logic. -JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml \ +JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ tokenizer_type=huggingface \ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/qwen3-tokenizer \ megablox=False \ diff --git a/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md b/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md index ad5b5162ed..6b452bf364 100644 --- a/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md +++ b/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md @@ -36,7 +36,7 @@ Pre-training and Fine-tuning After converting the checkpoint, you can use it for fine-tuning or start a pre-training run from scratch. The command below is an example for fine-tuning on a v5p-512 slice. To pre-train, simply remove the `load_parameters_path` argument. ``` -python3 -m MaxText.train src/MaxText/configs/base.yml \ +python3 -m MaxText.train src/maxtext/configs/base.yml \ base_output_directory=${BASE_OUTPUT_DIRECTORY} \ dataset_path=${DATASET_PATH} \ load_parameters_path=gs://your-gcs-bucket/qwen3_next_maxtext_ckpt/0/items \ diff --git a/tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh b/tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh index 9c36b34c9c..7ff15fdba9 100644 --- a/tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh +++ b/tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh @@ -30,7 +30,7 @@ export CKPT_PATH=gs://maxtext-qwen/qwen3/4b/unscanned/2025-08-04-21-31/0/items # export HF_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/hf/${idx} export LOCAL_PATH=./tmp/hf/${MODEL_NAME}/${idx} -python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ model_name=${MODEL_NAME} \ hf_access_token=${HF_TOKEN} \ load_parameters_path=${CKPT_PATH} \ @@ -43,7 +43,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_PKG_DIR:-${MA # We also test whether the forward pass logits match the original HF model # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/qwen3-tokenizer \ load_parameters_path=${CKPT_PATH} \ model_name=${MODEL_NAME} \ diff --git a/tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh b/tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh index 0e3ece3f24..935f84f7ef 100644 --- a/tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh +++ b/tests/end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh @@ -28,7 +28,7 @@ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu export MODEL_BUCKET=gs://maxtext-qwen/qwen3 # To get unscanned ckpt: -python -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ model_name=${MODEL_NAME} \ hf_access_token=${HF_TOKEN} \ base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx} \ @@ -38,7 +38,7 @@ export UNSCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx}/0 # We also test whether the forward pass logits match the original HF model # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu` -python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ tokenizer_path=${TOKENIZER_PATH}\ load_parameters_path=${UNSCANNED_CKPT_PATH} \ model_name=${MODEL_NAME} \ @@ -48,7 +48,7 @@ python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_ --run_hf_model=True # We can run decoding for unscanned checkpoints. -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" # # Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data export DATASET_PATH=gs://maxtext-dataset @@ -58,7 +58,7 @@ export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs # We can also run finetuning by using the scanned converted checkpoint. # Note that scanned checkpoint helps with efficient finetuning export FINETUNE_RUN_NAME=runner_finetune_${idx} -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} checkpoint_period=5 +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} checkpoint_period=5 # Now, run decoding on the checkpoint generated from our finetune run. -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to" diff --git a/tests/end_to_end/tpu/run_sft.sh b/tests/end_to_end/tpu/run_sft.sh index 9277bb34e3..9476eaa9d6 100644 --- a/tests/end_to_end/tpu/run_sft.sh +++ b/tests/end_to_end/tpu/run_sft.sh @@ -49,7 +49,7 @@ RUN_NAME=$(date +%Y-%m-%d-%H-%M-%S) if [ -z "${PRE_TRAINED_MODEL_CKPT_PATH}" ]; then echo "PRE_TRAINED_MODEL_CKPT_PATH is not set. Converting Hugging Face checkpoint to MaxText format." CONVERTED_CKPT_DIR=${BASE_OUTPUT_DIRECTORY}/${PRE_TRAINED_MODEL}/${RUN_NAME}/maxtext-checkpoint - python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ + python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ model_name=${PRE_TRAINED_MODEL} \ hf_access_token=${HF_TOKEN} \ base_output_directory=${CONVERTED_CKPT_DIR} \ @@ -59,7 +59,7 @@ fi echo "Running fine-tuning on checkpoint: ${PRE_TRAINED_MODEL_CKPT_PATH}" # Run Supervised Fine-Tuning on MaxText checkpoint using HuggingFaceH4/ultrachat_200k dataset -python3 -m MaxText.sft_trainer "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/sft.yml \ +python3 -m MaxText.sft_trainer "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//sft.yml \ run_name=${RUN_NAME} base_output_directory=${BASE_OUTPUT_DIRECTORY}/${PRE_TRAINED_MODEL} \ model_name=${PRE_TRAINED_MODEL} load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH} \ hf_access_token=$HF_TOKEN tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER} \ @@ -82,7 +82,7 @@ echo "Fine-tuned model checkpoint: ${FINE_TUNED_MODEL_CKPT_PATH}" # Convert the fine-tuned MaxText checkpoint to Hugging Face checkpoint export LOCAL_PATH=./tmp/hf/${PRE_TRAINED_MODEL}/${RUN_NAME} -python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \ +python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml \ model_name=${PRE_TRAINED_MODEL} \ hf_access_token=${HF_TOKEN} \ load_parameters_path=${FINE_TUNED_MODEL_CKPT_PATH} \ diff --git a/tests/end_to_end/tpu/test_checkpoint_resharding.sh b/tests/end_to_end/tpu/test_checkpoint_resharding.sh index 57bf9ace65..b21426ebf4 100644 --- a/tests/end_to_end/tpu/test_checkpoint_resharding.sh +++ b/tests/end_to_end/tpu/test_checkpoint_resharding.sh @@ -6,12 +6,12 @@ OUTPUT_PATH=${2} DATASET_PATH=${3} # Train and save checkpoint - sharded with DCN Data Parallelism + ICI FSDP Parallelism -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml run_name=$RUN_NAME steps=101\ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml run_name=$RUN_NAME steps=101\ metrics_file='saved_metrics.txt' checkpoint_period=20 base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ dcn_data_parallelism=2 dcn_fsdp_parallelism=1 ici_fsdp_parallelism=4 ici_tensor_parallelism=1 collect_stack_trace=False # Retrieve checkpoint - sharded with DCN Data Parallelism + ICI FSDP + Tensor Parallelism -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml run_name=$RUN_NAME steps=102\ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml run_name=$RUN_NAME steps=102\ metrics_file='restored_metrics.txt' base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH\ dcn_data_parallelism=2 dcn_fsdp_parallelism=1 ici_fsdp_parallelism=2 ici_tensor_parallelism=2 collect_stack_trace=False diff --git a/tests/end_to_end/tpu/test_convergence_1b_params.sh b/tests/end_to_end/tpu/test_convergence_1b_params.sh index 35a674074e..49fd76c5d5 100644 --- a/tests/end_to_end/tpu/test_convergence_1b_params.sh +++ b/tests/end_to_end/tpu/test_convergence_1b_params.sh @@ -56,7 +56,7 @@ then hf_eval_files=$DATASET_PATH/hf/c4/c4-validation-*.parquet " fi -TRAIN_CMD="python3 -m MaxText.train ${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/configs/base.yml \ +TRAIN_CMD="python3 -m MaxText.train ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml \ steps=$STEPS eval_steps=$EVAL_STEPS eval_interval=$EVAL_INTERVAL \ per_device_batch_size=$PER_DEVICE_BATCH_SIZE learning_rate=3e-4 enable_checkpointing=false \ max_target_length=2048 global_parameter_scale=1 \ diff --git a/tests/end_to_end/tpu/test_decode_load_quantized_ckpt.sh b/tests/end_to_end/tpu/test_decode_load_quantized_ckpt.sh index 28cc13e5c7..cac72ff8a5 100644 --- a/tests/end_to_end/tpu/test_decode_load_quantized_ckpt.sh +++ b/tests/end_to_end/tpu/test_decode_load_quantized_ckpt.sh @@ -44,7 +44,7 @@ mkdir -p $OUTDIR echo # Run script ${cmd} python3 -m MaxText.${script_name} \ - "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml \ + "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ tokenizer_path=${TOKENIZER_PATH} \ load_parameters_path=${LOAD_PARAMETERS_PATH} \ checkpoint_is_quantized=True \ diff --git a/tests/end_to_end/tpu/test_decode_save_quantized_ckpt.sh b/tests/end_to_end/tpu/test_decode_save_quantized_ckpt.sh index 83e2cdbdfd..b8a5bbe3b7 100644 --- a/tests/end_to_end/tpu/test_decode_save_quantized_ckpt.sh +++ b/tests/end_to_end/tpu/test_decode_save_quantized_ckpt.sh @@ -51,7 +51,7 @@ mkdir -p $OUTDIR echo # Run command ${cmd} python3 -m maxtext.decode \ - "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml \ + "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ tokenizer_path=${TOKENIZER_PATH} \ load_parameters_path=${LOAD_PARAMETERS_PATH} \ max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ diff --git a/tests/end_to_end/tpu/test_dpo.sh b/tests/end_to_end/tpu/test_dpo.sh index 1fdd9c17ea..9e965778c3 100644 --- a/tests/end_to_end/tpu/test_dpo.sh +++ b/tests/end_to_end/tpu/test_dpo.sh @@ -9,7 +9,7 @@ export GEMMA_2B_CKPT_PATH=$(gcloud storage ls gs://maxtext-gemma/gemma2/2b | sor LOGS="gs://maxtext-external/logs" # tfds pipeline -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/dpo.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma \ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/dpo.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma \ run_name="$RUN_NAME-tfds" model_name=gemma2-2b base_output_directory=${LOGS} \ load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \ per_device_batch_size=0.5 allow_split_physical_axes=True \ @@ -18,7 +18,7 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT # grain pipeline mkdir -p /tmp/anthropic_rlhf || true gcloud storage cp -r gs://maxtext-dataset/dpo/anthropic_rlhf/array_record /tmp/anthropic_rlhf -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/dpo.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma \ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/dpo.yml tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/tokenizer.gemma \ run_name="$RUN_NAME-grain" model_name=gemma2-2b base_output_directory=${LOGS} \ load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \ dataset_type=grain grain_worker_count=16 \ @@ -28,7 +28,7 @@ python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxT ici_data_parallelism=2 ici_tensor_parallelism=2 ici_fsdp_parallelism=1 # hf pipeline -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/dpo.yml tokenizer_path='google/gemma-2-2b-it' \ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/dpo.yml tokenizer_path='google/gemma-2-2b-it' \ run_name="$RUN_NAME-grain" model_name=gemma2-2b base_output_directory=${LOGS} \ load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items \ dataset_type=hf hf_access_token=$HF_TOKEN hf_path='Anthropic/hh-rlhf' \ diff --git a/tests/end_to_end/tpu/test_gpt3.sh b/tests/end_to_end/tpu/test_gpt3.sh index 064a15037e..114d48cfe7 100644 --- a/tests/end_to_end/tpu/test_gpt3.sh +++ b/tests/end_to_end/tpu/test_gpt3.sh @@ -9,7 +9,7 @@ export RUN_NAME=test_${TIMESTAMP} python3 -m MaxText.utils.ckpt_scripts.convert_gpt3_ckpt_from_paxml --paxml-ckpt-path=${PAXML_CKPT_PATH} --maxtext-model-name=gpt3-52k --run-name=${RUN_NAME} --base-output-directory=${OUTPUT_PATH} # Run gpt3-52k with the converted ckpt -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml run_name=${RUN_NAME} model_name=gpt3-52k\ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml run_name=${RUN_NAME} model_name=gpt3-52k\ steps=10 per_device_batch_size=6 enable_checkpointing=true async_checkpointing=false\ remat_policy=full max_target_length=2048 base_output_directory=${OUTPUT_PATH}\ dataset_type=synthetic diff --git a/tests/end_to_end/tpu/test_sft_trainer.sh b/tests/end_to_end/tpu/test_sft_trainer.sh index 3caa88a68a..b0c3a8d8d3 100755 --- a/tests/end_to_end/tpu/test_sft_trainer.sh +++ b/tests/end_to_end/tpu/test_sft_trainer.sh @@ -19,7 +19,7 @@ PER_DEVICE_BATCH_SIZE=1 LOSS_THRESHOLD=100.0 # Set to large value so test is guaranteed to pass # SFT with HF pipeline -python3 -m MaxText.sft_trainer "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/sft.yml \ +python3 -m MaxText.sft_trainer "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/sft.yml \ run_name=${RUN_NAME}-hf base_output_directory=${BASE_OUTPUT_DIRECTORY} \ model_name=${PRE_TRAINED_MODEL} load_parameters_path=${PRE_TRAINED_MODEL_CKPT_PATH} \ dataset_type=hf hf_access_token=$HF_TOKEN tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER} \ @@ -45,7 +45,7 @@ largest_dir="${sorted_dirs[-1]}" FINE_TUNED_MODEL_CKPT_PATH=${CHECKPOINTS_PATH}/${largest_dir}/items # Decode -python3 -m maxtext.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/sft.yml \ +python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/sft.yml \ run_name=${RUN_NAME}-hf-decode \ model_name=${PRE_TRAINED_MODEL} tokenizer_path=${PRE_TRAINED_MODEL_TOKENIZER} tokenizer_type=huggingface \ load_parameters_path=${FINE_TUNED_MODEL_CKPT_PATH} \ diff --git a/tests/end_to_end/tpu/test_vocab_creation.sh b/tests/end_to_end/tpu/test_vocab_creation.sh index 67b330088f..c437fdc95d 100644 --- a/tests/end_to_end/tpu/test_vocab_creation.sh +++ b/tests/end_to_end/tpu/test_vocab_creation.sh @@ -8,7 +8,7 @@ VOCAB_PATH=$OUTPUT_PATH/vocab_test_creation_$RUN_NAME #Train -python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"/configs/base.yml run_name=$RUN_NAME steps=5 enable_checkpointing=False\ +python3 -m MaxText.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml run_name=$RUN_NAME steps=5 enable_checkpointing=False\ base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH tokenizer_path=$VOCAB_PATH python3 tests/end_to_end/tpu/eval_assert.py vocab_creation $VOCAB_PATH diff --git a/tests/integration/aot_identical_test.py b/tests/integration/aot_identical_test.py index ac3e8ef969..40150204b8 100644 --- a/tests/integration/aot_identical_test.py +++ b/tests/integration/aot_identical_test.py @@ -26,7 +26,7 @@ import hashlib import re import jax -from MaxText.globals import MAXTEXT_PKG_DIR +from tests.utils.test_helpers import get_test_config_path from MaxText import train_compile from MaxText import train @@ -123,7 +123,7 @@ def assert_compile_and_real_match_hlo(self, test_name, *extra_args): # xla flag only sets once for train.main os.makedirs(local_landing_dir, exist_ok=True) train_argv = ( - (None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")) + (None, get_test_config_path()) + tuple(shared_args) + ( f"dump_hlo_xla_flags=--xla_dump_to={local_landing_dir} " @@ -139,7 +139,7 @@ def assert_compile_and_real_match_hlo(self, test_name, *extra_args): os.makedirs(local_landing_dir, exist_ok=True) topology = self.get_device_user_facing_name() aot_args = [f"compile_topology={topology}", "compile_topology_num_slices=1"] - compile_argv = (None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")) + tuple(shared_args) + tuple(aot_args) + compile_argv = (None, get_test_config_path()) + tuple(shared_args) + tuple(aot_args) train_compile.main(compile_argv) shutil.move(local_landing_dir, compile_dump_dir) jax.clear_caches() @@ -191,7 +191,7 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args): # Run train.py and dump jaxpr train_argv = ( None, - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + get_test_config_path(), f"dump_jaxpr_local_dir={train_dump_dir}", ) + tuple(shared_args) train.main(train_argv) @@ -201,7 +201,7 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args): topology = self.get_device_user_facing_name() compile_argv = ( None, - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + get_test_config_path(), f"dump_jaxpr_local_dir={compile_dump_dir}", f"compile_topology={topology}", "compile_topology_num_slices=1", diff --git a/tests/integration/decode_tests.py b/tests/integration/decode_tests.py index 17c9786f5c..1d5084aa3a 100644 --- a/tests/integration/decode_tests.py +++ b/tests/integration/decode_tests.py @@ -23,7 +23,7 @@ from contextlib import redirect_stdout from maxtext.decode import main as decode_main -from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_ASSETS_ROOT +from MaxText.globals import MAXTEXT_ASSETS_ROOT from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory pytestmark = [pytest.mark.tpu_only, pytest.mark.external_serving, pytest.mark.integration_test] @@ -80,7 +80,7 @@ class DecodeTests(unittest.TestCase): ], "decode_sampling": [ None, - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + get_test_config_path(), "base_output_directory=gs://runner-maxtext-logs", "run_name=runner_test", f"load_parameters_path={GEMMA_2B_CKPT_PATH}", diff --git a/tests/integration/determinism_test.py b/tests/integration/determinism_test.py index 0dc7d2fad7..245de59581 100644 --- a/tests/integration/determinism_test.py +++ b/tests/integration/determinism_test.py @@ -21,13 +21,12 @@ import datetime import json -import os import unittest import pytest from MaxText.train import main as train_main -from MaxText.globals import MAXTEXT_PKG_DIR +from tests.utils.test_helpers import get_test_config_path pytestmark = pytest.mark.integration_test @@ -52,7 +51,7 @@ def test_determinism(self): run_name = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") common_config = [ None, - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + get_test_config_path(), "steps=5", "enable_checkpointing=False", "enable_data_shuffling=True", diff --git a/tests/integration/gradient_accumulation_test.py b/tests/integration/gradient_accumulation_test.py index 7cc130ecb3..9b86ff5eb0 100644 --- a/tests/integration/gradient_accumulation_test.py +++ b/tests/integration/gradient_accumulation_test.py @@ -27,7 +27,7 @@ from MaxText.train import main as train_main from MaxText.sft_trainer import main as sft_main -from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_PKG_DIR +from MaxText.globals import MAXTEXT_ASSETS_ROOT from maxtext.common.gcloud_stub import is_decoupled from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory @@ -151,7 +151,7 @@ def test_sft_grad_accumulate_same_loss(self): sft_main( [ None, - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + get_test_config_path(), "base_output_directory=gs://runner-maxtext-logs", "dataset_path=gs://maxtext-dataset", "gradient_clipping_threshold=0", # Ensures we are testing raw scales of gradients (clipping off). diff --git a/tests/integration/smoke/inference_microbenchmark_smoke_test.py b/tests/integration/smoke/inference_microbenchmark_smoke_test.py index fc79a9ae11..f3d0dfb7a9 100644 --- a/tests/integration/smoke/inference_microbenchmark_smoke_test.py +++ b/tests/integration/smoke/inference_microbenchmark_smoke_test.py @@ -43,7 +43,7 @@ def test(self): config = pyconfig.initialize( [ None, - os.path.join(MAXTEXT_PKG_DIR, "configs", "tpu_smoke_test.yml"), + os.path.join(MAXTEXT_PKG_DIR, "configs", "tpu", "tpu_smoke_test.yml"), rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}", "ici_autoregressive_parallelism=-1", "ici_fsdp_parallelism=1", diff --git a/tests/integration/smoke/train_gpu_smoke_test.py b/tests/integration/smoke/train_gpu_smoke_test.py index 383e552f28..4dd2ac4116 100644 --- a/tests/integration/smoke/train_gpu_smoke_test.py +++ b/tests/integration/smoke/train_gpu_smoke_test.py @@ -43,7 +43,7 @@ def test_tiny_config(self): train_main( [ None, - os.path.join(MAXTEXT_PKG_DIR, "configs", "gpu_smoke_test.yml"), + os.path.join(MAXTEXT_PKG_DIR, "configs", "gpu", "gpu_smoke_test.yml"), # pylint: disable=f-string-without-interpolation f"base_output_directory={self.base_output_directory}", "run_name=runner_test", diff --git a/tests/integration/smoke/train_smoke_test.py b/tests/integration/smoke/train_smoke_test.py index 34ef7e6abe..13c06bf701 100644 --- a/tests/integration/smoke/train_smoke_test.py +++ b/tests/integration/smoke/train_smoke_test.py @@ -20,7 +20,7 @@ from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory from MaxText.train import main as train_main -from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_ASSETS_ROOT +from MaxText.globals import MAXTEXT_ASSETS_ROOT from maxtext.common.gcloud_stub import is_decoupled @@ -70,7 +70,7 @@ def test_tiny_config_no_scan(self): train_main( [ None, - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + get_test_config_path(), # pylint: disable=f-string-without-interpolation f"base_output_directory={self.base_output_directory}", "run_name=runner_test", diff --git a/tests/integration/smoke/train_using_ragged_dot_smoke_test.py b/tests/integration/smoke/train_using_ragged_dot_smoke_test.py index 6a9cf48fbd..c7d979f5fa 100644 --- a/tests/integration/smoke/train_using_ragged_dot_smoke_test.py +++ b/tests/integration/smoke/train_using_ragged_dot_smoke_test.py @@ -21,10 +21,9 @@ from absl.testing import parameterized from tests.utils.test_helpers import get_test_config_path -from MaxText import globals as maxtext_globals, train +from MaxText import globals as train train_main = train.main -MAXTEXT_PKG_DIR = maxtext_globals.MAXTEXT_PKG_DIR gettempdir = tempfile.gettempdir diff --git a/tests/integration/train_tests.py b/tests/integration/train_tests.py index 5d028df2e2..de220b3cb0 100644 --- a/tests/integration/train_tests.py +++ b/tests/integration/train_tests.py @@ -18,7 +18,7 @@ import pytest import jax from MaxText.train import main as train_main -from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_PKG_DIR +from MaxText.globals import MAXTEXT_ASSETS_ROOT from maxtext.common.gcloud_stub import is_decoupled from tests.utils.test_helpers import get_test_config_path, get_test_dataset_path, get_test_base_output_directory from absl.testing import absltest @@ -476,7 +476,7 @@ def test_base_model_shardy_false(self): def test_tpu_zero1_gradient_accumulation(self): zero1_ga = [ # tests Zero-1 optimizer sharding with gradient accumulation None, - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + get_test_config_path(), "base_output_directory=gs://runner-maxtext-logs", "run_name=runner_test", "dataset_path=gs://maxtext-dataset", diff --git a/tests/integration/xaot_test.py b/tests/integration/xaot_test.py index cad68b08c0..cf45125045 100644 --- a/tests/integration/xaot_test.py +++ b/tests/integration/xaot_test.py @@ -24,7 +24,7 @@ import os import shutil import jax -from MaxText.globals import MAXTEXT_PKG_DIR +from tests.utils.test_helpers import get_test_config_path from MaxText import train_compile from MaxText import train @@ -94,9 +94,7 @@ def run_compile_then_load(self, test_name, *extra_args): f"compiled_trainstep_file={self.pickle_file}", ] - compile_argv = ( - (None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")) + tuple(shared_args) + tuple(compile_specific_args) - ) + compile_argv = (None, get_test_config_path()) + tuple(shared_args) + tuple(compile_specific_args) print(f"\n--- Starting Compilation Step for {test_name} ---") # Clear caches before compile to ensure clean state @@ -112,9 +110,7 @@ def run_compile_then_load(self, test_name, *extra_args): f"compiled_trainstep_file={self.pickle_file}", ] - train_argv = ( - (None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")) + tuple(shared_args) + tuple(load_specific_args) - ) + train_argv = (None, get_test_config_path()) + tuple(shared_args) + tuple(load_specific_args) print(f"\n--- Starting Load/Train Step for {test_name} ---") # Clear caches before train to ensure we are actually loading from the pickle diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index cb25c48b2d..2914518ebc 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -15,7 +15,6 @@ """Tests for Attentions.""" import itertools -import os.path import random import sys import unittest @@ -37,7 +36,6 @@ MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, ) -from MaxText.globals import MAXTEXT_PKG_DIR from MaxText.layers.attention_mla import MLA from MaxText.layers.attention_op import ChunkedCausalMask, _generate_chunk_attention_mask, _make_bidirectional_block_mask from MaxText.layers.attentions import Attention @@ -1118,7 +1116,7 @@ def test_forward_serve_vllm(self, mock_sharded_ragged_paged_attention): vllm_config_arguments["attention"] = "vllm_rpa" vllm_config_arguments["chunk_attn_window_size"] = 128 config = pyconfig.initialize( - [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + [sys.argv[0], get_test_config_path()], **vllm_config_arguments, ) diff --git a/tests/unit/configs_test.py b/tests/unit/configs_test.py index 44dda1df3a..09ceb03f44 100644 --- a/tests/unit/configs_test.py +++ b/tests/unit/configs_test.py @@ -32,11 +32,11 @@ from pydantic import ValidationError from yaml import YAMLError -from MaxText.configs import types as pydantic_types +from maxtext.configs import types as pydantic_types from MaxText.globals import MAXTEXT_REPO_ROOT # Define the root directory where configuration files are located. -CONFIGS_DIR = os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "configs") +CONFIGS_DIR = os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "configs") @functools.lru_cache(maxsize=None) @@ -113,14 +113,14 @@ def run_config_validation(config_file_path: str): BASE_CONFIGS = [ os.path.join(CONFIGS_DIR, "base.yml"), - os.path.join(CONFIGS_DIR, "dpo.yml"), - os.path.join(CONFIGS_DIR, "gpu_smoke_test.yml"), - os.path.join(CONFIGS_DIR, "rl.yml"), - os.path.join(CONFIGS_DIR, "rl_mt_jt.yml"), - os.path.join(CONFIGS_DIR, "sft.yml"), - os.path.join(CONFIGS_DIR, "sft-vision-chartqa.yml"), - os.path.join(CONFIGS_DIR, "sft-vision-slidevqa.yml"), - os.path.join(CONFIGS_DIR, "tpu_smoke_test.yml"), + os.path.join(CONFIGS_DIR, "post_train", "dpo.yml"), + os.path.join(CONFIGS_DIR, "gpu/gpu_smoke_test.yml"), + os.path.join(CONFIGS_DIR, "post_train", "rl.yml"), + os.path.join(CONFIGS_DIR, "post_train", "rl_mt_jt.yml"), + os.path.join(CONFIGS_DIR, "post_train", "sft.yml"), + os.path.join(CONFIGS_DIR, "post_train", "sft-vision-chartqa.yml"), + os.path.join(CONFIGS_DIR, "post_train", "sft-vision-slidevqa.yml"), + os.path.join(CONFIGS_DIR, "tpu/tpu_smoke_test.yml"), ] @@ -151,11 +151,11 @@ def test_gemma_configs(config_file): # --- Test Group 3: Llama Model Family --- LLAMA_CONFIGS = [ - os.path.join(CONFIGS_DIR, "models", "gpu", "llama2_7b.yml"), - os.path.join(CONFIGS_DIR, "models", "gpu", "llama2_70b.yml"), - os.path.join(CONFIGS_DIR, "models", "gpu", "llama3_8b.yml"), - os.path.join(CONFIGS_DIR, "models", "gpu", "llama3_70b.yml"), - os.path.join(CONFIGS_DIR, "models", "gpu", "llama3.1_405b.yml"), + os.path.join(CONFIGS_DIR, "gpu", "models", "llama2_7b.yml"), + os.path.join(CONFIGS_DIR, "gpu", "models", "llama2_70b.yml"), + os.path.join(CONFIGS_DIR, "gpu", "models", "llama3_8b.yml"), + os.path.join(CONFIGS_DIR, "gpu", "models", "llama3_70b.yml"), + os.path.join(CONFIGS_DIR, "gpu", "models", "llama3.1_405b.yml"), os.path.join(CONFIGS_DIR, "models", "llama2-7b.yml"), os.path.join(CONFIGS_DIR, "models", "llama2-13b.yml"), os.path.join(CONFIGS_DIR, "models", "llama2-70b.yml"), @@ -215,9 +215,9 @@ def test_deepseek_configs(config_file): os.path.join(CONFIGS_DIR, "models", "mistral-7b.yml"), os.path.join(CONFIGS_DIR, "models", "mixtral-8x7b.yml"), os.path.join(CONFIGS_DIR, "models", "mixtral-8x22b.yml"), - os.path.join(CONFIGS_DIR, "models", "gpu", "mixtral_8x1b.yml"), - os.path.join(CONFIGS_DIR, "models", "gpu", "mixtral_8x2b.yml"), - os.path.join(CONFIGS_DIR, "models", "gpu", "mixtral_8x7b.yml"), + os.path.join(CONFIGS_DIR, "gpu", "models", "mixtral_8x1b.yml"), + os.path.join(CONFIGS_DIR, "gpu", "models", "mixtral_8x2b.yml"), + os.path.join(CONFIGS_DIR, "gpu", "models", "mixtral_8x7b.yml"), ] @@ -263,30 +263,30 @@ def test_kimi_configs(config_file): # --- Test Group 9: Inference-specific Configs --- INFERENCE_CONFIGS = [ - os.path.join(CONFIGS_DIR, "inference.yml"), - os.path.join(CONFIGS_DIR, "inference_jetstream.yml"), - os.path.join(CONFIGS_DIR, "v5e", "llama2_70b_v5e-16.yml"), - os.path.join(CONFIGS_DIR, "v5e", "llama3_70b_v5e-16.yml"), - os.path.join(CONFIGS_DIR, "v5e", "llama3_405b_v5e-64.yml"), - os.path.join(CONFIGS_DIR, "v6e", "inference", "llama4_maverick_v6e-64.yml"), + os.path.join(CONFIGS_DIR, "inference", "inference.yml"), + os.path.join(CONFIGS_DIR, "inference", "inference_jetstream.yml"), + os.path.join(CONFIGS_DIR, "tpu", "v5e", "llama2_70b_v5e-16.yml"), + os.path.join(CONFIGS_DIR, "tpu", "v5e", "llama3_70b_v5e-16.yml"), + os.path.join(CONFIGS_DIR, "tpu", "v5e", "llama3_405b_v5e-64.yml"), + os.path.join(CONFIGS_DIR, "tpu", "v6e", "inference", "llama4_maverick_v6e-64.yml"), os.path.join( MAXTEXT_REPO_ROOT, "src", "maxtext", - "inference", "configs", - "multi_host", + "inference", + "multihost", "disaggregation", "llama3_405b_v6e-16-16.yml", ), os.path.join( - MAXTEXT_REPO_ROOT, "src", "maxtext", "inference", "configs", "multi_host", "interleaved", "llama2_70b_v5e-16.yml" + MAXTEXT_REPO_ROOT, "src", "maxtext", "configs", "inference", "multihost", "interleaved", "llama2_70b_v5e-16.yml" ), os.path.join( - MAXTEXT_REPO_ROOT, "src", "maxtext", "inference", "configs", "multi_host", "interleaved", "llama3_70b_v5e-16.yml" + MAXTEXT_REPO_ROOT, "src", "maxtext", "configs", "inference", "multihost", "interleaved", "llama3_70b_v5e-16.yml" ), os.path.join( - MAXTEXT_REPO_ROOT, "src", "maxtext", "inference", "configs", "multi_host", "interleaved", "llama3_405b_v5e-64.yml" + MAXTEXT_REPO_ROOT, "src", "maxtext", "configs", "inference", "multihost", "interleaved", "llama3_405b_v5e-64.yml" ), ] diff --git a/tests/unit/configs_value_test.py b/tests/unit/configs_value_test.py index 4cf528b973..bd7f5a7d0b 100644 --- a/tests/unit/configs_value_test.py +++ b/tests/unit/configs_value_test.py @@ -21,12 +21,12 @@ import pydantic from MaxText import pyconfig -from MaxText.configs import types from MaxText.globals import MAXTEXT_REPO_ROOT from MaxText.pyconfig import initialize_pydantic +from maxtext.configs import types # Path to the base.yml config. This assumes that `pytest` is run from the project root. -_BASE_CONFIG_PATH = os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "configs", "base.yml") +_BASE_CONFIG_PATH = os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "configs", "base.yml") class ConfigTest(unittest.TestCase): diff --git a/tests/unit/deepseek32_vs_reference_test.py b/tests/unit/deepseek32_vs_reference_test.py index 3a4191b6d4..0cb119d9f1 100644 --- a/tests/unit/deepseek32_vs_reference_test.py +++ b/tests/unit/deepseek32_vs_reference_test.py @@ -32,7 +32,6 @@ """ -import os.path import math from dataclasses import dataclass, asdict from typing import Optional @@ -51,11 +50,11 @@ import jax.numpy as jnp from flax import nnx -from MaxText.globals import MAXTEXT_PKG_DIR from MaxText import pyconfig from MaxText.layers import embeddings, attention_mla from MaxText.common_types import MODEL_MODE_TRAIN from maxtext.utils import maxtext_utils +from tests.utils.test_helpers import get_test_config_path # ----------------------------------------------------------------------------- @@ -754,7 +753,7 @@ def get_jax_mla_weights(pt_mla, cfg): def get_cfg_and_mesh(config, run_name, dtype, batch_size, seq_len): """Returns MaxText configuration and mesh.""" cfg = pyconfig.initialize( - [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + [None, get_test_config_path()], run_name=run_name, enable_checkpointing=False, model_name="default", diff --git a/tests/unit/grain_data_processing_test.py b/tests/unit/grain_data_processing_test.py index fbb80cba13..cbb11bcd2e 100644 --- a/tests/unit/grain_data_processing_test.py +++ b/tests/unit/grain_data_processing_test.py @@ -30,7 +30,7 @@ from MaxText import pyconfig from MaxText.input_pipeline import _grain_data_processing from MaxText.input_pipeline import input_pipeline_interface -from MaxText.globals import MAXTEXT_PKG_DIR, MAXTEXT_ASSETS_ROOT, MAXTEXT_REPO_ROOT +from MaxText.globals import MAXTEXT_ASSETS_ROOT, MAXTEXT_REPO_ROOT from maxtext.common.gcloud_stub import is_decoupled from tests.utils.test_helpers import get_test_base_output_directory, get_test_config_path, get_test_dataset_path @@ -246,7 +246,7 @@ def setUp(self): json.dump(mixture_config, f) self.config = pyconfig.initialize( - [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + [sys.argv[0], get_test_config_path()], per_device_batch_size=1, run_name="test", mesh_axes=["data"], @@ -300,7 +300,7 @@ def setUp(self): base_output_directory = "gs://max-experiments/" self.config = pyconfig.initialize( - [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + [sys.argv[0], get_test_config_path()], per_device_batch_size=1, run_name="test", mesh_axes=["data"], @@ -373,7 +373,7 @@ def setUp(self): base_output_directory = "gs://max-experiments/" self.config = pyconfig.initialize( - [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + [sys.argv[0], get_test_config_path()], per_device_batch_size=1, run_name="test", mesh_axes=["data"], diff --git a/tests/unit/moe_test.py b/tests/unit/moe_test.py index ffbc715913..fbd8fc9b7e 100644 --- a/tests/unit/moe_test.py +++ b/tests/unit/moe_test.py @@ -13,7 +13,6 @@ # limitations under the License. """ Mixture of Experts (MoE) tests. """ -import os.path import unittest import pytest @@ -30,7 +29,6 @@ from maxtext.common.gcloud_stub import is_decoupled from MaxText import pyconfig from MaxText.common_types import Config, DType -from MaxText.globals import MAXTEXT_PKG_DIR from MaxText.layers import linears from MaxText.layers import moe from MaxText.layers.initializers import NdInitializer, nd_dense_init, variable_to_logically_partitioned @@ -687,7 +685,7 @@ def test_megablox_context_parallelism(self): @pytest.mark.tpu_only def test_megablox_expert_context_parallelism(self): cfg = pyconfig.initialize( - [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + [None, get_test_config_path()], run_name="moe_block_megablox_ep_cp_test", enable_checkpointing=False, model_name="mixtral-8x7b", @@ -720,7 +718,7 @@ def test_megablox_expert_context_parallelism(self): @pytest.mark.tpu_only def test_megablox_expert_tensor_parallelism(self): cfg = pyconfig.initialize( - [None, os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + [None, get_test_config_path()], run_name="moe_block_megablox_ep_tp_test", enable_checkpointing=False, model_name="mixtral-8x7b", diff --git a/tests/unit/pyconfig_deprecated_test.py b/tests/unit/pyconfig_deprecated_test.py index 160ce07633..753925d584 100644 --- a/tests/unit/pyconfig_deprecated_test.py +++ b/tests/unit/pyconfig_deprecated_test.py @@ -20,6 +20,7 @@ from MaxText import pyconfig from MaxText.globals import MAXTEXT_PKG_DIR from MaxText.pyconfig_deprecated import resolve_config_path +from tests.utils.test_helpers import get_test_config_path class PyconfigTest(unittest.TestCase): @@ -35,7 +36,7 @@ def test_basic_override(self): def test_empty_string_parse_as_empty_string(self): config = pyconfig.initialize( - [os.path.join(MAXTEXT_PKG_DIR, "train.py"), os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + [os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()], skip_jax_distributed_system=True, # We should check for this automatically instead - b/407047411 quantization="", ) @@ -82,7 +83,7 @@ def test_logical_axis_partial_override(self): def test_multiple_unmodifiable_configs(self): config_train = pyconfig.initialize( - [os.path.join(MAXTEXT_PKG_DIR, "train.py"), os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + [os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()], per_device_batch_size=1.0, run_name="test", enable_checkpointing=False, @@ -97,7 +98,7 @@ def test_multiple_unmodifiable_configs(self): ici_fsdp_parallelism=4, ) config_inference = pyconfig.initialize( - [os.path.join(MAXTEXT_PKG_DIR, "decode.py"), os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + [os.path.join(MAXTEXT_PKG_DIR, "decode.py"), get_test_config_path()], per_device_batch_size=1.0, run_name="test", enable_checkpointing=False, @@ -120,7 +121,7 @@ def test_multiple_unmodifiable_configs(self): def test_overriding_model(self): config = pyconfig.initialize( - [os.path.join(MAXTEXT_PKG_DIR, "train.py"), os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + [os.path.join(MAXTEXT_PKG_DIR, "train.py"), get_test_config_path()], skip_jax_distributed_system=True, model_name="gemma-7b", override_model_config=True, diff --git a/tests/unit/qwen3_omni_layers_test.py b/tests/unit/qwen3_omni_layers_test.py index ac19cd64a6..8613283ff9 100644 --- a/tests/unit/qwen3_omni_layers_test.py +++ b/tests/unit/qwen3_omni_layers_test.py @@ -83,7 +83,7 @@ ) # Initialize config once for all tests -base_config_path = os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "configs", "base.yml") +base_config_path = os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "configs", "base.yml") jax_config = pyconfig.initialize( ["", base_config_path], model_name="qwen3-omni-30b-a3b", @@ -584,7 +584,7 @@ class TestQwen3OmniPreprocessing(unittest.TestCase): """Test MaxText Qwen3 Omni preprocessor against HuggingFace reference.""" def setUp(self): - self.base_config_path = os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "configs", "base.yml") + self.base_config_path = os.path.join(MAXTEXT_REPO_ROOT, "src", "maxtext", "configs", "base.yml") self.image_path = os.path.join(MAXTEXT_REPO_ROOT, "tests", "assets", "test_image.jpg") self.video_path = os.path.join(MAXTEXT_REPO_ROOT, "tests", "assets", "test_video.mp4") self.maxtext_config = pyconfig.initialize( diff --git a/tests/unit/sft_data_processing_test.py b/tests/unit/sft_data_processing_test.py index c92ccd293c..3ad9d57dfe 100644 --- a/tests/unit/sft_data_processing_test.py +++ b/tests/unit/sft_data_processing_test.py @@ -314,7 +314,7 @@ def setUp(self): tokenizer_path = os.path.join(MAXTEXT_ASSETS_ROOT, "llama2-chat-tokenizer") self.config = pyconfig.initialize( - [os.path.join(MAXTEXT_PKG_DIR, "sft_trainer"), os.path.join(MAXTEXT_PKG_DIR, "configs", "sft.yml")], + [os.path.join(MAXTEXT_PKG_DIR, "sft_trainer"), os.path.join(MAXTEXT_PKG_DIR, "configs", "post_train", "sft.yml")], per_device_batch_size=2, run_name="test", mesh_axes=["data"], diff --git a/tests/unit/sft_hooks_test.py b/tests/unit/sft_hooks_test.py index 3186fcb06a..6464d4a984 100644 --- a/tests/unit/sft_hooks_test.py +++ b/tests/unit/sft_hooks_test.py @@ -36,7 +36,7 @@ class SFTHooksTest(unittest.TestCase): def setUp(self): super().setUp() self.config = pyconfig.initialize( - ["", os.path.join(MAXTEXT_PKG_DIR, "configs", "sft.yml")], + ["", os.path.join(MAXTEXT_PKG_DIR, "configs", "post_train", "sft.yml")], per_device_batch_size=1, run_name="test", base_output_directory="test", diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index ba87eab068..1dd31413dc 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -26,7 +26,6 @@ import pytest from MaxText.train_compile import main as train_compile_main -from MaxText.globals import MAXTEXT_PKG_DIR from tests.utils.test_helpers import get_test_config_path pytestmark = [pytest.mark.external_training] @@ -151,7 +150,7 @@ def test_save_compiled_tpu7x(self): train_compile_main( ( None, - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + get_test_config_path(), f"compiled_trainstep_file={compiled_trainstep_file}", "compile_topology=tpu7x-16", "compile_topology_num_slices=1", @@ -169,7 +168,7 @@ def test_save_compiled_tpu7x_two_slices(self): train_compile_main( ( None, - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + get_test_config_path(), f"compiled_trainstep_file={compiled_trainstep_file}", "compile_topology=tpu7x-8", "compile_topology_num_slices=2", @@ -740,7 +739,7 @@ def test_qwen3_next(self): train_compile_main( ( "", - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + get_test_config_path(), f"compiled_trainstep_file={compiled_trainstep_file}", "compile_topology=v5p-256", "compile_topology_num_slices=1", @@ -757,7 +756,7 @@ def test_deepseek32(self): train_compile_main( ( "", - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + get_test_config_path(), f"compiled_trainstep_file={compiled_trainstep_file}", "compile_topology=v5p-256", "use_iota_embed=true", @@ -784,7 +783,7 @@ def test_olmo3_7b(self): train_compile_main( ( "", - os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + get_test_config_path(), f"compiled_trainstep_file={compiled_trainstep_file}", "compile_topology=v5p-8", "compile_topology_num_slices=1", diff --git a/tests/utils/attention_test_util.py b/tests/utils/attention_test_util.py index e78b8da1ae..a8cac27cfd 100644 --- a/tests/utils/attention_test_util.py +++ b/tests/utils/attention_test_util.py @@ -13,7 +13,6 @@ # limitations under the License. """Test util for attention tests.""" -import os import sys from absl.testing import parameterized @@ -27,7 +26,6 @@ from maxtext.common.gcloud_stub import is_decoupled from MaxText import pyconfig from MaxText.common_types import AttentionType, DECODING_ACTIVE_SEQUENCE_INDICATOR, EP_AS_CONTEXT, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, ShardMode -from MaxText.globals import MAXTEXT_PKG_DIR from MaxText.layers.attention_mla import MLA from MaxText.sharding import maybe_shard_with_name from tests.utils.test_helpers import get_test_config_path @@ -81,7 +79,7 @@ def setUp(self): def init_mla(self, config_arguments, rope_type): """Helper function to initialize MLA with different model names.""" cfg = pyconfig.initialize( - [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], + [sys.argv[0], get_test_config_path()], **config_arguments, rope_type=rope_type, ) diff --git a/tests/utils/test_helper.py b/tests/utils/test_helper.py new file mode 100644 index 0000000000..a35ea2d780 --- /dev/null +++ b/tests/utils/test_helper.py @@ -0,0 +1,38 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test utilities file for helper for test configuration path selection. + +Provides a single helper to return the absolute path to a test config. When +running in decoupled mode (DECOUPLE_GCLOUD=TRUE) the decoupled test config is +returned. +""" + +import os +from maxtext.common.gcloud_stub import is_decoupled +from MaxText.globals import MAXTEXT_PKG_DIR + + +def get_test_config_path(): + """Return absolute path to the chosen test config file. + + Returns `decoupled_base_test.yml` when decoupled, otherwise `base.yml`. + """ + base_cfg = "base.yml" + if is_decoupled(): + base_cfg = "decoupled_base_test.yml" + return os.path.join(MAXTEXT_PKG_DIR, "configs", base_cfg) + + +__all__ = ["get_test_config_path"] diff --git a/tools/data_generation/generate_distillation_data.py b/tools/data_generation/generate_distillation_data.py index 2ed8a15474..a1b259723f 100644 --- a/tools/data_generation/generate_distillation_data.py +++ b/tools/data_generation/generate_distillation_data.py @@ -40,7 +40,7 @@ For more information, check out `python3 -m MaxText.generate_distillation_data --help`. Note: Make sure to run maxengine server in a new terminal before executing this command. Example command to run maxengine server: - python3 -m MaxText.maxengine_server src/MaxText/configs/base.yml \ + python3 -m MaxText.maxengine_server src/maxtext/configs/base.yml \ model_name=deepseek2-16b tokenizer_path=deepseek-ai/DeepSeek-V2-Lite-chat tokenizer_type=huggingface \ load_parameters_path= \ max_target_length=2048 max_prefill_predict_length=256 \