Joey Eamigh 67beaede45
quantization + onnx sweeps
Phase 10.8: torchao/bnb quant sweep on iter1-independent. bf16 already
optimal; torchao int8-wo gives -19% VRAM at no F1 cost; all 4-bit
variants collapse (ModernBERT-large too quant-sensitive).

Phase 10.9: ONNX export + ORT eval. Legacy exporter only working path
(dynamo adds 56 Memcpy nodes); ORT fp32 -22% latency vs torch via
kernel fusion but bf16+flash-attn-2 still wins; fp16 broken on rotary;
dynamic int8 silently CPU-fallback + 0.5 F1 collapse.

Driver scripts wired to bun run py:quant / py:onnx; full reports at
results/eval/{quant,onnx}/REPORT.md.
2026-04-07 05:10:38 -04:00

# Quantization Sweep — iter1-independent ModernBERT-large
**Date:** 2026-04-07
**Checkpoint:** `checkpoints/finetune/iter1-independent/final/`
**Hardware:** RTX 3090 (sm_8.6, 24 GB)
**Eval set:** 1,200-paragraph v2 holdout, proxy gold = GPT-5.4 + Opus-4.6
**Driver:** `python/scripts/quantize_sweep.py` (run via `bun run py:quant`)
## Setup
For each variant the *encoder* (ModernBERT-large backbone, 28 layers, 112
`nn.Linear` modules) is converted to the target precision/scheme, while the
attention pooler and the dual heads (category linear + 3 independent
threshold MLPs) are kept in bf16. The heads are <0.3% of params and sit on
already-distilled 1024-d representations: quantizing them buys nothing and
risks the threshold margins that drive most of the spec error budget.
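The selective scheme boils down to a name-based predicate over modules: quantize a module only if it is an encoder `nn.Linear` and not part of the pooler or heads. A minimal sketch of that predicate (module names here are illustrative, not the actual checkpoint's):

```python
ENCODER_PREFIX = "model.layers."  # hypothetical backbone prefix
SKIP_KEYWORDS = ("pooler", "category_head", "threshold")  # kept in bf16

def should_quantize(module_name: str, is_linear: bool) -> bool:
    """True only for encoder Linear modules; pooler/heads stay bf16."""
    if not is_linear:
        return False
    if any(k in module_name for k in SKIP_KEYWORDS):
        return False
    return module_name.startswith(ENCODER_PREFIX)
```

In torchao this is the sort of predicate passed as `filter_fn` to `quantize_`; for the bitsandbytes variants, module skipping is typically expressed via `llm_int8_skip_modules` instead.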
For every variant we measure end-to-end inference on the full 1,200-paragraph
holdout at batch=64, max_seq=512, after 5 warmup batches:
- **encoder_mb**: sum of `param.numel() * param.element_size()` over the
encoder. **Caveat:** for torchao tensor subclasses (`AffineQuantizedTensor`)
this reports the *outer* dtype (bf16) rather than the int8 storage, so the
790 MB figure for the torchao rows is an over-estimate; real on-disk
storage is roughly half. The bnb 4-bit row (275 MB) is correct because
`Params4bit` reports `uint8` element_size.
- **ms/sample**: wall-clock per paragraph at batch=64
- **peak VRAM**: `torch.cuda.max_memory_allocated()` over the timed run
(encoder fwd + activations)
- **F1 / QWK / ECE**: full eval pipeline reused from `src/finetune/eval.py`
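The bookkeeping behind those numbers is small. A self-contained sketch of the size and timing parts (simplified from the driver; names are illustrative, and the real driver also synchronizes CUDA around the timed region):

```python
import time

def encoder_mb(param_numels, element_size: int) -> float:
    """Parameter storage in decimal MB, mirroring numel * element_size.

    Note the torchao caveat above: tensor subclasses report the outer
    dtype's element_size, so this over-counts int8-backed weights.
    """
    return sum(param_numels) * element_size / 1e6

def ms_per_sample(run_batch, n_batches: int, batch_size: int,
                  warmup: int = 5) -> float:
    """Wall-clock ms per sample, discarding warmup batches."""
    for _ in range(warmup):
        run_batch()
    start = time.perf_counter()
    for _ in range(n_batches):
        run_batch()
    return (time.perf_counter() - start) * 1000.0 / (n_batches * batch_size)

# ~395M encoder params at 2 bytes (bf16) reproduces the 790 MB table entry
bf16_mb = encoder_mb([395_000_000], 2)
```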
## Results
| variant | enc MB | ms/samp | thru/s | VRAM MB | cat F1 (GPT) | spec F1 (GPT) | spec QWK | cat F1 (Opus) | spec F1 (Opus) | notes |
|--------------------|-------:|--------:|-------:|--------:|-------------:|--------------:|---------:|--------------:|---------------:|--------------------------------|
| fp32 | 1579 | 16.29 | 61 | 3504 | 0.9337 | 0.8943 | 0.9321 | 0.9227 | 0.8825 | sdpa (no flash-attn) |
| **bf16 (baseline)**| 790 | 5.52 | 181 | 1741 | 0.9337 | 0.8952 | 0.9324 | 0.9227 | 0.8834 | flash-attn-2 |
| fp16 | 790 | 5.54 | 181 | 1741 | 0.9337 | 0.8952 | 0.9324 | 0.9227 | 0.8834 | flash-attn-2 |
| **torchao int8-wo**| ~395* | 6.08 | 165 | 1416 | 0.9345 | 0.8941 | 0.9330 | 0.9235 | 0.8815 | weight-only int8 |
| torchao int8-dyn | ~395* | 9.67 | 103 | 1774 | 0.9336 | 0.8918 | 0.9315 | 0.9243 | 0.8827 | dyn act + int8 weight |
| torchao int4-wo | | | | | | | | | | requires `mslk>=1.0.0` |
| bnb LLM.int8 | ~395* | 7.76 | 129 | 2135 | 0.9361 | 0.8986 | 0.9308 | 0.9235 | 0.8827 | mixed-precision outliers |
| bnb nf4 (DQ) | 275 | 5.86 | 171 | 1287 | 0.3537 | 0.2205 | 0.2423 | 0.3576 | 0.2075 | **collapsed** |
| bnb nf4 (no DQ) | 275 | 5.86 | 171 | 1287 | 0.3537 | 0.2205 | 0.2423 | 0.3576 | 0.2075 | **collapsed** |
| bnb fp4 (no DQ) | 275 | 5.87 | 170 | 1287 | 0.1629 | 0.2085 | 0.2326 | 0.1686 | 0.1978 | **collapsed harder** |
\*torchao subclass tensors report bf16 element_size; true storage ~395 MB.
Per-variant detail (per-class F1, MCC, AUC, confusion matrices, calibration
bins) is in `results/eval/quant/{variant}/metrics.json`. Aggregate row-level
data is in `results/eval/quant/summary.json`.
## Findings
### 1. bf16 is already the production sweet spot
Flash-attention-2 + bf16 gives **3.0× the throughput of fp32** (181 vs 61
samples/sec) at **half the VRAM** (1.7 vs 3.5 GB) with accuracy unchanged
within noise. This is what we already train and serve at; the sweep simply
confirms there's no headroom in fp16/fp32 for this hardware.
### 2. fp16 ≡ bf16 on Ampere
Identical latency, identical VRAM, identical F1. RTX 3090 has matched
bf16/fp16 throughput on tensor cores and the model has no overflow issues
in either format. Pick whichever the loader prefers.
### 3. torchao int8 weight-only is the only quantization variant worth shipping
- **VRAM -19%** (1741 → 1416 MB): meaningful for batched serving
- **F1 essentially unchanged** (cat +0.0008, spec -0.0011 vs bf16, both
inside per-seed noise)
- **Latency +10%** (5.52 → 6.08 ms/sample): the int8 weight is dequantized
to bf16 on the fly because the RTX 3090 (sm_8.6) lacks the int8 tensor-core
matmul kernel paths torchao would otherwise use; on H100/A100/Ada this
same config would also be faster

The accuracy delta is statistically nothing: well within the ±0.002 std we
observed across the 3-seed ensemble. **This is the variant we'd ship as the
"low-VRAM" deployment option.**
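To make the weight-only scheme concrete, here is a pure-Python toy of per-row absmax int8 quantize/dequantize — a simplification of what torchao actually does (real kernels use per-group scales and fused dequant), shown only to illustrate why round-trip error stays small:

```python
def int8_wo_roundtrip(row):
    """Per-row absmax int8 quantization followed by dequantization.

    Toy stand-in for weight-only int8: weights are stored as int8 plus
    one float scale per row, and dequantized back at matmul time.
    """
    scale = max(abs(x) for x in row) / 127.0
    quantized = [max(-128, min(127, round(x / scale))) for x in row]
    dequantized = [q * scale for q in quantized]
    return quantized, dequantized, scale

weights = [0.8, -0.3, 0.05, -1.2, 0.6]
q, dq, s = int8_wo_roundtrip(weights)
# round-trip error per weight is bounded by half a quantization step
max_err = max(abs(a - b) for a, b in zip(weights, dq))
```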
### 4. torchao int8 dynamic activation: don't bother on this hardware
Throughput drops 43% (5.52 → 9.67 ms/sample) and peak VRAM is *higher* than
bf16 (1774 vs 1741 MB) because the per-batch activation quantization adds
work without unlocking int8 tensor cores. A pure regression on Ampere.
### 5. bnb LLM.int8: slowest int8, no accuracy upside
- **+41% latency** (5.52 → 7.76 ms/sample) due to mixed-precision outlier
handling
- **+23% VRAM** (1741 → 2135 MB): outlier columns are kept in fp16, plus
scratch buffers
- **F1 +0.0024 cat, +0.0034 spec**: within noise, not a real win

bnb LLM.int8 was designed for LLM-scale models where outlier features
dominate quant error; for an encoder of this size on a single 3090 it
just trades performance for nothing.
### 6. All 4-bit variants collapse — ModernBERT-large is too quant-sensitive
Both nf4 (with and without double-quantization) and fp4 produce essentially
random predictions:
| variant | cat F1 | spec F1 | spec ECE |
|---------|-------:|--------:|---------:|
| nf4 | 0.354 | 0.221 | 0.434 |
| fp4 | 0.163 | 0.209 | 0.443 |
Per-layer dequantization is faithful: we verified that the dequantized
weight of one MLP `Wi` layer differs from the original by mean 0.005 / max
0.11 (sub-1% error). But the relative output drift on a single Linear is
already ~98% (mean), and that error compounds across 28 transformer blocks
+ GLU FFN paths until the [CLS]/pooled representation no longer carries
the discriminative signal. The category head essentially collapses to a
near-uniform prior (cat ECE 0.10 vs the 0.054 baseline), and the threshold
heads collapse onto L1 because all three thresholds emit similar logits.
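The compounding intuition can be made concrete with a toy multiplicative error model — an illustration only, not the measured mechanism, since it assumes each block's relative error accumulates independently:

```python
def compounded_drift(per_block_rel_err: float, n_blocks: int = 28) -> float:
    """Toy model: relative output error after n blocks, assuming each
    block scales the accumulated error by (1 + e) on average."""
    return (1 + per_block_rel_err) ** n_blocks - 1

# even a modest 5% per-block error balloons to roughly 290% over 28 blocks
drift = compounded_drift(0.05)
```

With the ~98% single-Linear drift observed here, the signal is effectively scrambled long before the final block.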
The fact that **DQ vs no-DQ are bit-identical** at this scale tells us the
nf4 weight indices are stable under absmax requantization (only ~5% of the
weight bytes change, all in the metadata block): the catastrophe is
inherent to 4-bit weight precision on this architecture, not to a
quantization-config knob.
This is a real, noteworthy null result for the paper: **naive post-training
4-bit weight quantization is not viable for ModernBERT-large on this task**.
Recovering 4-bit would require either (a) QAT, (b) per-channel calibration
against a held-out activation distribution (GPTQ/AWQ-style), or (c) keeping
the GLU FFN in 8-bit and only 4-bit'ing the attention projections. None of
these is reachable inside the remaining capstone time budget.
### 7. torchao int4-wo: dependency hole
torchao 0.17 requires `mslk >= 1.0.0` for the new `Int4Tensor.from_hp` path.
It's not installed in the lockfile, and not worth chasing given the bnb
4-bit collapse: even if the kernel ran cleanly, we'd expect the same
compounding-error pattern.
## Recommendations
| Use case | Variant | Why |
|-----------------------------------|--------------------|-------------------------------------------------------------|
| **Production / paper headline** | bf16 | Best of every dimension on this hardware |
| **Low-VRAM batch serving**        | torchao int8-wo    | -19% VRAM, accuracy intact, only +10% latency penalty        |
| **Multi-GPU sharded serving** | bf16 | int8-wo's dequant overhead grows with replica count |
| **Embedded / 4-bit** | not viable | Needs QAT or AWQ-style calibration; future work |
## Paper-worthy notes
1. **Quantization story:** bf16 is already the sweet spot; torchao int8-wo
buys 19% VRAM savings with no accuracy cost; 4-bit fails. This adds
another row to the speed/cost table.
2. **Architecture-specific quant fragility:** ModernBERT-large's GLU FFN
amplifies per-layer weight error across 28 blocks. This is a noteworthy
counterpoint to the 4-bit-by-default LLM serving narrative and worth
one paragraph in the discussion section alongside the DAPT and
CORAL null results.
3. **Hardware caveat:** int8 latency results would invert on
Hopper/Ada/A100; the 3090 just doesn't have the matmul path. State the
sm_8.6 caveat in the table caption.
## Reproduce
```bash
# from repo root
bun run py:quant
# writes to results/eval/quant/{summary.json, REPORT.md, <variant>/metrics.json}
```
Run time: ~5 minutes total (most spent in fp32 + torchao build steps).