Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions optimized/tensorRT/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ onnx/
*.onnx
*.onnx.data

# FP8 calibration data — captured by build/make_calib.py from the model
# checkpoint, a reproducible producer artifact (~hundreds of MB). Regenerated
# on demand, never committed.
*.calib.npz

# T5 tokenizer — generally ignored (downloaded under models/<arch>/t5gemma/
# alongside the engine for legacy fallback), but the canonical copy ships
# bundled at scripts/tokenizer.json (arch-agnostic, 34 MB) so the build path
Expand All @@ -36,3 +41,7 @@ __pycache__/
# But keep __pycache__/ ignored even under build/ (the un-ignore above would
# accidentally re-include build/__pycache__/ otherwise).
build/**/__pycache__/
# Same for make_calib.py's captured calibration npz (the *.calib.npz rule
# above is otherwise shadowed by the build/ un-ignore — make_calib writes it
# next to the build scripts by default).
build/**/*.calib.npz
33 changes: 33 additions & 0 deletions optimized/tensorRT/build/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ python build_from_onnx.py same-s-decoder-fp32 # canonical ONNX is already FP32
python build_from_onnx.py sa3-m-fp32 # reads HF dit.onnx (already FP32)
python build_from_onnx.py all-fp32 # every FP32 target
python build_from_onnx.py all-both # canonical + FP32

# FP8 variant — opt-in, DiT-only. ~1.8x faster steps than FP16-mixed.
# Pair with `sa3_trt --precision fp8` at inference (fp8 DiT + fp16mixed decoder).
python build_from_onnx.py sa3-m-fp8 # reads HF dit_fp8.onnx (ModelOpt PTQ)
```

### Consumer deps
Expand Down Expand Up @@ -148,6 +152,33 @@ This wraps every RMSNorm chain, attention `Softmax`, and the RoPE region in `Cas

Naive `BuilderFlag.FP16` (without the surgery) catastrophically overflows in RMSNorm variance + attention softmax — the islands are mandatory. BF16 was tried earlier and compounds quantisation error over 8 sampling steps (cos-sim drifts from 0.99 single-step to 0.81 final-latent vs PT FP32) — audibly degraded.

### FP8 DiT (opt-in, ~1.8x)

`build_dit_fp8.py` extends the FP16-mixed recipe with a ModelOpt FP8 GEMM trunk: it takes the `dit_fp16mixed.onnx` plus a calibration `.npz` (real DiT inputs across the pingpong schedule) and produces `dit_fp8.onnx`: fp8 weight/activation Q/DQ on the MatMuls, the same FP32 islands re-applied (plus the conditioning front-end, which must stay FP32 or the t>=0.984 timestep features flush), and per-channel weight scales. Validated on sa3-m vs the FP16-mixed engine over the 47 reprompt Music prompts x 8 sigmas (L=646): worst single-step latent cosine 0.9982, 8-step compounded euler final-latent cosine mean 0.953 / median 0.957 / worst 0.873 over the 47 prompts (the compounded rollout is chaotic at the early sigmas, so judge by the distribution and by ear; decoded audio under the production pingpong sampler tracks the FP16-mixed generation at ~0.90 RMS-curve correlation), ~10.6-11.0 ms/step (vs ~18.7-19.4), ~1.8x. TRT tactic selection is nondeterministic per build; if a fresh engine benches noticeably slower, rebuild it. Under the stochastic pingpong sampler it yields a different but comparable sample.

First capture the calibration data from the model checkpoint with `make_calib.py` (drives the model's own pingpong `generate()` to record real DiT inputs across the sampling schedule; prompts come from the repo's own `interface/reprompt.py` Music examples, the deployment-matched reprompt format):

```bash
python make_calib.py \
--model-config <MODELS_ROOT>/SA3-M-hf/model_config.json \
--checkpoint <MODELS_ROOT>/SA3-M-hf/model.safetensors \
--out sa3-m.calib.npz
```

Then build the engine:

```bash
python build_dit_fp8.py \
--input $HF/onnx/sa3-m/dit_fp16mixed.onnx \
--calib sa3-m.calib.npz \
--onnx $HF/onnx/sa3-m/dit_fp8.onnx \
--engine ../models/$ARCH/sa3-m/dit_fp8.trt
```

`make_calib.py` needs only the repo + checkpoint (`torch`, `numpy`, `stable_audio_3`). `build_dit_fp8.py` additionally requires `nvidia-modelopt` + `onnxruntime-gpu` (the calibration-repair pass); consumers compile the published `dit_fp8.onnx` with plain `build_from_onnx.py sa3-m-fp8` (STRONGLY_TYPED, no ModelOpt, no calibration).

> **Not yet on HF.** `dit_fp8.onnx` + `dit_fp8.onnx.data` are not in the model repo yet, so `build_from_onnx.py sa3-m-fp8` and `sa3_trt --precision fp8` 404 until a producer run uploads them (under exactly those filenames). The consumer recipe and `--precision fp8` plumbing land here so the wiring is reviewed; the artifact upload is the follow-up step.

