243 lines
9.1 KiB
Python
243 lines
9.1 KiB
Python
"""Temperature scaling calibration for the trained ensemble.
|
|
|
|
Approach:
|
|
1. Run the 3-seed ensemble on the held-out 1,200 paragraphs.
|
|
2. Use the val split (10% of training data) to fit a single scalar T per
|
|
head by minimizing NLL via LBFGS — this avoids touching the holdout
|
|
used for F1 reporting.
|
|
3. Apply T to holdout logits, recompute ECE.
|
|
|
|
Temperature scaling preserves argmax → all F1 metrics are unchanged.
|
|
Only the calibration metric (ECE) and probability distributions change.
|
|
"""
|
|
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from transformers import AutoTokenizer
|
|
|
|
from src.common.config import FinetuneConfig
|
|
from src.finetune.data import CAT2ID, CATEGORIES, load_finetune_data
|
|
from src.finetune.eval import (
|
|
EvalConfig,
|
|
SPEC_LABELS,
|
|
_ordinal_to_class_probs,
|
|
compute_ece,
|
|
load_holdout_data,
|
|
load_model,
|
|
run_inference,
|
|
)
|
|
from src.finetune.model import ordinal_predict, softmax_predict
|
|
|
|
|
|
# The three ensemble members: one fine-tuned checkpoint per training seed.
CHECKPOINTS = {
    "seed42": "../checkpoints/finetune/iter1-independent/final",
    "seed69": "../checkpoints/finetune/iter1-seed69/final",
    "seed420": "../checkpoints/finetune/iter1-seed420/final",
}
# Training YAML reused so load_val_records() rebuilds the exact val split
# (same validation_split fraction and seed) used during training.
TRAIN_CONFIG = "configs/finetune/iter1-independent.yaml"
# Corpus of paragraphs the model is evaluated on.
PARAGRAPHS_PATH = "../data/paragraphs/paragraphs-clean.patched.jsonl"
# IDs of the held-out paragraphs (F1-reporting holdout; never used to fit T).
HOLDOUT_PATH = "../data/gold/v2-holdout-ids.json"
# Reference annotation sets; ECE is recomputed per benchmark.
BENCHMARK_PATHS = {
    "GPT-5.4": "../data/annotations/v2-bench/gpt-5.4.jsonl",
    "Opus-4.6": "../data/annotations/v2-bench/opus-4.6.jsonl",
}
# Where the temperature_scaling.json summary is written.
OUTPUT_DIR = Path("../results/eval/ensemble-3seed-tempscaled")
# Which specificity head variant to evaluate (passed through to run_inference).
SPEC_HEAD = "independent"
|
|
|
|
|
|
def fit_temperature(logits: torch.Tensor, labels: torch.Tensor, mode: str) -> float:
    """Fit a single scalar temperature T by minimizing NLL on (logits, labels).

    Args:
        logits: (N, K) raw logits (ensemble-averaged upstream).
        labels: (N,) integer class labels. In ordinal mode the K columns are
            cumulative thresholds, so labels may range over [0, K].
        mode: 'ce'      -> categorical cross-entropy on softmax(logits / T);
              'ordinal' -> cumulative BCE on sigmoid(logits / T) against
                           targets target[:, k] = 1{label > k}.

    Returns:
        The fitted temperature as a plain Python float.

    Raises:
        ValueError: if `mode` is not 'ce' or 'ordinal'. (Previously an
            unknown mode fell through to the ordinal branch and crashed
            with a confusing NameError on `cum_targets`.)
    """
    if mode not in ("ce", "ordinal"):
        raise ValueError(f"mode must be 'ce' or 'ordinal', got {mode!r}")

    T = torch.nn.Parameter(torch.ones(1, dtype=torch.float64))
    optimizer = torch.optim.LBFGS([T], lr=0.05, max_iter=100)
    # float64 keeps the NLL surface smooth for LBFGS line searches.
    logits = logits.double()
    labels_t = labels.long()

    if mode == "ordinal":
        # Cumulative targets, vectorized: target[:, k] = 1 if label > k.
        K = logits.shape[1]
        thresholds = torch.arange(K, device=logits.device)
        cum_targets = (labels_t.unsqueeze(1) > thresholds.unsqueeze(0)).double()

    def closure() -> torch.Tensor:
        # LBFGS re-evaluates the loss several times per step via this closure.
        optimizer.zero_grad()
        # Clamp keeps T strictly positive so the division stays finite.
        scaled = logits / T.clamp(min=1e-3)
        if mode == "ce":
            loss = F.cross_entropy(scaled, labels_t)
        else:
            loss = F.binary_cross_entropy_with_logits(scaled, cum_targets)
        loss.backward()
        return loss

    optimizer.step(closure)
    return float(T.detach().item())
