# SEC-cyBERT/python/scripts/temperature_scale.py
"""Temperature scaling calibration for the trained ensemble.
Approach:
1. Run the 3-seed ensemble on the held-out 1,200 paragraphs.
2. Use the val split (10% of training data) to fit a single scalar T per
head by minimizing NLL via LBFGS — this avoids touching the holdout
used for F1 reporting.
3. Apply T to holdout logits, recompute ECE.
Temperature scaling preserves argmax → all F1 metrics are unchanged.
Only the calibration metric (ECE) and probability distributions change.
"""
import json
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from src.common.config import FinetuneConfig
from src.finetune.data import CAT2ID, CATEGORIES, load_finetune_data
from src.finetune.eval import (
EvalConfig,
SPEC_LABELS,
_ordinal_to_class_probs,
compute_ece,
load_holdout_data,
load_model,
run_inference,
)
from src.finetune.model import ordinal_predict, softmax_predict
# Checkpoint directories for the three fine-tuning seeds that form the ensemble.
CHECKPOINTS = {
    "seed42": "../checkpoints/finetune/iter1-independent/final",
    "seed69": "../checkpoints/finetune/iter1-seed69/final",
    "seed420": "../checkpoints/finetune/iter1-seed420/final",
}
# Training config used to reconstruct the val split with the same seed/fractions
# as training (so T is never fit on the holdout used for F1 reporting).
TRAIN_CONFIG = "configs/finetune/iter1-independent.yaml"
# Paragraph corpus and holdout-id list consumed by load_holdout_data / EvalConfig.
PARAGRAPHS_PATH = "../data/paragraphs/paragraphs-clean.patched.jsonl"
HOLDOUT_PATH = "../data/gold/v2-holdout-ids.json"
# Reference annotation files; ECE is recomputed per benchmark in main().
BENCHMARK_PATHS = {
    "GPT-5.4": "../data/annotations/v2-bench/gpt-5.4.jsonl",
    "Opus-4.6": "../data/annotations/v2-bench/opus-4.6.jsonl",
}
# Destination for the temperature_scaling.json summary.
OUTPUT_DIR = Path("../results/eval/ensemble-3seed-tempscaled")
# Specificity head variant passed through to EvalConfig / run_inference.
SPEC_HEAD = "independent"
def fit_temperature(logits: torch.Tensor, labels: torch.Tensor, mode: str) -> float:
"""Fit a single scalar T to minimize NLL on (logits, labels).
mode='ce' → standard categorical cross-entropy on softmax(logits/T).
mode='ordinal' → cumulative BCE on sigmoid(logits/T) against ordinal targets.
"""
T = torch.nn.Parameter(torch.ones(1, dtype=torch.float64))
optimizer = torch.optim.LBFGS([T], lr=0.05, max_iter=100)
logits = logits.double()
labels_t = labels.long()
if mode == "ordinal":
# Build cumulative targets: target[k] = 1 if label > k
K = logits.shape[1]
cum_targets = torch.zeros_like(logits)
for k in range(K):
cum_targets[:, k] = (labels_t > k).double()
def closure() -> torch.Tensor:
optimizer.zero_grad()
scaled = logits / T.clamp(min=1e-3)
if mode == "ce":
loss = F.cross_entropy(scaled, labels_t)
else:
loss = F.binary_cross_entropy_with_logits(scaled, cum_targets)
loss.backward()
return loss
optimizer.step(closure)
return float(T.detach().item())
def collect_ensemble_logits(records: list[dict], device: torch.device):
    """Run every seed checkpoint over `records`, return averaged cat/spec logits.

    Loads each of the three seed models in turn, collects its category and
    specificity logits, and returns the element-wise mean of each stack.
    """
    all_cat: list = []
    all_spec: list = []
    for seed_name, checkpoint in CHECKPOINTS.items():
        print(f" [{seed_name}] loading {checkpoint}")
        eval_cfg = EvalConfig(
            checkpoint_path=checkpoint,
            paragraphs_path=PARAGRAPHS_PATH,
            holdout_path=HOLDOUT_PATH,
            benchmark_paths=BENCHMARK_PATHS,
            output_dir=str(OUTPUT_DIR),
            specificity_head=SPEC_HEAD,
        )
        model, tokenizer = load_model(eval_cfg, device)
        outputs = run_inference(
            model,
            tokenizer,
            records,
            eval_cfg.max_seq_length,
            eval_cfg.batch_size,
            device,
            SPEC_HEAD,
        )
        all_cat.append(outputs["cat_logits"])
        all_spec.append(outputs["spec_logits"])
        # Free GPU memory before loading the next seed's weights.
        del model
        torch.cuda.empty_cache()
    avg_cat = np.stack(all_cat, axis=0).mean(axis=0)
    avg_spec = np.stack(all_spec, axis=0).mean(axis=0)
    return avg_cat, avg_spec
