Joey Eamigh 67beaede45
quantization + onnx sweeps
Phase 10.8: torchao/bnb quant sweep on iter1-independent. bf16 already
optimal; torchao int8-wo gives -19% VRAM at no F1 cost; all 4-bit
variants collapse (ModernBERT-large too quant-sensitive).

Phase 10.9: ONNX export + ORT eval. Legacy exporter only working path
(dynamo adds 56 Memcpy nodes); ORT fp32 -22% latency vs torch via
kernel fusion but bf16+flash-attn-2 still wins; fp16 broken on rotary;
dynamic int8 silently CPU-fallback + 0.5 F1 collapse.

Driver scripts wired to bun run py:quant / py:onnx; full reports at
results/eval/{quant,onnx}/REPORT.md.
2026-04-07 05:10:38 -04:00


Quantization Sweep — iter1-independent ModernBERT-large

Date: 2026-04-07
Checkpoint: checkpoints/finetune/iter1-independent/final/
Hardware: RTX 3090 (sm_8.6, 24 GB)
Eval set: 1,200-paragraph v2 holdout, proxy gold = GPT-5.4 + Opus-4.6
Driver: python/scripts/quantize_sweep.py (run via bun run py:quant)

Setup

For each variant the encoder (ModernBERT-large backbone, 28 layers, 112 nn.Linear modules) is converted to the target precision/scheme, while the attention pooler and the dual heads (category linear + 3 independent threshold MLPs) are kept in bf16. Heads are <0.3% of params and sit on already-distilled 1024-d representations — quantizing them buys nothing and risks the threshold margins that drive most of the spec error budget.
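A minimal sketch of the encoder-only selection described above. It assumes the fine-tuned model exposes the backbone under an `encoder.` prefix, with the pooler and heads under other names; that prefix and the module names in the usage note are hypothetical, not taken from the repo. The predicate has the `(module, fqn) -> bool` shape that torchao's `quantize_` accepts as `filter_fn`. The class-name check stands in for `isinstance(module, torch.nn.Linear)` so the snippet runs without torch.

```python
# Hypothetical prefix: assumes the backbone lives under "encoder." in the
# fine-tuned model, while the pooler and heads live elsewhere and stay in bf16.
ENCODER_PREFIX = "encoder."


def is_encoder_linear(module, fqn: str) -> bool:
    """Return True only for Linear layers inside the encoder backbone."""
    is_linear = type(module).__name__ == "Linear"
    return is_linear and fqn.startswith(ENCODER_PREFIX)


def select_for_quantization(named_modules):
    """Collect the names of modules that the predicate would quantize."""
    return [fqn for fqn, module in named_modules if is_encoder_linear(module, fqn)]
```

With torch installed, `select_for_quantization(model.named_modules())` would be expected to list 112 names for this checkpoint if the prefix assumption holds, which makes a cheap sanity check before converting anything.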

For every variant we measure end-to-end inference on the full 1,200-paragraph holdout at batch=64, max_seq=512, after 5 warmup batches:

  • encoder_mb: sum of param.numel() * param.element_size() over the encoder. Caveat: for torchao tensor subclasses (AffineQuantizedTensor) this reports the outer dtype (bf16) rather than the int8 storage, so the raw 790 MB figure for the torchao rows is an over-estimate; real on-disk storage is roughly half, which is the ~395 MB shown in the table. The bnb 4-bit row (275 MB) is correct because Params4bit reports uint8 element_size.
  • ms/sample: wall-clock time per paragraph at batch=64
  • peak VRAM: torch.cuda.max_memory_allocated() over the timed run (encoder fwd + activations)
  • F1 / QWK / ECE: full eval pipeline reused from src/finetune/eval.py
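The first two measurements can be sketched roughly as follows. This is not the code in quantize_sweep.py: the function names are illustrative, 1 MB is taken as 1e6 bytes, and the warmup batches are assumed to be excluded from the timed window, none of which the report states. The helpers are duck-typed so they run without torch; on a GPU the `sync` hook would be `torch.cuda.synchronize`.

```python
import time


def encoder_mb(params) -> float:
    """Sum numel * element_size over parameters, in MB (assumed 1 MB = 1e6 bytes).

    Works with anything exposing .numel() and .element_size(), such as torch
    parameters. Tensor subclasses that report their outer dtype are
    over-counted, which is the caveat noted in the list above.
    """
    return sum(p.numel() * p.element_size() for p in params) / 1e6


def ms_per_sample(run_batch, batches, warmup: int = 5, sync=None) -> float:
    """Average wall-clock milliseconds per sample, excluding warmup batches.

    run_batch: callable taking one batch (a sized collection of samples).
    sync: optional callable that blocks until queued device work finishes.
    """
    batches = list(batches)
    for batch in batches[:warmup]:
        run_batch(batch)
    if sync is not None:
        sync()

    timed = batches[warmup:]
    n_samples = sum(len(b) for b in timed)
    if n_samples == 0:
        raise ValueError("no samples left to time after warmup")

    start = time.perf_counter()
    for batch in timed:
        run_batch(batch)
    if sync is not None:
        sync()  # without this, asynchronous GPU work would be under-counted
    elapsed = time.perf_counter() - start
    return 1000.0 * elapsed / n_samples
```

Synchronizing before reading the clock matters on CUDA because kernel launches return before the work finishes; without it the per-sample figure would mostly measure launch overhead.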

Results

| variant | enc MB | ms/sample | samples/s | VRAM MB | cat F1 (GPT) | spec F1 (GPT) | spec QWK | cat F1 (Opus) | spec F1 (Opus) | notes |
|---|---|---|---|---|---|---|---|---|---|---|
| fp32 | 1579 | 16.29 | 61 | 3504 | 0.9337 | 0.8943 | 0.9321 | 0.9227 | 0.8825 | sdpa (no flash-attn) |
| bf16 (baseline) | 790 | 5.52 | 181 | 1741 | 0.9337 | 0.8952 | 0.9324 | 0.9227 | 0.8834 | flash-attn-2 |
| fp16 | 790 | 5.54 | 181 | 1741 | 0.9337 | 0.8952 | 0.9324 | 0.9227 | 0.8834 | flash-attn-2 |
| torchao int8-wo | ~395* | 6.08 | 165 | 1416 | 0.9345 | 0.8941 | 0.9330 | 0.9235 | 0.8815 | weight-only int8 |
| torchao int8-dyn | ~395* | 9.67 | 103 | 1774 | 0.9336 | 0.8918 | 0.9315 | 0.9243 | 0.8827 | dyn act + int8 weight |
| torchao int4-wo | | | | | | | | | | requires mslk>=1.0.0 |
| bnb LLM.int8 | ~395* | 7.76 | 129 | 2135 | 0.9361 | 0.8986 | 0.9308 | 0.9235 | 0.8827 | mixed-precision outliers |
| bnb nf4 (DQ) | 275 | 5.86 | 171 | 1287 | 0.3537 | 0.2205 | 0.2423 | 0.3576 | 0.2075 | collapsed |
| bnb nf4 (no DQ) | 275 | 5.86 | 171 | 1287 | 0.3537 | 0.2205 | 0.2423 | 0.3576 | 0.2075 | collapsed |
| bnb fp4 (no DQ) | 275 | 5.87 | 170 | 1287 | 0.1629 | 0.2085 | 0.2326 | 0.1686 | 0.1978 | collapsed harder |

*torchao subclass tensors report bf16 element_size; true storage ~395 MB.

Per-variant detail (per-class F1, MCC, AUC, confusion matrices, calibration bins) is in results/eval/quant/{variant}/metrics.json. Aggregate row-level data is in results/eval/quant/summary.json.
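The percentage figures quoted in the findings can be re-derived from the results table. A small sketch, with the values copied by hand from the table rather than read from summary.json (whose schema is not shown here):

```python
def pct_change(baseline: float, value: float) -> float:
    """Signed percentage change of value relative to baseline."""
    return 100.0 * (value - baseline) / baseline


# (ms/sample, samples/s, peak VRAM MB) copied from the results table.
ROWS = {
    "bf16": (5.52, 181, 1741),
    "torchao int8-wo": (6.08, 165, 1416),
    "torchao int8-dyn": (9.67, 103, 1774),
    "bnb LLM.int8": (7.76, 129, 2135),
}


def deltas_vs_bf16(rows=ROWS):
    """Return {variant: (latency %, throughput %, VRAM %)} relative to bf16."""
    base = rows["bf16"]
    return {
        name: tuple(round(pct_change(b, v)) for b, v in zip(base, vals))
        for name, vals in rows.items()
        if name != "bf16"
    }


for name, (lat, thru, vram) in deltas_vs_bf16().items():
    print(f"{name:18s} latency {lat:+d}%  throughput {thru:+d}%  VRAM {vram:+d}%")
```

Latency and throughput changes are not mirror images of each other (a 75% latency increase corresponds to a 43% throughput drop), so it helps to state which of the two a percentage refers to.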

Findings

1. bf16 is already the production sweet spot

Flash-attention-2 + bf16 gives 3.0× the throughput of fp32 (181 vs 61 samples/sec) at half the VRAM (1.7 vs 3.5 GB) with effectively identical accuracy (cat F1 unchanged, spec F1 within 0.001). This is what we already train and serve at; the sweep simply confirms there is nothing to gain from fp16 or fp32 on this hardware.

2. fp16 ≡ bf16 on Ampere

Identical latency, identical VRAM, identical F1. RTX 3090 has matched bf16/fp16 throughput on tensor cores and the model has no overflow issues in either format. Pick whichever the loader prefers.

3. torchao int8 weight-only is the only quantization variant worth shipping

  • VRAM -19% (1741 → 1416 MB), meaningful for batched serving
  • F1 essentially unchanged (cat +0.0008, spec -0.0011 vs bf16; both inside per-seed noise)
  • Latency +10% (5.52 → 6.08 ms/sample): the int8 weight is dequantized to bf16 on the fly because RTX 3090 (sm_8.6) lacks the int8 tensor-core matmul kernel paths torchao would otherwise use; on H100/A100/Ada this same config would be expected to be faster as well

The accuracy delta is statistically insignificant, well within the ±0.002 std we observed across the 3-seed ensemble. This is the variant we'd ship as the "low-VRAM" deployment option.

4. torchao int8 dynamic activation: don't bother on this hardware

-43% throughput (181 → 103 samples/sec; latency 5.52 → 9.67 ms/sample) and more peak VRAM than bf16 (1774 vs 1741 MB), because the per-batch activation quantization adds work without unlocking int8 tensor cores. Pure regression on Ampere.

5. bnb LLM.int8: highest-VRAM int8, no accuracy upside

  • +41% latency (5.52 → 7.76 ms/sample) due to mixed-precision outlier handling
  • +23% VRAM (1741 → 2135 MB) — outlier columns are kept in fp16 plus scratch buffers
  • F1 +0.0024 cat, +0.0034 spec — within noise; not a real win

bnb LLM.int8 was designed for LLM-scale models where outlier features dominate quant error; for an encoder of this size on a single 3090 it just trades performance for nothing.

6. All 4-bit variants collapse — ModernBERT-large is too quant-sensitive

Both nf4 (with and without double-quantization) and fp4 produce essentially random predictions:

| variant | cat F1 | spec F1 | spec ECE |
|---|---|---|---|
| nf4 | 0.354 | 0.221 | 0.434 |
| fp4 | 0.163 | 0.209 | 0.443 |

Per-layer dequantization is faithful: we verified that the dequantized weight of one MLP Wi layer differs from the original by mean 0.005 / max 0.11 (sub-1% error). But the relative output drift on a single Linear is already ~98% (mean), and that error compounds across 28 transformer blocks + GLU FFN paths until the [CLS]/pooled representation no longer carries the discriminative signal. The category head essentially collapses to a near-uniform prior (cat ECE 0.10 vs the 0.054 baseline) and the threshold heads collapse onto L1 because all three thresholds emit similar logits.
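One way to define the "relative output drift" of a single Linear is the norm ratio ||W_q x - W x|| / ||W x||. The report does not say which definition was used, so that choice is an assumption. The toy sketch below also swaps in a plain symmetric absmax quantizer rather than bnb's nf4 and uses random weights, so it will not reproduce the ~98% figure; it only shows how the metric is computed and how it grows as the bit width shrinks.

```python
import math
import random


def quantize_absmax(row, bits=4):
    """Symmetric uniform quantize-dequantize of one weight row (not nf4)."""
    levels = 2 ** (bits - 1) - 1  # 7 for 4-bit, 127 for 8-bit
    scale = max(abs(w) for w in row) / levels
    if scale == 0.0:
        return list(row)
    return [round(w / scale) * scale for w in row]


def matvec(weight, x):
    """Dense matrix-vector product over nested lists."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def relative_output_drift(weight, weight_q, x):
    """||W_q x - W x|| / ||W x|| for a single input vector x."""
    y = matvec(weight, x)
    y_q = matvec(weight_q, x)
    num = math.sqrt(sum((a - b) ** 2 for a, b in zip(y_q, y)))
    den = math.sqrt(sum(a * a for a in y))
    return num / den


random.seed(0)
dim = 64
W = [[random.gauss(0.0, 0.02) for _ in range(dim)] for _ in range(dim)]
x = [random.gauss(0.0, 1.0) for _ in range(dim)]
W4 = [quantize_absmax(row, bits=4) for row in W]
W8 = [quantize_absmax(row, bits=8) for row in W]
print("4-bit drift:", relative_output_drift(W, W4, x))
print("8-bit drift:", relative_output_drift(W, W8, x))
```

Running the same measurement on real layer weights and real activations, layer by layer, would show where in the stack the drift starts to dominate.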

The fact that DQ vs no-DQ are bit-identical at this scale tells us the nf4 weight indices are stable under absmax requantization (only ~5% of the weight bytes change, all in the metadata block) — the catastrophe is inherent to 4-bit weight precision on this architecture, not to a quantization-config knob.

This is a genuinely noteworthy null result for the paper: naive post-training 4-bit weight quantization is not viable for ModernBERT-large on this task. Recovering 4-bit would require either (a) QAT, (b) per-channel calibration with a held-out activation distribution (GPTQ / AWQ-style), or (c) keeping the GLU FFN in 8-bit and only 4-bit'ing attention projections. None of these are reachable inside the remaining capstone time budget.

7. torchao int4-wo: dependency hole

torchao 0.17 requires mslk >= 1.0.0 for the new Int4Tensor.from_hp path. Not installed in the lockfile and not worth chasing given the bnb 4-bit collapse — even if the kernel ran cleanly we'd expect the same compounding error pattern.

Recommendations

| Use case | Variant | Why |
|---|---|---|
| Production / paper headline | bf16 | Best throughput with full accuracy on this hardware |
| Low-VRAM batch serving | torchao int8-wo | -19% VRAM, accuracy intact, only +10% latency |
| Multi-GPU sharded serving | bf16 | int8-wo's dequant overhead grows with replica count |
| Embedded / 4-bit | not viable | Needs QAT or AWQ-style calibration; future work |

Paper-worthy notes

  1. Quantization story: bf16 is already the sweet spot; torchao int8-wo saves 19% VRAM with no accuracy cost; 4-bit fails. This adds another row to the speed/cost table.
  2. Architecture-specific quant fragility: ModernBERT-large's GLU FFN amplifies per-layer weight error across 28 blocks. This is a noteworthy counterpoint to the 4-bit-by-default LLM serving narrative and worth one paragraph in the discussion section alongside the DAPT and CORAL null results.
  3. Hardware caveat: int8 latency results would likely invert on Hopper/Ada/A100; the 3090 just doesn't have the matmul path. State the sm_8.6 caveat in the table caption.

Reproduce

```bash
# from repo root
bun run py:quant
# writes to results/eval/quant/{summary.json, REPORT.md, <variant>/metrics.json}
```

Run time: ~5 minutes total (most spent in fp32 + torchao build steps).