Phase 10.8: torchao/bnb quant sweep on iter1-independent. bf16 already
optimal; torchao int8-wo gives -19% VRAM at no F1 cost; all 4-bit
variants collapse (ModernBERT-large too quant-sensitive).
Phase 10.9: ONNX export + ORT eval. Legacy exporter only working path
(dynamo adds 56 Memcpy nodes); ORT fp32 -22% latency vs torch via
kernel fusion but bf16+flash-attn-2 still wins; fp16 broken on rotary;
dynamic int8 silently CPU-fallback + 0.5 F1 collapse.
Driver scripts wired to bun run py:quant / py:onnx; full reports at
results/eval/{quant,onnx}/REPORT.md.
# ONNX Export + Eval — iter1-independent ModernBERT-large

**Date:** 2026-04-07  
**Checkpoint:** `checkpoints/finetune/iter1-independent/final/`  
**Hardware:** RTX 3090 (sm_86, 24 GB), onnxruntime-gpu 1.24.4, onnx 1.21  
**Driver:** `python/scripts/onnx_export_eval.py` (`bun run py:onnx`)  
**Eval set:** 1,200-paragraph v2 holdout, proxy gold = GPT-5.4 + Opus-4.6

## TL;DR

ONNX export of this model is *technically* possible, but the path is full of dead ends. The dynamo exporter produces a graph with 56 Memcpy nodes that makes ORT 8× slower than native torch and 4× more VRAM-heavy; the legacy TorchScript exporter produces a clean graph that's actually 22% faster than torch fp32 (kernel fusion); fp16 conversion breaks on the rotary embedding; dynamic int8 quantization via ORT silently falls back to CPU and drops ~0.5 macro F1. **Net: torchao int8-wo from the earlier sweep is still the right int8 deployment path. ONNX is not.**

## What we tried

| variant | exporter | size MB | ms/sample | VRAM MB | cat F1 | spec F1 | result |
|--------------------|----------------------|--------:|----------:|--------:|-------:|--------:|-----------------|
| onnx-fp32 (dynamo) | torch.onnx (dynamo) | 1583 | 42.92 | 15388 | 0.9337 | 0.8943 | works but unusable |
| onnx-int8 (dynamo) | dynamo + ORT int8 | 1580 | 42.82 | 15398 | 0.9337 | 0.8943 | no-op (no quant) |
| **onnx-fp32 (legacy)** | torch.onnx (TorchScript) | 1583 | **12.70** | 8228 | 0.9337 | 0.8952 | **clean graph, faster than torch** |
| onnx-fp16 (legacy) | onnxconverter_common | 754 | err | err | err | err | rotary type-unify failure |
| onnx-int8 (legacy) | ORT quantize_dynamic | 527 | 95.91 | ~CPU | 0.3972 | 0.3364 | CPU fallback + accuracy collapse |

(All entries above were re-run from scratch — fp32 timing improved 3× moving from dynamo to legacy export.)

## Six things broke along the way (workarounds in the script)

