Phase 10.8: torchao/bnb quant sweep on iter1-independent. bf16 already
optimal; torchao int8-wo gives -19% VRAM at no F1 cost; all 4-bit
variants collapse (ModernBERT-large too quant-sensitive).
Phase 10.9: ONNX export + ORT eval. Legacy exporter only working path
(dynamo adds 56 Memcpy nodes); ORT fp32 -22% latency vs torch via
kernel fusion but bf16+flash-attn-2 still wins; fp16 broken on rotary;
dynamic int8 silently CPU-fallback + 0.5 F1 collapse.
Driver scripts wired to bun run py:quant / py:onnx; full reports at
results/eval/{quant,onnx}/REPORT.md.
492 lines
19 KiB
Python
492 lines
19 KiB
Python
"""Quantization sweep for the iter1-independent ModernBERT-large checkpoint.
|
|
|
|
Loads the trained DualHeadModernBERT, applies a series of quantization
|
|
schemes to the *encoder* (heads kept in their native dtype), and re-runs
|
|
holdout evaluation against the GPT-5.4 / Opus-4.6 proxy gold.
|
|
|
|
For each variant we record:
|
|
- cat / spec macro F1, per-class F1, QWK, MAE, ECE
|
|
- peak VRAM (encoder forward)
|
|
- latency (ms/sample, batch=64) and throughput
|
|
- encoder parameter footprint in MB
|
|
- delta vs bf16 baseline
|
|
|
|
Variants:
|
|
fp32, bf16 (baseline), fp16,
|
|
torchao int8 weight-only,
|
|
torchao int8 dynamic-act + int8 weight,
|
|
torchao int4 weight-only (group=128),
|
|
bitsandbytes LLM.int8 (8-bit),
|
|
bitsandbytes nf4 (4-bit, double-quant, bf16 compute).
|
|
|
|
Heads (category linear, attention pooler, independent threshold MLPs)
|
|
stay in bf16 — they sit on a 1024-dim representation and account for
|
|
< 0.3% of params, so quantizing them buys nothing and risks the threshold
|
|
margins which already drive most of the spec error budget.
|
|
|
|
Usage:
|
|
bun run py:quant # via package.json wrapper
|
|
# or directly:
|
|
cd python && uv run scripts/quantize_sweep.py
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import gc
|
|
import json
|
|
import sys
|
|
import time
|
|
import traceback
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from safetensors.torch import load_file
|
|
from transformers import AutoModel, AutoTokenizer
|
|
|
|
# Make `src` importable when run as a script
|
|
ROOT = Path(__file__).resolve().parents[1]
|
|
sys.path.insert(0, str(ROOT))
|
|
|
|
from src.finetune.data import CAT2ID, CATEGORIES, NUM_CATEGORIES, NUM_SPECIFICITY # noqa: E402
|
|
from src.finetune.eval import ( # noqa: E402
|
|
SPEC_LABELS,
|
|
compute_all_metrics,
|
|
load_holdout_data,
|
|
)
|
|
from src.finetune.model import DualHeadModernBERT, ordinal_predict # noqa: E402
|
|
|
|
REPO = ROOT.parent
# Trained iter1-independent checkpoint (HF tokenizer files + model.safetensors).
CHECKPOINT = REPO / "checkpoints/finetune/iter1-independent/final"
# Cleaned paragraph corpus the holdout ids index into.
PARAGRAPHS = REPO / "data/paragraphs/paragraphs-clean.patched.jsonl"
HOLDOUT = REPO / "data/gold/v2-holdout-ids.json"
# Proxy-gold annotation files, keyed by the display name used in reports/tables.
BENCHMARKS = {
    "GPT-5.4": REPO / "data/annotations/v2-bench/gpt-5.4.jsonl",
    "Opus-4.6": REPO / "data/annotations/v2-bench/opus-4.6.jsonl",
}
OUTPUT_DIR = REPO / "results/eval/quant"
BATCH_SIZE = 64      # inference batch size (latency is reported per-sample at this batch)
MAX_SEQ = 512        # tokenizer truncation length
WARMUP_BATCHES = 5   # untimed forward passes before the measured loop
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────
|
|
# Model loading
|
|
# ──────────────────────────────────────────────────────────────────────
|
|
|
|
def _build_model(dtype: torch.dtype, attn_impl: str = "sdpa") -> tuple[DualHeadModernBERT, AutoTokenizer]:
    """Construct DualHeadModernBERT and load trained weights at the requested dtype.

    Args:
        dtype: parameter dtype for backbone and heads.
        attn_impl: HF attention implementation ("sdpa" or "flash_attention_2").

    Returns:
        (model, tokenizer) with the model in eval mode, still on CPU.
    """
    tokenizer = AutoTokenizer.from_pretrained(str(CHECKPOINT))
    backbone = AutoModel.from_pretrained(
        "answerdotai/ModernBERT-large",
        trust_remote_code=True,
        attn_implementation=attn_impl,
        dtype=dtype,
    )
    model = DualHeadModernBERT(
        backbone=backbone,
        hidden_size=backbone.config.hidden_size,
        num_categories=NUM_CATEGORIES,
        num_specificity=NUM_SPECIFICITY,
        specificity_head_type="independent",
        spec_mlp_dim=256,
        pooling="attention",
    )
    state = load_file(str(CHECKPOINT / "model.safetensors"))
    # strict=False tolerates benign key drift, but a silent mismatch would
    # invalidate the whole sweep — report any missing/unexpected keys loudly
    # instead of discarding the load_state_dict result.
    result = model.load_state_dict(state, strict=False)
    if result.missing_keys:
        print(f"  WARNING: checkpoint missing keys: {result.missing_keys}")
    if result.unexpected_keys:
        print(f"  WARNING: checkpoint unexpected keys: {result.unexpected_keys}")
    model = model.to(dtype)
    model.eval()
    return model, tokenizer
|
|
|
|
|
|
def _try_flash_attn() -> str:
|
|
try:
|
|
import flash_attn # noqa: F401
|
|
return "flash_attention_2"
|
|
except ImportError:
|
|
return "sdpa"
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────
|
|
# Quantization variants
|
|
# ──────────────────────────────────────────────────────────────────────
|
|
|
|
def variant_native(dtype: torch.dtype, attn: str | None = None):
|
|
def _build():
|
|
impl = attn or _try_flash_attn()
|
|
# bf16/fp16 supported by flash-attn; fp32 must use sdpa
|
|
if dtype == torch.float32:
|
|
impl = "sdpa"
|
|
model, tok = _build_model(dtype, attn_impl=impl)
|
|
return model.cuda(), tok
|
|
return _build
|
|
|
|
|
|
def variant_torchao(config_factory):
    """Return a zero-arg builder that torchao-quantizes the encoder.

    `config_factory` is a zero-arg callable producing a fresh torchao config.
    Only `model.backbone` is quantized; heads and pooler stay in bf16.
    """
    def _build():
        from torchao.quantization import quantize_

        # torchao expects bf16 master weights
        model, tok = _build_model(torch.bfloat16, attn_impl=_try_flash_attn())
        model = model.cuda()
        # Quantize encoder linears only (skip heads + attention pooler)
        config = config_factory()
        quantize_(model.backbone, config)
        return model, tok

    return _build
|
|
|
|
|
|
def _swap_bnb_linear(
    module: nn.Module,
    mode: str,
    compute_dtype=torch.bfloat16,
    compress_statistics: bool = True,
) -> int:
    """Recursively replace nn.Linear with bnb 8-bit / 4-bit equivalents.

    Returns number of layers swapped. Copies weights from the original
    module so the trained checkpoint is preserved.

    Args:
        module: subtree rewritten in place (children reassigned via setattr).
        mode: "int8" for LLM.int8, "nf4" / "fp4" for 4-bit.
        compute_dtype: matmul dtype used by the 4-bit kernels.
        compress_statistics: double-quantize the 4-bit scale factors.
    """
    import bitsandbytes as bnb

    swapped = 0
    for name, child in list(module.named_children()):
        if isinstance(child, nn.Linear):
            in_f, out_f = child.in_features, child.out_features
            has_bias = child.bias is not None
            if mode == "int8":
                new = bnb.nn.Linear8bitLt(
                    in_f, out_f, bias=has_bias,
                    has_fp16_weights=False, threshold=6.0,
                )
                # Int8Params quantizes the copied fp weights on the .cuda() move.
                new.weight = bnb.nn.Int8Params(
                    child.weight.data.clone(),
                    requires_grad=False,
                    has_fp16_weights=False,
                )
                if has_bias:
                    new.bias = nn.Parameter(child.bias.data.clone())
            elif mode in ("nf4", "fp4"):
                new = bnb.nn.Linear4bit(
                    in_f, out_f, bias=has_bias,
                    compute_dtype=compute_dtype,
                    quant_type=mode,
                    compress_statistics=compress_statistics,
                    quant_storage=torch.uint8,
                    device="cuda",
                )
                # Params4bit packs the fp32 copy into 4-bit on .cuda().
                w = child.weight.data.detach().to(torch.float32).clone()
                new.weight = bnb.nn.Params4bit(
                    w, requires_grad=False, quant_type=mode,
                    compress_statistics=compress_statistics, module=new,
                ).cuda()
                if has_bias:
                    new.bias = nn.Parameter(
                        child.bias.data.detach().to(compute_dtype).clone().cuda()
                    )
            else:
                raise ValueError(mode)
            new = new.cuda()
            setattr(module, name, new)
            swapped += 1
        else:
            # BUG FIX: compress_statistics was not forwarded here, so nested
            # linears (i.e. nearly all of them) always used double-quant and
            # the nf4-nodq / fp4 variants did not measure what they claimed.
            swapped += _swap_bnb_linear(child, mode, compute_dtype, compress_statistics)
    return swapped
|
|
|
|
|
|
def variant_bnb(mode: str, compress_statistics: bool = True):
    """Return a zero-arg builder that swaps encoder linears for bnb layers.

    `mode` selects "int8" / "nf4" / "fp4"; `compress_statistics` toggles
    double quantization for the 4-bit modes.
    """
    def _build():
        model, tok = _build_model(torch.bfloat16, attn_impl="sdpa")
        model = model.cuda()
        count = _swap_bnb_linear(
            model.backbone, mode, compress_statistics=compress_statistics,
        )
        print(f" bnb {mode} (cs={compress_statistics}): swapped {count} linears")
        return model, tok

    return _build
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────
|
|
# Inference + measurement
|
|
# ──────────────────────────────────────────────────────────────────────
|
|
|
|
def _encoder_param_bytes(model: DualHeadModernBERT) -> int:
|
|
"""Sum bytes of every parameter / buffer inside the encoder backbone.
|
|
|
|
Handles bnb Int8Params (int8 storage) and Params4bit (uint8 packed)
|
|
correctly because element_size() reflects the storage dtype.
|
|
"""
|
|
total = 0
|
|
seen = set()
|
|
for p in list(model.backbone.parameters()) + list(model.backbone.buffers()):
|
|
if id(p) in seen:
|
|
continue
|
|
seen.add(id(p))
|
|
total += p.numel() * p.element_size()
|
|
return total
|
|
|
|
|
|
@torch.no_grad()
def run_inference(model, tokenizer, texts: list[str]) -> dict:
    """Run batched forward passes over `texts`, measuring latency and peak VRAM.

    Args:
        model: DualHeadModernBERT variant, already on GPU and in eval mode.
        tokenizer: matching HF tokenizer.
        texts: paragraphs to classify.

    Returns:
        Dict with concatenated fp32 CPU logits for both heads
        ("cat_logits", "spec_logits") plus timing stats: total_time_s,
        ms_per_sample, throughput, peak_vram_mb, num_samples.
    """
    device = next(model.parameters()).device
    cat_logits_list = []
    spec_logits_list = []

    # Warmup: untimed forward passes so lazy init / kernel selection does not
    # pollute the timed loop; then reset the VRAM high-water mark.
    warm_batch = tokenizer(
        texts[: BATCH_SIZE], truncation=True, max_length=MAX_SEQ,
        padding="longest", return_tensors="pt",
    ).to(device)
    for _ in range(WARMUP_BATCHES):
        _ = model(input_ids=warm_batch["input_ids"], attention_mask=warm_batch["attention_mask"])
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()

    total_time = 0.0
    for i in range(0, len(texts), BATCH_SIZE):
        batch = texts[i : i + BATCH_SIZE]
        enc = tokenizer(
            batch, truncation=True, max_length=MAX_SEQ,
            padding="longest", return_tensors="pt",
        ).to(device)
        # Synchronize on both sides of the forward so perf_counter measures
        # completed GPU work, not just async kernel launches. Tokenization is
        # deliberately excluded from the timing.
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])
        torch.cuda.synchronize()
        total_time += time.perf_counter() - t0
        # Cast to fp32 on CPU immediately so quantized/half dtypes don't leak
        # into the metric computation, and GPU memory is freed per batch.
        cat_logits_list.append(out["category_logits"].float().cpu())
        spec_logits_list.append(out["specificity_logits"].float().cpu())

    peak_vram = torch.cuda.max_memory_allocated()
    cat_logits = torch.cat(cat_logits_list)
    spec_logits = torch.cat(spec_logits_list)
    return {
        "cat_logits": cat_logits,
        "spec_logits": spec_logits,
        "total_time_s": total_time,
        "ms_per_sample": (total_time / len(texts)) * 1000,
        "throughput": len(texts) / total_time,
        "peak_vram_mb": peak_vram / (1024 ** 2),
        "num_samples": len(texts),
    }
|
|
|
|
|
|
def evaluate_predictions(
    cat_logits: torch.Tensor,
    spec_logits: torch.Tensor,
    records: list[dict],
    ref_name: str,
) -> dict:
    """Score category + specificity predictions against one benchmark reference.

    Args:
        cat_logits: (N, num_categories) fp32 logits.
        spec_logits: (N, K-1) ordinal threshold logits.
        records: holdout records; each has a "benchmark_labels" dict whose
            ref_name entry (if present) carries "category" and 1-based
            "specificity" labels.
        ref_name: which benchmark annotator to score against.

    Returns:
        Merged dict of category metrics ("cat_*") and specificity metrics
        ("spec_*") from compute_all_metrics.
    """
    cat_probs_all = F.softmax(cat_logits, dim=1).numpy()
    cat_preds_all = cat_logits.argmax(dim=1).numpy()
    spec_preds_all = ordinal_predict(spec_logits).numpy()
    # ordinal → class probs: with s_k = P(y > k),
    # P(y=0) = 1 - s_0, P(y=k) = s_{k-1} - s_k, P(y=K-1) = s_{K-2}.
    sp = torch.sigmoid(spec_logits)
    K = sp.shape[1] + 1
    spec_probs_all = torch.zeros(sp.shape[0], K)
    spec_probs_all[:, 0] = 1 - sp[:, 0]
    for k in range(1, K - 1):
        spec_probs_all[:, k] = sp[:, k - 1] - sp[:, k]
    spec_probs_all[:, -1] = sp[:, -1]
    # Non-monotone thresholds can produce negative "probs"; clamp then renorm.
    spec_probs_all = spec_probs_all.clamp(min=0)
    # BUG FIX: when sigmoids saturate, every clamped entry can be exactly 0
    # and the old renormalization divided by zero, producing NaN rows that
    # poisoned ECE. Guard the denominator.
    denom = spec_probs_all.sum(dim=1, keepdim=True).clamp(min=1e-12)
    spec_probs_all = spec_probs_all / denom
    spec_probs_all = spec_probs_all.numpy()

    cat_labels, spec_labels = [], []
    cat_p, spec_p, cat_pr, spec_pr = [], [], [], []
    for i, rec in enumerate(records):
        b = rec["benchmark_labels"].get(ref_name)
        if b is None:
            continue  # paragraph not annotated by this reference
        cat_labels.append(CAT2ID[b["category"]])
        spec_labels.append(b["specificity"] - 1)  # 1-based label → 0-based index
        cat_p.append(cat_preds_all[i])
        spec_p.append(spec_preds_all[i])
        cat_pr.append(cat_probs_all[i])
        spec_pr.append(spec_probs_all[i])

    cat_m = compute_all_metrics(
        np.array(cat_p), np.array(cat_labels), np.array(cat_pr),
        CATEGORIES, "cat", is_ordinal=False,
    )
    spec_m = compute_all_metrics(
        np.array(spec_p), np.array(spec_labels), np.array(spec_pr),
        SPEC_LABELS, "spec", is_ordinal=True,
    )
    return {**cat_m, **spec_m}
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────
|
|
# Variant registry
|
|
# ──────────────────────────────────────────────────────────────────────
|
|
|
|
@dataclass
class Variant:
    """One quantization scheme to benchmark in the sweep."""

    name: str             # short id; also used as the per-variant output subdirectory
    description: str      # human-readable summary printed and written to summary.json
    builder: callable     # zero-arg factory returning (model, tokenizer) ready on GPU
    # NOTE(review): not consulted by main() in this file — presumably for
    # manual skipping; confirm before removing.
    skip_reason: str | None = None
|
|
|
|
|
|
def build_variants() -> list[Variant]:
    """Assemble the sweep: native dtypes, torchao schemes, bitsandbytes schemes."""
    from torchao.quantization import (
        Int4WeightOnlyConfig,
        Int8DynamicActivationInt8WeightConfig,
        Int8WeightOnlyConfig,
    )

    native = [
        Variant("fp32", "Float32 encoder + heads", variant_native(torch.float32, attn="sdpa")),
        Variant("bf16", "BFloat16 baseline (matches eval pipeline)", variant_native(torch.bfloat16)),
        Variant("fp16", "Float16 encoder + heads", variant_native(torch.float16)),
    ]
    torchao = [
        Variant(
            "torchao-int8-wo",
            "torchao Int8 weight-only on encoder linears",
            variant_torchao(Int8WeightOnlyConfig),
        ),
        Variant(
            "torchao-int8-dyn",
            "torchao Int8 dynamic activation + Int8 weight on encoder",
            variant_torchao(Int8DynamicActivationInt8WeightConfig),
        ),
        Variant(
            "torchao-int4-wo",
            "torchao Int4 weight-only (group=128) on encoder",
            variant_torchao(lambda: Int4WeightOnlyConfig(group_size=128)),
        ),
    ]
    bnb = [
        Variant("bnb-int8", "bitsandbytes LLM.int8 on encoder linears", variant_bnb("int8")),
        Variant("bnb-nf4", "bitsandbytes NF4 4-bit (double-quant, bf16 compute)", variant_bnb("nf4", compress_statistics=True)),
        Variant("bnb-nf4-nodq", "bitsandbytes NF4 4-bit (no double-quant)", variant_bnb("nf4", compress_statistics=False)),
        Variant("bnb-fp4", "bitsandbytes FP4 4-bit (no double-quant)", variant_bnb("fp4", compress_statistics=False)),
    ]
    return native + torchao + bnb
|
|
|
|
|
|
# ──────────────────────────────────────────────────────────────────────
|
|
# Driver
|
|
# ──────────────────────────────────────────────────────────────────────
|
|
|
|
def free():
    """Release Python garbage and cached CUDA blocks between variants.

    Order matters: gc.collect() drops Python references first so
    empty_cache() can actually return their CUDA allocations to the driver.
    """
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
|
|
|
|
|
def main():
    """Run the quantization sweep end to end.

    For each variant: build the model, record encoder footprint, time
    inference over the holdout, score against every benchmark reference,
    dump per-variant metrics.json, then write summary.json and print a
    compact comparison table. A failing variant is logged and skipped.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    print(f"Loading holdout from {HOLDOUT}")
    records = load_holdout_data(
        str(PARAGRAPHS), str(HOLDOUT), {k: str(v) for k, v in BENCHMARKS.items()},
    )
    texts = [r["text"] for r in records]
    print(f" {len(records)} holdout paragraphs loaded")

    variants = build_variants()
    summary = []

    for v in variants:
        print(f"\n══ {v.name} — {v.description}")
        free()
        try:
            t0 = time.perf_counter()
            model, tokenizer = v.builder()
            build_s = time.perf_counter() - t0
            enc_bytes = _encoder_param_bytes(model)
            print(f" encoder footprint: {enc_bytes / 1e6:.1f} MB (build {build_s:.1f}s)")
            inf = run_inference(model, tokenizer, texts)
            print(
                f" latency {inf['ms_per_sample']:.2f} ms/sample, "
                f"throughput {inf['throughput']:.0f}/s, "
                f"peak VRAM {inf['peak_vram_mb']:.0f} MB"
            )

            metrics_per_ref = {}
            for ref in BENCHMARKS:
                m = evaluate_predictions(inf["cat_logits"], inf["spec_logits"], records, ref)
                metrics_per_ref[ref] = m
                print(
                    f" vs {ref}: cat F1={m['cat_macro_f1']:.4f}, "
                    f"spec F1={m['spec_macro_f1']:.4f}, QWK={m['spec_qwk']:.4f}, "
                    f"cat ECE={m['cat_ece']:.4f}, spec ECE={m['spec_ece']:.4f}"
                )

            row = {
                "variant": v.name,
                "description": v.description,
                "encoder_mb": enc_bytes / 1e6,
                "ms_per_sample": inf["ms_per_sample"],
                "throughput_per_s": inf["throughput"],
                "peak_vram_mb": inf["peak_vram_mb"],
                "build_s": build_s,
            }
            for ref, m in metrics_per_ref.items():
                row[f"{ref}_cat_f1"] = m["cat_macro_f1"]
                row[f"{ref}_spec_f1"] = m["spec_macro_f1"]
                row[f"{ref}_cat_mcc"] = m["cat_mcc"]
                row[f"{ref}_spec_qwk"] = m["spec_qwk"]
                row[f"{ref}_spec_mae"] = m["spec_mae"]
                row[f"{ref}_cat_ece"] = m["cat_ece"]
                row[f"{ref}_spec_ece"] = m["spec_ece"]
                # per-spec-level F1
                for s in SPEC_LABELS:
                    short = s.replace(" ", "").replace(":", "")[:8]
                    row[f"{ref}_spec_f1_{short}"] = m.get(f"spec_f1_{short}", 0)
            summary.append(row)

            # Per-variant detailed metrics dump
            vdir = OUTPUT_DIR / v.name
            vdir.mkdir(parents=True, exist_ok=True)
            with open(vdir / "metrics.json", "w") as f:
                ser = {}
                for ref, m in metrics_per_ref.items():
                    # BUG FIX: the old filter excluded np.ndarray entirely, so
                    # the tolist() branch was dead code and array-valued
                    # metrics were silently dropped from metrics.json.
                    ser[ref] = {
                        k: (v_.tolist() if isinstance(v_, np.ndarray) else v_)
                        for k, v_ in m.items()
                        if isinstance(v_, (int, float, str, list, bool, np.ndarray))
                    }
                ser["_runtime"] = {
                    "encoder_mb": enc_bytes / 1e6,
                    "ms_per_sample": inf["ms_per_sample"],
                    "throughput_per_s": inf["throughput"],
                    "peak_vram_mb": inf["peak_vram_mb"],
                    "build_s": build_s,
                }
                json.dump(ser, f, indent=2, default=str)

            # Drop GPU-resident objects before the next variant is built.
            del model, tokenizer, inf
        except Exception as e:
            print(f" FAILED: {type(e).__name__}: {e}")
            traceback.print_exc()
            summary.append({
                "variant": v.name,
                "description": v.description,
                "error": f"{type(e).__name__}: {e}",
            })
        free()

    # Write summary
    summary_path = OUTPUT_DIR / "summary.json"
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2, default=str)
    print(f"\nSummary written to {summary_path}")

    # Print compact table (GPT-5.4 columns only; full numbers are in summary.json)
    print("\n" + "=" * 110)
    print(f"{'variant':<18} {'enc MB':>9} {'ms/samp':>9} {'throughput':>11} "
          f"{'VRAM MB':>9} {'cat F1':>9} {'spec F1':>9} {'spec QWK':>9}")
    print("-" * 110)
    for r in summary:
        if "error" in r:
            print(f"{r['variant']:<18} ERROR: {r['error']}")
            continue
        print(
            f"{r['variant']:<18} {r['encoder_mb']:>9.1f} {r['ms_per_sample']:>9.2f} "
            f"{r['throughput_per_s']:>11.0f} {r['peak_vram_mb']:>9.0f} "
            f"{r['GPT-5.4_cat_f1']:>9.4f} {r['GPT-5.4_spec_f1']:>9.4f} {r['GPT-5.4_spec_qwk']:>9.4f}"
        )
    print("=" * 110)
|
|
|
|
|
|
# Script entry point: run directly (or via the package.json wrapper per the
# module docstring); importing this module does not trigger the sweep.
if __name__ == "__main__":
    main()
|