SEC-cyBERT/python/scripts/quantize_sweep.py
Joey Eamigh 67beaede45
quantization + onnx sweeps
Phase 10.8: torchao/bnb quant sweep on iter1-independent. bf16 already
optimal; torchao int8-wo gives -19% VRAM at no F1 cost; all 4-bit
variants collapse (ModernBERT-large too quant-sensitive).

Phase 10.9: ONNX export + ORT eval. Legacy exporter only working path
(dynamo adds 56 Memcpy nodes); ORT fp32 -22% latency vs torch via
kernel fusion but bf16+flash-attn-2 still wins; fp16 broken on rotary;
dynamic int8 silently CPU-fallback + 0.5 F1 collapse.

Driver scripts wired to bun run py:quant / py:onnx; full reports at
results/eval/{quant,onnx}/REPORT.md.
2026-04-07 05:10:38 -04:00

492 lines
19 KiB
Python

"""Quantization sweep for the iter1-independent ModernBERT-large checkpoint.
Loads the trained DualHeadModernBERT, applies a series of quantization
schemes to the *encoder* (heads kept in their native dtype), and re-runs
holdout evaluation against the GPT-5.4 / Opus-4.6 proxy gold.
For each variant we record:
- cat / spec macro F1, per-class F1, QWK, MAE, ECE
- peak VRAM (encoder forward)
- latency (ms/sample, batch=64) and throughput
- encoder parameter footprint in MB
- delta vs bf16 baseline
Variants:
fp32, bf16 (baseline), fp16,
torchao int8 weight-only,
torchao int8 dynamic-act + int8 weight,
torchao int4 weight-only (group=128),
bitsandbytes LLM.int8 (8-bit),
bitsandbytes nf4 (4-bit, double-quant, bf16 compute).
Heads (category linear, attention pooler, independent threshold MLPs)
stay in bf16 — they sit on a 1024-dim representation and account for
< 0.3% of params, so quantizing them buys nothing and risks the threshold
margins which already drive most of the spec error budget.
Usage:
bun run py:quant # via package.json wrapper
# or directly:
cd python && uv run scripts/quantize_sweep.py
"""
from __future__ import annotations
import gc
import json
import sys
import time
import traceback
from dataclasses import dataclass, field
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
# Make `src` importable when run as a script
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
from src.finetune.data import CAT2ID, CATEGORIES, NUM_CATEGORIES, NUM_SPECIFICITY # noqa: E402
from src.finetune.eval import ( # noqa: E402
SPEC_LABELS,
compute_all_metrics,
load_holdout_data,
)
from src.finetune.model import DualHeadModernBERT, ordinal_predict # noqa: E402
REPO = ROOT.parent  # repository root (this script lives under python/scripts/)
# Trained iter1-independent checkpoint that every variant re-loads.
CHECKPOINT = REPO / "checkpoints/finetune/iter1-independent/final"
# Cleaned paragraph corpus, holdout id split, and proxy-gold benchmark labels.
PARAGRAPHS = REPO / "data/paragraphs/paragraphs-clean.patched.jsonl"
HOLDOUT = REPO / "data/gold/v2-holdout-ids.json"
BENCHMARKS = {
    "GPT-5.4": REPO / "data/annotations/v2-bench/gpt-5.4.jsonl",
    "Opus-4.6": REPO / "data/annotations/v2-bench/opus-4.6.jsonl",
}
OUTPUT_DIR = REPO / "results/eval/quant"  # per-variant dirs + summary.json land here
BATCH_SIZE = 64       # inference batch size (matches the latency claim in the docstring)
MAX_SEQ = 512         # tokenizer truncation length
WARMUP_BATCHES = 5    # untimed forward passes before latency measurement
# ──────────────────────────────────────────────────────────────────────
# Model loading
# ──────────────────────────────────────────────────────────────────────
def _build_model(dtype: torch.dtype, attn_impl: str = "sdpa") -> tuple[DualHeadModernBERT, AutoTokenizer]:
    """Instantiate DualHeadModernBERT at `dtype` and restore the trained weights.

    A fresh ModernBERT-large backbone is fetched at the requested dtype and
    attention implementation, the dual-head wrapper is built around it, and
    the iter1-independent safetensors state is loaded on top (strict=False —
    presumably so checkpoint keys override pretrained backbone weights while
    any extras are ignored; verify against the checkpoint's key set).
    Returns the model in eval mode (still on CPU) plus its tokenizer.
    """
    tok = AutoTokenizer.from_pretrained(str(CHECKPOINT))
    encoder = AutoModel.from_pretrained(
        "answerdotai/ModernBERT-large",
        trust_remote_code=True,
        attn_implementation=attn_impl,
        dtype=dtype,
    )
    net = DualHeadModernBERT(
        backbone=encoder,
        hidden_size=encoder.config.hidden_size,
        num_categories=NUM_CATEGORIES,
        num_specificity=NUM_SPECIFICITY,
        specificity_head_type="independent",
        spec_mlp_dim=256,
        pooling="attention",
    )
    weights = load_file(str(CHECKPOINT / "model.safetensors"))
    net.load_state_dict(weights, strict=False)
    # Cast heads (built in default dtype) to match the backbone, then freeze.
    net = net.to(dtype)
    net.eval()
    return net, tok
