SEC-cyBERT/python/scripts/onnx_export_eval.py
Joey Eamigh 67beaede45
quantization + onnx sweeps
Phase 10.8: torchao/bnb quant sweep on iter1-independent. bf16 already
optimal; torchao int8-wo gives -19% VRAM at no F1 cost; all 4-bit
variants collapse (ModernBERT-large too quant-sensitive).

Phase 10.9: ONNX export + ORT eval. Legacy exporter is the only working
path (dynamo adds 56 Memcpy nodes); ORT fp32 cuts latency 22% vs torch via
kernel fusion, but bf16+flash-attn-2 still wins; fp16 is broken on rotary;
dynamic int8 silently falls back to CPU with a 0.5 F1 collapse.

Driver scripts wired to bun run py:quant / py:onnx; full reports at
results/eval/{quant,onnx}/REPORT.md.
2026-04-07 05:10:38 -04:00

"""ONNX export + eval for the iter1-independent ModernBERT-large checkpoint.
Variants:
onnx-fp32 — straight torch.onnx.export from the fp32 model
onnx-fp16 — fp32 export converted to fp16 via onnxconverter_common
(proxy for bf16; ORT does not support bf16 inference natively)
onnx-int8-dyn — dynamic int8 quantization of the fp32 graph via
onnxruntime.quantization.quantize_dynamic (weights in int8,
activations quantized at runtime)
For each variant:
- latency (ms/sample, batch=64, 5 warmup batches)
- peak GPU memory delta around the session (free-mem snapshot)
- on-disk size of model.onnx + model.onnx.data
- cat / spec macro F1, QWK, ECE on the 1,200-paragraph holdout
against GPT-5.4 + Opus-4.6 proxy gold
Usage:
bun run py:onnx
"""
from __future__ import annotations

import gc
import json
import os
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))

from src.finetune.data import CAT2ID, CATEGORIES, NUM_CATEGORIES, NUM_SPECIFICITY  # noqa: E402
from src.finetune.eval import SPEC_LABELS, compute_all_metrics, load_holdout_data  # noqa: E402
from src.finetune.model import ordinal_predict  # noqa: E402
from scripts.quantize_sweep import (  # noqa: E402
    BENCHMARKS, BATCH_SIZE, HOLDOUT, MAX_SEQ, PARAGRAPHS, WARMUP_BATCHES,
    _build_model, evaluate_predictions,
)

OUTPUT_DIR = ROOT.parent / "results/eval/onnx"
ONNX_DIR = OUTPUT_DIR / "models"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
ONNX_DIR.mkdir(parents=True, exist_ok=True)

# ──────────────────────────────────────────────────────────────────────
# Export
# ──────────────────────────────────────────────────────────────────────
class _Wrap(nn.Module):
    """Expose (cat_logits, spec_logits) as positional outputs: torch.onnx.export
    needs tensor tuples, not the model's output dict."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask):
        out = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return out["category_logits"], out["specificity_logits"]

def export_fp32(out_path: Path, sample_batch: int = 4, sample_seq: int = 64) -> None:
    print(" building fp32 torch model...")
    model, tokenizer = _build_model(torch.float32, attn_impl="sdpa")
    model = model.cuda().eval()
    wrap = _Wrap(model).cuda().eval()
    dummy_text = ["the company maintains a cybersecurity program overseen by the board"] * sample_batch
    enc = tokenizer(
        dummy_text, padding="max_length", max_length=sample_seq,
        truncation=True, return_tensors="pt",
    ).to("cuda")
    print(f" exporting → {out_path}")
    # Legacy TorchScript exporter (dynamo=False). The dynamo path produces a
    # graph with 56+ Memcpy nodes when run on CUDAExecutionProvider, blowing
    # latency 8× and VRAM 4× over native torch — unusable. The legacy
    # exporter emits clean Gemm/MatMul/LayerNorm nodes ORT can fuse.
    torch.onnx.export(
        wrap,
        (enc["input_ids"], enc["attention_mask"]),
        str(out_path),
        input_names=["input_ids", "attention_mask"],
        output_names=["cat_logits", "spec_logits"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "seq"},
            "attention_mask": {0: "batch", 1: "seq"},
            "cat_logits": {0: "batch"},
            "spec_logits": {0: "batch"},
        },
        opset_version=17,
        dynamo=False,
        do_constant_folding=True,
    )
    del wrap, model
    gc.collect()
    torch.cuda.empty_cache()

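# Diagnostic sketch (hypothetical helper, not called by the driver): the
# Memcpy nodes blamed above are inserted by ORT at CPU/CUDA partition
# boundaries, so they only show up in the *optimized* graph, not in the raw
# export. Dumping that graph via SessionOptions.optimized_model_filepath and
# counting Memcpy* ops is one way to compare the two exporters directly.
def _count_memcpy_nodes(model_path: Path) -> int:
    import onnx
    import onnxruntime as ort

    so = ort.SessionOptions()
    opt_path = model_path.with_suffix(".optimized.onnx")
    so.optimized_model_filepath = str(opt_path)
    # Session creation runs the optimizer and writes the partitioned graph.
    ort.InferenceSession(
        str(model_path), so,
        providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    )
    graph = onnx.load(str(opt_path), load_external_data=False).graph
    return sum(1 for node in graph.node if node.op_type.startswith("Memcpy"))
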
def convert_fp16(fp32_path: Path, fp16_path: Path) -> None:
    """Convert an fp32 ONNX model to fp16 via onnxconverter_common."""
    import onnx
    from onnxconverter_common import float16

    print(f" loading {fp32_path}")
    model = onnx.load(str(fp32_path), load_external_data=True)
    print(" converting to fp16...")
    model_fp16 = float16.convert_float_to_float16(
        model, keep_io_types=False, disable_shape_infer=True,
    )
    print(f" saving → {fp16_path}")
    onnx.save_model(
        model_fp16, str(fp16_path),
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=fp16_path.name + ".data",
        size_threshold=1024,
    )

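# Untested mitigation sketch for the rotary fp16 breakage noted in the
# report: convert_float_to_float16 accepts an op_block_list that pins
# selected op types to fp32. Assuming RoPE lowers to Range/Sin/Cos/Pow nodes
# in this export (not verified), blocking them would keep the rotary tables
# in fp32 while the heavy MatMuls still run in fp16.
def convert_fp16_blocked(fp32_path: Path, fp16_path: Path) -> None:
    import onnx
    from onnxconverter_common import float16

    model = onnx.load(str(fp32_path), load_external_data=True)
    model_fp16 = float16.convert_float_to_float16(
        model, keep_io_types=False, disable_shape_infer=True,
        op_block_list=float16.DEFAULT_OP_BLOCK_LIST + ["Range", "Sin", "Cos", "Pow"],
    )
    onnx.save_model(
        model_fp16, str(fp16_path),
        save_as_external_data=True, all_tensors_to_one_file=True,
        location=fp16_path.name + ".data", size_threshold=1024,
    )
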
def quantize_int8_dynamic(fp32_path: Path, int8_path: Path) -> None:
    """Dynamic int8 quantization (weights → int8, activations on the fly).

    Two shape-inference paths in the ORT quantizer choke on the dynamo
    export of ModernBERT-large:
      1. `SymbolicShapeInference._infer_Range` asserts on the dynamic limit
         input emitted by RoPE (`assert len(x) == 1` in `as_scalar`).
      2. `onnx.shape_inference.infer_shapes_path` (C++) raises a (1024)/(7)
         dim mismatch on the category head Gemm — the dynamo decomposition
         leaves a dimension hint the C++ inferencer disagrees with.

    The skip flags on `quant_pre_process` are ignored (it always runs
    `SymbolicShapeInference.infer_shapes`), and `ONNXQuantizer.__init__`
    calls `save_and_reload_model_with_shape_infer` unconditionally. We
    monkey-patch both to no-ops, then run `quantize_dynamic` restricted to
    MatMul/Gemm ops (the only nodes we want quantized anyway).
    """
    import onnx
    from onnxruntime.quantization import QuantType, quantize_dynamic
    from onnxruntime.quantization import quant_utils
    from onnxruntime.tools import symbolic_shape_infer as sym

    # No-op the broken shape passes.
    original_save_reload = quant_utils.save_and_reload_model_with_shape_infer

    def _passthrough(model):
        return model

    quant_utils.save_and_reload_model_with_shape_infer = _passthrough
    # Some imports cache the symbol — patch the onnx_quantizer module too.
    import onnxruntime.quantization.onnx_quantizer as oq

    oq.save_and_reload_model_with_shape_infer = _passthrough
    try:
        print(f" quantizing {fp32_path} → {int8_path}")
        quantize_dynamic(
            model_input=str(fp32_path),
            model_output=str(int8_path),
            weight_type=QuantType.QInt8,
            per_channel=True,
            reduce_range=False,
            op_types_to_quantize=["MatMul", "Gemm"],
            use_external_data_format=True,
            extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
        )
    finally:
        quant_utils.save_and_reload_model_with_shape_infer = original_save_reload
        oq.save_and_reload_model_with_shape_infer = original_save_reload

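# Diagnostic sketch (hypothetical helper, not called by the driver): the int8
# CPU fallback is silent at the Python level because ORT only logs unassigned
# nodes. At verbose logger severity, session creation dumps per-node provider
# placements, which is one way to trace where the quantized MatMuls actually
# ran.
def _check_node_placement(model_path: Path) -> None:
    import onnxruntime as ort

    ort.set_default_logger_severity(0)  # VERBOSE: logs each node's provider
    try:
        ort.InferenceSession(
            str(model_path),
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
    finally:
        ort.set_default_logger_severity(2)  # restore the default WARNING level
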
# ──────────────────────────────────────────────────────────────────────
# Inference + metrics
# ──────────────────────────────────────────────────────────────────────
def _files_size(model_path: Path) -> int:
    """Sum of model.onnx + any external .data files in the same dir."""
    total = model_path.stat().st_size
    for sib in model_path.parent.iterdir():
        if sib.name.startswith(model_path.name) and sib != model_path:
            total += sib.stat().st_size
    return total

def run_onnx(model_path: Path, texts: list[str], use_cuda: bool = True) -> dict:
    """Run the holdout through an ORT session; return logits and perf stats."""
    import onnxruntime as ort
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "../checkpoints/finetune/iter1-independent/final"
    )
    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    providers = (
        ["CUDAExecutionProvider", "CPUExecutionProvider"] if use_cuda
        else ["CPUExecutionProvider"]
    )
    # ORT allocates outside torch's CUDA allocator, so VRAM is tracked via
    # whole-device free-memory snapshots rather than torch allocator stats.
    free_before, total_vram = torch.cuda.mem_get_info()
    sess = ort.InferenceSession(str(model_path), so, providers=providers)
    free_after_load, _ = torch.cuda.mem_get_info()
    load_vram_mb = (free_before - free_after_load) / (1024 ** 2)
    # Warmup
    warm_enc = tokenizer(
        texts[:BATCH_SIZE], truncation=True, max_length=MAX_SEQ,
        padding="longest", return_tensors="np",
    )
    warm_inputs = {
        "input_ids": warm_enc["input_ids"].astype(np.int64),
        "attention_mask": warm_enc["attention_mask"].astype(np.int64),
    }
    for _ in range(WARMUP_BATCHES):
        sess.run(None, warm_inputs)
    free_after_warm, _ = torch.cuda.mem_get_info()
    peak_vram_mb = (free_before - free_after_warm) / (1024 ** 2)

    cat_logits_list = []
    spec_logits_list = []
    total_time = 0.0
    for i in range(0, len(texts), BATCH_SIZE):
        batch = texts[i : i + BATCH_SIZE]
        enc = tokenizer(
            batch, truncation=True, max_length=MAX_SEQ,
            padding="longest", return_tensors="np",
        )
        inputs = {
            "input_ids": enc["input_ids"].astype(np.int64),
            "attention_mask": enc["attention_mask"].astype(np.int64),
        }
        t0 = time.perf_counter()
        out = sess.run(None, inputs)
        total_time += time.perf_counter() - t0
        cat_logits_list.append(torch.from_numpy(out[0].astype(np.float32)))
        spec_logits_list.append(torch.from_numpy(out[1].astype(np.float32)))
    free_end, _ = torch.cuda.mem_get_info()
    peak_vram_mb = max(peak_vram_mb, (free_before - free_end) / (1024 ** 2))
    del sess
    gc.collect()
    torch.cuda.empty_cache()
    return {
        "cat_logits": torch.cat(cat_logits_list),
        "spec_logits": torch.cat(spec_logits_list),
        "ms_per_sample": (total_time / len(texts)) * 1000,
        "throughput": len(texts) / total_time,
        "peak_vram_mb": peak_vram_mb,
        "load_vram_mb": load_vram_mb,
        "providers": providers,
    }

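# Possible follow-up (hypothetical helper, unused above): sess.run() with
# numpy inputs re-stages host→device copies on every batch. ORT's IOBinding
# API makes those transfers explicit and keeps outputs on the device, the
# usual next step if per-batch Memcpy overhead dominates a profile.
def _run_with_iobinding(sess, inputs: dict) -> list:
    io = sess.io_binding()
    for name, arr in inputs.items():
        io.bind_cpu_input(name, arr)  # ORT stages the host→device copy
    for out in sess.get_outputs():
        io.bind_output(out.name, "cuda")  # allocate outputs on-device
    sess.run_with_iobinding(io)
    return io.copy_outputs_to_cpu()  # single explicit device→host copy
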
# ──────────────────────────────────────────────────────────────────────
# Driver
# ──────────────────────────────────────────────────────────────────────
def main():
    print("loading holdout...")
    records = load_holdout_data(
        str(PARAGRAPHS), str(HOLDOUT), {k: str(v) for k, v in BENCHMARKS.items()},
    )
    texts = [r["text"] for r in records]
    print(f" {len(records)} paragraphs")

    fp32_path = ONNX_DIR / "model_fp32.onnx"
    fp16_path = ONNX_DIR / "model_fp16.onnx"
    int8_path = ONNX_DIR / "model_int8_dyn.onnx"

    # ── Export fp32 (source for both fp16 and int8 quant) ──
    if not fp32_path.exists():
        print("\n══ exporting fp32 ONNX")
        export_fp32(fp32_path)
    else:
        print(f"\n══ reusing existing {fp32_path}")

    # ── fp16 conversion ──
    if not fp16_path.exists():
        print("\n══ converting → fp16 ONNX")
        convert_fp16(fp32_path, fp16_path)
    else:
        print(f"\n══ reusing existing {fp16_path}")

    # ── int8 dynamic quantization ──
    if not int8_path.exists():
        print("\n══ quantizing → int8 dynamic ONNX")
        quantize_int8_dynamic(fp32_path, int8_path)
    else:
        print(f"\n══ reusing existing {int8_path}")

    summary = []
    variants = [
        ("onnx-fp32", fp32_path),
        ("onnx-fp16", fp16_path),
        ("onnx-int8-dyn", int8_path),
    ]
    for name, path in variants:
        print(f"\n══ {name} — {path.name}")
        size_mb = _files_size(path) / 1e6
        print(f" on-disk size: {size_mb:.1f} MB")
        try:
            inf = run_onnx(path, texts, use_cuda=True)
            print(
                f" latency {inf['ms_per_sample']:.2f} ms/sample, "
                f"throughput {inf['throughput']:.0f}/s, "
                f"peak VRAM {inf['peak_vram_mb']:.0f} MB "
                f"(load {inf['load_vram_mb']:.0f} MB)"
            )
            row = {
                "variant": name,
                "model_mb": size_mb,
                "ms_per_sample": inf["ms_per_sample"],
                "throughput_per_s": inf["throughput"],
                "peak_vram_mb": inf["peak_vram_mb"],
                "load_vram_mb": inf["load_vram_mb"],
            }
            for ref in BENCHMARKS:
                m = evaluate_predictions(inf["cat_logits"], inf["spec_logits"], records, ref)
                print(
                    f" vs {ref}: cat F1={m['cat_macro_f1']:.4f}, "
                    f"spec F1={m['spec_macro_f1']:.4f}, QWK={m['spec_qwk']:.4f}, "
                    f"cat ECE={m['cat_ece']:.4f}, spec ECE={m['spec_ece']:.4f}"
                )
                row[f"{ref}_cat_f1"] = m["cat_macro_f1"]
                row[f"{ref}_spec_f1"] = m["spec_macro_f1"]
                row[f"{ref}_cat_mcc"] = m["cat_mcc"]
                row[f"{ref}_spec_qwk"] = m["spec_qwk"]
                row[f"{ref}_spec_mae"] = m["spec_mae"]
                row[f"{ref}_cat_ece"] = m["cat_ece"]
                row[f"{ref}_spec_ece"] = m["spec_ece"]
            summary.append(row)
        except Exception as e:
            import traceback

            traceback.print_exc()
            summary.append({"variant": name, "error": f"{type(e).__name__}: {e}"})

    summary_path = OUTPUT_DIR / "summary.json"
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2, default=str)
    print(f"\nsummary → {summary_path}")

    print("\n" + "=" * 110)
    print(f"{'variant':<18} {'MB':>9} {'ms/samp':>9} {'throughput':>11} "
          f"{'VRAM MB':>9} {'cat F1':>9} {'spec F1':>9} {'spec QWK':>9}")
    print("-" * 110)
    for r in summary:
        if "error" in r:
            print(f"{r['variant']:<18} ERROR: {r['error']}")
            continue
        print(
            f"{r['variant']:<18} {r['model_mb']:>9.1f} {r['ms_per_sample']:>9.2f} "
            f"{r['throughput_per_s']:>11.0f} {r['peak_vram_mb']:>9.0f} "
            f"{r['GPT-5.4_cat_f1']:>9.4f} {r['GPT-5.4_spec_f1']:>9.4f} "
            f"{r['GPT-5.4_spec_qwk']:>9.4f}"
        )
    print("=" * 110)


if __name__ == "__main__":
    main()