"""SEC-cyBERT training CLI.
|
|
|
|
Usage:
|
|
uv run main.py dapt --config configs/dapt/modernbert.yaml
|
|
uv run main.py dapt --config configs/dapt/modernbert.yaml \\
|
|
--model-path ../checkpoints/dapt/modernbert-large/final \\
|
|
--data-path ../data/paragraphs/paragraphs-clean.jsonl \\
|
|
--output-dir ../checkpoints/tapt/modernbert-large \\
|
|
--stage tapt
|
|
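
The other subcommands follow the same pattern; for example (paths here are
illustrative, see each subcommand's --help for the full flag list):

    uv run main.py finetune --config configs/finetune/modernbert.yaml --loss-type focal
    uv run main.py ablate --config configs/finetune/modernbert.yaml --epochs 2
    uv run main.py eval --checkpoint ../checkpoints/finetune/modernbert-large/final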
"""

import argparse


def cmd_dapt(args: argparse.Namespace) -> None:
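    """Run DAPT or TAPT pre-training (masked language modeling) from a YAML config."""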
    from src.common.config import DAPTConfig
    from src.dapt.train import train

    config = DAPTConfig.from_yaml(args.config)
    config.apply_overrides(
        model_path=args.model_path,
        data_path=args.data_path,
        output_dir=args.output_dir,
        stage=args.stage,
    )
    train(config)


def cmd_finetune(args: argparse.Namespace) -> None:
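    """Fine-tune the dual-head classifier, applying any CLI overrides to the YAML config."""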
    from src.common.config import FinetuneConfig
    from src.finetune.train import train

    config = FinetuneConfig.from_yaml(args.config)
    config.apply_overrides(
        model_path=args.model_path,
        output_dir=args.output_dir,
        loss_type=args.loss_type,
        class_weighting=args.class_weighting,
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
    )
    train(config)


def cmd_eval(args: argparse.Namespace) -> None:
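    """Evaluate a trained checkpoint on the holdout set against benchmark references."""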
    from src.finetune.eval import EvalConfig, evaluate

    # Use explicit --benchmark NAME PATH pairs when given, else the default benchmarks.
    if args.benchmark:
        benchmark_paths = dict(args.benchmark)
    else:
        benchmark_paths = {
            "GPT-5.4": "../data/annotations/v2-bench/gpt-5.4.jsonl",
            "Opus-4.6": "../data/annotations/v2-bench/opus-4.6.jsonl",
        }

    config = EvalConfig(
        checkpoint_path=args.checkpoint,
        paragraphs_path=args.paragraphs,
        holdout_path=args.holdout,
        benchmark_paths=benchmark_paths,
        output_dir=args.output_dir,
        max_seq_length=args.max_seq_length,
        batch_size=args.batch_size,
        specificity_head=args.spec_head,
        spec_mlp_dim=args.spec_mlp_dim,
        pooling=args.pooling,
    )
    evaluate(config)


def cmd_ablate(args: argparse.Namespace) -> None:
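    """Run the full ablation grid (3 checkpoints x 2 weighting x 2 loss)."""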
    from src.common.config import FinetuneConfig
    from src.finetune.train import ablate

    config = FinetuneConfig.from_yaml(args.config)
    if args.output_dir is not None:
        config.training.output_dir = args.output_dir
    if args.epochs is not None:
        config.training.num_train_epochs = args.epochs
    ablate(config)


def main() -> None:
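    """Build the argument parser and dispatch to the selected subcommand handler."""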
    parser = argparse.ArgumentParser(
        description="SEC-cyBERT training pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    sub = parser.add_subparsers(dest="command", required=True)

    # ── dapt / tapt ──
    dapt = sub.add_parser(
        "dapt",
        help="Run DAPT or TAPT pre-training (masked language modeling)",
    )
    dapt.add_argument("--config", required=True, help="Path to YAML config file")
    dapt.add_argument("--model-path", help="Override model name or checkpoint path")
    dapt.add_argument("--data-path", help="Override corpus path (file or directory)")
    dapt.add_argument("--output-dir", help="Override output directory")
    dapt.add_argument("--stage", choices=["dapt", "tapt"], help="Override stage label")
    dapt.set_defaults(func=cmd_dapt)

    # ── finetune ──
    ft = sub.add_parser("finetune", help="Fine-tune classifier (dual-head)")
    ft.add_argument("--config", required=True, help="Path to YAML config file")
    ft.add_argument("--model-path", help="Override model checkpoint path")
    ft.add_argument("--output-dir", help="Override output directory")
    ft.add_argument("--loss-type", choices=["ce", "focal"], help="Override loss type")
    ft.add_argument(
        "--class-weighting",
        type=lambda x: x.lower() == "true",  # parse "true"/"false" into a bool
        help="Override class weighting (true/false)",
    )
    ft.add_argument("--epochs", type=int, help="Override number of epochs")
    ft.add_argument("--batch-size", type=int, help="Override batch size")
    ft.add_argument("--learning-rate", type=float, help="Override learning rate")
    ft.set_defaults(func=cmd_finetune)

    # ── ablate ──
    ab = sub.add_parser("ablate", help="Run the full ablation grid (3 ckpts x 2 weighting x 2 loss)")
    ab.add_argument("--config", required=True, help="Path to YAML config file")
    ab.add_argument("--output-dir", help="Override base output directory")
    ab.add_argument("--epochs", type=int, help="Override epochs per ablation run (default: config value)")
    ab.set_defaults(func=cmd_ablate)

    # ── eval ──
    ev = sub.add_parser("eval", help="Evaluate a trained model on the holdout set")
    ev.add_argument("--checkpoint", required=True, help="Path to model checkpoint directory")
    ev.add_argument("--paragraphs", default="../data/paragraphs/paragraphs-clean.patched.jsonl")
    ev.add_argument("--holdout", default="../data/gold/v2-holdout-ids.json")
    ev.add_argument(
        "--benchmark",
        action="append",
        nargs=2,
        metavar=("NAME", "PATH"),
        help="Benchmark reference: NAME PATH (can repeat)",
    )
    ev.add_argument("--output-dir", default="../results/eval")
    ev.add_argument("--max-seq-length", type=int, default=512)
    ev.add_argument("--batch-size", type=int, default=64)
    ev.add_argument("--spec-head", default="independent", choices=["coral", "independent", "softmax"])
    ev.add_argument("--spec-mlp-dim", type=int, default=256)
    ev.add_argument("--pooling", default="attention", choices=["cls", "attention"])
    ev.set_defaults(func=cmd_eval)

    args = parser.parse_args()
    args.func(args)


if __name__ == "__main__":
    main()