"""Ensemble evaluation: average logits across N trained seed checkpoints. Runs inference for each checkpoint, averages category and specificity logits, derives predictions from the averaged logits, then computes the same metric suite as src.finetune.eval against the proxy gold benchmarks. """ import json from pathlib import Path import numpy as np import torch import torch.nn.functional as F from src.finetune.data import CAT2ID, CATEGORIES from src.finetune.eval import ( EvalConfig, SPEC_LABELS, _ordinal_to_class_probs, compute_all_metrics, format_report, generate_comparison_figures, generate_figures, load_holdout_data, load_model, run_inference, ) from src.finetune.model import ordinal_predict, softmax_predict CHECKPOINTS = { "seed42": "../checkpoints/finetune/iter1-independent/final", "seed69": "../checkpoints/finetune/iter1-seed69/final", "seed420": "../checkpoints/finetune/iter1-seed420/final", } BENCHMARK_PATHS = { "GPT-5.4": "../data/annotations/v2-bench/gpt-5.4.jsonl", "Opus-4.6": "../data/annotations/v2-bench/opus-4.6.jsonl", } PARAGRAPHS_PATH = "../data/paragraphs/paragraphs-clean.patched.jsonl" HOLDOUT_PATH = "../data/gold/v2-holdout-ids.json" OUTPUT_DIR = "../results/eval/ensemble-3seed" SPEC_HEAD = "independent" def main() -> None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") output_dir = Path(OUTPUT_DIR) output_dir.mkdir(parents=True, exist_ok=True) print(f"\n Device: {device}") print(f" Ensemble: {list(CHECKPOINTS.keys())}\n") # Load holdout once records = load_holdout_data(PARAGRAPHS_PATH, HOLDOUT_PATH, BENCHMARK_PATHS) print(f" Holdout paragraphs: {len(records)}") # Run each seed, collect logits per_seed_cat_logits = [] per_seed_spec_logits = [] per_seed_inference = {} for name, ckpt_path in CHECKPOINTS.items(): print(f"\n ── {name} ── loading {ckpt_path}") cfg = EvalConfig( checkpoint_path=ckpt_path, paragraphs_path=PARAGRAPHS_PATH, holdout_path=HOLDOUT_PATH, benchmark_paths=BENCHMARK_PATHS, output_dir=str(output_dir), specificity_head=SPEC_HEAD, ) model, tokenizer = load_model(cfg, device) inference = run_inference( model, tokenizer, records, cfg.max_seq_length, cfg.batch_size, device, SPEC_HEAD, ) print(f" {inference['avg_ms_per_sample']:.2f}ms/sample") per_seed_cat_logits.append(inference["cat_logits"]) per_seed_spec_logits.append(inference["spec_logits"]) per_seed_inference[name] = inference # Free GPU mem before next load del model torch.cuda.empty_cache() # Average logits across seeds cat_logits = np.mean(np.stack(per_seed_cat_logits, axis=0), axis=0) spec_logits = np.mean(np.stack(per_seed_spec_logits, axis=0), axis=0) cat_logits_t = torch.from_numpy(cat_logits) spec_logits_t = torch.from_numpy(spec_logits) cat_probs = F.softmax(cat_logits_t, dim=1).numpy() cat_preds = cat_logits_t.argmax(dim=1).numpy() if SPEC_HEAD == "softmax": spec_preds = softmax_predict(spec_logits_t).numpy() spec_probs = F.softmax(spec_logits_t, dim=1).numpy() else: spec_preds = ordinal_predict(spec_logits_t).numpy() spec_probs = _ordinal_to_class_probs(spec_logits_t).numpy() ensemble_inference = { "cat_preds": cat_preds, "cat_probs": cat_probs, "cat_logits": cat_logits, "spec_preds": spec_preds, "spec_probs": spec_probs, "spec_logits": spec_logits, "total_time_s": sum(p["total_time_s"] for p in per_seed_inference.values()), "num_samples": len(records), "avg_ms_per_sample": sum(p["avg_ms_per_sample"] for p in per_seed_inference.values()), } # Evaluate against benchmarks model_name = "ensemble-3seed" all_results = {} for ref_name in BENCHMARK_PATHS: print(f"\n 


def main() -> None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"\n Device: {device}")
    print(f" Ensemble: {list(CHECKPOINTS.keys())}\n")

    # Load holdout once
    records = load_holdout_data(PARAGRAPHS_PATH, HOLDOUT_PATH, BENCHMARK_PATHS)
    print(f" Holdout paragraphs: {len(records)}")

    # Run each seed, collect logits
    per_seed_cat_logits = []
    per_seed_spec_logits = []
    per_seed_inference = {}
    for name, ckpt_path in CHECKPOINTS.items():
        print(f"\n ── {name} ── loading {ckpt_path}")
        cfg = EvalConfig(
            checkpoint_path=ckpt_path,
            paragraphs_path=PARAGRAPHS_PATH,
            holdout_path=HOLDOUT_PATH,
            benchmark_paths=BENCHMARK_PATHS,
            output_dir=str(output_dir),
            specificity_head=SPEC_HEAD,
        )
        model, tokenizer = load_model(cfg, device)
        inference = run_inference(
            model, tokenizer, records, cfg.max_seq_length, cfg.batch_size, device, SPEC_HEAD
        )
        print(f" {inference['avg_ms_per_sample']:.2f}ms/sample")
        per_seed_cat_logits.append(inference["cat_logits"])
        per_seed_spec_logits.append(inference["spec_logits"])
        per_seed_inference[name] = inference
        # Free GPU mem before next load
        del model
        torch.cuda.empty_cache()

    # Average logits across seeds
    cat_logits = np.mean(np.stack(per_seed_cat_logits, axis=0), axis=0)
    spec_logits = np.mean(np.stack(per_seed_spec_logits, axis=0), axis=0)

    cat_logits_t = torch.from_numpy(cat_logits)
    spec_logits_t = torch.from_numpy(spec_logits)
    cat_probs = F.softmax(cat_logits_t, dim=1).numpy()
    cat_preds = cat_logits_t.argmax(dim=1).numpy()
    if SPEC_HEAD == "softmax":
        spec_preds = softmax_predict(spec_logits_t).numpy()
        spec_probs = F.softmax(spec_logits_t, dim=1).numpy()
    else:
        spec_preds = ordinal_predict(spec_logits_t).numpy()
        spec_probs = _ordinal_to_class_probs(spec_logits_t).numpy()

    ensemble_inference = {
        "cat_preds": cat_preds,
        "cat_probs": cat_probs,
        "cat_logits": cat_logits,
        "spec_preds": spec_preds,
        "spec_probs": spec_probs,
        "spec_logits": spec_logits,
        "total_time_s": sum(p["total_time_s"] for p in per_seed_inference.values()),
        "num_samples": len(records),
        # Per-sample latency sums across seeds: each sample passes through every model.
        "avg_ms_per_sample": sum(p["avg_ms_per_sample"] for p in per_seed_inference.values()),
    }

    # Evaluate against benchmarks
    model_name = "ensemble-3seed"
    all_results = {}
    for ref_name in BENCHMARK_PATHS:
        print(f"\n Evaluating ensemble vs {ref_name}...")
        cat_labels, spec_labels = [], []
        e_cat_preds, e_spec_preds = [], []
        e_cat_probs, e_spec_probs = [], []
        # Keep only paragraphs this benchmark labeled; specificity is 1-based
        # in the annotations, 0-based in the model heads.
        for i, rec in enumerate(records):
            bench = rec["benchmark_labels"].get(ref_name)
            if bench is None:
                continue
            cat_labels.append(CAT2ID[bench["category"]])
            spec_labels.append(bench["specificity"] - 1)
            e_cat_preds.append(cat_preds[i])
            e_spec_preds.append(spec_preds[i])
            e_cat_probs.append(cat_probs[i])
            e_spec_probs.append(spec_probs[i])

        cat_labels = np.array(cat_labels)
        spec_labels = np.array(spec_labels)
        e_cat_preds = np.array(e_cat_preds)
        e_spec_preds = np.array(e_spec_preds)
        e_cat_probs = np.array(e_cat_probs)
        e_spec_probs = np.array(e_spec_probs)
        print(f" Matched samples: {len(cat_labels)}")

        cat_metrics = compute_all_metrics(
            e_cat_preds, cat_labels, e_cat_probs, CATEGORIES, "cat", is_ordinal=False
        )
        spec_metrics = compute_all_metrics(
            e_spec_preds, spec_labels, e_spec_probs, SPEC_LABELS, "spec", is_ordinal=True
        )
        combined = {**cat_metrics, **spec_metrics, **ensemble_inference}
        combined["combined_macro_f1"] = (
            combined["cat_macro_f1"] + combined["spec_macro_f1"]
        ) / 2

        report = format_report(model_name, ref_name, combined, ensemble_inference)
        print(report)
        safe_ref = ref_name.lower().replace(" ", "_").replace(".", "")
        report_path = output_dir / f"report_{safe_ref}.txt"
        with open(report_path, "w") as f:
            f.write(report)

        figs = generate_figures(combined, output_dir, model_name, ref_name)
        print(f" Figures: {len(figs)}")
        all_results[f"{model_name}_vs_{ref_name}"] = combined

    generate_comparison_figures(all_results, output_dir)

    # Save JSON (only JSON-serializable metric values; logit/pred arrays are dropped)
    serializable = {}
    for k, v in all_results.items():
        serializable[k] = {
            mk: mv for mk, mv in v.items() if isinstance(mv, (int, float, str, list, bool))
        }
    with open(output_dir / "metrics.json", "w") as f:
        json.dump(serializable, f, indent=2, default=str)

    print(f"\n Results saved to {output_dir}")


if __name__ == "__main__":
    main()
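
# Reading the saved metrics back (a minimal sketch; the key follows the
# f"{model_name}_vs_{ref_name}" pattern used above, and the benchmark names
# come from BENCHMARK_PATHS):
#
#   import json
#   with open("../results/eval/ensemble-3seed/metrics.json") as f:
#       metrics = json.load(f)
#   print(metrics["ensemble-3seed_vs_GPT-5.4"]["combined_macro_f1"])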