SEC-cyBERT/python/main.py
2026-04-05 12:16:16 -04:00

108 lines
4.0 KiB
Python

"""SEC-cyBERT training CLI.
Usage:
uv run main.py dapt --config configs/dapt/modernbert.yaml
uv run main.py dapt --config configs/dapt/modernbert.yaml \\
--model-path ../checkpoints/dapt/modernbert-large/final \\
--data-path ../data/paragraphs/paragraphs-clean.jsonl \\
--output-dir ../checkpoints/tapt/modernbert-large \\
--stage tapt
"""
import argparse
import sys
def cmd_dapt(args: argparse.Namespace) -> None:
    """Handle the ``dapt`` subcommand: run DAPT/TAPT masked-LM pre-training.

    Imports are deferred so heavy training dependencies load only when
    this command is actually invoked.
    """
    from src.common.config import DAPTConfig
    from src.dapt.train import train

    cfg = DAPTConfig.from_yaml(args.config)
    # CLI flags, when provided, take precedence over the YAML values.
    cfg.apply_overrides(
        model_path=args.model_path,
        data_path=args.data_path,
        output_dir=args.output_dir,
        stage=args.stage,
    )
    train(cfg)
def cmd_finetune(args: argparse.Namespace) -> None:
    """Handle the ``finetune`` subcommand: fine-tune the dual-head classifier.

    Imports are deferred so heavy training dependencies load only when
    this command is actually invoked.
    """
    from src.common.config import FinetuneConfig
    from src.finetune.train import train

    cfg = FinetuneConfig.from_yaml(args.config)
    # CLI flags, when provided, take precedence over the YAML values.
    cfg.apply_overrides(
        model_path=args.model_path,
        output_dir=args.output_dir,
        loss_type=args.loss_type,
        class_weighting=args.class_weighting,
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
    )
    train(cfg)
def cmd_ablate(args: argparse.Namespace) -> None:
    """Handle the ``ablate`` subcommand: run the full ablation grid.

    Unlike the other commands, only ``output_dir`` and ``epochs`` can be
    overridden, and they are applied directly onto the training section.
    """
    from src.common.config import FinetuneConfig
    from src.finetune.train import ablate

    cfg = FinetuneConfig.from_yaml(args.config)
    # NOTE(review): falsy values (None, 0) are skipped — epochs=0 therefore
    # cannot override the config value; presumably intentional.
    if args.output_dir:
        cfg.training.output_dir = args.output_dir
    if args.epochs:
        cfg.training.num_train_epochs = args.epochs
    ablate(cfg)
def main() -> None:
    """Build the CLI argument parser and dispatch to the chosen subcommand.

    Subcommands: ``dapt`` (DAPT/TAPT pre-training), ``finetune`` (classifier
    fine-tuning), ``ablate`` (full ablation grid), ``eval`` (placeholder).
    """
    parser = argparse.ArgumentParser(
        description="SEC-cyBERT training pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    commands = parser.add_subparsers(dest="command", required=True)

    # dapt / tapt -- masked-LM pre-training
    pretrain = commands.add_parser(
        "dapt",
        help="Run DAPT or TAPT pre-training (masked language modeling)",
    )
    pretrain.add_argument("--config", required=True, help="Path to YAML config file")
    pretrain.add_argument("--model-path", help="Override model name or checkpoint path")
    pretrain.add_argument("--data-path", help="Override corpus path (file or directory)")
    pretrain.add_argument("--output-dir", help="Override output directory")
    pretrain.add_argument("--stage", choices=["dapt", "tapt"], help="Override stage label")
    pretrain.set_defaults(func=cmd_dapt)

    # finetune -- supervised classifier training
    finetune = commands.add_parser("finetune", help="Fine-tune classifier (dual-head)")
    finetune.add_argument("--config", required=True, help="Path to YAML config file")
    finetune.add_argument("--model-path", help="Override model checkpoint path")
    finetune.add_argument("--output-dir", help="Override output directory")
    finetune.add_argument("--loss-type", choices=["ce", "focal"], help="Override loss type")
    # Parses any string case-insensitively: only "true" maps to True.
    finetune.add_argument(
        "--class-weighting",
        type=lambda x: x.lower() == "true",
        help="Override class weighting (true/false)",
    )
    finetune.add_argument("--epochs", type=int, help="Override number of epochs")
    finetune.add_argument("--batch-size", type=int, help="Override batch size")
    finetune.add_argument("--learning-rate", type=float, help="Override learning rate")
    finetune.set_defaults(func=cmd_finetune)

    # ablate -- exhaustive configuration sweep
    ablate = commands.add_parser("ablate", help="Run full ablation grid (3 ckpts x 2 weighting x 2 loss)")
    ablate.add_argument("--config", required=True, help="Path to YAML config file")
    ablate.add_argument("--output-dir", help="Override base output directory")
    ablate.add_argument("--epochs", type=int, help="Override epochs per ablation run (default: config value)")
    ablate.set_defaults(func=cmd_ablate)

    # eval -- not yet implemented, prints a placeholder message
    evaluate = commands.add_parser("eval", help="Evaluate a trained model")
    evaluate.add_argument("--config", required=True, help="Path to YAML config file")
    evaluate.set_defaults(func=lambda args: print("Evaluation not yet implemented."))

    parsed = parser.parse_args()
    parsed.func(parsed)
# Standard script entry guard: run the CLI only when executed directly.
if __name__ == "__main__":
    main()