"""Analyze ablation results and launch the best config for a full training run.

Reads the ablation summary JSON, prints every successful run ranked by
combined macro F1, decodes the winning run's name
(``<checkpoint>_<weighting>_<loss>``) into fine-tuning options, and launches
``main.py finetune`` with them (or only prints the command with --dry-run).

Usage:
    uv run python scripts/run-best-config.py [--epochs 5] [--dry-run]

All paths are relative to the current working directory (the launched
command references ``main.py`` and ``configs/...`` directly), so run this
from the directory that contains ``main.py``.
"""

import argparse
import json
import os
import shlex
import subprocess
import sys
from pathlib import Path

ABLATION_DIR = Path("../checkpoints/finetune/ablation")
RESULTS_FILE = ABLATION_DIR / "ablation_results.json"

# Maps the checkpoint component of an ablation run name to a model path / hub ID.
CHECKPOINT_MAP = {
    "base": "answerdotai/ModernBERT-large",
    "dapt": "../checkpoints/dapt/modernbert-large/final",
    "tapt": "../checkpoints/tapt/modernbert-large/final",
}


def _parse_args(argv):
    """Parse ``--epochs N`` / ``--epochs=N`` and ``--dry-run`` from *argv*."""
    parser = argparse.ArgumentParser(
        description="Launch a full fine-tuning run with the best ablation config.",
    )
    parser.add_argument(
        "--epochs", type=int, default=5,
        help="number of training epochs (default: 5)",
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="print the training command instead of running it",
    )
    return parser.parse_args(argv)


def _metric(run, key):
    """Return metric *key* of *run*, treating a missing or null value as 0."""
    value = run.get(key)
    # JSON null would otherwise break both sorting and ':.4f' formatting.
    return value if value is not None else 0


def _load_successful_runs(results_file):
    """Load runs without an ``"error"`` key, sorted by combined macro F1 (best first).

    Exits with status 1 if the file is missing or contains no successful runs.
    """
    results_file = Path(results_file)
    if not results_file.exists():
        print(f"No results file found at {results_file}", file=sys.stderr)
        print("Ablation may still be running. Check: ps aux | grep ablate", file=sys.stderr)
        sys.exit(1)

    with open(results_file, encoding="utf-8") as f:
        results = json.load(f)

    successful = [r for r in results if "error" not in r]
    if not successful:
        print("No successful ablation runs found!", file=sys.stderr)
        sys.exit(1)

    successful.sort(key=lambda r: _metric(r, "eval_combined_macro_f1"), reverse=True)
    return successful


def _print_table(runs):
    """Print *runs* (already sorted best-first) as a table, marking the first row."""
    print(f"\n{'='*80}")
    print(" ABLATION RESULTS (sorted by combined F1)")
    print(f"{'='*80}")
    print(f" {'Run':<45} {'Combined':>10} {'Cat F1':>10} {'Spec F1':>10} {'QWK':>10}")
    print(f" {'-'*45} {'-'*10} {'-'*10} {'-'*10} {'-'*10}")
    for rank, run in enumerate(runs):
        # Mark by position rather than dict equality with runs[0].
        marker = " <-- BEST" if rank == 0 else ""
        print(
            f" {run['run']:<45}"
            f" {_metric(run, 'eval_combined_macro_f1'):>10.4f}"
            f" {_metric(run, 'eval_cat_macro_f1'):>10.4f}"
            f" {_metric(run, 'eval_spec_macro_f1'):>10.4f}"
            f" {_metric(run, 'eval_spec_qwk'):>10.4f}"
            f"{marker}"
        )


def _parse_run_name(name):
    """Decode ``<checkpoint>_<weighting>_<loss>[_...]`` into run options.

    Returns ``(model_path, weighting, loss_type)`` where *weighting* is True
    only when the second component is exactly ``"weighted"``. Components after
    the third are ignored.

    Raises:
        ValueError: if the name has fewer than three components or names a
            checkpoint not in ``CHECKPOINT_MAP``.
    """
    parts = name.split("_")
    if len(parts) < 3:
        raise ValueError(
            f"Cannot parse run name {name!r}: expected <checkpoint>_<weighting>_<loss>"
        )
    ckpt_name, weighting_tag, loss_type = parts[0], parts[1], parts[2]
    if ckpt_name not in CHECKPOINT_MAP:
        raise ValueError(
            f"Unknown checkpoint {ckpt_name!r} in run name {name!r}; "
            f"expected one of {sorted(CHECKPOINT_MAP)}"
        )
    return CHECKPOINT_MAP[ckpt_name], weighting_tag == "weighted", loss_type


def _build_command(model_path, output_dir, loss_type, weighting, epochs):
    """Build the ``main.py finetune`` argv for the chosen configuration."""
    return [
        "uv", "run", "python", "main.py", "finetune",
        "--config", "configs/finetune/modernbert.yaml",
        "--model-path", model_path,
        "--output-dir", output_dir,
        "--loss-type", loss_type,
        "--class-weighting", str(weighting).lower(),
        "--epochs", str(epochs),
    ]


def main(argv=None, results_file=None):
    """Rank ablation runs and launch (or print) full training for the best one.

    Args:
        argv: CLI arguments without the program name; ``None`` means
            ``sys.argv[1:]``.
        results_file: path to the ablation results JSON; ``None`` means
            ``RESULTS_FILE``.

    Returns:
        0 for a dry run, otherwise the exit code of the training subprocess.

    Raises ``SystemExit(1)`` when there are no usable results or the best
    run's name cannot be decoded, and ``SystemExit(2)`` (from argparse) on
    invalid CLI arguments.
    """
    args = _parse_args(argv)
    epochs = args.epochs

    successful = _load_successful_runs(RESULTS_FILE if results_file is None else results_file)
    _print_table(successful)

    name = successful[0]["run"]
    try:
        model_path, weighting, loss_type = _parse_run_name(name)
    except ValueError as err:
        print(f"\n {err}", file=sys.stderr)
        sys.exit(1)

    output_dir = f"../checkpoints/finetune/best-{name}-ep{epochs}"

    print(f"\n Best config: {name}")
    print(f" Model: {model_path}")
    print(f" Class weighting: {weighting}")
    print(f" Loss: {loss_type}")
    print(f" Epochs: {epochs}")
    print(f" Output: {output_dir}")

    cmd = _build_command(model_path, output_dir, loss_type, weighting, epochs)

    if args.dry_run:
        print("\n Dry run — would execute:")
        # shlex.join quotes any unsafe argument so the line is copy-pastable.
        print(f" {shlex.join(cmd)}")
        return 0

    print("\n Launching full training...")
    # Default the allocator setting, but let a value already set by the user win.
    env = dict(os.environ)
    env.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
    result = subprocess.run(cmd, env=env)
    # Propagate the training outcome instead of always exiting 0.
    return result.returncode


if __name__ == "__main__":
    sys.exit(main())