"""DAPT and TAPT training via HuggingFace Trainer.
Both stages use the same masked language modeling objective; the only
differences are the corpus (full filings for DAPT, Item 1C paragraphs for TAPT)
and the starting checkpoint (base model for DAPT, the DAPT checkpoint for TAPT).
"""
import math
from pathlib import Path
from transformers import (
AutoModelForMaskedLM,
AutoTokenizer,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
from ..common.config import DAPTConfig
from ..data.corpus import load_corpus, tokenize_and_chunk
def train(config: DAPTConfig) -> None:
"""Run DAPT or TAPT training from a config."""
print(f"\n{'='*60}")
print(f" SEC-cyBERT {config.stage.upper()} Training")
print(f" Model: {config.model.name_or_path}")
print(f" Data: {config.data.corpus_path}")
print(f" Output: {config.training.output_dir}")
print(f"{'='*60}\n")
# Load tokenizer
tokenizer_name = config.model.tokenizer or config.model.name_or_path
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
trust_remote_code=config.model.trust_remote_code,
)
# Use Flash Attention 2 if available (fastest kernel), else SDPA fallback
try:
import flash_attn # noqa: F401
attn_impl = "flash_attention_2"
except ImportError:
attn_impl = "sdpa"
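    # ("sdpa" uses PyTorch's built-in scaled_dot_product_attention; no extra install)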
model = AutoModelForMaskedLM.from_pretrained(
config.model.name_or_path,
trust_remote_code=config.model.trust_remote_code,
attn_implementation=attn_impl,
torch_dtype="bfloat16" if config.training.bf16 else None,
)
print(f" Model parameters: {model.num_parameters() / 1e6:.0f}M")
# Load and prepare data (with disk cache to avoid re-tokenizing on resume)
output_dir = Path(config.training.output_dir)
cache_dir = output_dir / ".data_cache"
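    # NOTE: the cache is keyed only by output_dir, not by data settings. Delete
    # .data_cache after changing corpus_path or max_seq_length to force re-tokenizing.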
if cache_dir.exists():
print(f" Loading cached dataset from {cache_dir}...")
from datasets import DatasetDict
split = DatasetDict.load_from_disk(str(cache_dir))
print(f" Train: {len(split['train']):,} | Val: {len(split['test']):,}\n")
else:
print(f" Loading corpus from {config.data.corpus_path}...")
dataset = load_corpus(config.data.corpus_path, config.data.text_field)
print(f" Raw documents: {len(dataset):,}")
# Filter tiny documents (cover pages, empty filings) — DAPT only
if config.stage == "dapt":
min_chars = 10_000
before = len(dataset)
dataset = dataset.filter(lambda x: len(x[config.data.text_field]) >= min_chars)
filtered = before - len(dataset)
if filtered > 0:
print(f" Filtered {filtered} docs < {min_chars:,} chars → {len(dataset):,} remaining")
# Subsample corpus if max_tokens is set (Ponnock 2025: diminishing
# returns beyond ~250M tokens for SEC DAPT). Takes from the END of
# the corpus (newest filings first, since accessions sort chronologically).
if config.data.max_tokens is not None:
chars_per_token = 4.72 # empirical for ModernBERT tokenizer
max_chars = int(config.data.max_tokens * chars_per_token)
cumulative = 0
n = len(dataset)
keep_from = n
for i in range(n - 1, -1, -1):
cumulative += len(dataset[i][config.data.text_field])
keep_from = i
if cumulative >= max_chars:
break
dataset = dataset.select(range(keep_from, n))
est_tokens = cumulative / chars_per_token
print(f" Subsampled to {n - keep_from:,} docs (~{est_tokens / 1e6:.0f}M tokens, newest filings, max_tokens={config.data.max_tokens:,})")
wwm = config.training.whole_word_mask
if wwm:
print(f" Tokenizing to {config.data.max_seq_length} tokens (whole-word mask)...")
else:
print(f" Tokenizing and chunking to {config.data.max_seq_length} tokens...")
chunked = tokenize_and_chunk(
dataset,
tokenizer,
text_field=config.data.text_field,
max_seq_length=config.data.max_seq_length,
whole_word_mask=wwm,
)
print(f" Training sequences: {len(chunked):,}")
# Train/val split
split = chunked.train_test_split(
test_size=config.data.validation_split,
seed=config.training.seed,
)
print(f" Train: {len(split['train']):,} | Val: {len(split['test']):,}")
# Cache to disk for fast resume
split.save_to_disk(str(cache_dir))
print(f" Cached to {cache_dir}\n")
# Data collator — handles dynamic masking each epoch
collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=True,
mlm_probability=config.training.mlm_probability,
whole_word_mask=config.training.whole_word_mask,
)
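    # NOTE: whole_word_mask as a DataCollatorForLanguageModeling kwarg assumes a
    # recent transformers release; older versions used DataCollatorForWholeWordMask.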
    # Training arguments
    steps_per_epoch = len(split["train"]) // (
        config.training.per_device_train_batch_size
        * config.training.gradient_accumulation_steps
    )
    # Warmup spans warmup_ratio of ONE epoch's optimizer steps, unlike
    # TrainingArguments.warmup_ratio, which is a fraction of total training steps.
    warmup_steps = int(config.training.warmup_ratio * steps_per_epoch)
training_kwargs: dict = dict(
output_dir=str(output_dir),
learning_rate=config.training.learning_rate,
num_train_epochs=config.training.num_train_epochs,
per_device_train_batch_size=config.training.per_device_train_batch_size,
gradient_accumulation_steps=config.training.gradient_accumulation_steps,
warmup_steps=warmup_steps,
weight_decay=config.training.weight_decay,
bf16=config.training.bf16,
gradient_checkpointing=config.training.gradient_checkpointing,
torch_compile=True,
optim="adamw_torch_fused",
tf32=True,
dataloader_persistent_workers=True,
logging_steps=config.training.logging_steps,
save_strategy=config.training.save_strategy,
eval_strategy=config.training.eval_strategy,
save_total_limit=config.training.save_total_limit,
dataloader_num_workers=config.training.dataloader_num_workers,
seed=config.training.seed,
report_to="none",
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
)
if config.training.gradient_checkpointing:
training_kwargs["gradient_checkpointing_kwargs"] = {"use_reentrant": False}
# Long sequences need small eval batch to avoid OOM
training_kwargs["per_device_eval_batch_size"] = 1
# Only pass step counts when using step-based strategy
if config.training.save_strategy == "steps":
training_kwargs["save_steps"] = config.training.save_steps
if config.training.eval_strategy == "steps":
training_kwargs["eval_steps"] = config.training.eval_steps
args = TrainingArguments(**training_kwargs)
trainer = Trainer(
model=model,
args=args,
train_dataset=split["train"],
eval_dataset=split["test"],
data_collator=collator,
)
    # Resume from the latest checkpoint in output_dir unless the config says
    # otherwise (True tells Trainer to locate the latest checkpoint itself)
    resume = config.training.resume_from_checkpoint
    if resume is None and any(output_dir.glob("checkpoint-*")):
        resume = True
trainer.train(resume_from_checkpoint=resume)
# Save final model + tokenizer
final_dir = output_dir / "final"
print(f"\n Saving final model to {final_dir}...")
trainer.save_model(str(final_dir))
tokenizer.save_pretrained(str(final_dir))
    # Log final eval (Trainer reports natural-log cross-entropy, so perplexity
    # is exp(loss), not 2 ** loss)
    metrics = trainer.evaluate()
    print(f"\n Final eval loss: {metrics['eval_loss']:.4f}")
    print(f" Final perplexity: {math.exp(metrics['eval_loss']):.2f}")
print(f"\n {config.stage.upper()} training complete.")