SEC-cyBERT/docs/DAPT-PROCEDURE.md
2026-03-30 19:46:20 -04:00


DAPT/TAPT Training Procedure

Date: 2026-03-29
Hardware: NVIDIA RTX 3090 (24GB VRAM), CUDA driver 13.2, PyTorch 2.11.0+cu130


Pre-flight Checklist

| Check | Status |
|---|---|
| PyTorch 2.11.0+cu130, CUDA available | Verified |
| RTX 3090, 25.3 GB VRAM, bf16 supported | Verified |
| Flash Attention 2 (flash-attn 2.6.3+cu130torch2.11) | Verified |
| ModernBERT-large loads: 396M params, max_position_embeddings=8192 | Verified |
| Corpus: 14,756 docs, ~1.06B tokens, 15 shards | Verified |
| After <10K filter: 14,568 docs, ~1.056B tokens (0.027% loss) | Verified |
| Tokenize+chunk pipeline: 10 docs -> 85 sequences of 8192 tokens | Verified |
| Config: seq_len=8192, batch=4, grad_accum=8, 1 epoch, lr=5e-5, mlm=0.30 | Set |

DAPT Corpus Summary

  • 14,568 documents (after filtering 188 cover pages <10K chars)
  • ~1.056 billion tokens (ModernBERT tokenizer, 4.72 chars/token)
  • ~136K training sequences at seq_len=8192
  • Median document: ~73K tokens (347K chars) — 90.6% of docs exceed 8192 tokens
  • Cleaned: XBRL data blobs stripped, exhibit listings stripped, URLs removed, F-N page numbers removed
  • Source: 14,759 cached 10-K HTML filings, FY2023-FY2025, processed by ts/scripts/dapt-corpus-prep.ts
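The concatenate-and-chunk step verified in the pre-flight checklist can be sketched in a few lines. This is a stdlib-only illustration; the real pipeline in python/src/data/corpus.py uses the ModernBERT tokenizer, for which the toy tokenizer below is merely a stand-in:

```python
from typing import Iterable

SEQ_LEN = 8192  # matches max_seq_length in the DAPT config

def toy_tokenize(text: str) -> list[int]:
    # Stand-in for the ModernBERT tokenizer: one "token" per
    # whitespace-separated word. Real token IDs come from HF.
    return [hash(w) % 50280 for w in text.split()]

def concatenate_and_chunk(docs: Iterable[str], seq_len: int = SEQ_LEN) -> list[list[int]]:
    # DAPT-style packing: tokenize every document, concatenate the
    # streams, then cut fixed-length training sequences. The trailing
    # remainder shorter than seq_len is dropped.
    buffer: list[int] = []
    chunks: list[list[int]] = []
    for doc in docs:
        buffer.extend(toy_tokenize(doc))
        while len(buffer) >= seq_len:
            chunks.append(buffer[:seq_len])
            buffer = buffer[seq_len:]
    return chunks
```

Because sequences are packed across document boundaries, no padding tokens are needed, but offset_mapping is destroyed, which is why the TAPT stage tokenizes paragraphs individually instead.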

Training Configuration

Config file: python/configs/dapt/modernbert.yaml

| Parameter | Value | Rationale |
|---|---|---|
| max_seq_length | 8192 | Match ModernBERT's pre-training context length |
| max_tokens | 500,000,000 | Subsample to newest 500M tokens (Ponnock 2025: diminishing returns past 250M) |
| per_device_train_batch_size | 4 | Maximum for 24GB VRAM with FA2 + torch.compile + grad checkpointing |
| gradient_accumulation_steps | 8 | Effective batch size = 32 |
| num_train_epochs | 1 | Single pass per Gururangan et al. (2020) and Ponnock (2025) |
| learning_rate | 5e-5 | Conservative for post-decay checkpoint (see note below) |
| mlm_probability | 0.30 | Matches ModernBERT's pre-training masking rate (Warner et al., 2024) |
| weight_decay | 1e-5 | Matches ModernBERT pre-training; used by BioClinical-ModernBERT and Patent-ModernBERT |
| warmup_ratio | 0.05 | ~5% of total steps spent in LR warmup |
| gradient_checkpointing | true | Required for 8192 seq_len on 24GB |
| gradient_checkpointing_kwargs | use_reentrant: False | Required for torch.compile compatibility |
| bf16 | true | Native RTX 3090 (Ampere) support |
| torch_compile | true | 20-40% speedup; fixes FA2 memory anomaly on ModernBERT (AnswerDotAI/ModernBERT#172) |
| optim | adamw_torch_fused | Fused optimizer kernel, 5-10% speedup |
| tf32 | true | Free speedup on remaining fp32 ops (Ampere architecture) |
| attn_implementation | flash_attention_2 | Flash Attention 2 via flash-attn package; falls back to SDPA if unavailable |
| save_steps | 1000 | Checkpoint every 1,000 steps |
| eval_steps | 1000 | Evaluate every 1,000 steps |
| save_total_limit | 3 | Keep last 3 checkpoints |
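For orientation, a plausible shape for python/configs/dapt/modernbert.yaml given the table above. The key names here are assumptions drawn from the table; the authoritative schema lives in python/src/common/config.py:

```yaml
# Hypothetical sketch of configs/dapt/modernbert.yaml -- key names assumed
model_name: answerdotai/ModernBERT-large
max_seq_length: 8192
max_tokens: 500000000
per_device_train_batch_size: 4
gradient_accumulation_steps: 8
num_train_epochs: 1
learning_rate: 5.0e-5
mlm_probability: 0.30
weight_decay: 1.0e-5
warmup_ratio: 0.05
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
bf16: true
torch_compile: true
optim: adamw_torch_fused
attn_implementation: flash_attention_2
save_steps: 1000
eval_steps: 1000
save_total_limit: 3
```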

Hyperparameter Rationale

Learning rate (5e-5): Conservative because we start from the published ModernBERT-large checkpoint, which is the post-decay final model. BioClinical-ModernBERT (Sounack et al., 2025) and Patent-ModernBERT (Luo et al., 2025) used 3e-4 but started from pre-decay stable-phase checkpoints. The ModernBERT authors released training checkpoints (answerdotai/ModernBERT-large-training-checkpoints) and noted: "Anyone is free to restart training from any of our pre-decay checkpoints, and perform annealing on domain-appropriate data" (Warner et al., 2024). Starting from the post-decay model with a high LR risks destabilizing learned representations.

Weight decay (1e-5): The original ModernBERT pre-training used 1e-5 weight decay. Both BioClinical-ModernBERT and Patent-ModernBERT preserved this value. The commonly-used 0.01 is a BERT/RoBERTa default that doesn't apply here.

Performance Optimizations

Flash Attention 2 reduces attention from O(n^2) to O(n) memory and provides ~2-4x throughput improvement at seq_len=8192. ModernBERT was designed with FA2 support, including alternating attention: every 3rd layer uses global attention (full 8192-token context with RoPE theta 160K), while other layers use 128-token local sliding window attention (RoPE theta 10K). This dramatically reduces the O(n^2) cost (Warner et al., 2024).
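A back-of-the-envelope sketch of the alternating-attention savings, counting attention score entries per layer (heads and constant factors ignored; the 1-in-3 global ratio and 128-token window are from Warner et al., 2024):

```python
SEQ = 8192
WINDOW = 128  # local sliding-window size in ModernBERT

# Global layer: every token attends to all 8192 positions.
global_entries = SEQ * SEQ
# Local layer: every token attends to at most a 128-token window.
local_entries = SEQ * WINDOW

# With 1 global layer out of every 3:
avg_entries = (1 * global_entries + 2 * local_entries) / 3
reduction = global_entries / avg_entries
print(f"~{reduction:.1f}x fewer attention score entries on average")  # ~2.9x
```

The saving grows with sequence length: the ratio is 3*S/(S + 2*W), which approaches 3x as S far exceeds the window size.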

torch.compile JIT-compiles the model into fused CUDA kernels via the Inductor backend. On ModernBERT specifically, it also resolves a known memory anomaly where FA2 uses ~88% GPU memory vs ~48% for SDPA during MLM training (AnswerDotAI/ModernBERT#172). The fix is enabling both torch_compile=True and gradient_checkpointing=True together.

Fused AdamW merges parameter updates into a single CUDA kernel, reducing kernel launch overhead across 396M parameters.

Corpus subsampling to 500M tokens (from 1.06B) halves training time. The subsampler takes from the tail of the corpus (newest filings, FY2024-2025) since the accession-sorted shards are roughly chronological. Ponnock (2025) showed the largest DAPT gains occur in the first 200-250M tokens with shallow power-law scaling thereafter — 500M provides a comfortable margin.
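The step counts and wall-clock figures used in this document follow from simple arithmetic over tokens consumed per optimizer step. The estimates below land close to, though not exactly on, the quoted ~13.5h and ~29h numbers, which were derived from the actual packed sequence counts:

```python
SEQ_LEN = 8192
EFFECTIVE_BATCH = 4 * 8    # per_device_batch x grad_accum = 32
SECONDS_PER_STEP = 27      # measured on the RTX 3090

def dapt_estimate(total_tokens: int) -> tuple[int, float]:
    # One optimizer step consumes effective_batch * seq_len tokens.
    steps = total_tokens // (SEQ_LEN * EFFECTIVE_BATCH)
    hours = steps * SECONDS_PER_STEP / 3600
    return steps, hours

steps_500m, hours_500m = dapt_estimate(500_000_000)     # ~1,900 steps, ~14h
steps_1b, hours_1b = dapt_estimate(1_056_000_000)       # ~4,000 steps, ~30h
```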

Optimization Journey

The path to the final config involved iterative experimentation on the RTX 3090:

| Change | s/step | VRAM | Outcome |
|---|---|---|---|
| Baseline (PyTorch 2.10, no FA2, batch=1) | ~47s | ~16GB | Compute-bound, attention is O(n²) |
| + Flash Attention 2 (PyTorch 2.11+cu130) | ~27s | ~16GB | FA2 halves attention compute time |
| + batch=2 (grad_accum 32→16) | ~27s | ~18.2GB | GPU already saturated at seq_len=8192 — bigger batch doesn't help s/step |
| + torch.compile (with FA2) | ~27s | ~11.9GB | Graph breaks at FA2 kernels prevent compute speedup, but fusing non-attention ops halved activation memory |
| + batch=4 (using compile's memory savings) | ~27s | ~18.5GB | Same s/step, but 4x fewer grad_accum micro-steps reduces overhead marginally |
| + 500M token subsample | ~27s | ~18.5GB | Half the steps → ~13.5h instead of ~29h |

Key insight: at seq_len=8192, the 3090's ~71 bf16 tensor-core TFLOPS is the hard ceiling. torch.compile couldn't speed up the attention bottleneck (FA2 kernels are opaque to Dynamo), but it unexpectedly halved activation memory by fusing the surrounding ops, enabling larger batch sizes.

Cloud Alternative: AWS g7e.2xlarge

For faster turnaround, an AWS g7e.2xlarge instance (NVIDIA RTX PRO 6000 Blackwell Server Edition, 96GB VRAM, ~236 bf16 TFLOPS) could complete DAPT significantly faster:

| | RTX 3090 (local) | RTX PRO 6000 (g7e.2xlarge) |
|---|---|---|
| bf16 TFLOPS | 71 | ~236 (3.3x) |
| VRAM | 24 GB | 96 GB |
| Gradient checkpointing | Required | Not needed (1.33x speedup) |
| Max batch size | 4 | 16+ |
| Estimated s/step | ~27s | ~6.5-7s |
| 500M tokens | ~13.5h, ~$1.50 electricity | ~3.7h, ~$4-5 spot |
| 1B tokens | ~29h, ~$3 electricity | ~7.3h, ~$9 spot |

The 96GB VRAM allows dropping gradient checkpointing entirely (eliminating activation recomputation) and running batch=16 with grad_accum=2 for the same effective batch of 32. Combined with the 3.3x raw TFLOPS advantage, the estimated speedup is ~4x.

The g6e.2xlarge (NVIDIA L40S, 48GB, ~181 bf16 TFLOPS) is a cheaper alternative at $2.24/hr but slower (~5.6h for 500M tokens). H100 instances (p5) are overkill for a 396M parameter model.

Epoch Decision Justification

We train for 1 epoch (single pass over the corpus), following the empirical consensus:

  • Gururangan et al. (2020), "Don't Stop Pretraining" (ACL 2020): Trained DAPT for "12.5K steps, which amounts to a single pass on each domain dataset" across corpora ranging from 2-8B tokens. A single pass was sufficient for consistent downstream gains across all four domains and eight tasks.

  • Ponnock (2025), "The Data Efficiency Frontier of Financial Foundation Models" (arXiv:2512.12384): Found that SEC-specific DAPT exhibits diminishing marginal returns beyond ~250M tokens within a single epoch: "Both models exhibit their largest improvements in the early stages of continued pretraining: loss drops noticeably between 50M and 200M tokens, after which the rate of improvement slows." Our ~1B token corpus is already well past the diminishing-returns threshold.

Additional epochs risk overfitting to the domain corpus without proportional downstream benefit, while general-domain capability remains stable through a single pass.

Sequence Length Decision

ModernBERT was pre-trained with an 8192-token context. We match this during DAPT so the full context is exercised: global-attention layers see genuine long-range dependencies, and RoPE rotations at large relative distances receive gradient signal. At seq_len=2048, attention over distances beyond 2048 tokens would never be exercised during DAPT.

The tradeoff is memory: without torch.compile's activation savings, batch size drops from 4 (at 2048) to 1 (at 8192), compensated by gradient_accumulation=32 to maintain an effective batch of 32; the final config recovers batch=4 at 8192 (see Optimization Journey). Training time is comparable either way because 4x fewer optimizer steps offset the slower per-step time.

For our downstream task (paragraph classification at ~50-400 tokens), the long-context benefit is modest — the primary DAPT benefit is vocabulary and domain language patterns, which transfer at any sequence length. But there is no cost to using 8192, so we preserve the model's full capability.

Step 1: DAPT

Command

```shell
cd python
bun run py:train dapt --config configs/dapt/modernbert.yaml
```

Equivalent to: uv run main.py dapt --config configs/dapt/modernbert.yaml

What happens

  1. Loads ModernBERT-large from HuggingFace (cached after first download)
  2. Loads 14,756 docs from data/dapt-corpus/, filters 188 < 10K chars
  3. Tokenizes all text, concatenates, chunks into ~136K sequences of 8192 tokens
  4. Splits 2% validation (~2,700 sequences), 98% train (~133K sequences)
  5. Trains 1 epoch of MLM with 30% masking, bf16, gradient checkpointing
  6. ~4,257 steps over the full corpus (roughly half that with the 500M-token subsample), logging every 50, checkpoint+eval every 1,000
  7. Saves final model + tokenizer to checkpoints/dapt/modernbert-large/final/
  8. Reports final eval loss and perplexity

Expected duration

~13.5 hours on the RTX 3090 at the measured ~27 s/step for the 500M-token subsample (~29 hours for the full 1.06B-token corpus).

Resume if interrupted

HuggingFace Trainer auto-saves checkpoints every 1,000 steps. Re-run the same command — it detects existing checkpoints and resumes automatically.

Output

```text
checkpoints/dapt/modernbert-large/
  checkpoint-1000/
  checkpoint-2000/
  checkpoint-3000/
  final/                  <- final model + tokenizer
    config.json
    model.safetensors
    tokenizer.json
    ...
```

Step 2: TAPT

After DAPT completes, continue MLM on the 72K Item 1C paragraphs using the DAPT checkpoint.

Command

```shell
cd python
bun run py:train dapt --config configs/tapt/modernbert.yaml
```

Equivalent to: uv run main.py dapt --config configs/tapt/modernbert.yaml

TAPT Configuration

Config file: python/configs/tapt/modernbert.yaml

The TAPT corpus is 72K Item 1C paragraphs (~10M tokens) — 50x smaller than DAPT. This changes several training decisions:

| Parameter | Value | vs. DAPT | Rationale |
|---|---|---|---|
| max_seq_length | 512 | 8192 → 512 | Data-driven: paragraphs average 127 tokens (P99=386, 99.6% fit in 512). 8192 would be ~98.5% padding. |
| num_train_epochs | 5 | 1 → 5 | Match total token exposure: 5 × 10M = 50M ≈ upper bound of Gururangan et al. (2020) TAPT exposure. |
| whole_word_mask | true | false → true | Mask entire words, not subword pieces. The model knows subword composition from DAPT; TAPT focuses on domain-specific whole words ("CISO", "materiality", "tabletop"). |
| per_device_train_batch_size | 32 | 4 → 32 | Short sequences free VRAM. Tested: 22.7 GB peak with torch.compile at batch=32 (OOM at 48). |
| gradient_accumulation_steps | 1 | 8 → 1 | Effective batch = 32 in both cases. |
| gradient_checkpointing | false | true → false | Not needed at 512 seq_len; would add 30-40% overhead for no benefit. |
| save_strategy | epoch | steps → epoch | Checkpoint + evaluate after each of 5 epochs. |
| validation_split | 0.05 | 0.02 → 0.05 | Larger val split for a 50x smaller dataset — need stable eval loss. |

Sequence Length Decision (512 vs. 8192)

DAPT used 8192 to match ModernBERT's pre-training context and exercise the full RoPE position range and global attention layers. TAPT uses 512 because:

  1. The data fits in 512. Paragraphs average 127 tokens (P99=386); there is no long-range structure to learn.
  2. 50M tokens won't cause forgetting. TAPT's 50M-token exposure is 0.0025% of ModernBERT's ~2T pre-training tokens and 10% of DAPT's. The model's long-range patterns are deeply established.
  3. RoPE is position-independent. Positions 0-511 compute identically at any max_length; positions 512-8191 remain untouched from DAPT.
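The padding claim above is straightforward arithmetic over the average paragraph length (ignoring special tokens and the dynamic per-batch padding that reduces waste further in practice):

```python
AVG_TOKENS = 127   # mean Item 1C paragraph length (P99 = 386)

for seq_len in (512, 8192):
    padding = 1 - AVG_TOKENS / seq_len
    print(f"seq_len={seq_len}: ~{padding:.1%} padding on average")
```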

Whole-Word Masking Implementation

Whole-word masking requires offset_mapping from the tokenizer to determine word boundaries. This is incompatible with DAPT's concatenate-and-chunk approach (which destroys offset_mapping by merging documents). For TAPT, each paragraph is tokenized individually with truncation, preserving offset_mapping. The data collator handles dynamic padding per batch.

Note: with whole_word_mask=True, the HuggingFace collator automatically disables random token replacement (mask_replace_prob=1.0). All masked positions receive the [MASK] token.

What happens

  1. Loads the DAPT checkpoint from checkpoints/dapt/modernbert-large/final/
  2. Loads 72,045 patched paragraphs from paragraphs-clean.patched.jsonl
  3. Tokenizes each paragraph individually (truncation at 512, with offset_mapping for whole-word masking)
  4. Splits 5% validation (~3,602 paragraphs), 95% train (~68,443 paragraphs)
  5. Trains 5 epochs of MLM with whole-word masking — different masking each epoch
  6. Saves checkpoint after each epoch; saves final model to checkpoints/tapt/modernbert-large/final/

Expected duration

~2,138 steps/epoch × 5 epochs = ~10,700 total steps. At seq_len=512 on the 3090 (~0.5-1s/step), estimated 1-3 hours.

Resume if interrupted

Re-run the same command — it detects existing checkpoints and resumes automatically.

Output

```text
checkpoints/tapt/modernbert-large/
  checkpoint-epoch-1/
  checkpoint-epoch-2/
  ...
  final/                  <- SEC-cyBERT-large (DAPT + TAPT)
```

Step 3: Ablation Checkpoints

The training pipeline produces clean ablation rows for the paper:

| Model | Checkpoint | Description |
|---|---|---|
| Base | answerdotai/ModernBERT-large | Off-the-shelf, no domain adaptation |
| +DAPT | checkpoints/dapt/modernbert-large/final | After domain pre-training on 14.5K filings |
| +DAPT+TAPT | checkpoints/tapt/modernbert-large/final | After task pre-training on 72K paragraphs |

Each checkpoint can be independently fine-tuned with classification heads to isolate the contribution of each pre-training stage.

Monitoring

During training, the Trainer logs to stderr every 50 steps:

  • loss — training MLM loss (cross-entropy on masked tokens)
  • learning_rate — current LR (ramps up during warmup, then decays)
  • epoch — progress through the epoch

Every 1,000 steps, it also reports:

  • eval_loss — validation MLM loss
  • Perplexity can be computed as exp(eval_loss), since the reported loss is natural-log cross-entropy

Expected loss trajectory:

  • Starting loss: ~0.80 (the model already knows English — it's only learning SEC-specific patterns)
  • Final loss: ~0.55-0.65 (a gentle downward drift, not a dramatic curve)
  • For reference, a randomly initialized model would start at ~10.8 (ln of the 50,280-token vocabulary)
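Because the reported loss is natural-log cross-entropy, the conversion to perplexity and the random-init reference point are:

```python
import math

def perplexity(eval_loss: float) -> float:
    # The Trainer reports mean cross-entropy in nats, so ppl = e^loss.
    return math.exp(eval_loss)

# Uniform guessing over the vocabulary gives loss = ln(vocab_size).
random_init_loss = math.log(50280)   # ~10.8
```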

What to watch for:

  • grad_norm should stay small (0.05-0.15). Healthy = gentle weight updates. Spikes >1.0 = LR too high.
  • learning_rate ramps up during warmup (first 5% of steps ≈ 93 steps), then decays.
  • Loss going up after warmup → LR too high or data issue
  • Loss stuck flat after 500+ steps → model isn't learning, LR too low
  • Loss < 0.3 → possible overfitting (unlikely in 1 epoch of 500M tokens)

The DAPT loss number itself matters less than the downstream impact. DAPT teaches the model SEC vocabulary and co-occurrence patterns ("NIST CSF", "tabletop exercise", "materiality assessment"). Whether the final loss is 0.55 or 0.65 is less important than whether the [CLS] embeddings produce better classification after fine-tuning. The real evaluation is the ablation: base vs +DAPT vs +DAPT+TAPT.

Artifacts

| File | Purpose |
|---|---|
| python/configs/dapt/modernbert.yaml | DAPT config |
| python/configs/tapt/modernbert.yaml | TAPT config |
| python/configs/dapt/neobert.yaml | NeoBERT config (if needed) |
| python/main.py | CLI entrypoint |
| python/src/dapt/train.py | Training loop |
| python/src/data/corpus.py | Corpus loading + tokenization |
| python/src/common/config.py | Typed YAML config |
| ts/scripts/dapt-corpus-prep.ts | Corpus preparation from HTML |
| ts/scripts/dapt-corpus-analytics.ts | Corpus analytics |
| data/dapt-corpus/shard-*.jsonl | Cleaned corpus (15 shards) |

References

  • Warner, B., Clavié, B., Soldaini, L., et al. (2024). "Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Fine-tuning and Inference." arXiv:2412.13663. — ModernBERT architecture, pre-training config (30% MLM, StableAdamW, weight_decay 1e-5, alternating attention, 8192 context), pre-decay checkpoint release.
  • Gururangan, S., Marasovic, A., Swayamdipta, S., Lo, K., Beltagy, I., Downey, D., & Smith, N.A. (2020). "Don't Stop Pretraining: Adapt Language Models to Domains and Tasks." ACL 2020, pp. 8342-8360. — Single-epoch DAPT on 2-8B token corpora, TAPT at 100 epochs on task data.
  • Ponnock, J. (2025). "The Data Efficiency Frontier of Financial Foundation Models: Scaling Laws from Continued Pretraining." arXiv:2512.12384. Johns Hopkins University. — SEC filing DAPT shows diminishing returns beyond ~250M tokens, shallow power-law scaling.
  • Sounack, T., et al. (2025). "BioClinical ModernBERT." arXiv:2506.10896. — DAPT on 160B clinical tokens using ModernBERT, lr=3e-4, weight_decay=1e-5, pre-decay checkpoint, sequence packing.
  • Luo, Z., et al. (2025). "Patent ModernBERT: A Pretrained Language Model for Intellectual Property." arXiv:2509.14926. — DAPT on 31.6B patent tokens, lr=3e-4, StableAdamW, weight_decay=1e-5.
  • Dao, T. (2024). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR 2024. — O(n) memory attention, critical for 8192 seq_len training on consumer GPUs.
  • AnswerDotAI/ModernBERT#172 — Known FA2 memory anomaly during MLM training, resolved by combining torch.compile + gradient_checkpointing.