Phase 10.8: torchao/bnb quant sweep on iter1-independent. bf16 already
optimal; torchao int8-wo gives -19% VRAM at no F1 cost; all 4-bit
variants collapse (ModernBERT-large too quant-sensitive).
Phase 10.9: ONNX export + ORT eval. Legacy exporter only working path
(dynamo adds 56 Memcpy nodes); ORT fp32 -22% latency vs torch via
kernel fusion but bf16+flash-attn-2 still wins; fp16 broken on rotary;
dynamic int8 silently CPU-fallback + 0.5 F1 collapse.
Driver scripts wired to `bun run py:quant` / `bun run py:onnx`; full reports at
`results/eval/{quant,onnx}/REPORT.md`.
# ONNX Export + Eval — iter1-independent ModernBERT-large

- **Date:** 2026-04-07
- **Checkpoint:** `checkpoints/finetune/iter1-independent/final/`
- **Hardware:** RTX 3090 (sm_86, 24 GB), onnxruntime-gpu 1.24.4, onnx 1.21
- **Driver:** `python/scripts/onnx_export_eval.py` (`bun run py:onnx`)
- **Eval set:** 1,200-paragraph v2 holdout, proxy gold = GPT-5.4 + Opus-4.6
## TL;DR
ONNX export of this model is technically possible but the path is full of dead ends. The dynamo exporter produces a graph with 56 Memcpy nodes that makes ORT 8× slower than native torch and 4× more VRAM-heavy; the legacy TorchScript exporter produces a clean graph that's actually 22% faster than torch fp32 (kernel fusion); fp16 conversion breaks on the rotary embedding; dynamic int8 quantization via ORT silently falls back to CPU and drops ~0.5 macro F1. Net: torchao int8-wo from the earlier sweep is still the right int8 deployment path. ONNX is not.
## What we tried
| variant | exporter | size (MB) | ms/sample | VRAM (MB) | cat F1 | spec F1 | result |
|---|---|---|---|---|---|---|---|
| onnx-fp32 (dynamo) | torch.onnx (dynamo) | 1583 | 42.92 | 15388 | 0.9337 | 0.8943 | works but unusable |
| onnx-int8 (dynamo) | dynamo + ORT int8 | 1580 | 42.82 | 15398 | 0.9337 | 0.8943 | no-op (nothing quantized) |
| onnx-fp32 (legacy) | torch.onnx (TorchScript) | 1583 | 12.70 | 8228 | 0.9337 | 0.8952 | clean graph, faster than torch |
| onnx-fp16 (legacy) | onnxconverter_common | 754 | err | err | err | err | rotary type-unify error |
| onnx-int8 (legacy) | ORT quantize_dynamic | 527 | 95.91 | ~CPU | 0.3972 | 0.3364 | CPU fallback + accuracy collapse |
(All entries above were re-run from scratch — fp32 timing improved 3× moving from dynamo to legacy export.)
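One way to verify the Memcpy counts that distinguish the two exporters is to dump ORT's optimized graph and count the inserted copy nodes. A minimal sketch; the counter itself is dependency-free, and only the helper below assumes the standard `onnx` package:

```python
def count_memcpy(op_types, prefix="Memcpy"):
    """Count op types beginning with `prefix` (MemcpyFromHost / MemcpyToHost)."""
    return sum(op.startswith(prefix) for op in op_types)

def memcpy_in_model(optimized_path):
    # Load the graph ORT dumped after optimization and count inserted copies.
    import onnx  # local import so count_memcpy stays dependency-free
    model = onnx.load(optimized_path)
    return count_memcpy([n.op_type for n in model.graph.node])
```

To get the dump, set `SessionOptions.optimized_model_filepath` before creating the `InferenceSession`; on this run the dynamo graph showed 56 Memcpys and the legacy graph 1.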
## Six things broke along the way (workarounds in the script)
1. **Dynamo exporter optimizer crashes.** `torch.onnx.export(..., dynamo=True)` succeeds at translation, but the post-translation `InlinePass` optimizer trips on `onnx_ir`. Workaround: `optimize=False`.
2. **Dynamo-exported graph is unusable on the CUDA EP.** ORT inserts 56 Memcpy nodes between layers because dynamo emits scalar tensors with CPU placement metadata. Result: 42.9 ms/sample (8× torch fp32) and 15.4 GB VRAM (4.4× torch fp32). The legacy exporter inserts only 1 Memcpy.
3. **`op_types_to_quantize=['MatMul']` quantizes nothing on the dynamo graph.** Dynamo emits the encoder linears as `Gemm` nodes, not `MatMul`. Fix: `op_types_to_quantize=['MatMul', 'Gemm']`.
4. **Both ORT shape-inference paths choke on ModernBERT.** Symbolic inference asserts in `_infer_Range` (the rotary embedding's limit input is not a scalar); the C++ inference raises a (1024)/(7) dim mismatch on the category-head Gemm. The `skip_*` flags on `quant_pre_process` are ignored, and `ONNXQuantizer.__init__` calls `save_and_reload_model_with_shape_infer` unconditionally. Workaround: monkey-patch `quant_utils.save_and_reload_model_with_shape_infer` and the cached binding in `onnx_quantizer` to a no-op, then pass `extra_options={'DefaultTensorType': onnx.TensorProto.FLOAT}` so the quantizer can still type the head MatMul.
5. **fp16 conversion via `onnxconverter_common` breaks on rotary embeddings.** Two distinct failure modes seen across exports: `Type Error: Type (tensor(float16)) of output arg (val_58) of node (node_Expand_56) does not match expected type (tensor(float))` on the dynamo graph, and `Type parameter (T) of Optype (Mul) bound to different types (tensor(float) and tensor(float16) in node (/model/backbone/rotary_emb_1/Mul_2)` on the legacy graph. The converter leaves the `inv_freq` buffer in fp32, and the surrounding Mul/Expand ops then can't unify their type parameter. This could be patched with an `op_block_list` for the rotary subgraph, but the cost/value isn't there given the dynamic-int8 result below.
6. **Dynamic int8 via ORT silently falls back to CPU.** The quantizer replaces Gemm/MatMul with `MatMulInteger` + `DynamicQuantizeLinear`, neither of which has CUDA kernels in onnxruntime-gpu 1.24. Session creation succeeds with CUDAExecutionProvider but routes the quantized ops to the CPU EP — observable from `load_vram_mb` collapsing from 2074 MB (fp32) to 266 MB (int8) and latency exploding to 95.9 ms/sample. Per-channel int8 weights also drop accuracy from 0.934 → 0.397 on category and 0.895 → 0.336 on spec, confirming the kernel path is wrong (not just slow).
## What actually works
onnx-fp32 via the legacy TorchScript exporter is the one clean win: 12.70 ms/sample vs 16.29 for torch fp32 — a 22% latency improvement from ORT's LayerNorm/Gelu/MatMul fusion at bit-identical accuracy. VRAM is 8228 MB vs 3504 MB for torch fp32 (the ORT session allocates a separate ~5 GB workspace), so the speedup costs you ~2.3× memory. On a single 3090 batch=64 inference run that's a fair trade.
But this is fp32 — bf16 torch + flash-attn-2 is still the strict winner at 5.52 ms / 1741 MB (Phase 10.8 result). ORT can't run bf16 natively, and fp16 conversion is broken. So even the working ONNX path is dominated by what we already ship.
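As plain arithmetic, the tradeoffs in the two paragraphs above:

```python
# Numbers copied from the runs above (ms/sample, MB).
torch_fp32_ms, ort_fp32_ms, bf16_fa2_ms = 16.29, 12.70, 5.52
torch_fp32_vram, ort_fp32_vram = 3504, 8228

latency_gain = 1 - ort_fp32_ms / torch_fp32_ms  # ~0.22: the 22% ORT win
vram_cost = ort_fp32_vram / torch_fp32_vram     # ~2.35x: the memory price
bf16_lead = ort_fp32_ms / bf16_fa2_ms           # ~2.3x: why bf16+FA2 still wins
```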
## Recommendation
Don't use ONNX for this model on this hardware. The torchao int8-wo result from the quantization sweep (5.52 → 6.08 ms, 1741 → 1416 MB peak VRAM, F1 within ±0.001) covers the "smaller deployment" use case more cleanly than anything ONNX can offer here, and bf16 + flash-attn-2 remains the production default.
ONNX would be worth revisiting in any of these scenarios:
- CPU-only deployment — fp32 ONNX runs fine on CPUExecutionProvider and ORT's int8 dynamic path is actually designed for this case. Worth benchmarking if a CPU serving target ever shows up.
- Cross-runtime portability — TensorRT, OpenVINO, mobile runtimes. These would each need their own export validation pass.
- Static int8 with calibration — `quantize_static` with a calibration dataset can avoid the dynamic-quant CPU fallback path. Would need a ModernBERT-friendly calibration loop and probably an `op_block_list` to keep the rotary in fp32. Real engineering work, not a one-shot.
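If the static-int8 route is ever picked up, the skeleton would look something like this sketch (API names from `onnxruntime.quantization`; the reader class, paths, and excluded-node list are hypothetical and untested here):

```python
class ParagraphCalibrationReader:
    """Duck-typed stand-in for onnxruntime.quantization.CalibrationDataReader:
    feeds pre-tokenized paragraph batches to the calibrator (hypothetical)."""

    def __init__(self, batches):
        # e.g. [{"input_ids": ..., "attention_mask": ...}, ...]
        self._it = iter(batches)

    def get_next(self):
        return next(self._it, None)  # None tells the calibrator to stop

def static_int8_sketch(src_path, dst_path, batches, rotary_nodes):
    from onnxruntime.quantization import quantize_static, QuantType

    quantize_static(
        src_path,
        dst_path,
        calibration_data_reader=ParagraphCalibrationReader(batches),
        op_types_to_quantize=["MatMul", "Gemm"],
        nodes_to_exclude=rotary_nodes,  # keep the rotary subgraph in fp32
        weight_type=QuantType.QInt8,
    )
```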
## Reproduce
```sh
bun run py:onnx
# writes to:
#   results/eval/onnx/models/{model_fp32,model_fp16,model_int8_dyn}.onnx[.data]
#   results/eval/onnx/summary.json
#   results/eval/onnx/REPORT.md (this file)
```