"""DAPT and TAPT training via HuggingFace Trainer.
Both stages use the same masked language modeling objective — the only
difference is the corpus (full filings for DAPT, Item 1C paragraphs for TAPT)
and the starting checkpoint (base model for DAPT, DAPT checkpoint for TAPT).
"""

import math
from pathlib import Path

from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from ..common.config import DAPTConfig
from ..data.corpus import load_corpus, tokenize_and_chunk


def train(config: DAPTConfig) -> None:
    """Run DAPT or TAPT training from a config."""
    print(f"\n{'='*60}")
    print(f" SEC-cyBERT {config.stage.upper()} Training")
    print(f" Model: {config.model.name_or_path}")
    print(f" Data: {config.data.corpus_path}")
    print(f" Output: {config.training.output_dir}")
    print(f"{'='*60}\n")

    # Load tokenizer
    tokenizer_name = config.model.tokenizer or config.model.name_or_path
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name,
        trust_remote_code=config.model.trust_remote_code,
    )

    # Load model
    model = AutoModelForMaskedLM.from_pretrained(
        config.model.name_or_path,
        trust_remote_code=config.model.trust_remote_code,
    )
    print(f" Model parameters: {model.num_parameters() / 1e6:.0f}M")

    # Load and prepare data (with disk cache to avoid re-tokenizing on resume)
    output_dir = Path(config.training.output_dir)
    cache_dir = output_dir / ".data_cache"
    if cache_dir.exists():
        print(f" Loading cached dataset from {cache_dir}...")
        from datasets import DatasetDict

        split = DatasetDict.load_from_disk(str(cache_dir))
        print(f" Train: {len(split['train']):,} | Val: {len(split['test']):,}\n")
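
        # NB: the cache is keyed only by its location, not by the config, so delete
        # .data_cache after changing the tokenizer, corpus, or max_seq_length.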
    else:
        print(f" Loading corpus from {config.data.corpus_path}...")
        dataset = load_corpus(config.data.corpus_path, config.data.text_field)
        print(f" Raw documents: {len(dataset):,}")

        # Filter tiny documents (cover pages, empty filings)
        min_chars = 10_000
        before = len(dataset)
        dataset = dataset.filter(lambda x: len(x[config.data.text_field]) >= min_chars)
        filtered = before - len(dataset)
        if filtered > 0:
            print(f" Filtered {filtered} docs < {min_chars:,} chars → {len(dataset):,} remaining")

        print(f" Tokenizing and chunking to {config.data.max_seq_length} tokens...")
        chunked = tokenize_and_chunk(
            dataset,
            tokenizer,
            text_field=config.data.text_field,
            max_seq_length=config.data.max_seq_length,
        )
        print(f" Training sequences: {len(chunked):,}")

        # Train/val split
        split = chunked.train_test_split(
            test_size=config.data.validation_split,
            seed=config.training.seed,
        )
        print(f" Train: {len(split['train']):,} | Val: {len(split['test']):,}")

        # Cache to disk for fast resume
        split.save_to_disk(str(cache_dir))
        print(f" Cached to {cache_dir}\n")

    # Data collator — handles dynamic masking each epoch
    collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=config.training.mlm_probability,
    )

    # Training arguments
    args = TrainingArguments(
        output_dir=str(output_dir),
        learning_rate=config.training.learning_rate,
        num_train_epochs=config.training.num_train_epochs,
        per_device_train_batch_size=config.training.per_device_train_batch_size,
        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
        warmup_ratio=config.training.warmup_ratio,
        weight_decay=config.training.weight_decay,
        bf16=config.training.bf16,
        gradient_checkpointing=config.training.gradient_checkpointing,
        logging_steps=config.training.logging_steps,
        save_steps=config.training.save_steps,
        eval_strategy="steps",
        eval_steps=config.training.eval_steps,
        save_total_limit=config.training.save_total_limit,
        dataloader_num_workers=config.training.dataloader_num_workers,
        seed=config.training.seed,
        report_to="none",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
    )
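    # Note: with load_best_model_at_end=True the Trainer expects the save and eval
    # cadences to line up (save_steps a round multiple of eval_steps), so choose
    # save_steps and eval_steps in the config accordingly.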

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=split["train"],
        eval_dataset=split["test"],
        data_collator=collator,
    )

    # Train (with optional checkpoint resume)
    trainer.train(resume_from_checkpoint=config.training.resume_from_checkpoint)

    # Save final model + tokenizer
    final_dir = output_dir / "final"
    print(f"\n Saving final model to {final_dir}...")
    trainer.save_model(str(final_dir))
    tokenizer.save_pretrained(str(final_dir))

    # Log final eval
    metrics = trainer.evaluate()
    print(f"\n Final eval loss: {metrics['eval_loss']:.4f}")
    # eval_loss is mean cross-entropy in nats, so perplexity is exp(loss), not 2**loss
    print(f" Final perplexity: {math.exp(metrics['eval_loss']):.2f}")
    print(f"\n {config.stage.upper()} training complete.")
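

# A minimal CLI entry point sketch, assuming DAPTConfig can be built from a YAML
# file via a hypothetical `DAPTConfig.from_yaml` helper; swap in the project's real
# config loader if it differs. Because of the relative imports above, run this as a
# module (python -m <package>.<module> path/to/config.yaml) rather than as a script.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run DAPT or TAPT MLM training from a config file")
    parser.add_argument("config", help="path to a YAML training config")
    cli_args = parser.parse_args()

    train(DAPTConfig.from_yaml(cli_args.config))  # hypothetical loader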