def _try_flash_attn() -> str:
try:
import flash_attn # noqa: F401
return "flash_attention_2"
except ImportError:
return "sdpa"
# ──────────────────────────────────────────────────────────────────────
# Quantization variants
# ──────────────────────────────────────────────────────────────────────
def variant_native(dtype: torch.dtype, attn: str | None = None):
    """Return a zero-arg builder for an unquantized model at `dtype`.

    `attn` pins an attention implementation; when None the best available
    one is auto-detected.  fp32 is always forced onto sdpa because the
    flash-attn kernels only cover half precision.
    """
    def _build():
        if dtype == torch.float32:
            chosen = "sdpa"
        else:
            chosen = attn or _try_flash_attn()
        model, tok = _build_model(dtype, attn_impl=chosen)
        return model.cuda(), tok
    return _build
def variant_torchao(config_factory):
    """Return a zero-arg builder for a torchao-quantized variant.

    `config_factory` is invoked lazily inside the build so torchao config
    objects only exist when the variant actually runs.  Quantization is
    applied in place to the encoder backbone only; heads and the attention
    pooler stay in bf16.
    """
    def _build():
        from torchao.quantization import quantize_

        # torchao quantizes on top of bf16 master weights.
        model, tok = _build_model(torch.bfloat16, attn_impl=_try_flash_attn())
        model = model.cuda()
        quantize_(model.backbone, config_factory())  # heads untouched
        return model, tok
    return _build
def _swap_bnb_linear(
    module: nn.Module,
    mode: str,
    compute_dtype=torch.bfloat16,
    compress_statistics: bool = True,
) -> int:
    """Recursively replace nn.Linear with bnb 8-bit / 4-bit equivalents.

    Returns the number of layers swapped.  Weights (and biases) are copied
    from the original module so the trained checkpoint is preserved.

    mode: "int8" → Linear8bitLt (LLM.int8 with outlier threshold 6.0);
          "nf4" / "fp4" → Linear4bit with the given compute dtype.
    compress_statistics: enable double quantization for the 4-bit modes.
    """
    import bitsandbytes as bnb

    swapped = 0
    for name, child in list(module.named_children()):
        if not isinstance(child, nn.Linear):
            # BUG FIX: the recursion previously dropped compress_statistics,
            # so nested modules silently reverted to double-quant=True in the
            # "no double-quant" variants.  Thread the flag all the way down.
            swapped += _swap_bnb_linear(child, mode, compute_dtype, compress_statistics)
            continue
        in_f, out_f = child.in_features, child.out_features
        has_bias = child.bias is not None
        if mode == "int8":
            new = bnb.nn.Linear8bitLt(
                in_f, out_f, bias=has_bias,
                has_fp16_weights=False, threshold=6.0,
            )
            new.weight = bnb.nn.Int8Params(
                child.weight.data.clone(),
                requires_grad=False,
                has_fp16_weights=False,
            )
            if has_bias:
                new.bias = nn.Parameter(child.bias.data.clone())
        elif mode in ("nf4", "fp4"):
            new = bnb.nn.Linear4bit(
                in_f, out_f, bias=has_bias,
                compute_dtype=compute_dtype,
                quant_type=mode,
                compress_statistics=compress_statistics,
                quant_storage=torch.uint8,
                device="cuda",
            )
            # Params4bit quantizes when moved to CUDA; feed it fp32 weights.
            w = child.weight.data.detach().to(torch.float32).clone()
            new.weight = bnb.nn.Params4bit(
                w, requires_grad=False, quant_type=mode,
                compress_statistics=compress_statistics, module=new,
            ).cuda()
            if has_bias:
                new.bias = nn.Parameter(
                    child.bias.data.detach().to(compute_dtype).clone().cuda()
                )
        else:
            raise ValueError(mode)
        new = new.cuda()
        setattr(module, name, new)
        swapped += 1
    return swapped
