# DAPT/TAPT Training Procedure
**Date:** 2026-03-29

**Hardware:** NVIDIA RTX 3090 (24GB VRAM), CUDA driver 13.2, PyTorch 2.11.0+cu130

---
## Pre-flight Checklist
| Check | Status |
|-------|--------|
| PyTorch 2.11.0+cu130, CUDA available | Verified |
| RTX 3090, 25.3 GB VRAM, bf16 supported | Verified |
| Flash Attention 2 (flash-attn 2.6.3+cu130torch2.11) | Verified |
| ModernBERT-large loads: 396M params, max_position_embeddings=8192 | Verified |
| Corpus: 14,756 docs, ~1.06B tokens, 15 shards | Verified |
| After <10K filter: 14,568 docs, ~1.056B tokens (0.027% loss) | Verified |
| Tokenize+chunk pipeline: 10 docs -> 85 sequences of 8192 tokens | Verified |
| Config: seq_len=8192, batch=4, grad_accum=8, 1 epoch, lr=5e-5, mlm=0.30 | Set |
## DAPT Corpus Summary
- **14,568 documents** (after filtering 188 cover pages <10K chars)
- **~1.056 billion tokens** (ModernBERT tokenizer, 4.72 chars/token)
- **~136K training sequences** at seq_len=8192
- **Median document: ~73K tokens** (347K chars); 90.6% of docs exceed 8192 tokens
- Cleaned: XBRL data blobs stripped, exhibit listings stripped, URLs removed, F-N page numbers removed
- Source: 14,759 cached 10-K HTML filings, FY2023-FY2025, processed by `ts/scripts/dapt-corpus-prep.ts`
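
The summary figures above can be approximated without running the tokenizer. The sketch below is a hypothetical helper, not the logic of `ts/scripts/dapt-corpus-prep.ts`: it assumes each shard line is a JSON object with a `text` field (the shard schema isn't documented here), applies the 10K-character filter, and converts characters to tokens with the measured 4.72 chars/token ratio.

```python
import glob
import json
import statistics

MIN_CHARS = 10_000        # cover-page filter threshold
CHARS_PER_TOKEN = 4.72    # measured ModernBERT tokenizer ratio on this corpus
SEQ_LEN = 8192


def iter_shard_texts(pattern="data/dapt-corpus/shard-*.jsonl"):
    """Yield document texts from JSONL shards (assumes a 'text' field per line)."""
    for path in sorted(glob.glob(pattern)):
        with open(path, encoding="utf-8") as f:
            for line in f:
                if line.strip():
                    yield json.loads(line)["text"]


def summarize(texts, min_chars=MIN_CHARS):
    """Drop short documents and estimate token-level statistics for the rest."""
    kept = [t for t in texts if len(t) >= min_chars]
    est_tokens = [len(t) / CHARS_PER_TOKEN for t in kept]
    total = sum(est_tokens)
    return {
        "docs_kept": len(kept),
        "docs_dropped": len(texts) - len(kept),
        "est_tokens": round(total),
        "est_sequences": int(total // SEQ_LEN),
        "median_doc_tokens": round(statistics.median(est_tokens)) if kept else 0,
        "share_over_seq_len": (
            sum(n > SEQ_LEN for n in est_tokens) / len(kept) if kept else 0.0
        ),
    }
```

Usage: `summarize(list(iter_shard_texts()))`. Because it estimates tokens from character counts rather than tokenizing, its sequence total will differ somewhat from the tokenizer-based figures.
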
## Training Configuration
**Config file:** `python/configs/dapt/modernbert.yaml`
| Parameter | Value | Rationale |
|-----------|-------|-----------|
| `max_seq_length` | 8192 | Match ModernBERT's pre-training context length |
| `max_tokens` | 500,000,000 | Subsample to newest 500M tokens (Ponnock 2025: diminishing returns past 250M) |
| `per_device_train_batch_size` | 4 | Maximum for 24GB VRAM with FA2 + torch.compile + grad checkpointing |
| `gradient_accumulation_steps` | 8 | Effective batch size = 32 |
| `num_train_epochs` | 1 | Single pass per Gururangan et al. (2020) and Ponnock (2025) |
| `learning_rate` | 5e-5 | Conservative for post-decay checkpoint (see note below) |
| `mlm_probability` | 0.30 | Matches ModernBERT's pre-training masking rate (Warner et al., 2024) |
| `weight_decay` | 1e-5 | Matches ModernBERT pre-training; used by BioClinical-ModernBERT and Patent-ModernBERT |
| `warmup_ratio` | 0.05 | ~93 warmup steps (5% of ~1,870 optimizer steps at the 500M-token budget) |
| `gradient_checkpointing` | true | Required for 8192 seq_len on 24GB |
| `gradient_checkpointing_kwargs` | `use_reentrant: False` | Required for torch.compile compatibility |
| `bf16` | true | Native RTX 3090 support |
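| `torch_compile` | true | Fuses non-attention ops, which roughly halved activation memory here and made batch=4 fit (no s/step gain at 8192, since FA2 kernels break the graph); also fixes the FA2 memory anomaly on ModernBERT (AnswerDotAI/ModernBERT#172) |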
| `optim` | `adamw_torch_fused` | Fused optimizer kernel, 5-10% speedup |
| `tf32` | true | Free speedup on remaining fp32 ops (Ampere architecture) |
| `attn_implementation` | `flash_attention_2` | Flash Attention 2 via flash-attn package; falls back to SDPA if unavailable |
| `save_steps` | 1000 | Checkpoint every ~1000 steps |
| `eval_steps` | 1000 | Evaluate every ~1000 steps |
| `save_total_limit` | 3 | Keep last 3 checkpoints |
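
One way to see how the table maps onto Hugging Face objects: most rows are `TrainingArguments` fields, `attn_implementation` goes to `from_pretrained`, and `mlm_probability` goes to the data collator, while `max_seq_length` and `max_tokens` belong to the data pipeline. This is a sketch assuming a recent `transformers` release (argument names such as `eval_strategy` have changed across versions), not the contents of `python/src/dapt/train.py`; `build_trainer` and the constant names are hypothetical.

```python
# Hypothetical mapping of the DAPT config table onto Hugging Face objects.
MODEL_NAME = "answerdotai/ModernBERT-large"
MAX_SEQ_LENGTH = 8192        # used when chunking the corpus, not by Trainer
MAX_TOKENS = 500_000_000     # used by the subsampler, not by Trainer

TRAINING_KWARGS = {
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 8,    # 4 x 8 = effective batch of 32
    "num_train_epochs": 1,
    "learning_rate": 5e-5,
    "weight_decay": 1e-5,
    "warmup_ratio": 0.05,
    "gradient_checkpointing": True,
    "gradient_checkpointing_kwargs": {"use_reentrant": False},
    "bf16": True,
    "tf32": True,
    "torch_compile": True,
    "optim": "adamw_torch_fused",
    "logging_steps": 50,
    "eval_strategy": "steps",            # "evaluation_strategy" in older releases
    "eval_steps": 1000,
    "save_steps": 1000,
    "save_total_limit": 3,
}
MODEL_KWARGS = {"attn_implementation": "flash_attention_2"}
COLLATOR_KWARGS = {"mlm": True, "mlm_probability": 0.30}


def build_trainer(output_dir, train_dataset, eval_dataset):
    """Assemble a Trainer from the settings above.

    Imports are deferred so the constants can be inspected without
    transformers installed.
    """
    from transformers import (
        AutoModelForMaskedLM,
        AutoTokenizer,
        DataCollatorForLanguageModeling,
        Trainer,
        TrainingArguments,
    )

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME, **MODEL_KWARGS)
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, **COLLATOR_KWARGS)
    args = TrainingArguments(output_dir=output_dir, **TRAINING_KWARGS)
    return Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collator,
    )
```
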
### Hyperparameter Rationale
**Learning rate (5e-5):** Conservative because we start from the published ModernBERT-large checkpoint, which is the post-decay final model. BioClinical-ModernBERT (Sounack et al., 2025) and Patent-ModernBERT (Luo et al., 2025) used 3e-4 but started from pre-decay stable-phase checkpoints. The ModernBERT authors released training checkpoints (`answerdotai/ModernBERT-large-training-checkpoints`) and noted: "Anyone is free to restart training from any of our pre-decay checkpoints, and perform annealing on domain-appropriate data" (Warner et al., 2024). Starting from the post-decay model with a high LR risks destabilizing learned representations.

**Weight decay (1e-5):** The original ModernBERT pre-training used 1e-5 weight decay. Both BioClinical-ModernBERT and Patent-ModernBERT preserved this value. The commonly-used 0.01 is a BERT/RoBERTa default that doesn't apply here.
### Performance Optimizations
**Flash Attention 2** reduces attention memory from O(n^2) to O(n) (the compute is still O(n^2)) and provides ~2-4x throughput improvement at seq_len=8192. ModernBERT was designed around FA2 and alternating attention: every 3rd layer uses global attention (full 8192-token context with RoPE theta 160K), while the other layers use 128-token local sliding window attention (RoPE theta 10K). This confines the full O(n^2) attention cost to one layer in three (Warner et al., 2024).
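
To quantify how much the alternating pattern saves at seq_len=8192, the sketch below counts scored query-key pairs per layer. It assumes ModernBERT-large's 28 encoder layers, a global layer whenever `layer_index % 3 == 0` (the convention we believe Hugging Face's ModernBERT implementation uses), and a 128-token local window split evenly on each side of the query; the helper names are hypothetical.

```python
def layer_kinds(num_layers=28, global_every=3):
    """Label each encoder layer; layer i is global when i % global_every == 0."""
    return ["global" if i % global_every == 0 else "local" for i in range(num_layers)]


def attended_pairs(seq_len, half_window=None):
    """Count (query, key) pairs scored by one attention layer.

    half_window=None means full global attention; otherwise each query sees
    keys within +/- half_window positions, clipped at the sequence edges.
    """
    if half_window is None:
        return seq_len * seq_len
    total = 0
    for q in range(seq_len):
        lo = max(0, q - half_window)
        hi = min(seq_len - 1, q + half_window)
        total += hi - lo + 1
    return total


def attention_cost_ratio(seq_len=8192, num_layers=28, global_every=3, local_window=128):
    """Scored pairs with alternating attention, relative to all-global layers."""
    per_kind = {
        "global": attended_pairs(seq_len),
        "local": attended_pairs(seq_len, local_window // 2),
    }
    alternating = sum(per_kind[k] for k in layer_kinds(num_layers, global_every))
    return alternating / (num_layers * per_kind["global"])
```

Under these assumptions, 10 of the 28 layers are global and `attention_cost_ratio()` comes out at about 0.367: the alternating design scores roughly 37% of the pairs that 28 global layers would. This counts attention pairs only; MLP and projection costs are unchanged.
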

**torch.compile** JIT-compiles the model into fused CUDA kernels via the Inductor backend. On ModernBERT specifically, it also resolves a known memory anomaly where FA2 uses ~88% GPU memory vs ~48% for SDPA during MLM training (AnswerDotAI/ModernBERT#172). The fix is enabling both `torch_compile=True` and `gradient_checkpointing=True` together.

**Fused AdamW** merges parameter updates into a single CUDA kernel, reducing kernel launch overhead across 396M parameters.

**Corpus subsampling** to 500M tokens (from 1.06B) halves training time. The subsampler takes from the tail of the corpus (newest filings, FY2024-2025), since the accession-sorted shards are roughly chronological. Ponnock (2025) showed the largest DAPT gains occur in the first 200-250M tokens, with shallow power-law scaling thereafter; 500M provides a comfortable margin.
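
A minimal sketch of tail subsampling at document granularity. It assumes documents arrive in accession (roughly chronological) order with precomputed token counts; whether the real subsampler cuts at document or sequence boundaries isn't stated here, so `tail_subsample` is hypothetical.

```python
def tail_subsample(docs, token_counts, max_tokens):
    """Keep the newest documents (the end of the list) until max_tokens is reached.

    Returns the kept documents in their original (chronological) order and the
    number of tokens they contain. The last document added may push the total
    slightly past max_tokens.
    """
    if len(docs) != len(token_counts):
        raise ValueError("docs and token_counts must have the same length")
    kept_tokens = 0
    start = len(docs)
    for i in range(len(docs) - 1, -1, -1):   # walk backwards from the newest
        if kept_tokens >= max_tokens:
            break
        kept_tokens += token_counts[i]
        start = i
    return docs[start:], kept_tokens
```

Walking backwards and slicing once keeps the output in chronological order, so downstream chunking sees documents in the same order as the full corpus.
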
### Optimization Journey
The path to the final config involved iterative experimentation on the RTX 3090:
| Change | s/step | VRAM | Outcome |
|--------|--------|------|---------|
| Baseline (PyTorch 2.10, no FA2, batch=1) | ~47s | ~16GB | Compute-bound, attention is O(n²) |
| + Flash Attention 2 (PyTorch 2.11+cu130) | ~27s | ~16GB | FA2 halves attention compute time |
| + batch=2 (grad_accum 32→16) | ~27s | ~18.2GB | GPU already saturated at seq_len=8192; bigger batch doesn't help s/step |
| + torch.compile (with FA2) | ~27s | **~11.9GB** | Graph breaks at FA2 kernels prevent compute speedup, but fusing non-attention ops halved activation memory |
| + batch=4 (using compile's memory savings) | ~27s | ~18.5GB | Same s/step, but fewer grad_accum micro-steps (8 vs. the original 32) reduces overhead marginally |
| + 500M token subsample | ~27s | ~18.5GB | Half the steps: ~13.5h instead of ~29h |

Key insight: at seq_len=8192, the 3090's ~71 dense bf16 tensor-core TFLOPS are the hard ceiling. torch.compile couldn't speed up the attention bottleneck (FA2 kernels are opaque to Dynamo), but it unexpectedly halved activation memory by fusing surrounding ops, enabling larger batch sizes.
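
The step counts and durations quoted in this section follow from simple arithmetic. A sketch, assuming one epoch, a ceiling-rounded 2% validation split, and floor division when grouping micro-batches into optimizer steps; these rounding choices approximate Trainer's bookkeeping and may be off by a step or two.

```python
import math


def dapt_schedule(total_tokens, seq_len=8192, val_fraction=0.02,
                  per_device_batch=4, grad_accum=8, warmup_ratio=0.05,
                  seconds_per_step=None):
    """Estimate optimizer steps, warmup steps, and wall-clock hours for one epoch."""
    n_seqs = total_tokens // seq_len                   # full chunks only
    n_val = math.ceil(n_seqs * val_fraction)           # assumed rounding for the split
    n_train = n_seqs - n_val
    micro_batches = math.ceil(n_train / per_device_batch)
    steps = micro_batches // grad_accum                # optimizer updates per epoch
    warmup = math.ceil(steps * warmup_ratio)
    hours = steps * seconds_per_step / 3600 if seconds_per_step else None
    return {
        "sequences": n_seqs,
        "train_sequences": n_train,
        "steps": steps,
        "warmup_steps": warmup,
        "hours": hours,
    }
```

For `dapt_schedule(500_000_000, seconds_per_step=27)` this gives 61,035 sequences, 1,869 optimizer steps, 94 warmup steps, and about 14 hours, in line with the ~13.5h figure in the table.
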
### Cloud Alternative: AWS g7e.2xlarge
For faster turnaround, an AWS g7e.2xlarge instance (NVIDIA RTX PRO 6000 Blackwell Server Edition, 96GB VRAM, ~236 bf16 TFLOPS) could complete DAPT significantly faster:
| | RTX 3090 (local) | RTX PRO 6000 (g7e.2xlarge) |
|--|--|--|
| bf16 TFLOPS | 71 | ~236 (3.3x) |
| VRAM | 24 GB | 96 GB |
| Gradient checkpointing | Required | Not needed (1.33x speedup) |
| Max batch size | 4 | 16+ |
| Estimated s/step | ~27s | ~6.5-7s |
| **500M tokens** | **~13.5h, ~$1.50 electricity** | **~3.7h, ~$4-5 spot** |
| **1B tokens** | **~29h, ~$3 electricity** | **~7.3h, ~$9 spot** |

The 96GB VRAM allows dropping gradient checkpointing entirely (eliminating activation recomputation) and running batch=16 with grad_accum=2 for the same effective batch of 32. Combined with the 3.3x raw TFLOPS advantage, the estimated speedup is ~4x.

The g6e.2xlarge (NVIDIA L40S, 48GB, ~181 bf16 TFLOPS) is a cheaper alternative at $2.24/hr but slower (~5.6h for 500M tokens). H100 instances (p5) are overkill for a 396M parameter model.
### Epoch Decision Justification
We train for 1 epoch (single pass over the corpus), following the empirical consensus:
- **Gururangan et al. (2020), "Don't Stop Pretraining" (ACL 2020):** Trained DAPT for "12.5K steps, which amounts to a single pass on each domain dataset" across corpora ranging from 2-8B tokens. A single pass was sufficient for consistent downstream gains across all four domains and eight tasks.
- **Ponnock (2025), "The Data Efficiency Frontier of Financial Foundation Models" (arXiv:2512.12384):** Found that SEC-specific DAPT exhibits diminishing marginal returns beyond ~250M tokens within a single epoch: "Both models exhibit their largest improvements in the early stages of continued pretraining: loss drops noticeably between 50M and 200M tokens, after which the rate of improvement slows." Our ~1B token corpus is already well past the diminishing-returns threshold.

Additional epochs risk overfitting to the domain corpus without proportional downstream benefit, while general-domain capability remains stable through a single pass.
### Sequence Length Decision
ModernBERT was pre-trained with 8192-token context. We match this during DAPT so that the global attention layers are trained across the full range of relative distances. ModernBERT uses RoPE, so there are no per-position embedding weights to leave stale; the risk at seq_len=2048 is that attention over distances beyond 2,047 tokens would receive no training signal during DAPT.

The tradeoff is memory. The original plan ran batch_size=1 at 8192 (vs. 4 at 2048) with gradient_accumulation=32; after torch.compile halved activation memory, the final config runs batch=4 with grad_accum=8. Either way the effective batch size is 32 sequences. Training time is comparable because each step at 8192 covers 4x more tokens, so 4x fewer steps offset the slower per-step time.

For our downstream task (paragraph classification at ~50-400 tokens), the long-context benefit is modest; the primary DAPT benefit is vocabulary and domain language patterns, which transfer at any sequence length. But there is no cost to using 8192, so we preserve the model's full capability.
## Step 1: DAPT
### Command
```bash
cd python
bun run py:train dapt --config configs/dapt/modernbert.yaml
```
Equivalent to: `uv run main.py dapt --config configs/dapt/modernbert.yaml`
### What happens
1. Loads ModernBERT-large from HuggingFace (cached after first download)
2. Loads 14,756 docs from `data/dapt-corpus/` and filters out the 188 under 10K chars
3. Tokenizes all text, concatenates, chunks into ~136K sequences of 8192 tokens
4. Splits 2% validation (~2,700 sequences), 98% train (~133K sequences)
5. Trains 1 epoch of MLM with 30% masking, bf16, gradient checkpointing
6. ~4,257 optimizer steps over the full corpus (~1,870 with the 500M-token `max_tokens` subsample), logging every 50 steps, checkpoint + eval every 1,000
7. Saves final model + tokenizer to `checkpoints/dapt/modernbert-large/final/`
8. Reports final eval loss and perplexity
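
Steps 3 and 4 can be sketched as follows. This mirrors the common concatenate-and-chunk recipe (as in the `group_texts` function of Hugging Face's `run_mlm.py` example) rather than `python/src/data/corpus.py` itself; the optional `sep_id` separator and the fixed shuffle `seed` are assumptions.

```python
import math
import random


def concat_and_chunk(token_lists, seq_len=8192, sep_id=None):
    """Concatenate tokenized documents into one stream and cut fixed-length chunks.

    The trailing remainder shorter than seq_len is dropped. If sep_id is given,
    it is appended after each document so boundaries stay visible to the model.
    """
    stream = []
    for ids in token_lists:
        stream.extend(ids)
        if sep_id is not None:
            stream.append(sep_id)
    n_full = len(stream) // seq_len
    return [stream[i * seq_len:(i + 1) * seq_len] for i in range(n_full)]


def train_val_split(sequences, val_fraction=0.02, seed=42):
    """Shuffle sequence indices with a fixed seed and hold out val_fraction."""
    order = list(range(len(sequences)))
    random.Random(seed).shuffle(order)
    n_val = math.ceil(len(order) * val_fraction)
    val = [sequences[i] for i in order[:n_val]]
    train = [sequences[i] for i in order[n_val:]]
    return train, val
```

Because chunks cross document boundaries, no padding is needed and every position in every sequence carries a real token; the cost is that tokenizer offsets no longer map back to a single document.
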
### Expected duration
~13.5-14 hours on the RTX 3090 for the 500M-token subsample (~1,870 steps at the measured ~27 s/step); ~29 hours for the full corpus.
### Resume if interrupted
HuggingFace Trainer auto-saves checkpoints every 1,000 steps. Re-run the same command; it detects existing checkpoints and resumes automatically.
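
Trainer names its checkpoints `checkpoint-<global_step>` by default, and resuming means passing the newest one to `trainer.train(resume_from_checkpoint=...)`. A stdlib sketch of finding that directory follows; `transformers.trainer_utils.get_last_checkpoint` does the same job, and how the project's CLI wires this up is an assumption here.

```python
import os
import re

_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def last_checkpoint(output_dir):
    """Return the checkpoint-<step> directory with the highest step, or None."""
    if not os.path.isdir(output_dir):
        return None
    found = []
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_RE.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            found.append((int(match.group(1)), name))   # sort numerically, not lexically
    if not found:
        return None
    return os.path.join(output_dir, max(found)[1])
```

Usage: `trainer.train(resume_from_checkpoint=last_checkpoint(output_dir))`; a `None` result starts from scratch. Directories that don't match the pattern, such as `final/`, are ignored.
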
### Output
```
checkpoints/dapt/modernbert-large/
  checkpoint-1000/
  checkpoint-2000/
  checkpoint-3000/
  final/                 <- final model + tokenizer
    config.json
    model.safetensors
    tokenizer.json
    ...
```
## Step 2: TAPT
After DAPT completes, continue MLM on the 72K Item 1C paragraphs using the DAPT checkpoint.
### Command
```bash
cd python
bun run py:train dapt --config configs/tapt/modernbert.yaml
```
Equivalent to: `uv run main.py dapt --config configs/tapt/modernbert.yaml`
### TAPT Configuration
**Config file:** `python/configs/tapt/modernbert.yaml`

The TAPT corpus is 72K Item 1C paragraphs (~10M tokens), 50x smaller than DAPT. This changes several training decisions:
| Parameter | Value | vs. DAPT | Rationale |
|-----------|-------|----------|-----------|
| `max_seq_length` | 512 | 8192 → 512 | Data-driven: paragraphs average 127 tokens (P99=386, 99.6% fit in 512). 8192 would be 98.5% padding. |
| `num_train_epochs` | 5 | 1 → 5 | Match total token exposure: 5 × 10M = 50M, the upper end of TAPT exposure in Gururangan et al. (2020). |
| `whole_word_mask` | true | false → true | Mask entire words, not subword pieces. The model knows subword composition from DAPT; TAPT focuses on domain-specific whole words ("CISO", "materiality", "tabletop"). |
| `per_device_train_batch_size` | 32 | 4 → 32 | Short sequences free VRAM. Tested: 22.7 GB peak with torch.compile at batch=32 (OOM at 48). |
| `gradient_accumulation_steps` | 1 | 8 → 1 | Effective batch = 32 in both cases. |
| `gradient_checkpointing` | false | true → false | Not needed at 512 seq_len; would add 30-40% overhead for no benefit. |
| `save_strategy` | epoch | steps → epoch | Checkpoint + evaluate after each of 5 epochs. |
| `validation_split` | 0.05 | 0.02 → 0.05 | Larger val split for the 50x smaller dataset, to keep eval loss stable. |
### Sequence Length Decision (512 vs. 8192)
DAPT used 8192 to match ModernBERT's pre-training context and exercise the global attention layers across long relative distances. TAPT uses 512 because:
1. **The data is 512.** Paragraphs average 127 tokens (P99=386). There is no long-range structure to learn.
2. **50M tokens won't cause forgetting.** TAPT's 50M token exposures are 0.0025% of ModernBERT's ~2T pre-training tokens and 10% of DAPT's 500M. The model's long-range patterns are deeply established.
3. **RoPE encodes relative, not absolute, position.** There are no learned per-position weights, so positions 0-511 compute identically at any max_length, and attention over distances beyond 511 tokens, learned in pre-training and DAPT, is simply not exercised by TAPT.
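
The relative-position property behind point 3 can be checked numerically. This sketch uses the interleaved-pair form of RoPE (Hugging Face's implementation uses a split-halves layout, which is equivalent up to a permutation of dimensions) and shows that the query-key score depends only on the offset between positions, whether theta is 10K (local layers) or 160K (global layers). The function names are illustrative.

```python
import math


def apply_rope(x, pos, theta=10_000.0):
    """Rotate consecutive pairs (x[i], x[i+1]) by angle pos * theta**(-i/d)."""
    d = len(x)
    if d % 2:
        raise ValueError("RoPE needs an even head dimension")
    out = list(x)
    for i in range(0, d, 2):
        angle = pos * theta ** (-i / d)
        c, s = math.cos(angle), math.sin(angle)
        out[i] = x[i] * c - x[i + 1] * s
        out[i + 1] = x[i] * s + x[i + 1] * c
    return out


def rope_score(q, k, q_pos, k_pos, theta=10_000.0):
    """Unscaled attention logit between a query and a key at the given positions."""
    qr = apply_rope(q, q_pos, theta)
    kr = apply_rope(k, k_pos, theta)
    return sum(a * b for a, b in zip(qr, kr))
```

For example, `rope_score(q, k, 5, 2)` equals `rope_score(q, k, 8003, 8000)` up to floating-point error: rotating both vectors by their absolute positions leaves only the rotation by their difference in the dot product.
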
### Whole-Word Masking Implementation
Whole-word masking requires `offset_mapping` from the tokenizer to determine word boundaries. This is incompatible with DAPT's concatenate-and-chunk approach (which destroys offset_mapping by merging documents). For TAPT, each paragraph is tokenized individually with truncation, preserving `offset_mapping`. The data collator handles dynamic padding per batch.

Note: with `whole_word_mask=True`, the HuggingFace collator automatically disables random token replacement (`mask_replace_prob=1.0`), so all masked positions receive the `[MASK]` token.
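
A minimal sketch of word grouping from `offset_mapping` and word-level masking, for illustration only; Hugging Face's collator has its own implementation. The word-boundary rule (a token starts a new word when its span begins with, or directly follows, whitespace), the per-word Bernoulli draw, and the 0.30 default (assuming TAPT keeps DAPT's masking rate) are all assumptions, and the helper names are hypothetical.

```python
import random


def word_groups(text, offsets):
    """Group token indices into words using (start, end) character offsets."""
    groups = []
    for i, (start, end) in enumerate(offsets):
        if start == end:                       # special tokens carry (0, 0) offsets
            continue
        starts_word = (
            not groups
            or start == 0
            or text[start].isspace()           # span includes its leading space
            or text[start - 1].isspace()       # span trimmed; space sits just before
            or offsets[i - 1][0] == offsets[i - 1][1]   # previous token was special
        )
        if starts_word:
            groups.append([i])
        else:
            groups[-1].append(i)
    return groups


def whole_word_mask(input_ids, text, offsets, mask_token_id, mlm_probability=0.30, seed=0):
    """Mask every subword of a selected word; labels are -100 where not masked."""
    rng = random.Random(seed)
    masked = list(input_ids)
    labels = [-100] * len(input_ids)          # -100 is ignored by the MLM loss
    for group in word_groups(text, offsets):
        if rng.random() < mlm_probability:
            for i in group:
                labels[i] = input_ids[i]
                masked[i] = mask_token_id     # always [MASK]; no random replacement
    return masked, labels
```

With `"CISO reports materiality"` split as `CI|SO| reports| material|ity`, the groups are `CISO`, `reports`, and `materiality`, so a selected word is never left half-masked.
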
### What happens
1. Loads the DAPT checkpoint from `checkpoints/dapt/modernbert-large/final/`
2. Loads 72,045 patched paragraphs from `paragraphs-clean.patched.jsonl`
3. Tokenizes each paragraph individually (truncation at 512, with offset_mapping for whole-word masking)
4. Splits 5% validation (~3,602 paragraphs), 95% train (~68,443 paragraphs)
5. Trains 5 epochs of MLM with whole-word masking; masks are re-sampled each epoch
6. Saves checkpoint after each epoch; saves final model to `checkpoints/tapt/modernbert-large/final/`
### Expected duration
~2,138 steps/epoch × 5 epochs = ~10,700 total steps. At seq_len=512 on the 3090 (~0.5-1s/step), estimated 1-3 hours.
### Resume if interrupted
Re-run the same command; it detects existing checkpoints and resumes automatically.
### Output
```
checkpoints/tapt/modernbert-large/
  checkpoint-epoch-1/
  checkpoint-epoch-2/
  ...
  final/                 <- SEC-cyBERT-large (DAPT + TAPT)
```
## Step 3: Ablation Checkpoints
The training pipeline produces clean ablation rows for the paper:
| Model | Checkpoint | Description |
|-------|-----------|-------------|
| Base | `answerdotai/ModernBERT-large` | Off-the-shelf, no domain adaptation |
| +DAPT | `checkpoints/dapt/modernbert-large/final` | After domain pre-training on 14.5K filings |
| +DAPT+TAPT | `checkpoints/tapt/modernbert-large/final` | After task pre-training on 72K paragraphs |

Each checkpoint can be independently fine-tuned with classification heads to isolate the contribution of each pre-training stage.
## Monitoring
During training, the Trainer logs to stderr every 50 steps:
- `loss`: training MLM loss (cross-entropy on masked tokens)
- `learning_rate`: current LR (ramps up during warmup, then decays)
- `epoch`: progress through the epoch

Every 1,000 steps, it also reports:
- `eval_loss`: validation MLM loss
- Perplexity can be computed as `exp(eval_loss)` (the loss is natural-log cross-entropy)

**Expected loss trajectory:**
- Starting loss: ~0.80 (the model already knows English; it's only learning SEC-specific patterns)
- Final loss: ~0.55-0.65 (a gentle downward drift, not a dramatic curve)
- For reference, a randomly initialized model would start at ~10.8 (ln(50280 vocab size))
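
Trainer reports the MLM loss as mean cross-entropy in nats, so perplexity is `e^loss` and a uniform guess over the vocabulary scores `ln(V)`. A two-function check:

```python
import math


def perplexity(eval_loss):
    """MLM loss is mean cross-entropy in nats, so perplexity = e**loss."""
    return math.exp(eval_loss)


def uniform_guess_loss(vocab_size):
    """Cross-entropy of a uniform prediction over the vocabulary: ln(V)."""
    return math.log(vocab_size)
```

`perplexity(0.80)` is about 2.23, and `uniform_guess_loss(50280)` is about 10.83, matching the ~10.8 random-initialization reference.
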

**What to watch for:**
- `grad_norm` should stay small (0.05-0.15). Healthy = gentle weight updates. Spikes >1.0 = LR too high.
- `learning_rate` ramps up during warmup (first 5% of steps ≈ 93 steps), then decays.
- Loss going *up* after warmup → LR too high or data issue
- Loss stuck flat after 500+ steps → model isn't learning, LR too low
- Loss < 0.3 → possible overfitting (unlikely in 1 epoch of 500M tokens)
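
The warning signs above can be scanned offline from the `log_history` list that Trainer writes to `trainer_state.json` in each checkpoint. A rough sketch: thresholds follow the checklist except `flat_tol`, which is an arbitrary illustrative value, and comparing single logged points is noisy, so treat hits as prompts to look at the curve rather than verdicts. The function names are hypothetical.

```python
import json


def load_log_history(trainer_state_path):
    """Read the log history Trainer stores in checkpoint-*/trainer_state.json."""
    with open(trainer_state_path, encoding="utf-8") as f:
        return json.load(f)["log_history"]


def check_health(log_history, warmup_steps, grad_spike=1.0, overfit_loss=0.3,
                 flat_window=500, flat_tol=0.01):
    """Return human-readable warnings for the failure modes in the checklist."""
    train = [e for e in log_history if "loss" in e and "step" in e]  # skip eval entries
    warnings = []
    for e in train:
        if e.get("grad_norm", 0.0) > grad_spike:
            warnings.append(
                f"step {e['step']}: grad_norm {e['grad_norm']:.2f} > {grad_spike} (LR too high?)")
        if e["loss"] < overfit_loss:
            warnings.append(
                f"step {e['step']}: loss {e['loss']:.3f} < {overfit_loss} (possible overfitting)")
    post = [e for e in train if e["step"] > warmup_steps]
    if len(post) >= 2 and post[-1]["loss"] > post[0]["loss"]:
        warnings.append("loss is higher now than just after warmup (LR too high or data issue)")
    if train and train[-1]["step"] - train[0]["step"] >= flat_window:
        if train[0]["loss"] - train[-1]["loss"] < flat_tol:
            warnings.append(f"loss flat over {flat_window}+ steps (LR too low?)")
    return warnings
```

Usage: `check_health(load_log_history("checkpoints/dapt/modernbert-large/checkpoint-1000/trainer_state.json"), warmup_steps=93)`.
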

**The DAPT loss number itself matters less than the downstream impact.** DAPT teaches the model SEC vocabulary and co-occurrence patterns ("NIST CSF", "tabletop exercise", "materiality assessment"). Whether the final loss is 0.55 or 0.65 is less important than whether the [CLS] embeddings produce better classification after fine-tuning. The real evaluation is the ablation: base vs +DAPT vs +DAPT+TAPT.
## Artifacts
| File | Purpose |
|------|---------|
| `python/configs/dapt/modernbert.yaml` | DAPT config |
| `python/configs/tapt/modernbert.yaml` | TAPT config |
| `python/configs/dapt/neobert.yaml` | NeoBERT config (if needed) |
| `python/main.py` | CLI entrypoint |
| `python/src/dapt/train.py` | Training loop |
| `python/src/data/corpus.py` | Corpus loading + tokenization |
| `python/src/common/config.py` | Typed YAML config |
| `ts/scripts/dapt-corpus-prep.ts` | Corpus preparation from HTML |
| `ts/scripts/dapt-corpus-analytics.ts` | Corpus analytics |
| `data/dapt-corpus/shard-*.jsonl` | Cleaned corpus (15 shards) |
## References
- Warner, B., Clavié, B., Soldaini, L., et al. (2024). "Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Fine-tuning and Inference." arXiv:2412.13663. ModernBERT architecture, pre-training config (30% MLM, StableAdamW, weight_decay 1e-5, alternating attention, 8192 context), pre-decay checkpoint release.
- Gururangan, S., Marasovic, A., Swayamdipta, S., Lo, K., Beltagy, I., Downey, D., & Smith, N.A. (2020). "Don't Stop Pretraining: Adapt Language Models to Domains and Tasks." ACL 2020, pp. 8342-8360. Single-epoch DAPT on 2-8B token corpora, TAPT at 100 epochs on task data.
- Ponnock, J. (2025). "The Data Efficiency Frontier of Financial Foundation Models: Scaling Laws from Continued Pretraining." arXiv:2512.12384. Johns Hopkins University. SEC filing DAPT shows diminishing returns beyond ~250M tokens, shallow power-law scaling.
- Sounack, T., et al. (2025). "BioClinical ModernBERT." arXiv:2506.10896. DAPT on 160B clinical tokens using ModernBERT, lr=3e-4, weight_decay=1e-5, pre-decay checkpoint, sequence packing.
- Luo, Z., et al. (2025). "Patent ModernBERT: A Pretrained Language Model for Intellectual Property." arXiv:2509.14926. DAPT on 31.6B patent tokens, lr=3e-4, StableAdamW, weight_decay=1e-5.
- Dao, T. (2024). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR 2024. O(n) memory attention, critical for 8192 seq_len training on consumer GPUs.
- AnswerDotAI/ModernBERT#172: Known FA2 memory anomaly during MLM training, resolved by combining torch.compile + gradient_checkpointing.