tapt setup

parent: c0273c9e2e
commit: 3292980d33
@ -166,42 +166,72 @@ checkpoints/dapt/modernbert-large/

## Step 2: TAPT

After DAPT completes, continue MLM on the 72K Item 1C paragraphs using the DAPT checkpoint.

### Command

```bash
cd python
bun run py:train dapt --config configs/tapt/modernbert.yaml
```

Equivalent to: `uv run main.py dapt --config configs/tapt/modernbert.yaml`

### TAPT Configuration

**Config file:** `python/configs/tapt/modernbert.yaml`

The TAPT corpus is 72K Item 1C paragraphs (~10M tokens) — 50x smaller than DAPT. This changes several training decisions:

| Parameter | Value | vs. DAPT | Rationale |
|-----------|-------|----------|-----------|
| `max_seq_length` | 512 | 8192 → 512 | Data-driven: paragraphs average 127 tokens (P99=386, 99.6% fit in 512). 8192 would be 98.5% padding. |
| `num_train_epochs` | 5 | 1 → 5 | Match total token exposure: 5 × 10M = 50M ≈ upper bound of Gururangan et al. (2020) TAPT exposure. |
| `whole_word_mask` | true | false → true | Mask entire words, not subword pieces. Model knows subword composition from DAPT; TAPT focuses on domain-specific whole words ("CISO", "materiality", "tabletop"). |
| `per_device_train_batch_size` | 32 | 4 → 32 | Short sequences free VRAM. Tested: 22.7 GB peak with torch.compile at batch=32 (OOM at 48). |
| `gradient_accumulation_steps` | 1 | 8 → 1 | Effective batch = 32 in both cases. |
| `gradient_checkpointing` | false | true → false | Not needed at 512 seq_len. Would add 30-40% overhead for no benefit. |
| `save_strategy` | epoch | steps → epoch | Checkpoint + evaluate after each of 5 epochs. |
| `validation_split` | 0.05 | 0.02 → 0.05 | Larger val split for 50x smaller dataset — need stable eval loss. |
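A quick sanity check on the batch-size and accumulation settings (values from the config; the check itself is illustrative):

```python
# Effective batch = per-device batch x gradient accumulation steps.
# TAPT drops accumulation because batch=32 fits directly in VRAM at seq_len=512.
dapt = {"per_device": 4, "grad_accum": 8}
tapt = {"per_device": 32, "grad_accum": 1}

def effective_batch(cfg: dict) -> int:
    return cfg["per_device"] * cfg["grad_accum"]

print(effective_batch(dapt), effective_batch(tapt))  # 32 32
```

Both stages therefore take the same number of optimizer steps per sample seen; only the memory layout changes.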
### Sequence Length Decision (512 vs. 8192)

DAPT used 8192 to match ModernBERT's pre-training context and exercise all positional embeddings and global attention layers. TAPT uses 512 because:

1. **The data is 512.** Paragraphs average 127 tokens (P99=386). There is no long-range structure to learn.
2. **50M tokens won't cause forgetting.** TAPT's 50M token-exposures are 0.0025% of ModernBERT's ~2T pre-training and 10% of DAPT. The model's long-range patterns are deeply established.
3. **RoPE is position-independent.** Positions 0-511 compute identically at any max_length. Positions 512-8191 remain untouched from DAPT.

### Whole-Word Masking Implementation

Whole-word masking requires `offset_mapping` from the tokenizer to determine word boundaries. This is incompatible with DAPT's concatenate-and-chunk approach (which destroys offset_mapping by merging documents). For TAPT, each paragraph is tokenized individually with truncation, preserving `offset_mapping`. The data collator handles dynamic padding per batch.

Note: with `whole_word_mask=True`, the HuggingFace collator automatically disables random token replacement (`mask_replace_prob=1.0`). All masked positions receive the `[MASK]` token.
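The core idea can be sketched in a few lines (illustrative code, not the project's collator; `whole_word_mask` here is a hypothetical helper, and `word_ids` mimics the subword-to-word alignment a fast HF tokenizer exposes via `word_ids()`):

```python
# Whole-word masking samples *words*, then expands the choice back to every
# subword token of each sampled word, so no word is ever partially masked.
import random

def whole_word_mask(word_ids, mlm_probability=0.30, seed=42):
    """Boolean mask covering entire words; None (special tokens) is never masked."""
    rng = random.Random(seed)
    words = sorted({w for w in word_ids if w is not None})
    n_to_mask = max(1, round(len(words) * mlm_probability))
    chosen = set(rng.sample(words, n_to_mask))
    return [w in chosen for w in word_ids]

# "cybersecurity" tokenized into three subwords -> positions 1-3 share word id 0.
word_ids = [None, 0, 0, 0, 1, 2, None]   # None = [CLS] / [SEP]
mask = whole_word_mask(word_ids)
# word 0 is either fully masked or fully unmasked, never split:
assert all(mask[1:4]) or not any(mask[1:4])
```

Token-level masking, by contrast, could mask `cyber` while leaving `security` visible, making the prediction trivial.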
### What happens

1. Loads the DAPT checkpoint from `checkpoints/dapt/modernbert-large/final/`
2. Loads 72,045 patched paragraphs from `paragraphs-clean.patched.jsonl`
3. Tokenizes each paragraph individually (truncation at 512, with offset_mapping for whole-word masking)
4. Splits 5% validation (~3,602 paragraphs), 95% train (~68,443 paragraphs)
5. Trains 5 epochs of MLM with whole-word masking — different masking each epoch
6. Saves a checkpoint after each epoch; saves the final model to `checkpoints/tapt/modernbert-large/final/`

### Expected duration

~2,138 steps/epoch × 5 epochs = ~10,700 total steps. At seq_len=512 on the 3090 (~0.5-1s/step), estimated 1-3 hours.
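The arithmetic behind those numbers, as a back-of-envelope check (corpus and batch figures from this doc; seconds-per-step is an estimate, not a measurement):

```python
# Step count and wall-clock estimate for TAPT.
paragraphs = 72_045
train_rows = round(paragraphs * 0.95)   # 95% train split -> 68,443
steps_per_epoch = train_rows // 32      # batch=32, grad_accum=1 -> 2,138
total_steps = steps_per_epoch * 5       # 5 epochs -> 10,690

for sec_per_step in (0.5, 1.0):         # assumed RTX 3090 range at seq_len=512
    hours = total_steps * sec_per_step / 3600
    print(f"{sec_per_step:.1f} s/step -> {hours:.1f} h")
```

At 0.5-1.0 s/step this lands at roughly 1.5-3.0 hours, consistent with the estimate above.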
### Resume if interrupted

Re-run the same command — it detects existing checkpoints and resumes automatically.
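A stdlib sketch of what that detection amounts to (HF's Trainer provides the real version via `transformers.trainer_utils.get_last_checkpoint`; the directory naming below is the standard `checkpoint-<step>` layout, and `last_checkpoint` is an illustrative helper):

```python
# Find the most recent checkpoint-* directory to resume from, if any.
import re
from pathlib import Path

def last_checkpoint(output_dir):
    """Return the highest-numbered checkpoint-* directory, or None on a fresh run."""
    ckpts = []
    for p in Path(output_dir).glob("checkpoint-*"):
        m = re.fullmatch(r"checkpoint-(\d+)", p.name)
        if m and p.is_dir():
            ckpts.append((int(m.group(1)), p))
    return max(ckpts)[1] if ckpts else None
```

If a checkpoint is found, training continues from its optimizer and scheduler state rather than starting over.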
### Output

```
checkpoints/tapt/modernbert-large/
  checkpoint-epoch-1/
  checkpoint-epoch-2/
  ...
  final/   <- SEC-cyBERT-large (DAPT + TAPT)
```
@ -247,6 +277,7 @@ Every 1,000 steps, it also reports:

| File | Purpose |
|------|---------|
| `python/configs/dapt/modernbert.yaml` | DAPT config |
| `python/configs/tapt/modernbert.yaml` | TAPT config |
| `python/configs/dapt/neobert.yaml` | NeoBERT config (if needed) |
| `python/main.py` | CLI entrypoint |
| `python/src/dapt/train.py` | Training loop |
@ -745,16 +745,39 @@ Full procedure, optimization journey, and cloud cost analysis in `docs/DAPT-PROC

| Step | Loss | grad_norm | LR | Epoch | Note |
|------|------|-----------|-----|-------|------|
| 54 | 0.7991 | 0.066 | 2.66e-5 | 0.03 | Warmup phase |
| 1280 | 0.7233 | 0.068 | 1.57e-5 | 0.70 | Steady decline |
| 1800 | 0.7253 | 0.073 | 1.48e-6 | 0.97 | LR near zero, loss plateaued |
| **Final** | **0.7250** | **0.043** | **5.7e-8** | **1.00** | **Eval loss: 0.7250, perplexity: 1.65** |

The loss dropped from 0.80 → 0.72 — a gentle 10% decline over one epoch. For comparison, a randomly initialized model would start at ~10.8 (the natural log of the 50,280-token vocabulary). Starting at 0.80 reflects that ModernBERT already knows English; DAPT taught it SEC-specific token co-occurrence patterns ("NIST CSF", "materiality assessment", "tabletop exercise"), not language fundamentals. grad_norm remained stable at 0.04-0.07 throughout with zero instability. Total training time: ~14 hours across two sessions on an RTX 3090 (resumed from checkpoint-1280).
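The ~10.8 baseline is simply the cross-entropy of a uniform guess over the vocabulary:

```python
# Expected MLM loss for a randomly initialized model: -ln(1/V) = ln(V).
import math

vocab_size = 50_280  # ModernBERT vocabulary size, per the text above
uniform_loss = math.log(vocab_size)
print(f"{uniform_loss:.2f}")  # 10.83
```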
The DAPT checkpoint is saved at `checkpoints/dapt/modernbert-large/final/` and is ready for TAPT.

### TAPT Configuration

The TAPT corpus is 72K Item 1C paragraphs (~10M tokens) — 50x smaller than the DAPT corpus. This changes several training decisions vs. DAPT. Config file: `python/configs/tapt/modernbert.yaml`.

| Parameter | DAPT | TAPT | Rationale for change |
|-----------|------|------|---------------------|
| `max_seq_length` | 8192 | 512 | Data-driven: paragraphs average 127 tokens (P99=386, 99.6% fit in 512). Using 8192 would mean 98.5% padding — pure waste. See the seq_len discussion below. |
| `num_train_epochs` | 1 | 5 | Gururangan et al. (2020) ran 100 epochs on 50-500K token TAPT corpora. We match total token exposure: 5 × 10M = 50M tokens ≈ upper bound of their TAPT exposure. |
| `whole_word_mask` | false | true | Masks entire words instead of subword pieces. Prevents trivially solvable masking patterns (e.g., masked `cyber` next to unmasked `security`). The model already knows subword composition from DAPT — TAPT should focus on domain-specific whole words ("CISO", "materiality", "tabletop"). |
| `per_device_train_batch_size` | 4 | 32 | Short sequences free VRAM. Tested: batch=32 uses 22.7 GB with torch.compile (vs. OOM at batch=48). |
| `gradient_accumulation_steps` | 8 | 1 | Effective batch = 32 in both cases. No accumulation needed since batch=32 fits directly. |
| `gradient_checkpointing` | true | false | Not needed at seq_len=512 — activations are small. Gradient checkpointing would slow training 30-40% for no memory benefit. |
| `save_strategy` / `eval_strategy` | steps (256) | epoch | 5 epochs; checkpoint and evaluate after each one. |
| `validation_split` | 0.02 | 0.05 | Larger val split for a 50x smaller dataset — need enough samples for stable eval loss. |

**Sequence length (512 vs. 8192):** The concern with a shorter seq_len is degrading the model's long-range attention capabilities. Three factors make this a non-issue for TAPT:

1. **The data is short.** Paragraphs average 127 tokens. There is no long-range structure to learn — the information simply isn't there.
2. **Scale of exposure.** TAPT is 50M token-exposures (5 epochs × 10M). ModernBERT was pre-trained on ~2T tokens; DAPT added 500M. 50M is 0.0025% of original pre-training — far too small to cause catastrophic forgetting of patterns established over trillions of tokens.
3. **RoPE positions are independent.** ModernBERT uses rotary position embeddings. Positions 0-511 compute identically whether max_length is 512 or 8192. Training at 512 updates the same parameters; positions 512-8191 remain as-is from DAPT, not degraded.

**Whole-word masking and tokenization:** Whole-word masking requires `offset_mapping` from the tokenizer to determine word boundaries. This is incompatible with DAPT's concatenate-and-chunk approach (which destroys offset_mapping by merging documents). TAPT tokenizes each paragraph individually with truncation, preserving offset_mapping. The data collator handles dynamic padding per batch. This is a different code path from DAPT's concatenation, but the data justifies it: paragraphs are natural self-contained units, unlike DAPT's long filings that must be chunked.

**Estimated training time:** ~2,138 steps/epoch × 5 epochs = ~10,700 total steps. At seq_len=512 on the 3090 (~0.5-1s/step), ballpark 1-3 hours.

---
@ -801,7 +824,7 @@ Only nano's portion ($21.24) of the first run was wasted — the gemini and grok

| Stage 2 judge production run (~3-5K paragraphs) | ~1h | ~$20-40 |
| Training data assembly | ~2h | $0 |
| DAPT pre-training (1 epoch) | ~4-8h GPU | $0 (own 3090) |
| TAPT pre-training (5 epochs, WWM) | ~1-3h GPU | $0 |
| Fine-tuning + ablations (7 experiments) | ~12-20h GPU | $0 |
| Full GenAI benchmark on 1,200 holdout (9 models) | ~1h | ~$30-50 |
| Evaluation + comparison + write-up | ~6-8h | $0 |
@ -31,11 +31,8 @@

## What's In Progress

### DAPT Training — Complete

Final eval loss: 0.7250, perplexity: 1.65. Loss: 0.80 → 0.72 over 1 epoch on 500M tokens. ~14h total across 2 sessions on RTX 3090. Checkpoint at `checkpoints/dapt/modernbert-large/final/`.

### Human Labeling (139/1,200)

- 3 of 6 annotators started: 68 + 50 + 21 paragraphs completed
@ -44,13 +41,10 @@ Running on RTX 3090. 500M tokens (newest filings), batch=4, ~27s/step, ~1,950 st

## What's Next (in dependency order)

### 1. TAPT (~1-3h, ready to run)

Continue MLM on 72K Item 1C paragraphs using the DAPT checkpoint. 5 epochs, whole-word masking, seq_len=512, batch=32.

```bash
cd python && bun run py:train dapt --config configs/tapt/modernbert.yaml
```

### 2. Fine-tuning pipeline (no blockers — can build now)
@ -23,7 +23,9 @@ training:

  bf16: true
  gradient_checkpointing: true
  logging_steps: 50
  save_strategy: steps
  save_steps: 256
  eval_strategy: steps
  eval_steps: 256
  save_total_limit: 8
  dataloader_num_workers: 4
python/configs/tapt/modernbert.yaml (new file, 30 lines)
@ -0,0 +1,30 @@

stage: tapt

model:
  name_or_path: ../checkpoints/dapt/modernbert-large/final
  trust_remote_code: false

data:
  corpus_path: ../data/paragraphs/paragraphs-clean.patched.jsonl
  text_field: text
  max_seq_length: 512        # 99.6% of paragraphs fit; mean=127, P99=386
  validation_split: 0.05     # larger val split — small dataset

training:
  output_dir: ../checkpoints/tapt/modernbert-large
  learning_rate: 5.0e-5
  mlm_probability: 0.30
  whole_word_mask: true
  num_train_epochs: 5
  per_device_train_batch_size: 32   # 22.7 GB peak w/ torch.compile at seq_len=512
  gradient_accumulation_steps: 1    # effective batch = 32 (matches DAPT)
  warmup_ratio: 0.05
  weight_decay: 1.0e-5
  bf16: true
  gradient_checkpointing: false     # short sequences, not needed
  logging_steps: 50
  save_strategy: epoch
  eval_strategy: epoch
  save_total_limit: 6               # keep all 5 epoch checkpoints + final
  dataloader_num_workers: 4
  seed: 42
@ -10,6 +10,8 @@ dependencies = [

    "datasets",
    "accelerate",
    "pyyaml",
    "nvidia-cusparselt-cu12>=0.8.1",
    "nvidia-nvshmem-cu12>=3.6.5",
]

[project.optional-dependencies]
@ -33,6 +33,7 @@ class TrainingConfig:

    output_dir: str
    learning_rate: float = 5e-5
    mlm_probability: float = 0.30
    whole_word_mask: bool = False
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 4
    gradient_accumulation_steps: int = 8

@ -41,8 +42,10 @@ class TrainingConfig:

    bf16: bool = True
    gradient_checkpointing: bool = True
    logging_steps: int = 50
    save_strategy: str = "steps"   # "steps" or "epoch"
    save_steps: int = 1000         # ignored when save_strategy="epoch"
    eval_strategy: str = "steps"   # "steps" or "epoch"
    eval_steps: int = 1000         # ignored when eval_strategy="epoch"
    save_total_limit: int = 3
    dataloader_num_workers: int = 4
    seed: int = 42
|||||||
@ -62,13 +62,14 @@ def train(config: DAPTConfig) -> None:
|
|||||||
dataset = load_corpus(config.data.corpus_path, config.data.text_field)
|
dataset = load_corpus(config.data.corpus_path, config.data.text_field)
|
||||||
print(f" Raw documents: {len(dataset):,}")
|
print(f" Raw documents: {len(dataset):,}")
|
||||||
|
|
||||||
# Filter tiny documents (cover pages, empty filings)
|
# Filter tiny documents (cover pages, empty filings) — DAPT only
|
||||||
min_chars = 10_000
|
if config.stage == "dapt":
|
||||||
before = len(dataset)
|
min_chars = 10_000
|
||||||
dataset = dataset.filter(lambda x: len(x[config.data.text_field]) >= min_chars)
|
before = len(dataset)
|
||||||
filtered = before - len(dataset)
|
dataset = dataset.filter(lambda x: len(x[config.data.text_field]) >= min_chars)
|
||||||
if filtered > 0:
|
filtered = before - len(dataset)
|
||||||
print(f" Filtered {filtered} docs < {min_chars:,} chars → {len(dataset):,} remaining")
|
if filtered > 0:
|
||||||
|
print(f" Filtered {filtered} docs < {min_chars:,} chars → {len(dataset):,} remaining")
|
||||||
|
|
||||||
# Subsample corpus if max_tokens is set (Ponnock 2025: diminishing
|
# Subsample corpus if max_tokens is set (Ponnock 2025: diminishing
|
||||||
# returns beyond ~250M tokens for SEC DAPT). Takes from the END of
|
# returns beyond ~250M tokens for SEC DAPT). Takes from the END of
|
||||||
@ -88,12 +89,17 @@ def train(config: DAPTConfig) -> None:
|
|||||||
est_tokens = cumulative / chars_per_token
|
est_tokens = cumulative / chars_per_token
|
||||||
print(f" Subsampled to {n - keep_from:,} docs (~{est_tokens / 1e6:.0f}M tokens, newest filings, max_tokens={config.data.max_tokens:,})")
|
print(f" Subsampled to {n - keep_from:,} docs (~{est_tokens / 1e6:.0f}M tokens, newest filings, max_tokens={config.data.max_tokens:,})")
|
||||||
|
|
||||||
print(f" Tokenizing and chunking to {config.data.max_seq_length} tokens...")
|
wwm = config.training.whole_word_mask
|
||||||
|
if wwm:
|
||||||
|
print(f" Tokenizing to {config.data.max_seq_length} tokens (whole-word mask)...")
|
||||||
|
else:
|
||||||
|
print(f" Tokenizing and chunking to {config.data.max_seq_length} tokens...")
|
||||||
chunked = tokenize_and_chunk(
|
chunked = tokenize_and_chunk(
|
||||||
dataset,
|
dataset,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
text_field=config.data.text_field,
|
text_field=config.data.text_field,
|
||||||
max_seq_length=config.data.max_seq_length,
|
max_seq_length=config.data.max_seq_length,
|
||||||
|
whole_word_mask=wwm,
|
||||||
)
|
)
|
||||||
print(f" Training sequences: {len(chunked):,}")
|
print(f" Training sequences: {len(chunked):,}")
|
||||||
|
|
||||||
@ -113,30 +119,34 @@ def train(config: DAPTConfig) -> None:
|
|||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
mlm=True,
|
mlm=True,
|
||||||
mlm_probability=config.training.mlm_probability,
|
mlm_probability=config.training.mlm_probability,
|
||||||
|
whole_word_mask=config.training.whole_word_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training arguments
|
# Training arguments
|
||||||
output_dir = Path(config.training.output_dir)
|
output_dir = Path(config.training.output_dir)
|
||||||
args = TrainingArguments(
|
steps_per_epoch = len(split["train"]) // (
|
||||||
|
config.training.per_device_train_batch_size
|
||||||
|
* config.training.gradient_accumulation_steps
|
||||||
|
)
|
||||||
|
warmup_steps = int(config.training.warmup_ratio * steps_per_epoch)
|
||||||
|
|
||||||
|
training_kwargs: dict = dict(
|
||||||
output_dir=str(output_dir),
|
output_dir=str(output_dir),
|
||||||
learning_rate=config.training.learning_rate,
|
learning_rate=config.training.learning_rate,
|
||||||
num_train_epochs=config.training.num_train_epochs,
|
num_train_epochs=config.training.num_train_epochs,
|
||||||
per_device_train_batch_size=config.training.per_device_train_batch_size,
|
per_device_train_batch_size=config.training.per_device_train_batch_size,
|
||||||
gradient_accumulation_steps=config.training.gradient_accumulation_steps,
|
gradient_accumulation_steps=config.training.gradient_accumulation_steps,
|
||||||
warmup_steps=int(config.training.warmup_ratio * (len(split["train"]) // (config.training.per_device_train_batch_size * config.training.gradient_accumulation_steps))),
|
warmup_steps=warmup_steps,
|
||||||
weight_decay=config.training.weight_decay,
|
weight_decay=config.training.weight_decay,
|
||||||
bf16=config.training.bf16,
|
bf16=config.training.bf16,
|
||||||
gradient_checkpointing=config.training.gradient_checkpointing,
|
gradient_checkpointing=config.training.gradient_checkpointing,
|
||||||
gradient_checkpointing_kwargs={"use_reentrant": False},
|
|
||||||
torch_compile=True,
|
torch_compile=True,
|
||||||
optim="adamw_torch_fused",
|
optim="adamw_torch_fused",
|
||||||
tf32=True,
|
tf32=True,
|
||||||
per_device_eval_batch_size=1,
|
|
||||||
dataloader_persistent_workers=True,
|
dataloader_persistent_workers=True,
|
||||||
logging_steps=config.training.logging_steps,
|
logging_steps=config.training.logging_steps,
|
||||||
save_steps=config.training.save_steps,
|
save_strategy=config.training.save_strategy,
|
||||||
eval_strategy="steps",
|
eval_strategy=config.training.eval_strategy,
|
||||||
eval_steps=config.training.eval_steps,
|
|
||||||
save_total_limit=config.training.save_total_limit,
|
save_total_limit=config.training.save_total_limit,
|
||||||
dataloader_num_workers=config.training.dataloader_num_workers,
|
dataloader_num_workers=config.training.dataloader_num_workers,
|
||||||
seed=config.training.seed,
|
seed=config.training.seed,
|
||||||
@ -145,6 +155,19 @@ def train(config: DAPTConfig) -> None:
|
|||||||
metric_for_best_model="eval_loss",
|
metric_for_best_model="eval_loss",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if config.training.gradient_checkpointing:
|
||||||
|
training_kwargs["gradient_checkpointing_kwargs"] = {"use_reentrant": False}
|
||||||
|
# Long sequences need small eval batch to avoid OOM
|
||||||
|
training_kwargs["per_device_eval_batch_size"] = 1
|
||||||
|
|
||||||
|
# Only pass step counts when using step-based strategy
|
||||||
|
if config.training.save_strategy == "steps":
|
||||||
|
training_kwargs["save_steps"] = config.training.save_steps
|
||||||
|
if config.training.eval_strategy == "steps":
|
||||||
|
training_kwargs["eval_steps"] = config.training.eval_steps
|
||||||
|
|
||||||
|
args = TrainingArguments(**training_kwargs)
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
model=model,
|
model=model,
|
||||||
args=args,
|
args=args,
|
||||||
|