def variant_bnb(mode: str, compress_statistics: bool = True):
    """Return a zero-arg builder for a bitsandbytes-quantized variant.

    mode is "int8", "nf4" or "fp4"; compress_statistics toggles double
    quantization for the 4-bit modes.  Only encoder linears are swapped.
    """
    def _build():
        # NOTE(review): bnb variants are built on sdpa rather than
        # flash-attn — presumably a compatibility choice; confirm.
        model, tok = _build_model(torch.bfloat16, attn_impl="sdpa")
        model = model.cuda()
        count = _swap_bnb_linear(
            model.backbone, mode, compress_statistics=compress_statistics,
        )
        print(f" bnb {mode} (cs={compress_statistics}): swapped {count} linears")
        return model, tok
    return _build
# ──────────────────────────────────────────────────────────────────────
# Inference + measurement
# ──────────────────────────────────────────────────────────────────────
def _encoder_param_bytes(model: DualHeadModernBERT) -> int:
"""Sum bytes of every parameter / buffer inside the encoder backbone.
Handles bnb Int8Params (int8 storage) and Params4bit (uint8 packed)
correctly because element_size() reflects the storage dtype.
"""
total = 0
seen = set()
for p in list(model.backbone.parameters()) + list(model.backbone.buffers()):
if id(p) in seen:
continue
seen.add(id(p))
total += p.numel() * p.element_size()
return total
@torch.no_grad()
def run_inference(model, tokenizer, texts: list[str]) -> dict:
    """Run all `texts` through `model`, timing only the GPU forward passes.

    Tokenization stays outside the timed region; each forward is bracketed
    by cuda synchronize so wall-clock deltas reflect kernel time.  After
    WARMUP_BATCHES untimed passes, peak-memory stats are reset so VRAM
    reflects steady-state inference only.  Returns fp32 CPU logits plus
    latency / throughput / peak-VRAM numbers.
    """
    device = next(model.parameters()).device

    def _encode(chunk):
        # Dynamic padding to the longest sample in the chunk.
        return tokenizer(
            chunk, truncation=True, max_length=MAX_SEQ,
            padding="longest", return_tensors="pt",
        ).to(device)

    warm = _encode(texts[:BATCH_SIZE])
    for _ in range(WARMUP_BATCHES):
        model(input_ids=warm["input_ids"], attention_mask=warm["attention_mask"])
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    cat_chunks, spec_chunks = [], []
    elapsed = 0.0
    for start in range(0, len(texts), BATCH_SIZE):
        enc = _encode(texts[start : start + BATCH_SIZE])
        torch.cuda.synchronize()
        tick = time.perf_counter()
        out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
        torch.cuda.synchronize()
        elapsed += time.perf_counter() - tick
        cat_chunks.append(out["category_logits"].float().cpu())
        spec_chunks.append(out["specificity_logits"].float().cpu())

    peak = torch.cuda.max_memory_allocated()
    return {
        "cat_logits": torch.cat(cat_chunks),
        "spec_logits": torch.cat(spec_chunks),
        "total_time_s": elapsed,
        "ms_per_sample": (elapsed / len(texts)) * 1000,
        "throughput": len(texts) / elapsed,
        "peak_vram_mb": peak / (1024 ** 2),
        "num_samples": len(texts),
    }
