"""SEC-cyBERT training CLI.

Usage:
    uv run main.py dapt --config configs/dapt/modernbert.yaml
    uv run main.py dapt --config configs/dapt/modernbert.yaml \\
        --model-path ../checkpoints/dapt/modernbert-large/final \\
        --data-path ../data/paragraphs/paragraphs-clean.jsonl \\
        --output-dir ../checkpoints/tapt/modernbert-large \\
        --stage tapt
"""
|
|
|
|
import argparse
|
|
import sys
|
|
|
|
|
|
def cmd_dapt(args: argparse.Namespace) -> None:
    """Handle the ``dapt`` subcommand: load config, apply CLI overrides, train.

    Args:
        args: Parsed command-line namespace. Reads ``config`` (path to the
            YAML file) plus ``model_path``, ``data_path``, ``output_dir``
            and ``stage``, which are forwarded unchanged as overrides.
    """
    # Imported here rather than at module top, so these project modules are
    # only loaded when this subcommand actually runs.
    from src.common.config import DAPTConfig
    from src.dapt.train import train

    run_config = DAPTConfig.from_yaml(args.config)

    cli_overrides = {
        "model_path": args.model_path,
        "data_path": args.data_path,
        "output_dir": args.output_dir,
        "stage": args.stage,
    }
    run_config.apply_overrides(**cli_overrides)

    train(run_config)
|
|
|
|
|
|
def _not_implemented(message: str):
    """Return a subcommand handler that only prints *message*.

    A factory is used (instead of a lambda in a loop) so each handler keeps
    its own message rather than sharing the loop variable.
    """

    def handler(args: argparse.Namespace) -> None:
        print(message)

    return handler


def main(argv: "list[str] | None" = None) -> None:
    """Parse the command line and dispatch to the selected subcommand.

    Subcommands:
        dapt:     DAPT/TAPT pre-training; handled by ``cmd_dapt``.
        finetune: placeholder, prints a "not yet implemented" message.
        eval:     placeholder, prints a "not yet implemented" message.
        decoder:  placeholder, prints a "not yet implemented" message.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which is
            the behaviour of calling ``main()`` with no arguments.

    Raises:
        SystemExit: Raised by argparse on ``--help`` or on invalid/missing
            arguments (a subcommand is required).
    """
    parser = argparse.ArgumentParser(
        description="SEC-cyBERT training pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    sub = parser.add_subparsers(dest="command", required=True)

    # ── dapt / tapt ──
    dapt = sub.add_parser(
        "dapt",
        help="Run DAPT or TAPT pre-training (masked language modeling)",
    )
    dapt.add_argument("--config", required=True, help="Path to YAML config file")
    dapt.add_argument("--model-path", help="Override model name or checkpoint path")
    dapt.add_argument("--data-path", help="Override corpus path (file or directory)")
    dapt.add_argument("--output-dir", help="Override output directory")
    dapt.add_argument("--stage", choices=["dapt", "tapt"], help="Override stage label")
    dapt.set_defaults(func=cmd_dapt)

    # ── finetune / eval / decoder (placeholders) ──
    # Each placeholder accepts the same required --config flag and only
    # prints a notice; they are registered from one table to avoid three
    # copies of identical wiring.
    placeholders = (
        (
            "finetune",
            "Fine-tune classifier (dual-head)",
            "Fine-tuning not yet implemented.",
        ),
        (
            "eval",
            "Evaluate a trained model",
            "Evaluation not yet implemented.",
        ),
        (
            "decoder",
            "Decoder experiment (Qwen LoRA)",
            "Decoder experiment not yet implemented.",
        ),
    )
    for name, help_text, message in placeholders:
        stub = sub.add_parser(name, help=help_text)
        stub.add_argument("--config", required=True, help="Path to YAML config file")
        stub.set_defaults(func=_not_implemented(message))

    # Every subparser stores its handler under ``func`` via set_defaults.
    args = parser.parse_args(argv)
    args.func(args)
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|