|
|
|
|
|
|
def collect_ensemble_logits(records: list[dict], device: torch.device):
    """Run every seed checkpoint over `records`; return averaged logits.

    Returns a (cat_logits, spec_logits) pair of numpy arrays, each the
    element-wise mean over the three seed models' raw logits.
    """
    per_seed_cat: list = []
    per_seed_spec: list = []

    for seed_name, checkpoint in CHECKPOINTS.items():
        print(f" [{seed_name}] loading {checkpoint}")
        eval_cfg = EvalConfig(
            checkpoint_path=checkpoint,
            paragraphs_path=PARAGRAPHS_PATH,
            holdout_path=HOLDOUT_PATH,
            benchmark_paths=BENCHMARK_PATHS,
            output_dir=str(OUTPUT_DIR),
            specificity_head=SPEC_HEAD,
        )
        model, tokenizer = load_model(eval_cfg, device)
        outputs = run_inference(
            model, tokenizer, records,
            eval_cfg.max_seq_length, eval_cfg.batch_size,
            device, SPEC_HEAD,
        )
        per_seed_cat.append(outputs["cat_logits"])
        per_seed_spec.append(outputs["spec_logits"])
        # Release this seed's weights before loading the next checkpoint.
        del model
        torch.cuda.empty_cache()

    mean_cat = np.stack(per_seed_cat, axis=0).mean(axis=0)
    mean_spec = np.stack(per_seed_spec, axis=0).mean(axis=0)
    return mean_cat, mean_spec
|
|
|
|
|
|
def load_val_records(tokenizer):
    """Load the val split as plain text records compatible with run_inference."""
    train_cfg = FinetuneConfig.from_yaml(TRAIN_CONFIG)
    splits = load_finetune_data(
        paragraphs_path=train_cfg.data.paragraphs_path,
        consensus_path=train_cfg.data.consensus_path,
        quality_path=train_cfg.data.quality_path,
        holdout_path=train_cfg.data.holdout_path,
        max_seq_length=train_cfg.data.max_seq_length,
        validation_split=train_cfg.data.validation_split,
        tokenizer=tokenizer,
        seed=train_cfg.training.seed,
    )
    val_split = splits["test"]

    # run_inference re-tokenizes from raw text, so decode each sample's
    # input_ids back into a string (special tokens stripped).
    return [
        {
            "text": tokenizer.decode(
                val_split[i]["input_ids"], skip_special_tokens=True
            ),
            "category_label": val_split[i]["category_labels"],
            "specificity_label": val_split[i]["specificity_labels"],
        }
        for i in range(len(val_split))
    ]