def evaluate_predictions(
    cat_logits: torch.Tensor,
    spec_logits: torch.Tensor,
    records: list[dict],
    ref_name: str,
) -> dict:
    """Score raw logits against one benchmark reference (`ref_name`).

    Category uses softmax probabilities and argmax predictions.  The
    specificity head emits ordinal (cumulative) logits; sigmoid gives
    threshold probabilities, whose adjacent differences become per-class
    probabilities (clamped at zero and renormalized).  Records lacking a
    label from this reference are skipped.  Returns the merged category +
    specificity metric dicts from compute_all_metrics.
    """
    cat_probs = F.softmax(cat_logits, dim=1).numpy()
    cat_pred = cat_logits.argmax(dim=1).numpy()
    spec_pred = ordinal_predict(spec_logits).numpy()

    # Cumulative sigmoids → per-class probability mass.
    cum = torch.sigmoid(spec_logits)
    n_cls = cum.shape[1] + 1
    probs = torch.zeros(cum.shape[0], n_cls)
    probs[:, 0] = 1 - cum[:, 0]
    # Middle classes: difference of neighbouring thresholds (same float
    # ops as the original per-k loop, vectorized).
    probs[:, 1 : n_cls - 1] = cum[:, : n_cls - 2] - cum[:, 1 : n_cls - 1]
    probs[:, -1] = cum[:, -1]
    probs = probs.clamp(min=0)
    spec_probs = (probs / probs.sum(dim=1, keepdim=True)).numpy()

    cat_y, spec_y = [], []
    cat_yhat, spec_yhat, cat_psel, spec_psel = [], [], [], []
    for idx, rec in enumerate(records):
        label = rec["benchmark_labels"].get(ref_name)
        if label is None:
            continue
        cat_y.append(CAT2ID[label["category"]])
        spec_y.append(label["specificity"] - 1)  # 1-based → 0-based
        cat_yhat.append(cat_pred[idx])
        spec_yhat.append(spec_pred[idx])
        cat_psel.append(cat_probs[idx])
        spec_psel.append(spec_probs[idx])

    cat_metrics = compute_all_metrics(
        np.array(cat_yhat), np.array(cat_y), np.array(cat_psel),
        CATEGORIES, "cat", is_ordinal=False,
    )
    spec_metrics = compute_all_metrics(
        np.array(spec_yhat), np.array(spec_y), np.array(spec_psel),
        SPEC_LABELS, "spec", is_ordinal=True,
    )
    return {**cat_metrics, **spec_metrics}
# ──────────────────────────────────────────────────────────────────────
# Variant registry
# ──────────────────────────────────────────────────────────────────────
@dataclass
class Variant:
    """Registry entry describing one quantization scheme to sweep."""
    name: str  # short id; also used as the per-variant output directory name
    description: str  # human-readable summary printed in logs / summary rows
    builder: callable  # zero-arg factory returning (model, tokenizer) on CUDA
    skip_reason: str | None = None  # NOTE(review): not consulted by the visible driver — confirm intent
