"""Analyze ablation results and launch the best config for a full training run.

Usage:
    uv run python scripts/run-best-config.py [--epochs 5] [--dry-run]
"""
|
|
|
|
import argparse
import json
import os
import subprocess
import sys
from pathlib import Path
|
|
|
|
|
|
# Directory holding the ablation outputs and the JSON results file that
# main() reads. Both paths are relative to the current working directory.
ABLATION_DIR = Path("..") / "checkpoints" / "finetune" / "ablation"
RESULTS_FILE = ABLATION_DIR.joinpath("ablation_results.json")

# First component of an ablation run name -> value passed as --model-path.
CHECKPOINT_MAP = dict(
    base="answerdotai/ModernBERT-large",
    dapt="../checkpoints/dapt/modernbert-large/final",
    tapt="../checkpoints/tapt/modernbert-large/final",
)
|
|
|
|
|
|
def _parse_args(argv):
    """Parse the command line: ``--epochs N`` (or ``--epochs=N``) and ``--dry-run``."""
    parser = argparse.ArgumentParser(
        description=(
            "Analyze ablation results and launch the best config "
            "for a full training run."
        ),
        # Abbreviations are disabled so a mistyped flag is rejected instead of
        # silently matching (or being ignored and launching a real run).
        allow_abbrev=False,
    )
    parser.add_argument(
        "--epochs", type=int, default=5,
        help="number of epochs for the full training run (default: 5)",
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="print the training command instead of executing it",
    )
    args = parser.parse_args(argv)
    if args.epochs < 1:
        parser.error("--epochs must be a positive integer")
    return args


def _load_successful_runs(results_file):
    """Return the error-free runs from ``results_file``, best combined macro F1 first."""
    if not results_file.exists():
        print(f"No results file found at {results_file}")
        print("Ablation may still be running. Check: ps aux | grep ablate")
        sys.exit(1)

    with open(results_file, encoding="utf-8") as f:
        results = json.load(f)

    # A run is considered failed when its entry carries an "error" key.
    successful = [r for r in results if "error" not in r]
    if not successful:
        print("No successful ablation runs found!")
        sys.exit(1)

    return sorted(
        successful,
        key=lambda r: r.get("eval_combined_macro_f1", 0),
        reverse=True,
    )


def _print_results_table(runs):
    """Print one row per run (already sorted best-first) and mark the top row."""
    rule = "=" * 80
    print("\n" + rule)
    print(" ABLATION RESULTS (sorted by combined F1)")
    print(rule)
    print(f" {'Run':<45} {'Combined':>10} {'Cat F1':>10} {'Spec F1':>10} {'QWK':>10}")
    print(f" {'-'*45} {'-'*10} {'-'*10} {'-'*10} {'-'*10}")

    for rank, run in enumerate(runs):
        combined = run.get("eval_combined_macro_f1", 0)
        cat = run.get("eval_cat_macro_f1", 0)
        spec = run.get("eval_spec_macro_f1", 0)
        qwk = run.get("eval_spec_qwk", 0)
        # Mark by position, not by dict equality, so duplicate entries do not
        # all get flagged as best.
        marker = " <-- BEST" if rank == 0 else ""
        print(
            f" {run['run']:<45} {combined:>10.4f} {cat:>10.4f} "
            f"{spec:>10.4f} {qwk:>10.4f}{marker}"
        )


def _parse_run_name(name):
    """Decode ``<checkpoint>_<weighting>_<loss>...`` into (model_path, weighting, loss_type)."""
    parts = name.split("_")
    if len(parts) < 3:
        print(f"Cannot parse run name {name!r}: expected '<checkpoint>_<weighting>_<loss>'")
        sys.exit(1)

    ckpt_name, weighting_name, loss_type = parts[:3]
    if ckpt_name not in CHECKPOINT_MAP:
        known = ", ".join(CHECKPOINT_MAP)
        print(f"Unknown checkpoint {ckpt_name!r} in run name {name!r} (known: {known})")
        sys.exit(1)

    return CHECKPOINT_MAP[ckpt_name], weighting_name == "weighted", loss_type


def main(argv=None):
    """Pick the best ablation run and launch (or print) its full training command.

    Reads ``RESULTS_FILE``, drops entries containing an ``"error"`` key, sorts
    the rest by ``eval_combined_macro_f1`` (missing values count as 0), prints
    a results table, and builds a ``main.py finetune`` command from the name
    of the top run.

    Args:
        argv: Command-line arguments without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Exits with status 1 when the results file is missing, contains no
    successful run, or the best run's name cannot be decoded; with status 2
    on invalid command-line options (unknown options are now rejected rather
    than ignored); and with the training process's exit status when that
    process fails.
    """
    args = _parse_args(argv)
    runs = _load_successful_runs(RESULTS_FILE)
    _print_results_table(runs)

    name = runs[0]["run"]
    model_path, weighting, loss_type = _parse_run_name(name)
    output_dir = f"../checkpoints/finetune/best-{name}-ep{args.epochs}"

    print(f"\n Best config: {name}")
    print(f" Model: {model_path}")
    print(f" Class weighting: {weighting}")
    print(f" Loss: {loss_type}")
    print(f" Epochs: {args.epochs}")
    print(f" Output: {output_dir}")

    cmd = [
        "uv", "run", "python", "main.py", "finetune",
        "--config", "configs/finetune/modernbert.yaml",
        "--model-path", model_path,
        "--output-dir", output_dir,
        "--loss-type", loss_type,
        "--class-weighting", str(weighting).lower(),
        "--epochs", str(args.epochs),
    ]

    if args.dry_run:
        print("\n Dry run — would execute:")
        print(f" {' '.join(cmd)}")
        return

    print("\n Launching full training...")
    # The allocator setting is only a default: a value already present in the
    # caller's environment takes precedence because os.environ is merged last.
    env = {"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", **os.environ}
    completed = subprocess.run(cmd, env=env, check=False)
    if completed.returncode != 0:
        # Propagate the failure so callers and shell scripts can detect it.
        sys.exit(completed.returncode)
|
|
|
|
|
|
# Script entry point: run only when executed directly, never on import.
# main() returns None, so the process still exits with status 0 on success.
if __name__ == "__main__":
    sys.exit(main())
|