1. **Dynamo exporter optimizer crashes.** `torch.onnx.export(..., dynamo=True)` succeeds at translation but the post-translation `InlinePass` optimizer trips on `onnx_ir`. Workaround: `optimize=False`.
2. **Dynamo-exported graph is unusable on the CUDA EP.** ORT inserts 56 Memcpy nodes between layers because dynamo emits scalar tensors with CPU placement metadata. Result: 42.9 ms/sample (8× torch fp32) and 15.4 GB VRAM (4.4× torch fp32). The legacy exporter inserts only 1 Memcpy.
3. **`op_types_to_quantize=['MatMul']` quantizes nothing on the dynamo graph.** Dynamo emits encoder linears as `Gemm` nodes, not `MatMul`. Fix: `op_types_to_quantize=['MatMul', 'Gemm']`.
4. **Both ORT shape-inference paths choke on ModernBERT.** Symbolic inference asserts in `_infer_Range` (the rotary embedding's limit input is not a scalar); the C++ inference raises a (1024)/(7) dim mismatch on the category-head Gemm. The `skip_*` flags on `quant_pre_process` are ignored, and `ONNXQuantizer.__init__` calls `save_and_reload_model_with_shape_infer` unconditionally. Workaround: monkey-patch `quant_utils.save_and_reload_model_with_shape_infer` *and* the cached binding in `onnx_quantizer` to a no-op, then pass `extra_options={'DefaultTensorType': onnx.TensorProto.FLOAT}` so the quantizer can still type the head MatMul.
5. **fp16 conversion via `onnxconverter_common` breaks on rotary embeddings.** Two distinct failure modes seen across exports: `Type Error: Type (tensor(float16)) of output arg (val_58) of node (node_Expand_56) does not match expected type (tensor(float))` (dynamo graph) and `Type parameter (T) of Optype (Mul) bound to different types (tensor(float) and tensor(float16) in node (/model/backbone/rotary_emb_1/Mul_2)` (legacy graph). The converter leaves the `inv_freq` buffer in fp32 and the surrounding Mul/Expand ops then can't unify their type parameter. Could be patched with an `op_block_list` for the rotary subgraph, but the cost/value isn't there given the dynamic int8 result below.
6. **Dynamic int8 via ORT silently falls back to CPU.** The quantizer replaces Gemm/MatMul with `MatMulInteger` + `DynamicQuantizeLinear`, neither of which has CUDA kernels in onnxruntime-gpu 1.24. Session creation succeeds with CUDAExecutionProvider but routes the quantized ops to the CPU EP — observable from `load_vram_mb` collapsing from 2074 MB (fp32) to 266 MB (int8) and latency exploding to 95.9 ms/sample. Per-channel int8 weights also drop accuracy from 0.934 → 0.397 on category and 0.895 → 0.336 on spec, further confirming the kernel path is wrong (not just slow).
## What actually works

**onnx-fp32 via the legacy TorchScript exporter** is the one clean win: 12.70 ms/sample vs 16.29 for torch fp32 — a **22% latency improvement from ORT's LayerNorm/Gelu/MatMul fusion** at bit-identical accuracy. VRAM is 8228 MB vs 3504 MB for torch fp32 (the ORT session allocates a separate ~5 GB workspace), so the speedup costs you ~2.3× memory. On a single-3090, batch=64 inference run that's a fair trade.

But this is fp32 — bf16 torch + flash-attn-2 is *still* the strict winner at 5.52 ms / 1741 MB (Phase 10.8 result). ORT can't run bf16 natively, and fp16 conversion is broken. So even the working ONNX path is dominated by what we already ship.
## Recommendation

**Don't use ONNX for this model on this hardware.** The torchao int8-wo result from the quantization sweep (5.52 → 6.08 ms, 1741 → 1416 MB peak VRAM, F1 within ±0.001) covers the "smaller deployment" use case more cleanly than anything ONNX can offer here, and bf16 + flash-attn-2 remains the production default.

ONNX *would* be worth revisiting in any of these scenarios:

- **CPU-only deployment** — fp32 ONNX runs fine on CPUExecutionProvider, and ORT's int8 dynamic path is actually designed for this case. Worth benchmarking if a CPU serving target ever shows up.
- **Cross-runtime portability** — TensorRT, OpenVINO, mobile runtimes. These would each need their own export validation pass.
- **Static int8 with calibration** — `quantize_static` with a calibration dataset can avoid the dynamic-quant CPU fallback path. Would need a ModernBERT-friendly calibration loop and probably an `op_block_list` to keep the rotary in fp32. Real engineering work, not a one-shot.
## Reproduce

```bash
bun run py:onnx
# writes to:
# results/eval/onnx/models/{model_fp32,model_fp16,model_int8_dyn}.onnx[.data]
# results/eval/onnx/summary.json
# results/eval/onnx/REPORT.md (this file)
```