def build_variants() -> list[Variant]:
    """Assemble the ordered list of quantization variants to sweep.

    torchao config classes are imported here rather than at module level,
    so the script can at least start on machines without torchao.
    """
    from torchao.quantization import (
        Int4WeightOnlyConfig,
        Int8DynamicActivationInt8WeightConfig,
        Int8WeightOnlyConfig,
    )

    registry: list[Variant] = [
        # Native-dtype baselines.
        Variant("fp32", "Float32 encoder + heads", variant_native(torch.float32, attn="sdpa")),
        Variant("bf16", "BFloat16 baseline (matches eval pipeline)", variant_native(torch.bfloat16)),
        Variant("fp16", "Float16 encoder + heads", variant_native(torch.float16)),
        # torchao schemes (configs built lazily inside the builder).
        Variant(
            "torchao-int8-wo",
            "torchao Int8 weight-only on encoder linears",
            variant_torchao(lambda: Int8WeightOnlyConfig()),
        ),
        Variant(
            "torchao-int8-dyn",
            "torchao Int8 dynamic activation + Int8 weight on encoder",
            variant_torchao(lambda: Int8DynamicActivationInt8WeightConfig()),
        ),
        Variant(
            "torchao-int4-wo",
            "torchao Int4 weight-only (group=128) on encoder",
            variant_torchao(lambda: Int4WeightOnlyConfig(group_size=128)),
        ),
        # bitsandbytes schemes.
        Variant("bnb-int8", "bitsandbytes LLM.int8 on encoder linears", variant_bnb("int8")),
        Variant("bnb-nf4", "bitsandbytes NF4 4-bit (double-quant, bf16 compute)", variant_bnb("nf4", compress_statistics=True)),
        Variant("bnb-nf4-nodq", "bitsandbytes NF4 4-bit (no double-quant)", variant_bnb("nf4", compress_statistics=False)),
        Variant("bnb-fp4", "bitsandbytes FP4 4-bit (no double-quant)", variant_bnb("fp4", compress_statistics=False)),
    ]
    return registry
# ──────────────────────────────────────────────────────────────────────
# Driver
# ──────────────────────────────────────────────────────────────────────
def free():
    """Reset GPU state between variants.

    Drops dead Python references, returns cached CUDA blocks to the
    driver, and drains pending kernels so the next variant's footprint
    and timing measurements start from a clean slate.
    """
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
def _serialize_metrics(metrics_per_ref: dict, runtime: dict) -> dict:
    """Build the JSON-safe payload for a variant's metrics.json.

    numpy arrays become lists and numpy scalars become native numbers.
    BUG FIX: the previous filter excluded np.ndarray (and np.float32
    scalars) entirely, so the tolist() branch was dead code and array
    metrics were silently dropped from the dump.
    """
    def _jsonable(val):
        if isinstance(val, np.ndarray):
            return val.tolist()
        if isinstance(val, (np.floating, np.integer)):
            return val.item()
        return val

    ser = {
        ref: {
            k: _jsonable(val)
            for k, val in m.items()
            if isinstance(val, (int, float, str, list, bool,
                                np.ndarray, np.floating, np.integer))
        }
        for ref, m in metrics_per_ref.items()
    }
    ser["_runtime"] = runtime
    return ser


def _print_summary_table(summary: list[dict]) -> None:
    """Print the compact end-of-run comparison table (GPT-5.4 columns)."""
    print("\n" + "=" * 110)
    print(f"{'variant':<18} {'enc MB':>9} {'ms/samp':>9} {'throughput':>11} "
          f"{'VRAM MB':>9} {'cat F1':>9} {'spec F1':>9} {'spec QWK':>9}")
    print("-" * 110)
    for r in summary:
        if "error" in r:
            print(f"{r['variant']:<18} ERROR: {r['error']}")
            continue
        print(
            f"{r['variant']:<18} {r['encoder_mb']:>9.1f} {r['ms_per_sample']:>9.2f} "
            f"{r['throughput_per_s']:>11.0f} {r['peak_vram_mb']:>9.0f} "
            f"{r['GPT-5.4_cat_f1']:>9.4f} {r['GPT-5.4_spec_f1']:>9.4f} {r['GPT-5.4_spec_qwk']:>9.4f}"
        )
    print("=" * 110)


