decisions for TAPT

Joey Eamigh 2026-03-30 00:33:28 -04:00
parent fe5155ab6d
commit 313e14fb96
No known key found for this signature in database
GPG Key ID: CE8C05DFFC53C9CB
2 changed files with 38 additions and 6 deletions


@ -178,17 +178,25 @@ bun run py:train dapt --config configs/dapt/modernbert.yaml \
--stage tapt
```
### TAPT Configuration Differences
The TAPT corpus is 72K Item 1C paragraphs (~10M tokens) — 50x smaller than DAPT. This changes the training dynamics:
**Epochs: 5-10 (not 1).** Gururangan et al. (2020) ran TAPT for 100 epochs, but their corpora were 50-500K tokens — 20-200x smaller than ours. We match on total token exposure (~50-100M) rather than epoch count: 5-10 epochs × 10M tokens = 50-100M tokens, comparable to the upper end of their TAPT exposure.
**Whole-word masking (optional).** `DataCollatorForWholeWordMask` masks entire words instead of random subword tokens, naturally emphasizing content words over fragments. Worth trying for TAPT since the model already knows subword patterns from DAPT — TAPT should focus on domain-specific whole words ("CISO", "materiality", "tabletop"). One-line change in `train.py`.
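The core idea behind whole-word masking can be sketched without the library: group subword pieces back into words, then mask whole groups at once. This is an illustrative sketch only — it uses WordPiece-style `##` continuation markers for clarity, and `whole_word_mask_indices` is a hypothetical helper, not part of `train.py`:

```python
import random

def whole_word_mask_indices(tokens, mask_prob=0.30, seed=0):
    """Pick token positions to mask so that a word's head token and all of
    its '##' continuation pieces are always masked together."""
    # Group token positions into words: a '##' token continues the previous word.
    words = []
    for i, tok in enumerate(tokens):
        if tok.startswith("##") and words:
            words[-1].append(i)
        else:
            words.append([i])
    rng = random.Random(seed)
    masked = []
    for word in words:
        if rng.random() < mask_prob:
            masked.extend(word)  # mask every piece of the word, or none of them
    return masked
```

So "tabletop" tokenized as `["table", "##top"]` is either fully masked or fully visible, forcing the model to predict the domain word rather than completing an easy subword fragment.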
### What happens
1. Loads the DAPT checkpoint (not the base ModernBERT)
2. Loads 72,045 patched paragraphs from `paragraphs-clean.patched.jsonl`
3. Tokenizes, concatenates, chunks (~10M tokens → ~1,220 sequences at 8192)
4. Trains 5-10 epochs of MLM with different masking each epoch
5. Saves to `checkpoints/tapt/modernbert-large/final/`
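Step 3's concatenate-and-chunk can be sketched as follows (a simplified sketch: the real pipeline uses the ModernBERT tokenizer and its own remainder handling, so the function name and shape here are illustrative):

```python
def concat_and_chunk(tokenized_docs, seq_len=8192):
    """Concatenate token ids from all documents into one stream, then slice
    the stream into fixed-length training sequences."""
    stream = [tok for doc in tokenized_docs for tok in doc]
    # Drop the trailing remainder that doesn't fill a full sequence.
    n_full = len(stream) // seq_len
    return [stream[i * seq_len:(i + 1) * seq_len] for i in range(n_full)]
```

The sequence count falls out directly: `10_000_000 // 8192 == 1220`, hence the ~1,220 sequences quoted above.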
### Expected duration
~1-2 hours (5-10 epochs on ~1,220 sequences = ~190-380 optimizer steps).
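The step estimate is simple arithmetic, assuming an effective batch size of 32 (an assumption for illustration, not a value taken from the configs):

```python
def approx_optimizer_steps(n_sequences, epochs, effective_batch=32):
    """One optimizer step per effective batch of sequences."""
    return (n_sequences * epochs) // effective_batch

low = approx_optimizer_steps(1220, 5)    # 5 epochs
high = approx_optimizer_steps(1220, 10)  # 10 epochs
```

This gives roughly 190 and 381 steps, matching the ~190-380 range above.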
### Output
@ -220,10 +228,19 @@ Every 1,000 steps, it also reports:
- `eval_loss` — validation MLM loss
- Perplexity can be computed as `exp(eval_loss)` — the loss is a natural-log cross-entropy (nats), not base-2
**Expected loss trajectory:**
- Starting loss: ~0.80 (the model already knows English — it's only learning SEC-specific patterns)
- Final loss: ~0.55-0.65 (a gentle downward drift, not a dramatic curve)
- For reference, a randomly initialized model would start at ~10.8 (ln(50280 vocab size))
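Both reference numbers check out directly: a uniform random guesser's expected loss is `ln(vocab_size)`, and perplexity is `e^loss` since the loss is in nats:

```python
import math

vocab_size = 50280
random_init_loss = math.log(vocab_size)   # ~10.8, the randomly initialized baseline
final_perplexity = math.exp(0.60)         # mid-range of the expected 0.55-0.65 final loss
```

A final loss of ~0.60 corresponds to a perplexity under 2 on masked tokens, i.e. the model is choosing between roughly two plausible candidates per mask.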
**What to watch for:**
- If loss spikes or goes to NaN, the learning rate may be too high
- `grad_norm` should stay small (0.05-0.15). Healthy = gentle weight updates. Spikes >1.0 = LR too high.
- `learning_rate` ramps up during warmup (first 5% of steps ≈ 93 steps), then decays.
- Loss going *up* after warmup → LR too high or data issue
- Loss stuck flat after 500+ steps → model isn't learning, LR too low
- Loss < 0.3 → possible overfitting (unlikely in 1 epoch over 500M tokens)
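The watch-list above can be folded into a small post-hoc check over the trainer's log history. This is a hypothetical helper, not part of `train.py`: the `step`, `loss`, and `grad_norm` keys match what the HF `Trainer` logs, and the thresholds are the ones listed above:

```python
import math

def audit_logs(history, warmup_steps=93):
    """Scan a list of {'step', 'loss', 'grad_norm'} dicts and flag the
    failure modes described above."""
    warnings = []
    for h in history:
        if math.isnan(h["loss"]):
            warnings.append(f"step {h['step']}: loss is NaN (LR too high?)")
        if h.get("grad_norm", 0.0) > 1.0:
            warnings.append(f"step {h['step']}: grad_norm spike (LR too high?)")
    post_warmup = [h for h in history if h["step"] > warmup_steps]
    if len(post_warmup) >= 2 and post_warmup[-1]["loss"] > post_warmup[0]["loss"]:
        warnings.append("loss rising after warmup (LR too high or data issue)")
    return warnings
```

Run it over the accumulated log dicts every few hundred steps; an empty return means the run matches the healthy profile described above.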
**The DAPT loss number itself matters less than the downstream impact.** DAPT teaches the model SEC vocabulary and co-occurrence patterns ("NIST CSF", "tabletop exercise", "materiality assessment"). Whether the final loss is 0.55 or 0.65 is less important than whether the [CLS] embeddings produce better classification after fine-tuning. The real evaluation is the ablation: base vs +DAPT vs +DAPT+TAPT.
## Artifacts


@ -741,6 +741,21 @@ Initial training ran at ~47s/step (projected ~56 hours for 1B tokens). Through i
Full procedure, optimization journey, and cloud cost analysis in `docs/DAPT-PROCEDURE.md`.
### Early Training Results
First eval at step 54 (~3% through):
- **Loss: 0.80** — the model already knows English, so loss starts low. For comparison, a randomly initialized model would start at ~10.8. The loss reflects the model's ability to predict masked SEC filing tokens from context.
- **grad_norm: 0.066** — very small, indicating gentle weight updates. Healthy sign.
- **learning_rate: 2.66e-5** — still in warmup phase (first 93 steps, 5% of training).
Expected trajectory: loss drifts from ~0.80 to ~0.55-0.65 over the run. This is not the dramatic loss curve of fine-tuning — DAPT is nudging a capable language model toward SEC-specific vocabulary and co-occurrence patterns, not teaching it a new task from scratch.
### TAPT Planning
The TAPT corpus is 72K Item 1C paragraphs (~10M tokens) — 50x smaller than the DAPT corpus. Following Gururangan et al. (2020), we run multiple epochs to compensate, but match on total token exposure rather than blindly copying their 100-epoch setting (which was calibrated for 50-500K token corpora, 20-200x smaller than ours). 5-10 epochs × 10M = 50-100M total tokens, matching the upper end of their TAPT exposure. Estimated training time: ~1-2 hours.
One planned experiment: **whole-word masking** (`DataCollatorForWholeWordMask`) for TAPT, which masks entire words rather than random subword tokens. Since the model already knows subword patterns from DAPT, TAPT should focus on domain-specific whole words. This is a one-line change.
---
## Cost and Time Ledger