SEC-cyBERT/python/scripts/eval_ensemble.py

"""Ensemble evaluation: average logits across N trained seed checkpoints.
Runs inference for each checkpoint, averages category and specificity logits,
derives predictions from the averaged logits, then computes the same metric
suite as src.finetune.eval against the proxy gold benchmarks.
"""
import json
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

from src.finetune.data import CAT2ID, CATEGORIES
from src.finetune.eval import (
    EvalConfig,
    SPEC_LABELS,
    _ordinal_to_class_probs,
    compute_all_metrics,
    format_report,
    generate_comparison_figures,
    generate_figures,
    load_holdout_data,
    load_model,
    run_inference,
)
from src.finetune.model import ordinal_predict, softmax_predict

CHECKPOINTS = {
    "seed42": "../checkpoints/finetune/iter1-independent/final",
    "seed69": "../checkpoints/finetune/iter1-seed69/final",
    "seed420": "../checkpoints/finetune/iter1-seed420/final",
}
BENCHMARK_PATHS = {
    "GPT-5.4": "../data/annotations/v2-bench/gpt-5.4.jsonl",
    "Opus-4.6": "../data/annotations/v2-bench/opus-4.6.jsonl",
}
PARAGRAPHS_PATH = "../data/paragraphs/paragraphs-clean.patched.jsonl"
HOLDOUT_PATH = "../data/gold/v2-holdout-ids.json"
OUTPUT_DIR = "../results/eval/ensemble-3seed"
SPEC_HEAD = "independent"
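# All data/checkpoint paths above are relative, so the script presumably expects to
# be launched from this scripts/ directory. SPEC_HEAD should match the specificity
# head the three checkpoints were trained with; any value other than "softmax" is
# decoded with the ordinal helpers in main() below.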


def main() -> None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"\n Device: {device}")
    print(f" Ensemble: {list(CHECKPOINTS.keys())}\n")

    # Load holdout once
    records = load_holdout_data(PARAGRAPHS_PATH, HOLDOUT_PATH, BENCHMARK_PATHS)
    print(f" Holdout paragraphs: {len(records)}")

    # Run each seed, collect logits
    per_seed_cat_logits = []
    per_seed_spec_logits = []
    per_seed_inference = {}
    for name, ckpt_path in CHECKPOINTS.items():
        print(f"\n ── {name} ── loading {ckpt_path}")
        cfg = EvalConfig(
            checkpoint_path=ckpt_path,
            paragraphs_path=PARAGRAPHS_PATH,
            holdout_path=HOLDOUT_PATH,
            benchmark_paths=BENCHMARK_PATHS,
            output_dir=str(output_dir),
            specificity_head=SPEC_HEAD,
        )
        model, tokenizer = load_model(cfg, device)
        inference = run_inference(
            model, tokenizer, records,
            cfg.max_seq_length, cfg.batch_size,
            device, SPEC_HEAD,
        )
        print(f" {inference['avg_ms_per_sample']:.2f}ms/sample")
        per_seed_cat_logits.append(inference["cat_logits"])
        per_seed_spec_logits.append(inference["spec_logits"])
        per_seed_inference[name] = inference
        # Free GPU mem before next load
        del model
        torch.cuda.empty_cache()

    # Average logits across seeds
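    # (Softmax of the mean logit vector is proportional to the geometric mean of the
    # per-seed softmax distributions, so this is equivalent to a geometric-mean
    # probability ensemble.)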
    cat_logits = np.mean(np.stack(per_seed_cat_logits, axis=0), axis=0)
    spec_logits = np.mean(np.stack(per_seed_spec_logits, axis=0), axis=0)
    cat_logits_t = torch.from_numpy(cat_logits)
    spec_logits_t = torch.from_numpy(spec_logits)
    cat_probs = F.softmax(cat_logits_t, dim=1).numpy()
    cat_preds = cat_logits_t.argmax(dim=1).numpy()
    if SPEC_HEAD == "softmax":
        spec_preds = softmax_predict(spec_logits_t).numpy()
        spec_probs = F.softmax(spec_logits_t, dim=1).numpy()
    else:
        spec_preds = ordinal_predict(spec_logits_t).numpy()
        spec_probs = _ordinal_to_class_probs(spec_logits_t).numpy()
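    # The specificity head is decoded with the same helpers the single-checkpoint
    # eval uses (softmax_predict vs. ordinal_predict / _ordinal_to_class_probs), so
    # the ensemble metrics remain directly comparable to the per-seed reports.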

    ensemble_inference = {
        "cat_preds": cat_preds,
        "cat_probs": cat_probs,
        "cat_logits": cat_logits,
        "spec_preds": spec_preds,
        "spec_probs": spec_probs,
        "spec_logits": spec_logits,
        "total_time_s": sum(p["total_time_s"] for p in per_seed_inference.values()),
        "num_samples": len(records),
        "avg_ms_per_sample": sum(p["avg_ms_per_sample"] for p in per_seed_inference.values()),
    }
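    # Timing fields are summed across seeds: the ensemble has to run every
    # checkpoint, so its cost is the combined cost of all three forward passes.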

    # Evaluate against benchmarks
    model_name = "ensemble-3seed"
    all_results = {}
    for ref_name in BENCHMARK_PATHS:
        print(f"\n Evaluating ensemble vs {ref_name}...")
        cat_labels, spec_labels = [], []
        e_cat_preds, e_spec_preds = [], []
        e_cat_probs, e_spec_probs = [], []
        for i, rec in enumerate(records):
            bench = rec["benchmark_labels"].get(ref_name)
            if bench is None:
                continue
            cat_labels.append(CAT2ID[bench["category"]])
            spec_labels.append(bench["specificity"] - 1)
            e_cat_preds.append(cat_preds[i])
            e_spec_preds.append(spec_preds[i])
            e_cat_probs.append(cat_probs[i])
            e_spec_probs.append(spec_probs[i])
        cat_labels = np.array(cat_labels)
        spec_labels = np.array(spec_labels)
        e_cat_preds = np.array(e_cat_preds)
        e_spec_preds = np.array(e_spec_preds)
        e_cat_probs = np.array(e_cat_probs)
        e_spec_probs = np.array(e_spec_probs)
        print(f" Matched samples: {len(cat_labels)}")

        cat_metrics = compute_all_metrics(
            e_cat_preds, cat_labels, e_cat_probs, CATEGORIES, "cat", is_ordinal=False
        )
        spec_metrics = compute_all_metrics(
            e_spec_preds, spec_labels, e_spec_probs, SPEC_LABELS, "spec", is_ordinal=True
        )
        combined = {**cat_metrics, **spec_metrics, **ensemble_inference}
        combined["combined_macro_f1"] = (combined["cat_macro_f1"] + combined["spec_macro_f1"]) / 2

        report = format_report(model_name, ref_name, combined, ensemble_inference)
        print(report)
        report_path = output_dir / f"report_{ref_name.lower().replace(' ', '_').replace('.', '')}.txt"
        with open(report_path, "w") as f:
            f.write(report)

        figs = generate_figures(combined, output_dir, model_name, ref_name)
        print(f" Figures: {len(figs)}")
        all_results[f"{model_name}_vs_{ref_name}"] = combined

    comp_figs = generate_comparison_figures(all_results, output_dir)

    # Save JSON
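    # Keep only plainly JSON-serializable values; the numpy arrays (logits and
    # per-sample predictions) carried in each result dict are dropped here so
    # metrics.json stays compact.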
    serializable = {}
    for k, v in all_results.items():
        serializable[k] = {
            mk: mv for mk, mv in v.items()
            if isinstance(mv, (int, float, str, list, bool))
        }
    with open(output_dir / "metrics.json", "w") as f:
        json.dump(serializable, f, indent=2, default=str)

    print(f"\n Results saved to {output_dir}")


if __name__ == "__main__":
    main()