"""DAPT and TAPT training via HuggingFace Trainer.

Both stages use the same masked language modeling objective — the only
difference is the corpus (full filings for DAPT, Item 1C paragraphs for TAPT)
and the starting checkpoint (base model for DAPT, DAPT checkpoint for TAPT).
"""

import math
import os
from pathlib import Path

from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from ..common.config import DAPTConfig
from ..data.corpus import load_corpus, tokenize_and_chunk

# DAPT only: documents shorter than this many characters (cover pages, empty
# filings) are dropped before chunking.
_DAPT_MIN_CHARS = 10_000

# Empirical characters-per-token ratio for the ModernBERT tokenizer. It is used
# to turn ``max_tokens`` into a character budget without tokenizing first.
_CHARS_PER_TOKEN = 4.72


def _subsample_newest(dataset, text_field: str, max_tokens: int):
    """Keep the trailing (newest) documents up to an approximate token budget.

    Walks backwards from the end of the corpus (accessions sort
    chronologically, so the end holds the newest filings). It accumulates
    character counts until the budget ``max_tokens * _CHARS_PER_TOKEN`` is
    reached. Motivation: Ponnock 2025 reports diminishing returns beyond
    ~250M tokens for SEC DAPT.

    Args:
        dataset: Corpus dataset supporting ``len``, integer row indexing and
            ``select``.
        text_field: Column holding the document text.
        max_tokens: Approximate number of tokens to keep.

    Returns:
        A tuple ``(subset, n_kept, est_tokens)``:
        - ``subset`` is the selected rows.
        - ``n_kept`` is how many documents were kept.
        - ``est_tokens`` is the estimated token count of the kept text.

        The document that crosses the budget is included, so ``est_tokens``
        may slightly exceed ``max_tokens``. If the whole corpus fits, every
        document is kept.
    """
    max_chars = int(max_tokens * _CHARS_PER_TOKEN)
    n = len(dataset)
    keep_from = n
    cumulative = 0
    for i in range(n - 1, -1, -1):
        cumulative += len(dataset[i][text_field])
        keep_from = i
        if cumulative >= max_chars:
            break
    subset = dataset.select(range(keep_from, n))
    return subset, n - keep_from, cumulative / _CHARS_PER_TOKEN


def _prepare_splits(config: DAPTConfig, tokenizer, cache_dir: Path):
    """Build (or reload) the tokenized train/validation split.

    On a cold start: load the corpus, drop tiny documents (DAPT only),
    optionally subsample to ``config.data.max_tokens``, tokenize/chunk,
    split, and save the result to ``cache_dir``.

    If ``cache_dir`` already exists it is loaded as-is. The cache is not
    invalidated when corpus, subsampling or tokenization settings change, so
    delete the directory after changing any of them.

    Returns:
        A ``DatasetDict`` with ``"train"`` and ``"test"`` (validation) splits.
    """
    if cache_dir.exists():
        print(f" Loading cached dataset from {cache_dir}...")
        from datasets import DatasetDict

        split = DatasetDict.load_from_disk(str(cache_dir))
        print(f" Train: {len(split['train']):,} | Val: {len(split['test']):,}\n")
        return split

    text_field = config.data.text_field

    print(f" Loading corpus from {config.data.corpus_path}...")
    dataset = load_corpus(config.data.corpus_path, text_field)
    print(f" Raw documents: {len(dataset):,}")

    # Filter tiny documents (cover pages, empty filings) — DAPT only.
    if config.stage == "dapt":
        before = len(dataset)
        dataset = dataset.filter(lambda x: len(x[text_field]) >= _DAPT_MIN_CHARS)
        filtered = before - len(dataset)
        if filtered > 0:
            print(f" Filtered {filtered} docs < {_DAPT_MIN_CHARS:,} chars → {len(dataset):,} remaining")

    if config.data.max_tokens is not None:
        dataset, n_kept, est_tokens = _subsample_newest(
            dataset, text_field, config.data.max_tokens
        )
        print(f" Subsampled to {n_kept:,} docs (~{est_tokens / 1e6:.0f}M tokens, newest filings, max_tokens={config.data.max_tokens:,})")

    wwm = config.training.whole_word_mask
    if wwm:
        print(f" Tokenizing to {config.data.max_seq_length} tokens (whole-word mask)...")
    else:
        print(f" Tokenizing and chunking to {config.data.max_seq_length} tokens...")
    chunked = tokenize_and_chunk(
        dataset,
        tokenizer,
        text_field=text_field,
        max_seq_length=config.data.max_seq_length,
        whole_word_mask=wwm,
    )
    print(f" Training sequences: {len(chunked):,}")

    split = chunked.train_test_split(
        test_size=config.data.validation_split,
        seed=config.training.seed,
    )
    print(f" Train: {len(split['train']):,} | Val: {len(split['test']):,}")

    # Cache to disk so a resumed run skips re-tokenizing.
    split.save_to_disk(str(cache_dir))
    print(f" Cached to {cache_dir}\n")
    return split


