Phase 10.8: torchao/bnb quant sweep on iter1-independent. bf16 already
optimal; torchao int8-wo gives -19% VRAM at no F1 cost; all 4-bit
variants collapse (ModernBERT-large too quant-sensitive).
Phase 10.9: ONNX export + ORT eval. Legacy exporter only working path
(dynamo adds 56 Memcpy nodes); ORT fp32 -22% latency vs torch via
kernel fusion but bf16+flash-attn-2 still wins; fp16 broken on rotary;
dynamic int8 silently CPU-fallback + 0.5 F1 collapse.
Driver scripts wired to bun run py:quant / py:onnx; full reports at
results/eval/{quant,onnx}/REPORT.md.
370 lines · 14 KiB · Python
"""ONNX export + eval for the iter1-independent ModernBERT-large checkpoint.
|
||
|
||
Variants:
|
||
onnx-fp32 — straight torch.onnx.export from the fp32 model
|
||
onnx-fp16 — fp32 export converted to fp16 via onnxconverter_common
|
||
(proxy for bf16; ORT does not support bf16 inference natively)
|
||
onnx-int8-dyn — dynamic int8 quantization of the fp32 graph via
|
||
onnxruntime.quantization.quantize_dynamic (weights in int8,
|
||
activations quantized at runtime)
|
||
|
||
For each variant:
|
||
- latency (ms/sample, batch=64, 5 warmup batches)
|
||
- peak GPU memory delta around the session (free-mem snapshot)
|
||
- on-disk size of model.onnx + model.onnx.data
|
||
- cat / spec macro F1, QWK, ECE on the 1,200-paragraph holdout
|
||
against GPT-5.4 + Opus-4.6 proxy gold
|
||
|
||
Usage:
|
||
bun run py:onnx
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import gc
|
||
import json
|
||
import os
|
||
import sys
|
||
import time
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
ROOT = Path(__file__).resolve().parents[1]
|
||
sys.path.insert(0, str(ROOT))
|
||
|
||
from src.finetune.data import CAT2ID, CATEGORIES, NUM_CATEGORIES, NUM_SPECIFICITY # noqa: E402
|
||
from src.finetune.eval import SPEC_LABELS, compute_all_metrics, load_holdout_data # noqa: E402
|
||
from src.finetune.model import ordinal_predict # noqa: E402
|
||
from scripts.quantize_sweep import ( # noqa: E402
|
||
BENCHMARKS, BATCH_SIZE, HOLDOUT, MAX_SEQ, PARAGRAPHS, WARMUP_BATCHES,
|
||
_build_model, evaluate_predictions,
|
||
)
|
||
|
||
OUTPUT_DIR = ROOT.parent / "results/eval/onnx"
|
||
ONNX_DIR = OUTPUT_DIR / "models"
|
||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||
ONNX_DIR.mkdir(parents=True, exist_ok=True)
|
||
|
||
|
||
# ──────────────────────────────────────────────────────────────────────
|
||
# Export
|
||
# ──────────────────────────────────────────────────────────────────────
|
||
|
||
class _Wrap(nn.Module):
|
||
def __init__(self, model):
|
||
super().__init__()
|
||
self.model = model
|
||
|
||
def forward(self, input_ids, attention_mask):
|
||
out = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
||
return out["category_logits"], out["specificity_logits"]
|
||
|
||
|
def export_fp32(out_path: Path, sample_batch: int = 4, sample_seq: int = 64) -> None:
    """Export the fp32 checkpoint to ONNX at *out_path*.

    Args:
        out_path: destination ``.onnx`` file (external weight data is written
            beside it by the exporter).
        sample_batch: batch size of the dummy tracing input.
        sample_seq: sequence length of the dummy tracing input; the exported
            graph keeps batch/seq dynamic via ``dynamic_axes``.
    """
    print(" building fp32 torch model...")  # fixed: was a placeholder-less f-string
    model, tokenizer = _build_model(torch.float32, attn_impl="sdpa")
    model = model.cuda().eval()
    wrap = _Wrap(model).cuda().eval()

    # Content of the dummy text is irrelevant; only its tokenized shape matters.
    dummy_text = ["the company maintains a cybersecurity program overseen by the board"] * sample_batch
    enc = tokenizer(
        dummy_text, padding="max_length", max_length=sample_seq,
        truncation=True, return_tensors="pt",
    ).to("cuda")

    print(f" exporting → {out_path}")
    # Legacy TorchScript exporter (dynamo=False). The dynamo path produces a
    # graph with 56+ Memcpy nodes when run on CUDAExecutionProvider, blowing
    # latency 8× and VRAM 4× over native torch — unusable. The legacy
    # exporter emits clean Gemm/MatMul/LayerNorm nodes ORT can fuse.
    torch.onnx.export(
        wrap,
        (enc["input_ids"], enc["attention_mask"]),
        str(out_path),
        input_names=["input_ids", "attention_mask"],
        output_names=["cat_logits", "spec_logits"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "seq"},
            "attention_mask": {0: "batch", 1: "seq"},
            "cat_logits": {0: "batch"},
            "spec_logits": {0: "batch"},
        },
        opset_version=17,
        dynamo=False,
        do_constant_folding=True,
    )

    # Release GPU memory before the ORT sessions load in the eval phase.
    del wrap, model
    gc.collect()
    torch.cuda.empty_cache()
def convert_fp16(fp32_path: Path, fp16_path: Path) -> None:
    """Convert an fp32 ONNX model to fp16 via onnxconverter_common.

    Args:
        fp32_path: source fp32 ONNX model (external data loaded too).
        fp16_path: destination fp16 model; weights saved as a single
            external ``<name>.data`` file beside it.
    """
    # Local imports: onnx tooling is only needed for this step.
    import onnx
    from onnxconverter_common import float16

    print(f" loading {fp32_path}")
    model = onnx.load(str(fp32_path), load_external_data=True)
    print(" converting to fp16...")  # fixed: was a placeholder-less f-string
    # keep_io_types=False: I/O tensors become fp16 too; shape inference is
    # skipped since it chokes on this graph (see quantize_int8_dynamic).
    model_fp16 = float16.convert_float_to_float16(
        model, keep_io_types=False, disable_shape_infer=True,
    )
    print(f" saving → {fp16_path}")
    onnx.save_model(
        model_fp16, str(fp16_path),
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=fp16_path.name + ".data",
        size_threshold=1024,
    )
def quantize_int8_dynamic(fp32_path: Path, int8_path: Path) -> None:
    """Dynamic int8 quantization (weights → int8, activations on the fly).

    Two shape-inference paths in the ORT quantizer choke on the dynamo
    export of ModernBERT-large:

    1. `SymbolicShapeInference._infer_Range` asserts on the dynamic limit
       input emitted by RoPE (`assert len(x) == 1` in `as_scalar`).
    2. `onnx.shape_inference.infer_shapes_path` (C++) raises a (1024)/(7)
       dim mismatch on the category head Gemm — the dynamo decomposition
       leaves a dimension hint the C++ inferencer disagrees with.

    The skip flags on `quant_pre_process` are ignored (it always runs
    `SymbolicShapeInference.infer_shapes`), and `ONNXQuantizer.__init__`
    calls `save_and_reload_model_with_shape_infer` unconditionally. We
    monkey-patch both module references to a no-op, then run
    `quantize_dynamic` restricted to MatMul/Gemm ops (the only nodes we
    want quantized anyway).
    """
    import onnx
    from onnxruntime.quantization import QuantType, quantize_dynamic
    from onnxruntime.quantization import quant_utils
    # NOTE: the previously-imported onnxruntime.tools.symbolic_shape_infer
    # was unused (F401) and has been removed.

    # No-op the broken shape passes. quant_utils holds the definition...
    original_save_reload = quant_utils.save_and_reload_model_with_shape_infer

    def _passthrough(model):
        return model

    quant_utils.save_and_reload_model_with_shape_infer = _passthrough
    # ...and onnx_quantizer imported the symbol at module load, so patch it too.
    import onnxruntime.quantization.onnx_quantizer as oq
    oq.save_and_reload_model_with_shape_infer = _passthrough

    try:
        print(f" quantizing {fp32_path} → {int8_path}")
        quantize_dynamic(
            model_input=str(fp32_path),
            model_output=str(int8_path),
            weight_type=QuantType.QInt8,
            per_channel=True,
            reduce_range=False,
            op_types_to_quantize=["MatMul", "Gemm"],
            use_external_data_format=True,
            extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
        )
    finally:
        # Always restore the real implementations for any later caller.
        quant_utils.save_and_reload_model_with_shape_infer = original_save_reload
        oq.save_and_reload_model_with_shape_infer = original_save_reload
||
|
||
# ──────────────────────────────────────────────────────────────────────
|
||
# Inference + metrics
|
||
# ──────────────────────────────────────────────────────────────────────
|
||
|
||
def _files_size(model_path: Path) -> int:
|
||
"""Sum of model.onnx + any external .data files in the same dir."""
|
||
total = model_path.stat().st_size
|
||
for sib in model_path.parent.iterdir():
|
||
if sib.name.startswith(model_path.name) and sib != model_path:
|
||
total += sib.stat().st_size
|
||
return total
|
||
|
def run_onnx(model_path: Path, texts: list[str], use_cuda: bool = True) -> dict:
    """Run one ONNX variant over *texts* and collect logits + perf stats.

    Args:
        model_path: ``.onnx`` file to load into an ORT inference session.
        texts: paragraphs to classify, batched by BATCH_SIZE.
        use_cuda: prefer CUDAExecutionProvider (with CPU fallback).

    Returns:
        Dict with concatenated fp32 ``cat_logits`` / ``spec_logits`` tensors,
        ``ms_per_sample``, ``throughput``, ``peak_vram_mb``, ``load_vram_mb``
        and the provider list used.
    """
    import onnxruntime as ort
    from transformers import AutoTokenizer

    # Anchored on ROOT so the script is CWD-independent (previously a bare
    # "../checkpoints/..." that only resolved when run from the repo root).
    tokenizer = AutoTokenizer.from_pretrained(
        str(ROOT.parent / "checkpoints/finetune/iter1-independent/final")
    )

    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    providers = (
        ["CUDAExecutionProvider", "CPUExecutionProvider"] if use_cuda
        else ["CPUExecutionProvider"]
    )

    # VRAM is measured as free-memory deltas around the session. Guarded:
    # torch.cuda.mem_get_info raises on CUDA-less hosts, which previously
    # crashed even pure-CPU runs (use_cuda=False).
    track_vram = torch.cuda.is_available()
    free_before = torch.cuda.mem_get_info()[0] if track_vram else 0
    sess = ort.InferenceSession(str(model_path), so, providers=providers)
    load_vram_mb = (
        (free_before - torch.cuda.mem_get_info()[0]) / (1024 ** 2)
        if track_vram else 0.0
    )

    # Warmup so lazy allocation / kernel autotuning don't pollute timings.
    warm_enc = tokenizer(
        texts[:BATCH_SIZE], truncation=True, max_length=MAX_SEQ,
        padding="longest", return_tensors="np",
    )
    warm_inputs = {
        "input_ids": warm_enc["input_ids"].astype(np.int64),
        "attention_mask": warm_enc["attention_mask"].astype(np.int64),
    }
    for _ in range(WARMUP_BATCHES):
        sess.run(None, warm_inputs)

    peak_vram_mb = (
        (free_before - torch.cuda.mem_get_info()[0]) / (1024 ** 2)
        if track_vram else 0.0
    )

    cat_logits_list = []
    spec_logits_list = []
    total_time = 0.0
    for i in range(0, len(texts), BATCH_SIZE):
        batch = texts[i : i + BATCH_SIZE]
        enc = tokenizer(
            batch, truncation=True, max_length=MAX_SEQ,
            padding="longest", return_tensors="np",
        )
        inputs = {
            "input_ids": enc["input_ids"].astype(np.int64),
            "attention_mask": enc["attention_mask"].astype(np.int64),
        }
        t0 = time.perf_counter()
        out = sess.run(None, inputs)
        total_time += time.perf_counter() - t0
        # Upcast to fp32 so downstream metrics are dtype-uniform (the fp16
        # variant returns fp16 outputs).
        cat_logits_list.append(torch.from_numpy(out[0].astype(np.float32)))
        spec_logits_list.append(torch.from_numpy(out[1].astype(np.float32)))

    if track_vram:
        free_end = torch.cuda.mem_get_info()[0]
        peak_vram_mb = max(peak_vram_mb, (free_before - free_end) / (1024 ** 2))

    del sess
    gc.collect()
    if track_vram:
        torch.cuda.empty_cache()

    return {
        "cat_logits": torch.cat(cat_logits_list),
        "spec_logits": torch.cat(spec_logits_list),
        "ms_per_sample": (total_time / len(texts)) * 1000,
        "throughput": len(texts) / total_time,
        "peak_vram_mb": peak_vram_mb,
        "load_vram_mb": load_vram_mb,
        "providers": providers,
    }
# ──────────────────────────────────────────────────────────────────────
# Driver
# ──────────────────────────────────────────────────────────────────────

def _ensure_artifact(path: Path, label: str, build) -> None:
    """Run *build*() to produce *path* unless it already exists on disk."""
    if not path.exists():
        print(f"\n══ {label}")
        build()
    else:
        print(f"\n══ reusing existing {path}")


def main():
    """Export the checkpoint to ONNX (fp32 / fp16 / int8-dyn) and benchmark each.

    Writes per-variant rows (perf + metrics, or an error record) to
    ``results/eval/onnx/summary.json`` and prints a final comparison table.
    """
    print("loading holdout...")
    records = load_holdout_data(
        str(PARAGRAPHS), str(HOLDOUT), {k: str(v) for k, v in BENCHMARKS.items()},
    )
    texts = [r["text"] for r in records]
    print(f" {len(records)} paragraphs")

    fp32_path = ONNX_DIR / "model_fp32.onnx"
    fp16_path = ONNX_DIR / "model_fp16.onnx"
    int8_path = ONNX_DIR / "model_int8_dyn.onnx"

    # fp32 export is the source graph for both the fp16 and int8 variants.
    _ensure_artifact(fp32_path, "exporting fp32 ONNX", lambda: export_fp32(fp32_path))
    _ensure_artifact(
        fp16_path, "converting → fp16 ONNX",
        lambda: convert_fp16(fp32_path, fp16_path),
    )
    _ensure_artifact(
        int8_path, "quantizing → int8 dynamic ONNX",
        lambda: quantize_int8_dynamic(fp32_path, int8_path),
    )

    summary = []
    variants = [
        ("onnx-fp32", fp32_path),
        ("onnx-fp16", fp16_path),
        ("onnx-int8-dyn", int8_path),
    ]
    for name, path in variants:
        print(f"\n══ {name} — {path.name}")
        size_mb = _files_size(path) / 1e6
        print(f" on-disk size: {size_mb:.1f} MB")
        try:
            inf = run_onnx(path, texts, use_cuda=True)
            print(
                f" latency {inf['ms_per_sample']:.2f} ms/sample, "
                f"throughput {inf['throughput']:.0f}/s, "
                f"peak VRAM {inf['peak_vram_mb']:.0f} MB "
                f"(load {inf['load_vram_mb']:.0f} MB)"
            )
            row = {
                "variant": name,
                "model_mb": size_mb,
                "ms_per_sample": inf["ms_per_sample"],
                "throughput_per_s": inf["throughput"],
                "peak_vram_mb": inf["peak_vram_mb"],
                "load_vram_mb": inf["load_vram_mb"],
            }
            # Score against every proxy-gold benchmark reference.
            for ref in BENCHMARKS:
                m = evaluate_predictions(inf["cat_logits"], inf["spec_logits"], records, ref)
                print(
                    f" vs {ref}: cat F1={m['cat_macro_f1']:.4f}, "
                    f"spec F1={m['spec_macro_f1']:.4f}, QWK={m['spec_qwk']:.4f}, "
                    f"cat ECE={m['cat_ece']:.4f}, spec ECE={m['spec_ece']:.4f}"
                )
                row[f"{ref}_cat_f1"] = m["cat_macro_f1"]
                row[f"{ref}_spec_f1"] = m["spec_macro_f1"]
                row[f"{ref}_cat_mcc"] = m["cat_mcc"]
                row[f"{ref}_spec_qwk"] = m["spec_qwk"]
                row[f"{ref}_spec_mae"] = m["spec_mae"]
                row[f"{ref}_cat_ece"] = m["cat_ece"]
                row[f"{ref}_spec_ece"] = m["spec_ece"]
            summary.append(row)
        except Exception as e:
            # A broken variant (e.g. fp16 rotary failure) must not kill the
            # sweep — record the error row and continue.
            import traceback
            traceback.print_exc()
            summary.append({"variant": name, "error": f"{type(e).__name__}: {e}"})

    summary_path = OUTPUT_DIR / "summary.json"
    with open(summary_path, "w") as f:
        # default=str keeps the dump robust against stray non-JSON values.
        json.dump(summary, f, indent=2, default=str)
    print(f"\nsummary → {summary_path}")

    # NOTE(review): the table below hardcodes "GPT-5.4" metric keys — it
    # assumes that reference is always present in BENCHMARKS.
    print("\n" + "=" * 110)
    print(f"{'variant':<18} {'MB':>9} {'ms/samp':>9} {'throughput':>11} "
          f"{'VRAM MB':>9} {'cat F1':>9} {'spec F1':>9} {'spec QWK':>9}")
    print("-" * 110)
    for r in summary:
        if "error" in r:
            print(f"{r['variant']:<18} ERROR: {r['error']}")
            continue
        print(
            f"{r['variant']:<18} {r['model_mb']:>9.1f} {r['ms_per_sample']:>9.2f} "
            f"{r['throughput_per_s']:>11.0f} {r['peak_vram_mb']:>9.0f} "
            f"{r['GPT-5.4_cat_f1']:>9.4f} {r['GPT-5.4_spec_f1']:>9.4f} "
            f"{r['GPT-5.4_spec_qwk']:>9.4f}"
        )
    print("=" * 110)


if __name__ == "__main__":
    main()