"""ONNX export + eval for the iter1-independent ModernBERT-large checkpoint. Variants: onnx-fp32 — straight torch.onnx.export from the fp32 model onnx-fp16 — fp32 export converted to fp16 via onnxconverter_common (proxy for bf16; ORT does not support bf16 inference natively) onnx-int8-dyn — dynamic int8 quantization of the fp32 graph via onnxruntime.quantization.quantize_dynamic (weights in int8, activations quantized at runtime) For each variant: - latency (ms/sample, batch=64, 5 warmup batches) - peak GPU memory delta around the session (free-mem snapshot) - on-disk size of model.onnx + model.onnx.data - cat / spec macro F1, QWK, ECE on the 1,200-paragraph holdout against GPT-5.4 + Opus-4.6 proxy gold Usage: bun run py:onnx """ from __future__ import annotations import gc import json import os import sys import time from pathlib import Path import numpy as np import torch import torch.nn as nn import torch.nn.functional as F ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(ROOT)) from src.finetune.data import CAT2ID, CATEGORIES, NUM_CATEGORIES, NUM_SPECIFICITY # noqa: E402 from src.finetune.eval import SPEC_LABELS, compute_all_metrics, load_holdout_data # noqa: E402 from src.finetune.model import ordinal_predict # noqa: E402 from scripts.quantize_sweep import ( # noqa: E402 BENCHMARKS, BATCH_SIZE, HOLDOUT, MAX_SEQ, PARAGRAPHS, WARMUP_BATCHES, _build_model, evaluate_predictions, ) OUTPUT_DIR = ROOT.parent / "results/eval/onnx" ONNX_DIR = OUTPUT_DIR / "models" OUTPUT_DIR.mkdir(parents=True, exist_ok=True) ONNX_DIR.mkdir(parents=True, exist_ok=True) # ────────────────────────────────────────────────────────────────────── # Export # ────────────────────────────────────────────────────────────────────── class _Wrap(nn.Module): def __init__(self, model): super().__init__() self.model = model def forward(self, input_ids, attention_mask): out = self.model(input_ids=input_ids, attention_mask=attention_mask) return out["category_logits"], out["specificity_logits"] def export_fp32(out_path: Path, sample_batch: int = 4, sample_seq: int = 64) -> None: print(f" building fp32 torch model...") model, tokenizer = _build_model(torch.float32, attn_impl="sdpa") model = model.cuda().eval() wrap = _Wrap(model).cuda().eval() dummy_text = ["the company maintains a cybersecurity program overseen by the board"] * sample_batch enc = tokenizer( dummy_text, padding="max_length", max_length=sample_seq, truncation=True, return_tensors="pt", ).to("cuda") print(f" exporting → {out_path}") # Legacy TorchScript exporter (dynamo=False). The dynamo path produces a # graph with 56+ Memcpy nodes when run on CUDAExecutionProvider, blowing # latency 8× and VRAM 4× over native torch — unusable. The legacy # exporter emits clean Gemm/MatMul/LayerNorm nodes ORT can fuse. 
    torch.onnx.export(
        wrap,
        (enc["input_ids"], enc["attention_mask"]),
        str(out_path),
        input_names=["input_ids", "attention_mask"],
        output_names=["cat_logits", "spec_logits"],
        dynamic_axes={
            "input_ids": {0: "batch", 1: "seq"},
            "attention_mask": {0: "batch", 1: "seq"},
            "cat_logits": {0: "batch"},
            "spec_logits": {0: "batch"},
        },
        opset_version=17,
        dynamo=False,
        do_constant_folding=True,
    )

    del wrap, model
    gc.collect()
    torch.cuda.empty_cache()


def convert_fp16(fp32_path: Path, fp16_path: Path) -> None:
    """Convert an fp32 ONNX model to fp16 via onnxconverter_common."""
    import onnx
    from onnxconverter_common import float16

    print(f" loading {fp32_path}")
    model = onnx.load(str(fp32_path), load_external_data=True)

    print(f" converting to fp16...")
    model_fp16 = float16.convert_float_to_float16(
        model,
        keep_io_types=False,
        disable_shape_infer=True,
    )

    print(f" saving → {fp16_path}")
    onnx.save_model(
        model_fp16,
        str(fp16_path),
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=fp16_path.name + ".data",
        size_threshold=1024,
    )


def quantize_int8_dynamic(fp32_path: Path, int8_path: Path) -> None:
    """Dynamic int8 quantization (weights → int8, activations on the fly).

    Two shape-inference paths in the ORT quantizer choke on the dynamo export
    of ModernBERT-large:

    1. `SymbolicShapeInference._infer_Range` asserts on the dynamic limit
       input emitted by RoPE (`assert len(x) == 1` in `as_scalar`).
    2. `onnx.shape_inference.infer_shapes_path` (C++) raises a (1024)/(7)
       dim mismatch on the category head Gemm — the dynamo decomposition
       leaves a dimension hint the C++ inferencer disagrees with.

    The skip flags on `quant_pre_process` are ignored (it always runs
    `SymbolicShapeInference.infer_shapes`), and `ONNXQuantizer.__init__`
    calls `save_and_reload_model_with_shape_infer` unconditionally. We
    monkey-patch both to no-ops, then run `quantize_dynamic` restricted to
    MatMul/Gemm ops (the only node types we want quantized anyway).
    """
    import onnx
    from onnxruntime.quantization import QuantType, quantize_dynamic
    from onnxruntime.quantization import quant_utils
    from onnxruntime.tools import symbolic_shape_infer as sym

    # No-op the broken shape passes.
    original_save_reload = quant_utils.save_and_reload_model_with_shape_infer

    def _passthrough(model):
        return model

    quant_utils.save_and_reload_model_with_shape_infer = _passthrough
    # Some imports cache the symbol — patch the onnx_quantizer module too.
    import onnxruntime.quantization.onnx_quantizer as oq

    oq.save_and_reload_model_with_shape_infer = _passthrough

    try:
        print(f" quantizing {fp32_path} → {int8_path}")
        quantize_dynamic(
            model_input=str(fp32_path),
            model_output=str(int8_path),
            weight_type=QuantType.QInt8,
            per_channel=True,
            reduce_range=False,
            op_types_to_quantize=["MatMul", "Gemm"],
            use_external_data_format=True,
            extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
        )
    finally:
        quant_utils.save_and_reload_model_with_shape_infer = original_save_reload
        oq.save_and_reload_model_with_shape_infer = original_save_reload


# ──────────────────────────────────────────────────────────────────────
# Inference + metrics
# ──────────────────────────────────────────────────────────────────────

def _files_size(model_path: Path) -> int:
    """Sum of model.onnx + any external .data files in the same dir."""
    total = model_path.stat().st_size
    for sib in model_path.parent.iterdir():
        if sib.name.startswith(model_path.name) and sib != model_path:
            total += sib.stat().st_size
    return total


def run_onnx(model_path: Path, texts: list[str], use_cuda: bool = True) -> dict:
    import onnxruntime as ort
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "../checkpoints/finetune/iter1-independent/final"
    )

    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    providers = (
        ["CUDAExecutionProvider", "CPUExecutionProvider"]
        if use_cuda
        else ["CPUExecutionProvider"]
    )

    free_before, total_vram = torch.cuda.mem_get_info()
    sess = ort.InferenceSession(str(model_path), so, providers=providers)
    free_after_load, _ = torch.cuda.mem_get_info()
    load_vram_mb = (free_before - free_after_load) / (1024 ** 2)

    # Warmup
    warm_enc = tokenizer(
        texts[:BATCH_SIZE],
        truncation=True,
        max_length=MAX_SEQ,
        padding="longest",
        return_tensors="np",
    )
    warm_inputs = {
        "input_ids": warm_enc["input_ids"].astype(np.int64),
        "attention_mask": warm_enc["attention_mask"].astype(np.int64),
    }
    for _ in range(WARMUP_BATCHES):
        sess.run(None, warm_inputs)
    free_after_warm, _ = torch.cuda.mem_get_info()
    peak_vram_mb = (free_before - free_after_warm) / (1024 ** 2)

    cat_logits_list = []
    spec_logits_list = []
    total_time = 0.0

    for i in range(0, len(texts), BATCH_SIZE):
        batch = texts[i : i + BATCH_SIZE]
        enc = tokenizer(
            batch,
            truncation=True,
            max_length=MAX_SEQ,
            padding="longest",
            return_tensors="np",
        )
        inputs = {
            "input_ids": enc["input_ids"].astype(np.int64),
            "attention_mask": enc["attention_mask"].astype(np.int64),
        }
        t0 = time.perf_counter()
        out = sess.run(None, inputs)
        total_time += time.perf_counter() - t0
        cat_logits_list.append(torch.from_numpy(out[0].astype(np.float32)))
        spec_logits_list.append(torch.from_numpy(out[1].astype(np.float32)))

    free_end, _ = torch.cuda.mem_get_info()
    peak_vram_mb = max(peak_vram_mb, (free_before - free_end) / (1024 ** 2))

    del sess
    gc.collect()
    torch.cuda.empty_cache()

    return {
        "cat_logits": torch.cat(cat_logits_list),
        "spec_logits": torch.cat(spec_logits_list),
        "ms_per_sample": (total_time / len(texts)) * 1000,
        "throughput": len(texts) / total_time,
        "peak_vram_mb": peak_vram_mb,
        "load_vram_mb": load_vram_mb,
        "providers": providers,
    }


# ──────────────────────────────────────────────────────────────────────
# Driver
# ──────────────────────────────────────────────────────────────────────

def main():
    print("loading holdout...")
    records = load_holdout_data(
        str(PARAGRAPHS),
        str(HOLDOUT),
        {k: str(v) for k, v in BENCHMARKS.items()},
    )
    texts = [r["text"] for r in records]
    print(f" {len(records)} paragraphs")
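    # Export artifacts are cached on disk: each variant below is rebuilt only
    # if its model file is missing, so delete files under ONNX_DIR to force a
    # fresh export / conversion / quantization pass.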
    fp32_path = ONNX_DIR / "model_fp32.onnx"
    fp16_path = ONNX_DIR / "model_fp16.onnx"
    int8_path = ONNX_DIR / "model_int8_dyn.onnx"

    # ── Export fp32 (source for both fp16 and int8 quant) ──
    if not fp32_path.exists():
        print("\n══ exporting fp32 ONNX")
        export_fp32(fp32_path)
    else:
        print(f"\n══ reusing existing {fp32_path}")

    # ── fp16 conversion ──
    if not fp16_path.exists():
        print("\n══ converting → fp16 ONNX")
        convert_fp16(fp32_path, fp16_path)
    else:
        print(f"\n══ reusing existing {fp16_path}")

    # ── int8 dynamic quantization ──
    if not int8_path.exists():
        print("\n══ quantizing → int8 dynamic ONNX")
        quantize_int8_dynamic(fp32_path, int8_path)
    else:
        print(f"\n══ reusing existing {int8_path}")

    summary = []
    variants = [
        ("onnx-fp32", fp32_path),
        ("onnx-fp16", fp16_path),
        ("onnx-int8-dyn", int8_path),
    ]

    for name, path in variants:
        print(f"\n══ {name} — {path.name}")
        size_mb = _files_size(path) / 1e6
        print(f" on-disk size: {size_mb:.1f} MB")

        try:
            inf = run_onnx(path, texts, use_cuda=True)
            print(
                f" latency {inf['ms_per_sample']:.2f} ms/sample, "
                f"throughput {inf['throughput']:.0f}/s, "
                f"peak VRAM {inf['peak_vram_mb']:.0f} MB "
                f"(load {inf['load_vram_mb']:.0f} MB)"
            )
            row = {
                "variant": name,
                "model_mb": size_mb,
                "ms_per_sample": inf["ms_per_sample"],
                "throughput_per_s": inf["throughput"],
                "peak_vram_mb": inf["peak_vram_mb"],
                "load_vram_mb": inf["load_vram_mb"],
            }
            for ref in BENCHMARKS:
                m = evaluate_predictions(inf["cat_logits"], inf["spec_logits"], records, ref)
                print(
                    f" vs {ref}: cat F1={m['cat_macro_f1']:.4f}, "
                    f"spec F1={m['spec_macro_f1']:.4f}, QWK={m['spec_qwk']:.4f}, "
                    f"cat ECE={m['cat_ece']:.4f}, spec ECE={m['spec_ece']:.4f}"
                )
                row[f"{ref}_cat_f1"] = m["cat_macro_f1"]
                row[f"{ref}_spec_f1"] = m["spec_macro_f1"]
                row[f"{ref}_cat_mcc"] = m["cat_mcc"]
                row[f"{ref}_spec_qwk"] = m["spec_qwk"]
                row[f"{ref}_spec_mae"] = m["spec_mae"]
                row[f"{ref}_cat_ece"] = m["cat_ece"]
                row[f"{ref}_spec_ece"] = m["spec_ece"]
            summary.append(row)
        except Exception as e:
            import traceback

            traceback.print_exc()
            summary.append({"variant": name, "error": f"{type(e).__name__}: {e}"})

    summary_path = OUTPUT_DIR / "summary.json"
    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2, default=str)
    print(f"\nsummary → {summary_path}")

    print("\n" + "=" * 110)
    print(
        f"{'variant':<18} {'MB':>9} {'ms/samp':>9} {'throughput':>11} "
        f"{'VRAM MB':>9} {'cat F1':>9} {'spec F1':>9} {'spec QWK':>9}"
    )
    print("-" * 110)
    for r in summary:
        if "error" in r:
            print(f"{r['variant']:<18} ERROR: {r['error']}")
            continue
        print(
            f"{r['variant']:<18} {r['model_mb']:>9.1f} {r['ms_per_sample']:>9.2f} "
            f"{r['throughput_per_s']:>11.0f} {r['peak_vram_mb']:>9.0f} "
            f"{r['GPT-5.4_cat_f1']:>9.4f} {r['GPT-5.4_spec_f1']:>9.4f} "
            f"{r['GPT-5.4_spec_qwk']:>9.4f}"
        )
    print("=" * 110)


if __name__ == "__main__":
    main()