def load_val_records(tokenizer):
    """Rebuild the val split as plain text records compatible with run_inference.

    The split is reconstructed with the exact seed and fractions used at
    training time, so the temperatures are fit on data disjoint from the
    holdout used for F1 reporting.
    """
    train_cfg = FinetuneConfig.from_yaml(TRAIN_CONFIG)
    splits = load_finetune_data(
        paragraphs_path=train_cfg.data.paragraphs_path,
        consensus_path=train_cfg.data.consensus_path,
        quality_path=train_cfg.data.quality_path,
        holdout_path=train_cfg.data.holdout_path,
        max_seq_length=train_cfg.data.max_seq_length,
        validation_split=train_cfg.data.validation_split,
        tokenizer=tokenizer,
        seed=train_cfg.training.seed,
    )
    val_split = splits["test"]
    # run_inference re-tokenizes from raw text, so decode each example's
    # input_ids back into a string (special tokens dropped).
    return [
        {
            "text": tokenizer.decode(example["input_ids"], skip_special_tokens=True),
            "category_label": example["category_labels"],
            "specificity_label": example["specificity_labels"],
        }
        for example in (val_split[i] for i in range(len(val_split)))
    ]
def main() -> None:
    """Fit per-head temperatures on the val split, then re-score the holdout.

    Steps:
      1. Reconstruct the val split (never used for holdout F1 reporting).
      2. Run the 3-seed ensemble on it and fit T_cat (CE) / T_spec (ordinal BCE).
      3. Run the ensemble on the holdout.
      4. Recompute ECE before/after scaling per benchmark and sanity-check that
         hard predictions are unchanged.
      5. Dump a JSON summary to OUTPUT_DIR / temperature_scaling.json.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n Device: {device}")
    # ── 1. Load val split via tokenizer from seed42 ──
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINTS["seed42"])
    print("\n Loading val split for temperature fitting...")
    val_records = load_val_records(tokenizer)
    print(f" Val samples: {len(val_records)}")
    # Subsample to avoid full ensemble pass on 7K samples (overkill for fitting T)
    rng = np.random.default_rng(0)
    if len(val_records) > 2000:
        idx = rng.choice(len(val_records), 2000, replace=False)
        val_records = [val_records[i] for i in idx]
        print(f" Subsampled to {len(val_records)} for T fitting")
    # ── 2. Run ensemble on val ──
    print("\n Running ensemble on val for T fitting...")
    val_cat_logits, val_spec_logits = collect_ensemble_logits(val_records, device)
    val_cat_labels = torch.tensor([r["category_label"] for r in val_records])
    val_spec_labels = torch.tensor([r["specificity_label"] for r in val_records])
    # ── 3. Fit T on val ──
    T_cat = fit_temperature(torch.from_numpy(val_cat_logits), val_cat_labels, mode="ce")
    T_spec = fit_temperature(torch.from_numpy(val_spec_logits), val_spec_labels, mode="ordinal")
    print(f"\n Fitted T_cat = {T_cat:.4f}")
    print(f" Fitted T_spec = {T_spec:.4f}")
    # ── 4. Run ensemble on holdout ──
    print("\n Running ensemble on holdout...")
    holdout_records = load_holdout_data(PARAGRAPHS_PATH, HOLDOUT_PATH, BENCHMARK_PATHS)
    h_cat_logits, h_spec_logits = collect_ensemble_logits(holdout_records, device)
    # ── 5. Apply temperature, recompute ECE per benchmark ──
    h_cat_logits_t = torch.from_numpy(h_cat_logits)
    h_spec_logits_t = torch.from_numpy(h_spec_logits)
    cat_probs_pre = F.softmax(h_cat_logits_t, dim=1).numpy()
    cat_probs_post = F.softmax(h_cat_logits_t / T_cat, dim=1).numpy()
    spec_probs_pre = _ordinal_to_class_probs(h_spec_logits_t).numpy()
    spec_probs_post = _ordinal_to_class_probs(h_spec_logits_t / T_spec).numpy()
    # Predictions are unchanged: softmax argmax is invariant under T > 0, and
    # the ordinal decision thresholds cumulative logits at 0, whose signs are
    # unaffected by division by T > 0.
    cat_preds = h_cat_logits_t.argmax(dim=1).numpy()
    spec_preds = ordinal_predict(h_spec_logits_t).numpy()
    # Re-decide specificity from scaled logits for the sanity check below.
    # NOTE: argmax over _ordinal_to_class_probs output is NOT invariant under
    # temperature scaling (only the cumulative-logit signs are), so checking
    # spec_probs_post.argmax would flag false prediction drift. The decision
    # rule is ordinal_predict, so that is what must be compared.
    spec_preds_post = ordinal_predict(h_spec_logits_t / T_spec).numpy()
    summary = {
        "T_cat": T_cat,
        "T_spec": T_spec,
        "per_benchmark": {},
    }
    for ref_name in BENCHMARK_PATHS:
        # Collect gold labels and the holdout indices annotated by this
        # benchmark (category and specificity always come from the same record,
        # so one index list covers both heads).
        cat_labels, spec_labels, bench_idx = [], [], []
        for i, rec in enumerate(holdout_records):
            bench = rec["benchmark_labels"].get(ref_name)
            if bench is None:
                continue
            cat_labels.append(CAT2ID[bench["category"]])
            spec_labels.append(bench["specificity"] - 1)  # 1-based -> 0-based
            bench_idx.append(i)
        cat_labels = np.array(cat_labels)
        spec_labels = np.array(spec_labels)
        bench_idx = np.array(bench_idx)
        ece_cat_pre, _ = compute_ece(cat_probs_pre[bench_idx], cat_labels)
        ece_cat_post, _ = compute_ece(cat_probs_post[bench_idx], cat_labels)
        ece_spec_pre, _ = compute_ece(spec_probs_pre[bench_idx], spec_labels)
        ece_spec_post, _ = compute_ece(spec_probs_post[bench_idx], spec_labels)
        # Sanity check: hard predictions must be identical pre/post scaling.
        cat_match = (cat_preds[bench_idx] == cat_probs_post[bench_idx].argmax(axis=1)).all()
        spec_match = (spec_preds[bench_idx] == spec_preds_post[bench_idx]).all()
        print(f"\n {ref_name}")
        print(f" Cat ECE: {ece_cat_pre:.4f} -> {ece_cat_post:.4f} ({ece_cat_post - ece_cat_pre:+.4f})")
        print(f" Spec ECE: {ece_spec_pre:.4f} -> {ece_spec_post:.4f} ({ece_spec_post - ece_spec_pre:+.4f})")
        print(f" Predictions preserved: cat={cat_match} spec={spec_match}")
        summary["per_benchmark"][ref_name] = {
            "ece_cat_pre": ece_cat_pre,
            "ece_cat_post": ece_cat_post,
            "ece_spec_pre": ece_spec_pre,
            "ece_spec_post": ece_spec_post,
            "cat_preds_preserved": bool(cat_match),
            "spec_preds_preserved": bool(spec_match),
        }
    with open(OUTPUT_DIR / "temperature_scaling.json", "w") as f:
        json.dump(summary, f, indent=2)
    print(f"\n Saved {OUTPUT_DIR / 'temperature_scaling.json'}")
# Script entry point: fit temperatures on val, apply to holdout, write JSON summary.
if __name__ == "__main__":
    main()