"""Quantization sweep for the iter1-independent ModernBERT-large checkpoint. Loads the trained DualHeadModernBERT, applies a series of quantization schemes to the *encoder* (heads kept in their native dtype), and re-runs holdout evaluation against the GPT-5.4 / Opus-4.6 proxy gold. For each variant we record: - cat / spec macro F1, per-class F1, QWK, MAE, ECE - peak VRAM (encoder forward) - latency (ms/sample, batch=64) and throughput - encoder parameter footprint in MB - delta vs bf16 baseline Variants: fp32, bf16 (baseline), fp16, torchao int8 weight-only, torchao int8 dynamic-act + int8 weight, torchao int4 weight-only (group=128), bitsandbytes LLM.int8 (8-bit), bitsandbytes nf4 (4-bit, double-quant, bf16 compute). Heads (category linear, attention pooler, independent threshold MLPs) stay in bf16 — they sit on a 1024-dim representation and account for < 0.3% of params, so quantizing them buys nothing and risks the threshold margins which already drive most of the spec error budget. Usage: bun run py:quant # via package.json wrapper # or directly: cd python && uv run scripts/quantize_sweep.py """ from __future__ import annotations import gc import json import sys import time import traceback from dataclasses import dataclass, field from pathlib import Path import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from safetensors.torch import load_file from transformers import AutoModel, AutoTokenizer # Make `src` importable when run as a script ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(ROOT)) from src.finetune.data import CAT2ID, CATEGORIES, NUM_CATEGORIES, NUM_SPECIFICITY # noqa: E402 from src.finetune.eval import ( # noqa: E402 SPEC_LABELS, compute_all_metrics, load_holdout_data, ) from src.finetune.model import DualHeadModernBERT, ordinal_predict # noqa: E402 REPO = ROOT.parent CHECKPOINT = REPO / "checkpoints/finetune/iter1-independent/final" PARAGRAPHS = REPO / "data/paragraphs/paragraphs-clean.patched.jsonl" HOLDOUT = REPO / "data/gold/v2-holdout-ids.json" BENCHMARKS = { "GPT-5.4": REPO / "data/annotations/v2-bench/gpt-5.4.jsonl", "Opus-4.6": REPO / "data/annotations/v2-bench/opus-4.6.jsonl", } OUTPUT_DIR = REPO / "results/eval/quant" BATCH_SIZE = 64 MAX_SEQ = 512 WARMUP_BATCHES = 5 # ────────────────────────────────────────────────────────────────────── # Model loading # ────────────────────────────────────────────────────────────────────── def _build_model(dtype: torch.dtype, attn_impl: str = "sdpa") -> tuple[DualHeadModernBERT, AutoTokenizer]: """Construct DualHeadModernBERT and load trained weights at the requested dtype.""" tokenizer = AutoTokenizer.from_pretrained(str(CHECKPOINT)) backbone = AutoModel.from_pretrained( "answerdotai/ModernBERT-large", trust_remote_code=True, attn_implementation=attn_impl, dtype=dtype, ) model = DualHeadModernBERT( backbone=backbone, hidden_size=backbone.config.hidden_size, num_categories=NUM_CATEGORIES, num_specificity=NUM_SPECIFICITY, specificity_head_type="independent", spec_mlp_dim=256, pooling="attention", ) state = load_file(str(CHECKPOINT / "model.safetensors")) model.load_state_dict(state, strict=False) model = model.to(dtype) model.eval() return model, tokenizer def _try_flash_attn() -> str: try: import flash_attn # noqa: F401 return "flash_attention_2" except ImportError: return "sdpa" # ────────────────────────────────────────────────────────────────────── # Quantization variants # ────────────────────────────────────────────────────────────────────── def variant_native(dtype: torch.dtype, attn: str | None = None): def _build(): impl = attn or _try_flash_attn() # bf16/fp16 supported by flash-attn; fp32 must use sdpa if dtype == torch.float32: impl = "sdpa" model, tok = _build_model(dtype, attn_impl=impl) return model.cuda(), tok return _build def variant_torchao(config_factory): def _build(): from torchao.quantization import quantize_ # torchao expects bf16 master weights model, tok = _build_model(torch.bfloat16, attn_impl=_try_flash_attn()) model = model.cuda() # Quantize encoder linears only (skip heads + attention pooler) quantize_(model.backbone, config_factory()) return model, tok return _build def _swap_bnb_linear( module: nn.Module, mode: str, compute_dtype=torch.bfloat16, compress_statistics: bool = True, ) -> int: """Recursively replace nn.Linear with bnb 8-bit / 4-bit equivalents. Returns number of layers swapped. Copies weights from the original module so the trained checkpoint is preserved. """ import bitsandbytes as bnb swapped = 0 for name, child in list(module.named_children()): if isinstance(child, nn.Linear): in_f, out_f = child.in_features, child.out_features has_bias = child.bias is not None if mode == "int8": new = bnb.nn.Linear8bitLt( in_f, out_f, bias=has_bias, has_fp16_weights=False, threshold=6.0, ) new.weight = bnb.nn.Int8Params( child.weight.data.clone(), requires_grad=False, has_fp16_weights=False, ) if has_bias: new.bias = nn.Parameter(child.bias.data.clone()) elif mode in ("nf4", "fp4"): new = bnb.nn.Linear4bit( in_f, out_f, bias=has_bias, compute_dtype=compute_dtype, quant_type=mode, compress_statistics=compress_statistics, quant_storage=torch.uint8, device="cuda", ) w = child.weight.data.detach().to(torch.float32).clone() new.weight = bnb.nn.Params4bit( w, requires_grad=False, quant_type=mode, compress_statistics=compress_statistics, module=new, ).cuda() if has_bias: new.bias = nn.Parameter( child.bias.data.detach().to(compute_dtype).clone().cuda() ) else: raise ValueError(mode) new = new.cuda() setattr(module, name, new) swapped += 1 else: swapped += _swap_bnb_linear(child, mode, compute_dtype) return swapped def variant_bnb(mode: str, compress_statistics: bool = True): def _build(): model, tok = _build_model(torch.bfloat16, attn_impl="sdpa") model = model.cuda() n = _swap_bnb_linear( model.backbone, mode, compress_statistics=compress_statistics, ) print(f" bnb {mode} (cs={compress_statistics}): swapped {n} linears") return model, tok return _build # ────────────────────────────────────────────────────────────────────── # Inference + measurement # ────────────────────────────────────────────────────────────────────── def _encoder_param_bytes(model: DualHeadModernBERT) -> int: """Sum bytes of every parameter / buffer inside the encoder backbone. Handles bnb Int8Params (int8 storage) and Params4bit (uint8 packed) correctly because element_size() reflects the storage dtype. """ total = 0 seen = set() for p in list(model.backbone.parameters()) + list(model.backbone.buffers()): if id(p) in seen: continue seen.add(id(p)) total += p.numel() * p.element_size() return total @torch.no_grad() def run_inference(model, tokenizer, texts: list[str]) -> dict: device = next(model.parameters()).device cat_logits_list = [] spec_logits_list = [] # Warmup warm_batch = tokenizer( texts[: BATCH_SIZE], truncation=True, max_length=MAX_SEQ, padding="longest", return_tensors="pt", ).to(device) for _ in range(WARMUP_BATCHES): _ = model(input_ids=warm_batch["input_ids"], attention_mask=warm_batch["attention_mask"]) torch.cuda.synchronize() torch.cuda.reset_peak_memory_stats() total_time = 0.0 for i in range(0, len(texts), BATCH_SIZE): batch = texts[i : i + BATCH_SIZE] enc = tokenizer( batch, truncation=True, max_length=MAX_SEQ, padding="longest", return_tensors="pt", ).to(device) torch.cuda.synchronize() t0 = time.perf_counter() out = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"]) torch.cuda.synchronize() total_time += time.perf_counter() - t0 cat_logits_list.append(out["category_logits"].float().cpu()) spec_logits_list.append(out["specificity_logits"].float().cpu()) peak_vram = torch.cuda.max_memory_allocated() cat_logits = torch.cat(cat_logits_list) spec_logits = torch.cat(spec_logits_list) return { "cat_logits": cat_logits, "spec_logits": spec_logits, "total_time_s": total_time, "ms_per_sample": (total_time / len(texts)) * 1000, "throughput": len(texts) / total_time, "peak_vram_mb": peak_vram / (1024 ** 2), "num_samples": len(texts), } def evaluate_predictions( cat_logits: torch.Tensor, spec_logits: torch.Tensor, records: list[dict], ref_name: str, ) -> dict: cat_probs_all = F.softmax(cat_logits, dim=1).numpy() cat_preds_all = cat_logits.argmax(dim=1).numpy() spec_preds_all = ordinal_predict(spec_logits).numpy() # ordinal → class probs sp = torch.sigmoid(spec_logits) K = sp.shape[1] + 1 spec_probs_all = torch.zeros(sp.shape[0], K) spec_probs_all[:, 0] = 1 - sp[:, 0] for k in range(1, K - 1): spec_probs_all[:, k] = sp[:, k - 1] - sp[:, k] spec_probs_all[:, -1] = sp[:, -1] spec_probs_all = spec_probs_all.clamp(min=0) spec_probs_all = spec_probs_all / spec_probs_all.sum(dim=1, keepdim=True) spec_probs_all = spec_probs_all.numpy() cat_labels, spec_labels = [], [] cat_p, spec_p, cat_pr, spec_pr = [], [], [], [] for i, rec in enumerate(records): b = rec["benchmark_labels"].get(ref_name) if b is None: continue cat_labels.append(CAT2ID[b["category"]]) spec_labels.append(b["specificity"] - 1) cat_p.append(cat_preds_all[i]) spec_p.append(spec_preds_all[i]) cat_pr.append(cat_probs_all[i]) spec_pr.append(spec_probs_all[i]) cat_m = compute_all_metrics( np.array(cat_p), np.array(cat_labels), np.array(cat_pr), CATEGORIES, "cat", is_ordinal=False, ) spec_m = compute_all_metrics( np.array(spec_p), np.array(spec_labels), np.array(spec_pr), SPEC_LABELS, "spec", is_ordinal=True, ) return {**cat_m, **spec_m} # ────────────────────────────────────────────────────────────────────── # Variant registry # ────────────────────────────────────────────────────────────────────── @dataclass class Variant: name: str description: str builder: callable skip_reason: str | None = None def build_variants() -> list[Variant]: from torchao.quantization import ( Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig, ) return [ Variant("fp32", "Float32 encoder + heads", variant_native(torch.float32, attn="sdpa")), Variant("bf16", "BFloat16 baseline (matches eval pipeline)", variant_native(torch.bfloat16)), Variant("fp16", "Float16 encoder + heads", variant_native(torch.float16)), Variant( "torchao-int8-wo", "torchao Int8 weight-only on encoder linears", variant_torchao(lambda: Int8WeightOnlyConfig()), ), Variant( "torchao-int8-dyn", "torchao Int8 dynamic activation + Int8 weight on encoder", variant_torchao(lambda: Int8DynamicActivationInt8WeightConfig()), ), Variant( "torchao-int4-wo", "torchao Int4 weight-only (group=128) on encoder", variant_torchao(lambda: Int4WeightOnlyConfig(group_size=128)), ), Variant("bnb-int8", "bitsandbytes LLM.int8 on encoder linears", variant_bnb("int8")), Variant("bnb-nf4", "bitsandbytes NF4 4-bit (double-quant, bf16 compute)", variant_bnb("nf4", compress_statistics=True)), Variant("bnb-nf4-nodq", "bitsandbytes NF4 4-bit (no double-quant)", variant_bnb("nf4", compress_statistics=False)), Variant("bnb-fp4", "bitsandbytes FP4 4-bit (no double-quant)", variant_bnb("fp4", compress_statistics=False)), ] # ────────────────────────────────────────────────────────────────────── # Driver # ────────────────────────────────────────────────────────────────────── def free(): gc.collect() torch.cuda.empty_cache() torch.cuda.synchronize() def main(): OUTPUT_DIR.mkdir(parents=True, exist_ok=True) print(f"Loading holdout from {HOLDOUT}") records = load_holdout_data( str(PARAGRAPHS), str(HOLDOUT), {k: str(v) for k, v in BENCHMARKS.items()}, ) texts = [r["text"] for r in records] print(f" {len(records)} holdout paragraphs loaded") variants = build_variants() summary = [] for v in variants: print(f"\n══ {v.name} — {v.description}") free() try: t0 = time.perf_counter() model, tokenizer = v.builder() build_s = time.perf_counter() - t0 enc_bytes = _encoder_param_bytes(model) print(f" encoder footprint: {enc_bytes / 1e6:.1f} MB (build {build_s:.1f}s)") inf = run_inference(model, tokenizer, texts) print( f" latency {inf['ms_per_sample']:.2f} ms/sample, " f"throughput {inf['throughput']:.0f}/s, " f"peak VRAM {inf['peak_vram_mb']:.0f} MB" ) metrics_per_ref = {} for ref in BENCHMARKS: m = evaluate_predictions(inf["cat_logits"], inf["spec_logits"], records, ref) metrics_per_ref[ref] = m print( f" vs {ref}: cat F1={m['cat_macro_f1']:.4f}, " f"spec F1={m['spec_macro_f1']:.4f}, QWK={m['spec_qwk']:.4f}, " f"cat ECE={m['cat_ece']:.4f}, spec ECE={m['spec_ece']:.4f}" ) row = { "variant": v.name, "description": v.description, "encoder_mb": enc_bytes / 1e6, "ms_per_sample": inf["ms_per_sample"], "throughput_per_s": inf["throughput"], "peak_vram_mb": inf["peak_vram_mb"], "build_s": build_s, } for ref, m in metrics_per_ref.items(): row[f"{ref}_cat_f1"] = m["cat_macro_f1"] row[f"{ref}_spec_f1"] = m["spec_macro_f1"] row[f"{ref}_cat_mcc"] = m["cat_mcc"] row[f"{ref}_spec_qwk"] = m["spec_qwk"] row[f"{ref}_spec_mae"] = m["spec_mae"] row[f"{ref}_cat_ece"] = m["cat_ece"] row[f"{ref}_spec_ece"] = m["spec_ece"] # per-spec-level F1 for s in SPEC_LABELS: short = s.replace(" ", "").replace(":", "")[:8] row[f"{ref}_spec_f1_{short}"] = m.get(f"spec_f1_{short}", 0) summary.append(row) # Per-variant detailed metrics dump vdir = OUTPUT_DIR / v.name vdir.mkdir(parents=True, exist_ok=True) with open(vdir / "metrics.json", "w") as f: ser = {} for ref, m in metrics_per_ref.items(): ser[ref] = { k: (v_ if not isinstance(v_, np.ndarray) else v_.tolist()) for k, v_ in m.items() if isinstance(v_, (int, float, str, list, bool)) } ser["_runtime"] = { "encoder_mb": enc_bytes / 1e6, "ms_per_sample": inf["ms_per_sample"], "throughput_per_s": inf["throughput"], "peak_vram_mb": inf["peak_vram_mb"], "build_s": build_s, } json.dump(ser, f, indent=2, default=str) del model, tokenizer, inf except Exception as e: print(f" FAILED: {type(e).__name__}: {e}") traceback.print_exc() summary.append({ "variant": v.name, "description": v.description, "error": f"{type(e).__name__}: {e}", }) free() # Write summary summary_path = OUTPUT_DIR / "summary.json" with open(summary_path, "w") as f: json.dump(summary, f, indent=2, default=str) print(f"\nSummary written to {summary_path}") # Print compact table print("\n" + "=" * 110) print(f"{'variant':<18} {'enc MB':>9} {'ms/samp':>9} {'throughput':>11} " f"{'VRAM MB':>9} {'cat F1':>9} {'spec F1':>9} {'spec QWK':>9}") print("-" * 110) for r in summary: if "error" in r: print(f"{r['variant']:<18} ERROR: {r['error']}") continue print( f"{r['variant']:<18} {r['encoder_mb']:>9.1f} {r['ms_per_sample']:>9.2f} " f"{r['throughput_per_s']:>11.0f} {r['peak_vram_mb']:>9.0f} " f"{r['GPT-5.4_cat_f1']:>9.4f} {r['GPT-5.4_spec_f1']:>9.4f} {r['GPT-5.4_spec_qwk']:>9.4f}" ) print("=" * 110) if __name__ == "__main__": main()