"""Temperature scaling calibration for the trained ensemble. Approach: 1. Run the 3-seed ensemble on the held-out 1,200 paragraphs. 2. Use the val split (10% of training data) to fit a single scalar T per head by minimizing NLL via LBFGS — this avoids touching the holdout used for F1 reporting. 3. Apply T to holdout logits, recompute ECE. Temperature scaling preserves argmax → all F1 metrics are unchanged. Only the calibration metric (ECE) and probability distributions change. """ import json from pathlib import Path import numpy as np import torch import torch.nn.functional as F from transformers import AutoTokenizer from src.common.config import FinetuneConfig from src.finetune.data import CAT2ID, CATEGORIES, load_finetune_data from src.finetune.eval import ( EvalConfig, SPEC_LABELS, _ordinal_to_class_probs, compute_ece, load_holdout_data, load_model, run_inference, ) from src.finetune.model import ordinal_predict, softmax_predict CHECKPOINTS = { "seed42": "../checkpoints/finetune/iter1-independent/final", "seed69": "../checkpoints/finetune/iter1-seed69/final", "seed420": "../checkpoints/finetune/iter1-seed420/final", } TRAIN_CONFIG = "configs/finetune/iter1-independent.yaml" PARAGRAPHS_PATH = "../data/paragraphs/paragraphs-clean.patched.jsonl" HOLDOUT_PATH = "../data/gold/v2-holdout-ids.json" BENCHMARK_PATHS = { "GPT-5.4": "../data/annotations/v2-bench/gpt-5.4.jsonl", "Opus-4.6": "../data/annotations/v2-bench/opus-4.6.jsonl", } OUTPUT_DIR = Path("../results/eval/ensemble-3seed-tempscaled") SPEC_HEAD = "independent" def fit_temperature(logits: torch.Tensor, labels: torch.Tensor, mode: str) -> float: """Fit a single scalar T to minimize NLL on (logits, labels). mode='ce' → standard categorical cross-entropy on softmax(logits/T). mode='ordinal' → cumulative BCE on sigmoid(logits/T) against ordinal targets. 
""" T = torch.nn.Parameter(torch.ones(1, dtype=torch.float64)) optimizer = torch.optim.LBFGS([T], lr=0.05, max_iter=100) logits = logits.double() labels_t = labels.long() if mode == "ordinal": # Build cumulative targets: target[k] = 1 if label > k K = logits.shape[1] cum_targets = torch.zeros_like(logits) for k in range(K): cum_targets[:, k] = (labels_t > k).double() def closure() -> torch.Tensor: optimizer.zero_grad() scaled = logits / T.clamp(min=1e-3) if mode == "ce": loss = F.cross_entropy(scaled, labels_t) else: loss = F.binary_cross_entropy_with_logits(scaled, cum_targets) loss.backward() return loss optimizer.step(closure) return float(T.detach().item()) def collect_ensemble_logits(records: list[dict], device: torch.device): """Run all 3 seeds on `records`, return averaged cat/spec logits.""" cat_stack, spec_stack = [], [] for name, ckpt_path in CHECKPOINTS.items(): print(f" [{name}] loading {ckpt_path}") cfg = EvalConfig( checkpoint_path=ckpt_path, paragraphs_path=PARAGRAPHS_PATH, holdout_path=HOLDOUT_PATH, benchmark_paths=BENCHMARK_PATHS, output_dir=str(OUTPUT_DIR), specificity_head=SPEC_HEAD, ) model, tokenizer = load_model(cfg, device) inf = run_inference( model, tokenizer, records, cfg.max_seq_length, cfg.batch_size, device, SPEC_HEAD, ) cat_stack.append(inf["cat_logits"]) spec_stack.append(inf["spec_logits"]) del model torch.cuda.empty_cache() cat_logits = np.mean(np.stack(cat_stack, axis=0), axis=0) spec_logits = np.mean(np.stack(spec_stack, axis=0), axis=0) return cat_logits, spec_logits def load_val_records(tokenizer): """Load the val split as plain text records compatible with run_inference.""" fcfg = FinetuneConfig.from_yaml(TRAIN_CONFIG) splits = load_finetune_data( paragraphs_path=fcfg.data.paragraphs_path, consensus_path=fcfg.data.consensus_path, quality_path=fcfg.data.quality_path, holdout_path=fcfg.data.holdout_path, max_seq_length=fcfg.data.max_seq_length, validation_split=fcfg.data.validation_split, tokenizer=tokenizer, seed=fcfg.training.seed, ) val = splits["test"] # Reconstruct text from input_ids so run_inference can re-tokenize records = [] for i in range(len(val)): text = tokenizer.decode(val[i]["input_ids"], skip_special_tokens=True) records.append({ "text": text, "category_label": val[i]["category_labels"], "specificity_label": val[i]["specificity_labels"], }) return records def main() -> None: OUTPUT_DIR.mkdir(parents=True, exist_ok=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"\n Device: {device}") # ── 1. Load val split via tokenizer from seed42 ── tokenizer = AutoTokenizer.from_pretrained(CHECKPOINTS["seed42"]) print("\n Loading val split for temperature fitting...") val_records = load_val_records(tokenizer) print(f" Val samples: {len(val_records)}") # Subsample to avoid full ensemble pass on 7K samples (overkill for fitting T) rng = np.random.default_rng(0) if len(val_records) > 2000: idx = rng.choice(len(val_records), 2000, replace=False) val_records = [val_records[i] for i in idx] print(f" Subsampled to {len(val_records)} for T fitting") # ── 2. Run ensemble on val ── print("\n Running ensemble on val for T fitting...") val_cat_logits, val_spec_logits = collect_ensemble_logits(val_records, device) val_cat_labels = torch.tensor([r["category_label"] for r in val_records]) val_spec_labels = torch.tensor([r["specificity_label"] for r in val_records]) # ── 3. 
def main() -> None:
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n Device: {device}")

    # ── 1. Load val split via tokenizer from seed42 ──
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINTS["seed42"])
    print("\n Loading val split for temperature fitting...")
    val_records = load_val_records(tokenizer)
    print(f" Val samples: {len(val_records)}")

    # Subsample to avoid full ensemble pass on 7K samples (overkill for fitting T)
    rng = np.random.default_rng(0)
    if len(val_records) > 2000:
        sub = rng.choice(len(val_records), 2000, replace=False)
        val_records = [val_records[i] for i in sub]
        print(f" Subsampled to {len(val_records)} for T fitting")

    # ── 2. Run ensemble on val ──
    print("\n Running ensemble on val for T fitting...")
    val_cat_logits, val_spec_logits = collect_ensemble_logits(val_records, device)
    val_cat_labels = torch.tensor([r["category_label"] for r in val_records])
    val_spec_labels = torch.tensor([r["specificity_label"] for r in val_records])

    # ── 3. Fit T on val ──
    T_cat = fit_temperature(torch.from_numpy(val_cat_logits), val_cat_labels, mode="ce")
    T_spec = fit_temperature(torch.from_numpy(val_spec_logits), val_spec_labels, mode="ordinal")
    print(f"\n Fitted T_cat = {T_cat:.4f}")
    print(f" Fitted T_spec = {T_spec:.4f}")

    # ── 4. Run ensemble on holdout ──
    print("\n Running ensemble on holdout...")
    holdout_records = load_holdout_data(PARAGRAPHS_PATH, HOLDOUT_PATH, BENCHMARK_PATHS)
    h_cat_logits, h_spec_logits = collect_ensemble_logits(holdout_records, device)

    # ── 5. Apply temperature, recompute ECE per benchmark ──
    h_cat_logits_t = torch.from_numpy(h_cat_logits)
    h_spec_logits_t = torch.from_numpy(h_spec_logits)

    cat_probs_pre = F.softmax(h_cat_logits_t, dim=1).numpy()
    cat_probs_post = F.softmax(h_cat_logits_t / T_cat, dim=1).numpy()
    spec_probs_pre = _ordinal_to_class_probs(h_spec_logits_t).numpy()
    spec_probs_post = _ordinal_to_class_probs(h_spec_logits_t / T_spec).numpy()

    # Predictions are unchanged (argmax invariant for cat; ordinal threshold at 0 invariant)
    cat_preds = h_cat_logits_t.argmax(dim=1).numpy()
    spec_preds = ordinal_predict(h_spec_logits_t).numpy()

    summary = {
        "T_cat": T_cat,
        "T_spec": T_spec,
        "per_benchmark": {},
    }

    for ref_name in BENCHMARK_PATHS:
        # Both heads share the same benchmark coverage, so one index list suffices.
        cat_labels, spec_labels, idx = [], [], []
        for i, rec in enumerate(holdout_records):
            bench = rec["benchmark_labels"].get(ref_name)
            if bench is None:
                continue
            cat_labels.append(CAT2ID[bench["category"]])
            spec_labels.append(bench["specificity"] - 1)  # 1-indexed → 0-indexed
            idx.append(i)
        cat_labels = np.array(cat_labels)
        spec_labels = np.array(spec_labels)
        idx = np.array(idx)

        ece_cat_pre, _ = compute_ece(cat_probs_pre[idx], cat_labels)
        ece_cat_post, _ = compute_ece(cat_probs_post[idx], cat_labels)
        ece_spec_pre, _ = compute_ece(spec_probs_pre[idx], spec_labels)
        ece_spec_post, _ = compute_ece(spec_probs_post[idx], spec_labels)

        # Sanity check: predictions unchanged
        cat_match = (cat_preds[idx] == cat_probs_post[idx].argmax(axis=1)).all()
        spec_match = (spec_preds[idx] == spec_probs_post[idx].argmax(axis=1)).all()

        print(f"\n {ref_name}")
        print(f" Cat ECE: {ece_cat_pre:.4f} → {ece_cat_post:.4f} (Δ {ece_cat_post - ece_cat_pre:+.4f})")
        print(f" Spec ECE: {ece_spec_pre:.4f} → {ece_spec_post:.4f} (Δ {ece_spec_post - ece_spec_pre:+.4f})")
        print(f" Predictions preserved: cat={cat_match} spec={spec_match}")

        summary["per_benchmark"][ref_name] = {
            # float() guards against numpy scalars, which json.dump cannot serialize
            "ece_cat_pre": float(ece_cat_pre),
            "ece_cat_post": float(ece_cat_post),
            "ece_spec_pre": float(ece_spec_pre),
            "ece_spec_post": float(ece_spec_post),
            "cat_preds_preserved": bool(cat_match),
            "spec_preds_preserved": bool(spec_match),
        }

    with open(OUTPUT_DIR / "temperature_scaling.json", "w") as f:
        json.dump(summary, f, indent=2)
    print(f"\n Saved {OUTPUT_DIR / 'temperature_scaling.json'}")


if __name__ == "__main__":
    main()
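
# Post-run note (standard temperature-scaling behavior, not a measured result):
# a fitted T > 1 flattens overconfident probabilities, while T < 1 sharpens
# underconfident ones; neither changes the logit ordering, which is what the
# "Predictions preserved" sanity check above verifies on the holdout.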