def _build_training_args(
    config: DAPTConfig, output_dir: Path, n_train: int
) -> TrainingArguments:
    """Translate the training section of ``config`` into ``TrainingArguments``.

    Args:
        config: Stage config; only ``config.training`` is read here.
        output_dir: Directory for checkpoints.
        n_train: Number of training sequences, used to size the warmup.
    """
    tc = config.training

    # One optimizer step consumes per_device * grad_accum sequences on each
    # data-parallel process. WORLD_SIZE is set by torchrun/accelerate
    # launchers and defaults to 1 for single-process runs.
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    effective_batch = (
        tc.per_device_train_batch_size * tc.gradient_accumulation_steps * world_size
    )
    steps_per_epoch = n_train // effective_batch
    # warmup_ratio is a fraction of *total* training steps, not of one epoch;
    # scaling by one epoch shortened warmup by num_train_epochs.
    total_steps = math.ceil(steps_per_epoch * tc.num_train_epochs)
    warmup_steps = int(tc.warmup_ratio * total_steps)

    training_kwargs: dict = dict(
        output_dir=str(output_dir),
        learning_rate=tc.learning_rate,
        num_train_epochs=tc.num_train_epochs,
        per_device_train_batch_size=tc.per_device_train_batch_size,
        # Long sequences need a small eval batch to avoid OOM.
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=tc.gradient_accumulation_steps,
        warmup_steps=warmup_steps,
        weight_decay=tc.weight_decay,
        bf16=tc.bf16,
        gradient_checkpointing=tc.gradient_checkpointing,
        torch_compile=True,
        optim="adamw_torch_fused",
        tf32=True,
        # torch's DataLoader rejects persistent_workers=True when
        # num_workers == 0, so only enable it when worker processes exist.
        dataloader_persistent_workers=tc.dataloader_num_workers > 0,
        logging_steps=tc.logging_steps,
        save_strategy=tc.save_strategy,
        eval_strategy=tc.eval_strategy,
        save_total_limit=tc.save_total_limit,
        dataloader_num_workers=tc.dataloader_num_workers,
        seed=tc.seed,
        report_to="none",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
    )
    if tc.gradient_checkpointing:
        training_kwargs["gradient_checkpointing_kwargs"] = {"use_reentrant": False}

    # Only pass step counts when using a step-based strategy.
    if tc.save_strategy == "steps":
        training_kwargs["save_steps"] = tc.save_steps
    if tc.eval_strategy == "steps":
        training_kwargs["eval_steps"] = tc.eval_steps

    return TrainingArguments(**training_kwargs)


def train(config: DAPTConfig) -> None:
    """Run DAPT or TAPT training from a config.

    Steps:
    1. Load the tokenizer and MLM model.
    2. Prepare (or reload from ``<output_dir>/.data_cache``) the tokenized
       train/val split.
    3. Train with HuggingFace ``Trainer``, auto-resuming from any
       ``checkpoint-*`` directory in the output dir.
    4. Save the final model and tokenizer to ``<output_dir>/final``.
    5. Report validation loss and perplexity.
    """
    print(f"\n{'='*60}")
    print(f" SEC-cyBERT {config.stage.upper()} Training")
    print(f" Model: {config.model.name_or_path}")
    print(f" Data: {config.data.corpus_path}")
    print(f" Output: {config.training.output_dir}")
    print(f"{'='*60}\n")

    tokenizer_name = config.model.tokenizer or config.model.name_or_path
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name,
        trust_remote_code=config.model.trust_remote_code,
    )

    # Use Flash Attention 2 if available (fastest kernel), else SDPA fallback.
    try:
        import flash_attn  # noqa: F401

        attn_impl = "flash_attention_2"
    except ImportError:
        attn_impl = "sdpa"

    # NOTE(review): with bf16 enabled the weights themselves are loaded in
    # bfloat16, so the optimizer updates bf16 parameters rather than fp32
    # master weights under AMP. Confirm this precision trade-off is intended.
    model = AutoModelForMaskedLM.from_pretrained(
        config.model.name_or_path,
        trust_remote_code=config.model.trust_remote_code,
        attn_implementation=attn_impl,
        torch_dtype="bfloat16" if config.training.bf16 else None,
    )
    print(f" Model parameters: {model.num_parameters() / 1e6:.0f}M")

    output_dir = Path(config.training.output_dir)
    split = _prepare_splits(config, tokenizer, output_dir / ".data_cache")

    # The data collator applies fresh (dynamic) masking each time a batch is built.
    collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=config.training.mlm_probability,
        whole_word_mask=config.training.whole_word_mask,
    )

    args = _build_training_args(config, output_dir, len(split["train"]))
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=split["train"],
        eval_dataset=split["test"],
        data_collator=collator,
    )

    # Auto-detect a checkpoint to resume from (True = latest in output_dir)
    # unless the config says otherwise.
    resume = config.training.resume_from_checkpoint
    if resume is None and any(output_dir.glob("checkpoint-*")):
        resume = True
    trainer.train(resume_from_checkpoint=resume)

    final_dir = output_dir / "final"
    print(f"\n Saving final model to {final_dir}...")
    trainer.save_model(str(final_dir))
    tokenizer.save_pretrained(str(final_dir))

    # With load_best_model_at_end=True this evaluates the best checkpoint.
    metrics = trainer.evaluate()
    print(f"\n Final eval loss: {metrics['eval_loss']:.4f}")
    # Trainer's MLM loss is natural-log cross-entropy, so perplexity is e**loss.
    print(f" Final perplexity: {math.exp(metrics['eval_loss']):.2f}")
    print(f"\n {config.stage.upper()} training complete.")