|
|
|
|
|
|
def main() -> None:
    """Fit per-head temperatures on the val split, apply them to holdout
    ensemble logits, and report ECE before/after per benchmark.

    Pipeline: load val split -> run 3-seed ensemble on val -> fit T_cat /
    T_spec by NLL -> run ensemble on holdout -> recompute ECE with scaled
    probabilities -> write temperature_scaling.json.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n Device: {device}")

    # ── 1. Load val split via tokenizer from seed42 ──
    # Any seed's tokenizer works here — all checkpoints share the same base
    # model; seed42 is picked arbitrarily.
    tokenizer = AutoTokenizer.from_pretrained(CHECKPOINTS["seed42"])

    print("\n Loading val split for temperature fitting...")
    val_records = load_val_records(tokenizer)
    print(f" Val samples: {len(val_records)}")

    # Subsample to avoid full ensemble pass on 7K samples (overkill for fitting T)
    # Fixed generator seed (0) keeps the subsample — and hence T — reproducible.
    rng = np.random.default_rng(0)
    if len(val_records) > 2000:
        idx = rng.choice(len(val_records), 2000, replace=False)
        val_records = [val_records[i] for i in idx]
        print(f" Subsampled to {len(val_records)} for T fitting")

    # ── 2. Run ensemble on val ──
    print("\n Running ensemble on val for T fitting...")
    val_cat_logits, val_spec_logits = collect_ensemble_logits(val_records, device)
    val_cat_labels = torch.tensor([r["category_label"] for r in val_records])
    val_spec_labels = torch.tensor([r["specificity_label"] for r in val_records])

    # ── 3. Fit T on val ──
    # Category head is a plain softmax classifier (CE); specificity head is
    # ordinal, so T is fit against cumulative BCE targets.
    T_cat = fit_temperature(torch.from_numpy(val_cat_logits), val_cat_labels, mode="ce")
    T_spec = fit_temperature(torch.from_numpy(val_spec_logits), val_spec_labels, mode="ordinal")
    print(f"\n Fitted T_cat = {T_cat:.4f}")
    print(f" Fitted T_spec = {T_spec:.4f}")

    # ── 4. Run ensemble on holdout ──
    print("\n Running ensemble on holdout...")
    holdout_records = load_holdout_data(PARAGRAPHS_PATH, HOLDOUT_PATH, BENCHMARK_PATHS)
    h_cat_logits, h_spec_logits = collect_ensemble_logits(holdout_records, device)

    # ── 5. Apply temperature, recompute ECE per benchmark ──
    h_cat_logits_t = torch.from_numpy(h_cat_logits)
    h_spec_logits_t = torch.from_numpy(h_spec_logits)

    # Pre/post pairs: identical logits, only divided by T for "post".
    cat_probs_pre = F.softmax(h_cat_logits_t, dim=1).numpy()
    cat_probs_post = F.softmax(h_cat_logits_t / T_cat, dim=1).numpy()

    spec_probs_pre = _ordinal_to_class_probs(h_spec_logits_t).numpy()
    spec_probs_post = _ordinal_to_class_probs(h_spec_logits_t / T_spec).numpy()

    # Predictions are unchanged (argmax invariant for cat; ordinal threshold at 0 invariant)
    cat_preds = h_cat_logits_t.argmax(dim=1).numpy()
    spec_preds = ordinal_predict(h_spec_logits_t).numpy()

    summary = {
        "T_cat": T_cat,
        "T_spec": T_spec,
        "per_benchmark": {},
    }

    for ref_name in BENCHMARK_PATHS:
        # Collect labels + row indices for records this benchmark annotated;
        # records missing a label for `ref_name` are skipped entirely.
        cat_labels, spec_labels = [], []
        cat_idx, spec_idx = [], []
        for i, rec in enumerate(holdout_records):
            bench = rec["benchmark_labels"].get(ref_name)
            if bench is None:
                continue
            cat_labels.append(CAT2ID[bench["category"]])
            # Specificity is 1-based in annotations; shift to 0-based class ids.
            spec_labels.append(bench["specificity"] - 1)
            cat_idx.append(i)
            spec_idx.append(i)
        # NOTE(review): cat_idx and spec_idx are always identical here (both
        # appended unconditionally together) — kept separate, presumably for
        # symmetry / future per-head filtering.

        cat_labels = np.array(cat_labels)
        spec_labels = np.array(spec_labels)
        cat_idx = np.array(cat_idx)
        spec_idx = np.array(spec_idx)

        ece_cat_pre, _ = compute_ece(cat_probs_pre[cat_idx], cat_labels)
        ece_cat_post, _ = compute_ece(cat_probs_post[cat_idx], cat_labels)
        ece_spec_pre, _ = compute_ece(spec_probs_pre[spec_idx], spec_labels)
        ece_spec_post, _ = compute_ece(spec_probs_post[spec_idx], spec_labels)

        # Sanity check: predictions unchanged
        # (temperature scaling preserves argmax, so F1 metrics are untouched)
        cat_match = (cat_preds[cat_idx] == cat_probs_post[cat_idx].argmax(axis=1)).all()
        spec_match = (spec_preds[spec_idx] == spec_probs_post[spec_idx].argmax(axis=1)).all()
        # NOTE(review): spec_match compares ordinal_predict (cumulative
        # threshold rule) against argmax of the derived class probs — these
        # agree in practice but are not the same decision rule; confirm.

        print(f"\n {ref_name}")
        print(f" Cat ECE: {ece_cat_pre:.4f} → {ece_cat_post:.4f} (Δ {ece_cat_post - ece_cat_pre:+.4f})")
        print(f" Spec ECE: {ece_spec_pre:.4f} → {ece_spec_post:.4f} (Δ {ece_spec_post - ece_spec_pre:+.4f})")
        print(f" Predictions preserved: cat={cat_match} spec={spec_match}")

        # NOTE(review): assumes compute_ece returns Python floats; if it
        # returns np.float64, json.dump below raises TypeError — verify.
        summary["per_benchmark"][ref_name] = {
            "ece_cat_pre": ece_cat_pre,
            "ece_cat_post": ece_cat_post,
            "ece_spec_pre": ece_spec_pre,
            "ece_spec_post": ece_spec_post,
            "cat_preds_preserved": bool(cat_match),
            "spec_preds_preserved": bool(spec_match),
        }

    with open(OUTPUT_DIR / "temperature_scaling.json", "w") as f:
        json.dump(summary, f, indent=2)
    print(f"\n Saved {OUTPUT_DIR / 'temperature_scaling.json'}")
|
|
|
|
|
|
# Script entry point: run the full calibration pipeline.
if __name__ == "__main__":
    main()
|