SEC-cyBERT/python/scripts/run-best-config.py
2026-04-05 12:16:16 -04:00

107 lines
3.2 KiB
Python

"""Analyze ablation results and launch the best config for a full training run.

Usage:
    uv run python scripts/run-best-config.py [--epochs 5] [--dry-run]
"""
import argparse
import json
import os
import subprocess
import sys
from pathlib import Path
# Directory holding the ablation outputs, and the JSON summary file inside it.
# Both paths are relative to the current working directory.
ABLATION_DIR = Path("../checkpoints/finetune/ablation")
RESULTS_FILE = ABLATION_DIR.joinpath("ablation_results.json")

# First component of an ablation run name -> value passed as --model-path.
CHECKPOINT_MAP = dict(
    base="answerdotai/ModernBERT-large",
    dapt="../checkpoints/dapt/modernbert-large/final",
    tapt="../checkpoints/tapt/modernbert-large/final",
)
def _parse_args(argv=None):
    """Parse the command line; ``argv=None`` means ``sys.argv[1:]``.

    Unknown flags are rejected so that a typo such as ``--dryrun`` cannot
    silently start a full training run.
    """
    parser = argparse.ArgumentParser(
        description="Analyze ablation results and launch the best config "
        "for a full training run.",
    )
    parser.add_argument(
        "--epochs", type=int, default=5,
        help="number of epochs for the full run (default: 5)",
    )
    parser.add_argument(
        "--dry-run", action="store_true",
        help="print the training command instead of executing it",
    )
    args = parser.parse_args(argv)
    if args.epochs < 1:
        parser.error("--epochs must be a positive integer")
    return args


def _load_successful_runs():
    """Return the error-free ablation records, best combined macro F1 first.

    Exits with status 1 when the results file is missing, cannot be parsed,
    or contains no record without an ``"error"`` key.
    """
    if not RESULTS_FILE.exists():
        print(f"No results file found at {RESULTS_FILE}")
        print("Ablation may still be running. Check: ps aux | grep ablate")
        sys.exit(1)

    try:
        with open(RESULTS_FILE, encoding="utf-8") as f:
            results = json.load(f)
    except json.JSONDecodeError as err:
        # A truncated file would otherwise surface as a raw traceback.
        print(f"Could not parse {RESULTS_FILE}: {err}")
        sys.exit(1)

    successful = [r for r in results if "error" not in r]
    if not successful:
        print("No successful ablation runs found!")
        sys.exit(1)

    # sorted() is stable, so ties keep their order from the results file.
    return sorted(
        successful,
        key=lambda r: r.get("eval_combined_macro_f1", 0),
        reverse=True,
    )


def _print_results_table(runs):
    """Print one row per run; ``runs`` must already be sorted best-first."""
    print(f"\n{'='*80}")
    print(" ABLATION RESULTS (sorted by combined F1)")
    print(f"{'='*80}")
    print(f" {'Run':<45} {'Combined':>10} {'Cat F1':>10} {'Spec F1':>10} {'QWK':>10}")
    print(f" {'-'*45} {'-'*10} {'-'*10} {'-'*10} {'-'*10}")
    for position, r in enumerate(runs):
        name = r["run"]
        combined = r.get("eval_combined_macro_f1", 0)
        cat = r.get("eval_cat_macro_f1", 0)
        spec = r.get("eval_spec_macro_f1", 0)
        qwk = r.get("eval_spec_qwk", 0)
        # Mark by position: dict equality would also flag any other run
        # whose record happens to be identical to the best one.
        marker = " <-- BEST" if position == 0 else ""
        print(f" {name:<45} {combined:>10.4f} {cat:>10.4f} {spec:>10.4f} {qwk:>10.4f}{marker}")


def _parse_run_name(name):
    """Split a run name into ``(model_path, weighting, loss_type)``.

    The name is split on ``"_"``: the first part is looked up in
    ``CHECKPOINT_MAP``, the second enables class weighting when it equals
    ``"weighted"``, and the third is the loss type. Exits with status 1 when
    the name has fewer than three parts or the checkpoint key is unknown.
    """
    parts = name.split("_")
    if len(parts) < 3:
        print(f"Cannot parse run name {name!r}: expected at least 3 '_'-separated parts")
        sys.exit(1)

    ckpt_name = parts[0]
    if ckpt_name not in CHECKPOINT_MAP:
        known = ", ".join(sorted(CHECKPOINT_MAP))
        print(f"Unknown checkpoint {ckpt_name!r} in run name {name!r} (known: {known})")
        sys.exit(1)

    return CHECKPOINT_MAP[ckpt_name], parts[1] == "weighted", parts[2]


def main(argv=None):
    """Pick the best ablation run and launch (or print) a full training run.

    Args:
        argv: Optional argument list; defaults to ``sys.argv[1:]``.

    Exits with status 1 when no usable ablation result is available, and
    with the training command's exit status when that command fails.
    """
    args = _parse_args(argv)
    epochs = args.epochs

    runs = _load_successful_runs()
    _print_results_table(runs)

    name = runs[0]["run"]
    model_path, weighting, loss_type = _parse_run_name(name)
    output_dir = f"../checkpoints/finetune/best-{name}-ep{epochs}"

    print(f"\n Best config: {name}")
    print(f" Model: {model_path}")
    print(f" Class weighting: {weighting}")
    print(f" Loss: {loss_type}")
    print(f" Epochs: {epochs}")
    print(f" Output: {output_dir}")

    cmd = [
        "uv", "run", "python", "main.py", "finetune",
        "--config", "configs/finetune/modernbert.yaml",
        "--model-path", model_path,
        "--output-dir", output_dir,
        "--loss-type", loss_type,
        "--class-weighting", str(weighting).lower(),
        "--epochs", str(epochs),
    ]

    if args.dry_run:
        print("\n Dry run — would execute:")
        print(f" {' '.join(cmd)}")
        return

    print("\n Launching full training...")
    # os.environ is merged last, so a value already set by the caller
    # overrides the default allocator setting.
    env = {"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", **os.environ}
    completed = subprocess.run(cmd, env=env)
    if completed.returncode != 0:
        # Propagate the failure instead of exiting 0.
        sys.exit(completed.returncode)
# Run main() only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()