def main():
    """Sweep every quantization variant over the holdout set.

    For each variant: build the model, measure encoder footprint /
    latency / peak VRAM, score against every benchmark reference, and
    write per-variant metrics.json plus a combined summary.json.  A
    variant that raises is recorded as an error row and the sweep
    continues; GPU state is reset between variants.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    print(f"Loading holdout from {HOLDOUT}")
    records = load_holdout_data(
        str(PARAGRAPHS), str(HOLDOUT), {k: str(v) for k, v in BENCHMARKS.items()},
    )
    texts = [r["text"] for r in records]
    print(f" {len(records)} holdout paragraphs loaded")
    summary = []
    for v in build_variants():
        # Separator added between name and description (was fused together).
        print(f"\n══ {v.name} — {v.description}")
        if v.skip_reason:
            # Honor the Variant.skip_reason field (previously ignored).
            print(f" SKIPPED: {v.skip_reason}")
            summary.append({
                "variant": v.name,
                "description": v.description,
                "error": f"skipped: {v.skip_reason}",
            })
            continue
        free()
        try:
            t0 = time.perf_counter()
            model, tokenizer = v.builder()
            build_s = time.perf_counter() - t0
            enc_bytes = _encoder_param_bytes(model)
            print(f" encoder footprint: {enc_bytes / 1e6:.1f} MB (build {build_s:.1f}s)")
            inf = run_inference(model, tokenizer, texts)
            print(
                f" latency {inf['ms_per_sample']:.2f} ms/sample, "
                f"throughput {inf['throughput']:.0f}/s, "
                f"peak VRAM {inf['peak_vram_mb']:.0f} MB"
            )
            metrics_per_ref = {}
            for ref in BENCHMARKS:
                m = evaluate_predictions(inf["cat_logits"], inf["spec_logits"], records, ref)
                metrics_per_ref[ref] = m
                print(
                    f" vs {ref}: cat F1={m['cat_macro_f1']:.4f}, "
                    f"spec F1={m['spec_macro_f1']:.4f}, QWK={m['spec_qwk']:.4f}, "
                    f"cat ECE={m['cat_ece']:.4f}, spec ECE={m['spec_ece']:.4f}"
                )
            runtime = {
                "encoder_mb": enc_bytes / 1e6,
                "ms_per_sample": inf["ms_per_sample"],
                "throughput_per_s": inf["throughput"],
                "peak_vram_mb": inf["peak_vram_mb"],
                "build_s": build_s,
            }
            row = {"variant": v.name, "description": v.description, **runtime}
            for ref, m in metrics_per_ref.items():
                row[f"{ref}_cat_f1"] = m["cat_macro_f1"]
                row[f"{ref}_spec_f1"] = m["spec_macro_f1"]
                row[f"{ref}_cat_mcc"] = m["cat_mcc"]
                row[f"{ref}_spec_qwk"] = m["spec_qwk"]
                row[f"{ref}_spec_mae"] = m["spec_mae"]
                row[f"{ref}_cat_ece"] = m["cat_ece"]
                row[f"{ref}_spec_ece"] = m["spec_ece"]
                # Per-spec-level F1, keyed by a compacted label slug
                # (assumes compute_all_metrics uses the same slugging — confirm).
                for s in SPEC_LABELS:
                    short = s.replace(" ", "").replace(":", "")[:8]
                    row[f"{ref}_spec_f1_{short}"] = m.get(f"spec_f1_{short}", 0)
            summary.append(row)
            # Per-variant detailed metrics dump.
            vdir = OUTPUT_DIR / v.name
            vdir.mkdir(parents=True, exist_ok=True)
            with open(vdir / "metrics.json", "w") as f:
                json.dump(_serialize_metrics(metrics_per_ref, runtime), f, indent=2, default=str)
            del model, tokenizer, inf
        except Exception as e:
            print(f" FAILED: {type(e).__name__}: {e}")
            traceback.print_exc()
            summary.append({
                "variant": v.name,
                "description": v.description,
                "error": f"{type(e).__name__}: {e}",
            })
        free()
    summary_path = OUTPUT_DIR / "summary.json"
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2, default=str)
    print(f"\nSummary written to {summary_path}")
    _print_summary_table(summary)
if __name__ == "__main__":
main()