Phase 10.8: torchao/bnb quant sweep on iter1-independent. bf16 already
optimal; torchao int8-wo gives -19% VRAM at no F1 cost; all 4-bit
variants collapse (ModernBERT-large too quant-sensitive).
Phase 10.9: ONNX export + ORT eval. The legacy exporter is the only working path
(dynamo adds 56 Memcpy nodes); ORT fp32 cuts latency 22% vs torch via
kernel fusion, but bf16+flash-attn-2 still wins; fp16 is broken on rotary embeddings;
dynamic int8 silently falls back to CPU with a 0.5 F1 collapse.
Driver scripts wired to bun run py:quant / py:onnx; full reports at
results/eval/{quant,onnx}/REPORT.md.
# Quantization Sweep — iter1-independent ModernBERT-large
- Date: 2026-04-07
- Checkpoint: `checkpoints/finetune/iter1-independent/final/`
- Hardware: RTX 3090 (sm_8.6, 24 GB)
- Eval set: 1,200-paragraph v2 holdout, proxy gold = GPT-5.4 + Opus-4.6
- Driver: `python/scripts/quantize_sweep.py` (run via `bun run py:quant`)
## Setup
For each variant the encoder (ModernBERT-large backbone, 28 layers, 112 nn.Linear modules) is converted to the target precision/scheme, while the attention pooler and the dual heads (category linear + 3 independent threshold MLPs) are kept in bf16. Heads are <0.3% of params and sit on already-distilled 1024-d representations — quantizing them buys nothing and risks the threshold margins that drive most of the spec error budget.
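The encoder-only split can be expressed as a `filter_fn` for torchao's `quantize_`. A minimal sketch, assuming illustrative attribute names (`encoder`, `pooler`, `cat_head`) and a tiny stand-in model rather than the repo's actual module tree; the guarded `quantize_` call marks where the real conversion happens:

```python
import torch.nn as nn

def encoder_linear_only(module: nn.Module, fqn: str) -> bool:
    """filter_fn for torchao's quantize_: convert only Linear layers
    inside the encoder backbone; pooler and heads stay in bf16."""
    return isinstance(module, nn.Linear) and fqn.startswith("encoder.")

# Tiny stand-in with the same encoder/pooler/heads split as the report.
model = nn.ModuleDict({
    "encoder": nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)),
    "pooler": nn.Linear(16, 16),
    "cat_head": nn.Linear(16, 4),
})

targets = [fqn for fqn, m in model.named_modules() if encoder_linear_only(m, fqn)]
print(targets)  # only the encoder Linears are selected

try:
    from torchao.quantization import quantize_, int8_weight_only
    quantize_(model, int8_weight_only(), filter_fn=encoder_linear_only)
except ImportError:
    pass  # torchao unavailable; the filter logic above is the point
```

The same filter works unchanged for the int8-dynamic and 4-bit variants, which is what keeps the heads in bf16 across the whole sweep.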
For every variant we measure end-to-end inference on the full 1,200-paragraph holdout at batch=64, max_seq=512, after 5 warmup batches:
- `encoder_mb` — sum of `param.numel() * param.element_size()` over the encoder. Caveat: for torchao tensor subclasses (`AffineQuantizedTensor`) this reports the outer dtype (bf16) rather than the int8 storage, so the 790 MB figure for the torchao rows is an over-estimate; real on-disk storage is roughly half. The bnb 4-bit row (275 MB) is correct because `Params4bit` reports `uint8` element_size.
- `ms/sample` — wall-clock per paragraph at batch=64
- peak VRAM — `torch.cuda.max_memory_allocated()` over the timed run (encoder fwd + activations)
- F1 / QWK / ECE — full eval pipeline reused from `src/finetune/eval.py`
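The measurement loop can be sketched as follows; `encoder` and `loader` are placeholders for the quantized backbone and the holdout DataLoader, and the real driver lives in `python/scripts/quantize_sweep.py`:

```python
import time
import torch

def encoder_mb(encoder: torch.nn.Module) -> float:
    # NB: over-reports torchao tensor subclasses, whose element_size()
    # is the outer bf16 dtype rather than the int8 storage
    return sum(p.numel() * p.element_size() for p in encoder.parameters()) / 2**20

@torch.inference_mode()
def timed_eval(encoder, loader, warmup=5):
    it = iter(loader)
    for _ in range(warmup):          # warmup batches before timing
        encoder(next(it))
    torch.cuda.reset_peak_memory_stats()
    n, t0 = 0, time.perf_counter()
    for batch in it:                 # remaining holdout batches
        encoder(batch)
        n += batch.shape[0]
    torch.cuda.synchronize()         # flush queued kernels before stopping the clock
    ms_per_sample = 1000 * (time.perf_counter() - t0) / n
    peak_mb = torch.cuda.max_memory_allocated() / 2**20
    return ms_per_sample, peak_mb
```

The `reset_peak_memory_stats()` call after warmup is what scopes peak VRAM to the timed run only.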
## Results
| variant | enc MB | ms/samp | thru/s | VRAM MB | cat F1 (GPT) | spec F1 (GPT) | spec QWK | cat F1 (Opus) | spec F1 (Opus) | notes |
|---|---|---|---|---|---|---|---|---|---|---|
| fp32 | 1579 | 16.29 | 61 | 3504 | 0.9337 | 0.8943 | 0.9321 | 0.9227 | 0.8825 | sdpa (no flash-attn) |
| bf16 (baseline) | 790 | 5.52 | 181 | 1741 | 0.9337 | 0.8952 | 0.9324 | 0.9227 | 0.8834 | flash-attn-2 |
| fp16 | 790 | 5.54 | 181 | 1741 | 0.9337 | 0.8952 | 0.9324 | 0.9227 | 0.8834 | flash-attn-2 |
| torchao int8-wo | ~395* | 6.08 | 165 | 1416 | 0.9345 | 0.8941 | 0.9330 | 0.9235 | 0.8815 | weight-only int8 |
| torchao int8-dyn | ~395* | 9.67 | 103 | 1774 | 0.9336 | 0.8918 | 0.9315 | 0.9243 | 0.8827 | dyn act + int8 weight |
| torchao int4-wo | — | — | — | — | — | — | — | — | — | requires mslk>=1.0.0 |
| bnb LLM.int8 | ~395* | 7.76 | 129 | 2135 | 0.9361 | 0.8986 | 0.9308 | 0.9235 | 0.8827 | mixed-precision outliers |
| bnb nf4 (DQ) | 275 | 5.86 | 171 | 1287 | 0.3537 | 0.2205 | 0.2423 | 0.3576 | 0.2075 | collapsed |
| bnb nf4 (no DQ) | 275 | 5.86 | 171 | 1287 | 0.3537 | 0.2205 | 0.2423 | 0.3576 | 0.2075 | collapsed |
| bnb fp4 (no DQ) | 275 | 5.87 | 170 | 1287 | 0.1629 | 0.2085 | 0.2326 | 0.1686 | 0.1978 | collapsed harder |
*torchao subclass tensors report bf16 element_size; true storage ~395 MB.
Per-variant detail (per-class F1, MCC, AUC, confusion matrices, calibration
bins) is in results/eval/quant/{variant}/metrics.json. Aggregate row-level
data is in results/eval/quant/summary.json.
## Findings
### 1. bf16 is already the production sweet spot
Flash-attention-2 + bf16 gives 3.0× the throughput of fp32 (181 vs 61 samples/sec) at half the VRAM (1.7 vs 3.5 GB) with bit-identical accuracy. This is what we already train and serve at; the sweep simply confirms there's no headroom in fp16/fp32 for this hardware.
### 2. fp16 ≡ bf16 on Ampere
Identical latency, identical VRAM, identical F1. RTX 3090 has matched bf16/fp16 throughput on tensor cores and the model has no overflow issues in either format. Pick whichever the loader prefers.
### 3. torchao int8 weight-only is the only quantization variant worth shipping
- VRAM −19% (1741 → 1416 MB) — meaningful for batched serving
- F1 essentially unchanged (cat +0.0008, spec −0.0011 vs bf16 — both inside per-seed noise)
- Latency +10% (5.52 → 6.08 ms/sample) — the int8 weight is dequantized to bf16 on the fly because RTX 3090 (sm_8.6) lacks the int8 tensor-core matmul kernel paths torchao would otherwise use; on H100/A100/Ada this same config would also be faster
The accuracy delta is statistically nothing — well within the ±0.002 std we observed across the 3-seed ensemble. This is the variant we'd ship as the "low-VRAM" deployment option.
### 4. torchao int8 dynamic activation: don't bother on this hardware
−43% throughput (181 → 103 samples/s; 5.52 → 9.67 ms/sample) and more peak VRAM than bf16 (1774 vs 1741 MB), because the per-batch activation quantization adds work without unlocking int8 tensor cores. Pure regression on Ampere.
### 5. bnb LLM.int8: slowest int8, no accuracy upside
- +41% latency (5.52 → 7.76 ms/sample) due to mixed-precision outlier handling
- +23% VRAM (1741 → 2135 MB) — outlier columns are kept in fp16 plus scratch buffers
- F1 +0.0024 cat, +0.0034 spec — within noise; not a real win
bnb LLM.int8 was designed for LLM-scale models where outlier features dominate quant error; for an encoder of this size on a single 3090 it just trades performance for nothing.
### 6. All 4-bit variants collapse — ModernBERT-large is too quant-sensitive
Both nf4 (with and without double-quantization) and fp4 produce essentially random predictions:
| variant | cat F1 | spec F1 | spec ECE |
|---|---|---|---|
| nf4 | 0.354 | 0.221 | 0.434 |
| fp4 | 0.163 | 0.209 | 0.443 |
Per-layer dequantization is faithful — we verified that the dequantized weight of one MLP Wi layer differs from the original by mean 0.005 / max 0.11 (sub-1% error). But the relative output drift on a single Linear is already ~98% (mean), and that error compounds across the 28 transformer blocks' GLU FFN paths until the [CLS]/pooled representation no longer carries the discriminative signal. The category head essentially collapses to a near-uniform prior (cat ECE 0.10 vs the 0.054 baseline) and the threshold heads collapse onto L1 because all three thresholds emit similar logits.
The fact that DQ and no-DQ score identically at this scale tells us the nf4 weight indices are stable under absmax requantization (only ~5% of the weight bytes change, all in the metadata block) — the catastrophe is inherent to 4-bit weight precision on this architecture, not to any quantization-config knob.
This is a real noteworthy null for the paper: naive post-training 4-bit weight quantization is not viable for ModernBERT-large on this task. Recovering 4-bit would require either (a) QAT, (b) per-channel calibration with a held-out activation distribution (GPTQ / AWQ-style), or (c) keeping the GLU FFN in 8-bit and only 4-bit'ing attention projections. None of these are reachable inside the remaining capstone time budget.
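The compounding mechanism can be illustrated with a crude stand-in: per-block absmax int4 rounding in plain torch. This is not bitsandbytes' actual nf4 codebook, so it will not reproduce the report's exact drift numbers; it only shows how a small absolute weight error still yields a visible relative output drift on a single Linear:

```python
import torch

def fake_quant_4bit(w: torch.Tensor, block: int = 64) -> torch.Tensor:
    """Per-block absmax round-to-nearest int4; a rough proxy for nf4."""
    flat = w.reshape(-1, block)
    scale = flat.abs().amax(dim=1, keepdim=True) / 7  # int4 grid [-7, 7]
    return ((flat / scale).round().clamp(-7, 7) * scale).reshape(w.shape)

torch.manual_seed(0)
w = torch.randn(1024, 1024) * 0.02     # pretrained-ish weight scale
x = torch.randn(64, 1024)              # a batch of activations
wq = fake_quant_4bit(w)

weight_err = (wq - w).abs().mean().item()            # tiny in absolute terms
drift = ((x @ wq.T - x @ w.T).norm() / (x @ w.T).norm()).item()
print(f"mean |dw| = {weight_err:.4f}, relative output drift = {drift:.3f}")
```

The report's actual check dequantized the real bnb weights; the stand-in only illustrates why per-layer drift, repeated over 28 blocks, destroys the pooled representation.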
### 7. torchao int4-wo: dependency hole
torchao 0.17 requires mslk >= 1.0.0 for the new Int4Tensor.from_hp path.
Not installed in the lockfile and not worth chasing given the bnb 4-bit
collapse — even if the kernel ran cleanly we'd expect the same compounding
error pattern.
## Recommendations
| Use case | Variant | Why |
|---|---|---|
| Production / paper headline | bf16 | Best of every dimension on this hardware |
| Low-VRAM batch serving | torchao int8-wo | −19% VRAM, accuracy intact, only 10% latency penalty |
| Multi-GPU sharded serving | bf16 | int8-wo's dequant overhead grows with replica count |
| Embedded / 4-bit | not viable | Needs QAT or AWQ-style calibration; future work |
## Paper-worthy notes
- Quantization story — bf16 is already the sweet spot; torchao int8-wo buys 19% VRAM with no accuracy cost; 4-bit fails. This adds another row to the speed/cost table.
- Architecture-specific quant fragility — ModernBERT-large's GLU FFN amplifies per-layer weight error across 28 blocks. This is a noteworthy counterpoint to the 4-bit-by-default LLM serving narrative and worth one paragraph in the discussion section alongside the DAPT and CORAL null results.
- Hardware caveat — int8 latency results would invert on Hopper/Ada/A100; the 3090 just doesn't have the matmul path. State the sm_8.6 caveat in the table caption.
## Reproduce

```sh
# from repo root
bun run py:quant
# writes to results/eval/quant/{summary.json, REPORT.md, <variant>/metrics.json}
```
Run time: ~5 minutes total (most spent in fp32 + torchao build steps).