"""Ensemble evaluation: average logits across N trained seed checkpoints.
|
|
|
|
Runs inference for each checkpoint, averages category and specificity logits,
|
|
derives predictions from the averaged logits, then computes the same metric
|
|
suite as src.finetune.eval against the proxy gold benchmarks.
|
|
"""

import json
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

from src.finetune.data import CAT2ID, CATEGORIES
from src.finetune.eval import (
    EvalConfig,
    SPEC_LABELS,
    _ordinal_to_class_probs,
    compute_all_metrics,
    format_report,
    generate_comparison_figures,
    generate_figures,
    load_holdout_data,
    load_model,
    run_inference,
)
from src.finetune.model import ordinal_predict, softmax_predict
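
# NOTE: the relative paths below assume the script is launched from a
# directory that sits alongside checkpoints/, data/, and results/ (e.g. a
# scripts/ subdirectory); adjust them if your layout differs.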
CHECKPOINTS = {
    "seed42": "../checkpoints/finetune/iter1-independent/final",
    "seed69": "../checkpoints/finetune/iter1-seed69/final",
    "seed420": "../checkpoints/finetune/iter1-seed420/final",
}

BENCHMARK_PATHS = {
    "GPT-5.4": "../data/annotations/v2-bench/gpt-5.4.jsonl",
    "Opus-4.6": "../data/annotations/v2-bench/opus-4.6.jsonl",
}

PARAGRAPHS_PATH = "../data/paragraphs/paragraphs-clean.patched.jsonl"
HOLDOUT_PATH = "../data/gold/v2-holdout-ids.json"
OUTPUT_DIR = "../results/eval/ensemble-3seed"
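
# SPEC_HEAD controls how the averaged specificity logits are decoded in
# main(): "softmax" uses softmax_predict; anything else takes the ordinal
# path. It should match the head the checkpoints were trained with.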
SPEC_HEAD = "independent"


def main() -> None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    output_dir = Path(OUTPUT_DIR)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\n Device: {device}")
    print(f" Ensemble: {list(CHECKPOINTS.keys())}\n")

    # Load holdout once
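    # (all seeds score the same records in the same order, so the per-seed
    # logit arrays stay row-aligned for elementwise averaging)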
    records = load_holdout_data(PARAGRAPHS_PATH, HOLDOUT_PATH, BENCHMARK_PATHS)
    print(f" Holdout paragraphs: {len(records)}")

    # Run each seed, collect logits
    per_seed_cat_logits = []
    per_seed_spec_logits = []
    per_seed_inference = {}

    for name, ckpt_path in CHECKPOINTS.items():
        print(f"\n ── {name} ── loading {ckpt_path}")
        cfg = EvalConfig(
            checkpoint_path=ckpt_path,
            paragraphs_path=PARAGRAPHS_PATH,
            holdout_path=HOLDOUT_PATH,
            benchmark_paths=BENCHMARK_PATHS,
            output_dir=str(output_dir),
            specificity_head=SPEC_HEAD,
        )
        model, tokenizer = load_model(cfg, device)
        inference = run_inference(
            model, tokenizer, records,
            cfg.max_seq_length, cfg.batch_size,
            device, SPEC_HEAD,
        )
        print(f" {inference['avg_ms_per_sample']:.2f}ms/sample")
        per_seed_cat_logits.append(inference["cat_logits"])
        per_seed_spec_logits.append(inference["spec_logits"])
        per_seed_inference[name] = inference

        # Free GPU memory before loading the next checkpoint: `del` drops the
        # last reference to the weights, and empty_cache() releases PyTorch's
        # cached blocks back to the driver (a no-op on CPU-only runs).
        del model
        torch.cuda.empty_cache()

    # Average logits across seeds. Averaging in logit space before the softmax
    # is equivalent (up to renormalization) to a geometric mean of the per-seed
    # probability distributions, which damps overconfident outlier seeds.
    cat_logits = np.mean(np.stack(per_seed_cat_logits, axis=0), axis=0)
    spec_logits = np.mean(np.stack(per_seed_spec_logits, axis=0), axis=0)

    cat_logits_t = torch.from_numpy(cat_logits)
    spec_logits_t = torch.from_numpy(spec_logits)

    cat_probs = F.softmax(cat_logits_t, dim=1).numpy()
    cat_preds = cat_logits_t.argmax(dim=1).numpy()
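
    # Decode the averaged specificity logits with the same helpers the
    # single-seed eval uses, so ensemble and per-seed metrics stay directly
    # comparable; any head other than "softmax" takes the ordinal path.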
    if SPEC_HEAD == "softmax":
        spec_preds = softmax_predict(spec_logits_t).numpy()
        spec_probs = F.softmax(spec_logits_t, dim=1).numpy()
    else:
        spec_preds = ordinal_predict(spec_logits_t).numpy()
        spec_probs = _ordinal_to_class_probs(spec_logits_t).numpy()
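
    # Mirror the dict returned by run_inference so downstream reporting helpers
    # work unchanged. Latency fields are summed across seeds because an
    # ensemble forward pass runs every sample through all N models.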
    ensemble_inference = {
        "cat_preds": cat_preds,
        "cat_probs": cat_probs,
        "cat_logits": cat_logits,
        "spec_preds": spec_preds,
        "spec_probs": spec_probs,
        "spec_logits": spec_logits,
        "total_time_s": sum(p["total_time_s"] for p in per_seed_inference.values()),
        "num_samples": len(records),
        "avg_ms_per_sample": sum(p["avg_ms_per_sample"] for p in per_seed_inference.values()),
    }

    # Evaluate against benchmarks
    model_name = "ensemble-3seed"
    all_results = {}

    for ref_name in BENCHMARK_PATHS:
        print(f"\n Evaluating ensemble vs {ref_name}...")

        cat_labels, spec_labels = [], []
        e_cat_preds, e_spec_preds = [], []
        e_cat_probs, e_spec_probs = [], []
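
        # Each benchmark labels only a subset of the holdout; keep predictions
        # for exactly those records, indexing the ensemble arrays by record
        # position so predictions and labels stay aligned.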
        for i, rec in enumerate(records):
            bench = rec["benchmark_labels"].get(ref_name)
            if bench is None:
                continue
            cat_labels.append(CAT2ID[bench["category"]])
            # Specificity is annotated 1-based; shift to 0-based class ids.
            spec_labels.append(bench["specificity"] - 1)
            e_cat_preds.append(cat_preds[i])
            e_spec_preds.append(spec_preds[i])
            e_cat_probs.append(cat_probs[i])
            e_spec_probs.append(spec_probs[i])

        cat_labels = np.array(cat_labels)
        spec_labels = np.array(spec_labels)
        e_cat_preds = np.array(e_cat_preds)
        e_spec_preds = np.array(e_spec_preds)
        e_cat_probs = np.array(e_cat_probs)
        e_spec_probs = np.array(e_spec_probs)

        print(f" Matched samples: {len(cat_labels)}")

        cat_metrics = compute_all_metrics(
            e_cat_preds, cat_labels, e_cat_probs, CATEGORIES, "cat", is_ordinal=False
        )
        spec_metrics = compute_all_metrics(
            e_spec_preds, spec_labels, e_spec_probs, SPEC_LABELS, "spec", is_ordinal=True
        )

        combined = {**cat_metrics, **spec_metrics, **ensemble_inference}
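        # Headline number: unweighted mean of the category and specificity
        # macro-F1 scores, matching the single-seed reports.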
        combined["combined_macro_f1"] = (combined["cat_macro_f1"] + combined["spec_macro_f1"]) / 2

        report = format_report(model_name, ref_name, combined, ensemble_inference)
        print(report)

        report_path = output_dir / f"report_{ref_name.lower().replace(' ', '_').replace('.', '')}.txt"
        with open(report_path, "w") as f:
            f.write(report)

        figs = generate_figures(combined, output_dir, model_name, ref_name)
        print(f" Figures: {len(figs)}")

        all_results[f"{model_name}_vs_{ref_name}"] = combined

    comp_figs = generate_comparison_figures(all_results, output_dir)
    print(f" Comparison figures: {len(comp_figs)}")

    # Save JSON, keeping only scalar/str/list metrics: numpy arrays (raw
    # logits, per-sample predictions) are filtered out, and default=str
    # catches any remaining non-serializable values.
    serializable = {}
    for k, v in all_results.items():
        serializable[k] = {
            mk: mv for mk, mv in v.items()
            if isinstance(mv, (int, float, str, list, bool))
        }
    with open(output_dir / "metrics.json", "w") as f:
        json.dump(serializable, f, indent=2, default=str)

    print(f"\n Results saved to {output_dir}")


if __name__ == "__main__":
    main()