Each script also writes the ONNX to `<HF_REPO>/onnx/<engine>/<file>.onnx`. After all 8 are done:

```bash
Expand All @@ -166,6 +197,8 @@ git push
| `build_from_onnx.py` | One target → download ONNX from HF + compile to TRT. **For the SA3 DiTs, pulls `dit_fp16mixed.onnx` (the pre-processed island-wrapped graph)** so the consumer just needs to invoke `STRONGLY_TYPED` compilation — no `onnx-graphsurgeon` required | consumer |
| `build_dit_profile.py` | Build a DiT with custom `(min, opt, max)` profile shapes (experimental — short-form / fixed-shape variants). Operates on either ONNX flavor. | consumer |
| `build_dit_fp16mixed.py` | **Producer-side** ONNX surgery: takes the canonical FP32 `dit.onnx`, finds RMSNorm chains + attention `Softmax` + RoPE region, wraps each in `Cast(FP32) ↔ Cast(FP16)` islands, converts non-island weights to FP16, and writes both the modified `dit_fp16mixed.onnx` AND the TRT engine. Only re-run when the model retrains or the island recipe changes. Requires `onnx` + `onnx-graphsurgeon`. | producer |
| `make_calib.py` | **Producer-side** FP8 calibration capture: drives the model's own pingpong `generate()` and records the six DiT engine inputs across the schedule into a `*.calib.npz` for `build_dit_fp8.py`. Needs only the checkpoint (`torch` + `stable_audio_3`). | producer |
| `build_dit_fp8.py` | **Producer-side** FP8 trunk on top of `dit_fp16mixed.onnx`: ModelOpt FP8 PTQ (MatMul/Gemm, max calibration from a `.npz`), restores ModelOpt-corrupted initializers + recalibrates activation scales, re-applies the FP32 islands (incl. the conditioning front-end), and per-channel weight scales. Writes `dit_fp8.onnx` + the TRT engine. ~1.8x faster steps than FP16-mixed. Requires `nvidia-modelopt` + `onnxruntime-gpu`. | producer |
| `build_t5gemma.py` | Trace + export T5Gemma encoder ONNX + build TRT | producer |
| `build_same_s_decoder.py` | Trace + export SAME-S decoder ONNX + build TRT | producer |
| `build_same_s_encoder.py` | Trace + export SAME-S encoder ONNX + build TRT | producer |
Expand Down
22 changes: 19 additions & 3 deletions optimized/tensorRT/build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ def _from_onnx(name):
{"label": "[opt-in] DiT sm-sfx FP32",
"command": _from_onnx("sa3-sm-sfx-fp32"),
"outputs": ["sa3-sm-sfx/dit_fp32.trt"]},
# FP8 variant, opt-in, DiT-only. ~1.8x faster steps than FP16-mixed
# (ModelOpt PTQ; producer build_dit_fp8.py). A different but comparable
# sample under the stochastic pingpong sampler.
{"label": "[opt-in] DiT medium FP8 (~1.8x)",
"command": _from_onnx("sa3-m-fp8"),
"outputs": ["sa3-m/dit_fp8.trt"],
"opt_in": True}, # built only by number, never via "Build all missing"
]


Expand Down Expand Up @@ -134,12 +141,19 @@ def render_menu(arch: str, arch_dir: Path) -> list[bool]:
size_s = fmt_size(sz) if sz >= 0 else f"{DIM}(missing){RESET}"
print(f" {tick} {DIM}{rel}{RESET} {size_s}")

n_missing = built_flags.count(False)
# "Build all missing" covers default targets only; opt-in variants (FP8)
# are built by number, so they are counted and reported separately.
n_missing = sum(1 for t, ok in zip(TARGETS, built_flags)
if not ok and not t.get("opt_in"))
n_optin_missing = sum(1 for t, ok in zip(TARGETS, built_flags)
if not ok and t.get("opt_in"))
print()
if n_missing == 0:
print(f" {BOLD}{GREEN}[A]{RESET} Build all missing {DIM}(nothing missing — all engines built){RESET}")
print(f" {BOLD}{GREEN}[A]{RESET} Build all missing {DIM}(nothing missing — all default engines built){RESET}")
else:
print(f" {BOLD}{YELLOW}[A]{RESET} Build all missing {DIM}({n_missing} target(s)){RESET}")
if n_optin_missing:
print(f" {DIM}(+{n_optin_missing} opt-in target(s) not in [A] — build by number){RESET}")
print(f" {BOLD}{DIM}[Q]{RESET} Quit")
return built_flags

Expand Down Expand Up @@ -180,7 +194,9 @@ def main() -> int:
return 0

if choice in ("a", "all"):
missing = [t for t, ok in zip(TARGETS, built_flags) if not ok]
# Opt-in variants (FP8) are excluded; build them by number.
missing = [t for t, ok in zip(TARGETS, built_flags)
if not ok and not t.get("opt_in")]
if not missing:
print(f" {DIM}Nothing to build.{RESET}")
continue
Expand Down
Loading