Joey Eamigh 67beaede45
quantization + onnx sweeps
Phase 10.8: torchao/bnb quant sweep on iter1-independent. bf16 already
optimal; torchao int8-wo gives -19% VRAM at no F1 cost; all 4-bit
variants collapse (ModernBERT-large too quant-sensitive).

Phase 10.9: ONNX export + ORT eval. Legacy exporter only working path
(dynamo adds 56 Memcpy nodes); ORT fp32 -22% latency vs torch via
kernel fusion but bf16+flash-attn-2 still wins; fp16 broken on rotary;
dynamic int8 silently CPU-fallback + 0.5 F1 collapse.

Driver scripts wired to bun run py:quant / py:onnx; full reports at
results/eval/{quant,onnx}/REPORT.md.
2026-04-07 05:10:38 -04:00

# ONNX Export + Eval — iter1-independent ModernBERT-large
**Date:** 2026-04-07
**Checkpoint:** `checkpoints/finetune/iter1-independent/final/`
**Hardware:** RTX 3090 (sm_86, 24 GB), onnxruntime-gpu 1.24.4, onnx 1.21
**Driver:** `python/scripts/onnx_export_eval.py` (`bun run py:onnx`)
**Eval set:** 1,200-paragraph v2 holdout, proxy gold = GPT-5.4 + Opus-4.6
## TL;DR
ONNX export of this model is *technically* possible but the path is full of
dead ends. The dynamo exporter produces a graph with 56 Memcpy nodes that
makes ORT 8× slower than native torch and 4× more VRAM-heavy; the legacy
TorchScript exporter produces a clean graph that's actually 22% faster than
torch fp32 (kernel fusion); fp16 conversion breaks on the rotary embedding;
dynamic int8 quantization via ORT silently falls back to CPU and drops
~0.5 macro F1. **Net: torchao int8-wo from the earlier sweep is still the
right int8 deployment path. ONNX is not.**
## What we tried
| variant | exporter | size MB | ms/sample | VRAM MB | cat F1 | spec F1 | result |
|--------------------|----------------------|--------:|----------:|--------:|-------:|--------:|-----------------|
| onnx-fp32 (dynamo) | torch.onnx (dynamo) | 1583 | 42.92 | 15388 | 0.9337 | 0.8943 | works but unusable |
| onnx-int8 (dynamo) | dynamo + ORT int8 | 1580 | 42.82 | 15398 | 0.9337 | 0.8943 | no-op (no quant) |
| **onnx-fp32 (legacy)** | torch.onnx (TorchScript) | 1583 | **12.70** | 8228 | 0.9337 | 0.8952 | **clean graph, faster than torch** |
| onnx-fp16 (legacy) | onnxconverter_common | 754 | err | err | err | err | rotary type unify |
| onnx-int8 (legacy) | ORT quantize_dynamic | 527 | 95.91 | ~CPU | 0.3972 | 0.3364 | CPU fallback + accuracy collapse |
(All entries above were re-run from scratch — fp32 timing improved 3× moving
from dynamo to legacy export.)
## Six things broke along the way (workarounds in the script)
1. **Dynamo exporter optimizer crashes.** `torch.onnx.export(..., dynamo=True)`
succeeds at translation but the post-translation `InlinePass` optimizer
trips on `onnx_ir`. Workaround: `optimize=False`.
2. **Dynamo-exported graph is unusable on CUDA EP.** ORT inserts 56 Memcpy
nodes between layers because dynamo emits scalar tensors with CPU
placement metadata. Result: 42.9 ms/sample (8× torch fp32) and 15.4 GB
VRAM (4.4× torch fp32). The legacy exporter only inserts 1 Memcpy.
3. **`op_types_to_quantize=['MatMul']` quantizes nothing on the dynamo
graph.** Dynamo emits encoder linears as `Gemm` nodes, not `MatMul`.
Fix: `op_types_to_quantize=['MatMul', 'Gemm']`.
4. **Both ORT shape-inference paths choke on ModernBERT.** Symbolic
inference asserts in `_infer_Range` (rotary embedding limit input is
not a scalar); the C++ inference raises a (1024)/(7) dim mismatch on
the category head Gemm. The `skip_*` flags on `quant_pre_process` are
ignored, and `ONNXQuantizer.__init__` calls
`save_and_reload_model_with_shape_infer` unconditionally. Workaround:
monkey-patch `quant_utils.save_and_reload_model_with_shape_infer`
*and* the cached binding in `onnx_quantizer` to a no-op, then pass
`extra_options={'DefaultTensorType': onnx.TensorProto.FLOAT}` so the
quantizer can still type the head MatMul.
5. **fp16 conversion via `onnxconverter_common` breaks on rotary
embeddings.** Two distinct failure modes seen across exports:
`Type Error: Type (tensor(float16)) of output arg (val_58) of node
(node_Expand_56) does not match expected type (tensor(float))` (dynamo
graph) and `Type parameter (T) of Optype (Mul) bound to different types
(tensor(float) and tensor(float16) in node
(/model/backbone/rotary_emb_1/Mul_2)` (legacy graph). The converter
leaves the `inv_freq` buffer in fp32 and the surrounding Mul/Expand
ops then can't unify their type parameter. Could be patched with an
`op_block_list` for the rotary subgraph, but the cost/value isn't
there given the dynamic int8 result below.
6. **Dynamic int8 via ORT silently falls back to CPU.** The quantizer
replaces Gemm/MatMul with `MatMulInteger` + `DynamicQuantizeLinear`,
neither of which has CUDA kernels in onnxruntime-gpu 1.24. Session
creation succeeds with CUDAExecutionProvider but routes the
quantized ops to the CPU EP — observable from the `load_vram_mb`
collapsing from 2074 MB (fp32) to 266 MB (int8) and latency exploding
to 95.9 ms/sample. Per-channel int8 weights also drop accuracy from
0.934 → 0.397 on category and 0.895 → 0.336 on spec, further
confirming the kernel path is wrong (not just slow).
## What actually works
**onnx-fp32 via the legacy TorchScript exporter** is the one clean win:
12.70 ms/sample vs 16.29 for torch fp32 — a **22% latency improvement
from ORT's LayerNorm/Gelu/MatMul fusion** at bit-identical accuracy. VRAM
is 8228 MB vs 3504 MB for torch fp32 (the ORT session allocates a separate
~5 GB workspace), so the speedup costs you ~2.3× memory. On a single
3090 batch=64 inference run that's a fair trade.
But this is fp32 — bf16 torch + flash-attn-2 is *still* the strict winner
at 5.52 ms / 1741 MB (Phase 10.8 result). ORT can't run bf16 natively, and
fp16 conversion is broken. So even the working ONNX path is dominated by
what we already ship.
## Recommendation
**Don't use ONNX for this model on this hardware.** The torchao int8-wo
result from the quantization sweep (5.52 → 6.08 ms, 1741 → 1416 MB peak
VRAM, F1 within ±0.001) covers the "smaller deployment" use case more
cleanly than anything ONNX can offer here, and bf16 + flash-attn-2
remains the production default.
ONNX *would* be worth revisiting in any of these scenarios:
- **CPU-only deployment** — fp32 ONNX runs fine on CPUExecutionProvider
and ORT's int8 dynamic path is actually designed for this case. Worth
benchmarking if a CPU serving target ever shows up.
- **Cross-runtime portability** — TensorRT, OpenVINO, mobile runtimes.
These would each need their own export validation pass.
- **Static int8 with calibration** — `quantize_static` with a calibration
dataset can avoid the dynamic-quant CPU fallback path. Would need a
ModernBERT-friendly calibration loop and probably an `op_block_list`
to keep the rotary in fp32. Real engineering work, not a one-shot.
## Reproduce
```bash
bun run py:onnx
# writes to:
# results/eval/onnx/models/{model_fp32,model_fp16,model_int8_dyn}.onnx[.data]
# results/eval/onnx/summary.json
# results/eval/onnx/REPORT.md (this file)
```