From fe5155ab6d99b923cdb3cb33a7e3bb221da865ac Mon Sep 17 00:00:00 2001 From: Joey Eamigh <55670930+JoeyEamigh@users.noreply.github.com> Date: Sun, 29 Mar 2026 23:55:49 -0400 Subject: [PATCH] pretraining config for run --- .gitignore | 1 + docs/DAPT-PROCEDURE.md | 82 ++++++++++++++++++++++++++--- docs/NARRATIVE.md | 64 ++++++++++++++++++---- docs/STATUS.md | 6 +-- python/configs/dapt/modernbert.yaml | 11 ++-- python/pyproject.toml | 6 ++- python/src/common/config.py | 1 + python/src/dapt/train.py | 32 ++++++++++- 8 files changed, 176 insertions(+), 27 deletions(-) diff --git a/.gitignore b/.gitignore index be6846f..425ee1a 100644 --- a/.gitignore +++ b/.gitignore @@ -51,3 +51,4 @@ report.[0-9]_.[0-9]_.[0-9]_.[0-9]_.json # Finder (MacOS) folder config .DS_Store +python/*.whl diff --git a/docs/DAPT-PROCEDURE.md b/docs/DAPT-PROCEDURE.md index 00ef8bd..3c525f6 100644 --- a/docs/DAPT-PROCEDURE.md +++ b/docs/DAPT-PROCEDURE.md @@ -1,7 +1,7 @@ # DAPT/TAPT Training Procedure **Date:** 2026-03-29 -**Hardware:** NVIDIA RTX 3090 (24GB VRAM), CUDA driver 13.2, PyTorch 2.10.0+cu128 +**Hardware:** NVIDIA RTX 3090 (24GB VRAM), CUDA driver 13.2, PyTorch 2.11.0+cu130 --- @@ -9,14 +9,14 @@ | Check | Status | |-------|--------| -| PyTorch 2.10.0+cu128, CUDA available | Verified | +| PyTorch 2.11.0+cu130, CUDA available | Verified | | RTX 3090, 25.3 GB VRAM, bf16 supported | Verified | -| CUDA driver 13.2 / runtime 12.8 forward compatible | Verified (GPU matmul test passed) | +| Flash Attention 2 (flash-attn 2.6.3+cu130torch2.11) | Verified | | ModernBERT-large loads: 396M params, max_position_embeddings=8192 | Verified | | Corpus: 14,756 docs, ~1.06B tokens, 15 shards | Verified | | After <10K filter: 14,568 docs, ~1.056B tokens (0.027% loss) | Verified | | Tokenize+chunk pipeline: 10 docs -> 85 sequences of 8192 tokens | Verified | -| Config: seq_len=8192, batch=1, grad_accum=32, 1 epoch, lr=5e-5, mlm=0.30 | Set | +| Config: seq_len=8192, batch=4, grad_accum=8, 1 epoch, lr=5e-5, mlm=0.30 | Set | ## DAPT Corpus Summary @@ -34,18 +34,74 @@ | Parameter | Value | Rationale | |-----------|-------|-----------| | `max_seq_length` | 8192 | Match ModernBERT's pre-training context length | -| `per_device_train_batch_size` | 1 | Memory-limited at 8192 seq_len on 24GB | -| `gradient_accumulation_steps` | 32 | Effective batch size = 32 | +| `max_tokens` | 500,000,000 | Subsample to newest 500M tokens (Ponnock 2025: diminishing returns past 250M) | +| `per_device_train_batch_size` | 4 | Maximum for 24GB VRAM with FA2 + torch.compile + grad checkpointing | +| `gradient_accumulation_steps` | 8 | Effective batch size = 32 | | `num_train_epochs` | 1 | Single pass per Gururangan et al. (2020) and Ponnock (2025) | -| `learning_rate` | 5e-5 | Standard for continued pre-training | -| `mlm_probability` | 0.30 | ModernBERT's pre-training masking rate | +| `learning_rate` | 5e-5 | Conservative for post-decay checkpoint (see note below) | +| `mlm_probability` | 0.30 | Matches ModernBERT's pre-training masking rate (Warner et al., 2024) | +| `weight_decay` | 1e-5 | Matches ModernBERT pre-training; used by BioClinical-ModernBERT and Patent-ModernBERT | | `warmup_ratio` | 0.05 | ~213 warmup steps | | `gradient_checkpointing` | true | Required for 8192 seq_len on 24GB | +| `gradient_checkpointing_kwargs` | `use_reentrant: False` | Required for torch.compile compatibility | | `bf16` | true | Native RTX 3090 support | +| `torch_compile` | true | 20-40% speedup; fixes FA2 memory anomaly on ModernBERT (AnswerDotAI/ModernBERT#172) | +| `optim` | `adamw_torch_fused` | Fused optimizer kernel, 5-10% speedup | +| `tf32` | true | Free speedup on remaining fp32 ops (Ampere architecture) | +| `attn_implementation` | `flash_attention_2` | Flash Attention 2 via flash-attn package; falls back to SDPA if unavailable | | `save_steps` | 1000 | Checkpoint every ~1000 steps | | `eval_steps` | 1000 | Evaluate every ~1000 steps | | `save_total_limit` | 3 | Keep last 3 checkpoints | +### Hyperparameter Rationale + +**Learning rate (5e-5):** Conservative because we start from the published ModernBERT-large checkpoint, which is the post-decay final model. BioClinical-ModernBERT (Sounack et al., 2025) and Patent-ModernBERT (Luo et al., 2025) used 3e-4 but started from pre-decay stable-phase checkpoints. The ModernBERT authors released training checkpoints (`answerdotai/ModernBERT-large-training-checkpoints`) and noted: "Anyone is free to restart training from any of our pre-decay checkpoints, and perform annealing on domain-appropriate data" (Warner et al., 2024). Starting from the post-decay model with a high LR risks destabilizing learned representations. + +**Weight decay (1e-5):** The original ModernBERT pre-training used 1e-5 weight decay. Both BioClinical-ModernBERT and Patent-ModernBERT preserved this value. The commonly-used 0.01 is a BERT/RoBERTa default that doesn't apply here. + +### Performance Optimizations + +**Flash Attention 2** reduces attention from O(n^2) to O(n) memory and provides ~2-4x throughput improvement at seq_len=8192. ModernBERT was designed with FA2 support, including alternating attention: every 3rd layer uses global attention (full 8192-token context with RoPE theta 160K), while other layers use 128-token local sliding window attention (RoPE theta 10K). This dramatically reduces the O(n^2) cost (Warner et al., 2024). + +**torch.compile** JIT-compiles the model into fused CUDA kernels via the Inductor backend. On ModernBERT specifically, it also resolves a known memory anomaly where FA2 uses ~88% GPU memory vs ~48% for SDPA during MLM training (AnswerDotAI/ModernBERT#172). The fix is enabling both `torch_compile=True` and `gradient_checkpointing=True` together. + +**Fused AdamW** merges parameter updates into a single CUDA kernel, reducing kernel launch overhead across 396M parameters. + +**Corpus subsampling** to 500M tokens (from 1.06B) halves training time. The subsampler takes from the tail of the corpus (newest filings, FY2024-2025) since the accession-sorted shards are roughly chronological. Ponnock (2025) showed the largest DAPT gains occur in the first 200-250M tokens with shallow power-law scaling thereafter — 500M provides a comfortable margin. + +### Optimization Journey + +The path to the final config involved iterative experimentation on the RTX 3090: + +| Change | s/step | VRAM | Outcome | +|--------|--------|------|---------| +| Baseline (PyTorch 2.10, no FA2, batch=1) | ~47s | ~16GB | Compute-bound, attention is O(n²) | +| + Flash Attention 2 (PyTorch 2.11+cu130) | ~27s | ~16GB | FA2 halves attention compute time | +| + batch=2 (grad_accum 32→16) | ~27s | ~18.2GB | GPU already saturated at seq_len=8192 — bigger batch doesn't help s/step | +| + torch.compile (with FA2) | ~27s | **~11.9GB** | Graph breaks at FA2 kernels prevent compute speedup, but fusing non-attention ops halved activation memory | +| + batch=4 (using compile's memory savings) | ~27s | ~18.5GB | Same s/step, but 4x fewer grad_accum micro-steps reduces overhead marginally | +| + 500M token subsample | ~27s | ~18.5GB | Half the steps → ~13.5h instead of ~29h | + +Key insight: at seq_len=8192, the 3090's 35.6 bf16 TFLOPS is the hard ceiling. torch.compile couldn't speed up the attention bottleneck (FA2 kernels are opaque to Dynamo), but it unexpectedly halved activation memory by fusing surrounding ops, enabling larger batch sizes. + +### Cloud Alternative: AWS g7e.2xlarge + +For faster turnaround, an AWS g7e.2xlarge instance (NVIDIA RTX PRO 6000 Blackwell Server Edition, 96GB VRAM, ~236 bf16 TFLOPS) could complete DAPT significantly faster: + +| | RTX 3090 (local) | RTX PRO 6000 (g7e.2xlarge) | +|--|--|--| +| bf16 TFLOPS | 71 | ~236 (3.3x) | +| VRAM | 24 GB | 96 GB | +| Gradient checkpointing | Required | Not needed (1.33x speedup) | +| Max batch size | 4 | 16+ | +| Estimated s/step | ~27s | ~6.5-7s | +| **500M tokens** | **~13.5h, ~$1.50 electricity** | **~3.7h, ~$4-5 spot** | +| **1B tokens** | **~29h, ~$3 electricity** | **~7.3h, ~$9 spot** | + +The 96GB VRAM allows dropping gradient checkpointing entirely (eliminating activation recomputation) and running batch=16 with grad_accum=2 for the same effective batch of 32. Combined with the 3.3x raw TFLOPS advantage, the estimated speedup is ~4x. + +The g6e.2xlarge (NVIDIA L40S, 48GB, ~181 bf16 TFLOPS) is a cheaper alternative at $2.24/hr but slower (~5.6h for 500M tokens). H100 instances (p5) are overkill for a 396M parameter model. + ### Epoch Decision Justification We train for 1 epoch (single pass over the corpus), following the empirical consensus: @@ -182,3 +238,13 @@ Every 1,000 steps, it also reports: | `ts/scripts/dapt-corpus-prep.ts` | Corpus preparation from HTML | | `ts/scripts/dapt-corpus-analytics.ts` | Corpus analytics | | `data/dapt-corpus/shard-*.jsonl` | Cleaned corpus (15 shards) | + +## References + +- Warner, B., Clavié, B., Soldaini, L., et al. (2024). "Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Fine-tuning and Inference." arXiv:2412.13663. — ModernBERT architecture, pre-training config (30% MLM, StableAdamW, weight_decay 1e-5, alternating attention, 8192 context), pre-decay checkpoint release. +- Gururangan, S., Marasovic, A., Swayamdipta, S., Lo, K., Beltagy, I., Downey, D., & Smith, N.A. (2020). "Don't Stop Pretraining: Adapt Language Models to Domains and Tasks." ACL 2020, pp. 8342-8360. — Single-epoch DAPT on 2-8B token corpora, TAPT at 100 epochs on task data. +- Ponnock, J. (2025). "The Data Efficiency Frontier of Financial Foundation Models: Scaling Laws from Continued Pretraining." arXiv:2512.12384. Johns Hopkins University. — SEC filing DAPT shows diminishing returns beyond ~250M tokens, shallow power-law scaling. +- Sounack, T., et al. (2025). "BioClinical ModernBERT." arXiv:2506.10896. — DAPT on 160B clinical tokens using ModernBERT, lr=3e-4, weight_decay=1e-5, pre-decay checkpoint, sequence packing. +- Luo, Z., et al. (2025). "Patent ModernBERT: A Pretrained Language Model for Intellectual Property." arXiv:2509.14926. — DAPT on 31.6B patent tokens, lr=3e-4, StableAdamW, weight_decay=1e-5. +- Dao, T. (2024). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." ICLR 2024. — O(n) memory attention, critical for 8192 seq_len training on consumer GPUs. +- AnswerDotAI/ModernBERT#172 — Known FA2 memory anomaly during MLM training, resolved by combining torch.compile + gradient_checkpointing. diff --git a/docs/NARRATIVE.md b/docs/NARRATIVE.md index 9dcc3c5..4bea0e7 100644 --- a/docs/NARRATIVE.md +++ b/docs/NARRATIVE.md @@ -696,25 +696,50 @@ The page number regex initially had a branch matching `[- ]\d{1,3}[- ]` that pro | Metric | Value | |--------|-------| -| Documents | 14,756 (14,568 after <10K filter) | -| Total tokens | ~1.056 billion (ModernBERT tokenizer) | -| Median document | ~73K tokens (347K chars) | -| Training sequences (seq_len=8192) | ~136K | -| Steps per epoch (eff. batch=32) | ~4,257 | -| Estimated training time | ~4-8 hours per epoch (RTX 3090) | +| Full corpus | 14,568 docs, ~1.056B tokens | +| Training subset | ~7,200 docs (newest 500M tokens, FY2024-2025) | +| Training sequences (seq_len=8192) | ~60K | +| Steps per epoch (eff. batch=32) | ~1,950 | +| Actual training time | ~13.5 hours (RTX 3090, 27s/step) | ### Sequence Length Decision -ModernBERT was pre-trained at 8192 tokens. We match this during DAPT to ensure all positional embedding and attention weights receive gradient updates. At seq_len=2048, positions 2048-8191 would get no updates. The tradeoff — batch_size drops from 4 to 1, compensated by gradient_accumulation=32 — results in comparable training time because 4x fewer steps offset slower per-step throughput. +ModernBERT was pre-trained at 8192 tokens (Warner et al., 2024). We match this during DAPT to ensure all positional embedding and attention weights — including ModernBERT's alternating local/global attention pattern — receive gradient updates. At seq_len=2048, positions 2048-8191 would get no updates, and the global attention layers (every 3rd layer, RoPE theta 160K) would never see long-range context during DAPT. ### Epoch Decision We train for 1 epoch (single pass), following the empirical consensus: -- **Gururangan et al. (2020), "Don't Stop Pretraining" (ACL):** Used a single pass over 2-8B token domain corpora. Sufficient for consistent downstream gains across all four domains tested. +- **Gururangan et al. (2020), "Don't Stop Pretraining" (ACL 2020):** Trained DAPT for "12.5K steps, which amounts to a single pass on each domain dataset" across 2-8B token corpora. Sufficient for consistent downstream gains across all four domains tested. - **Ponnock (2025), arXiv:2512.12384:** Found SEC-specific DAPT shows "diminishing marginal returns beyond roughly 250M tokens" within a single epoch. Our 1B token corpus is well past the diminishing-returns threshold. -Full procedure documented in `docs/DAPT-PROCEDURE.md`. +### Hyperparameters Aligned with Prior ModernBERT DAPT Work + +We aligned hyperparameters with the ModernBERT paper and two published DAPT efforts: + +- **MLM probability (30%):** Matches ModernBERT pre-training (Warner et al., 2024). +- **Weight decay (1e-5):** Matches ModernBERT pre-training and both BioClinical-ModernBERT (Sounack et al., 2025) and Patent-ModernBERT (Luo et al., 2025). The commonly-cited 0.01 is a BERT/RoBERTa default that doesn't apply to ModernBERT. +- **Learning rate (5e-5):** Conservative because we start from the published post-decay checkpoint. BioClinical and Patent-ModernBERT used 3e-4 but started from pre-decay stable-phase checkpoints that the ModernBERT authors released specifically for continued pre-training. + +### Training Optimizations + +Initial training ran at ~47s/step (projected ~56 hours for 1B tokens). Through iterative optimization we brought this down to ~13.5 hours: + +1. **Flash Attention 2** (Dao, 2024) — installed via precompiled wheel after upgrading to PyTorch 2.11+cu130 (CUDA 13.0 to match the driver). Without FA2, ModernBERT fell back to O(n²) eager attention at 8192 seq_len. This cut s/step from ~47s to ~27s. + +2. **torch.compile** — JIT-compiles non-attention ops into fused CUDA kernels. With external FA2, Dynamo hits graph breaks at every attention layer, so there was **no compute speedup**. However, fusing the surrounding ops (FFN, layer norms, residuals) unexpectedly **halved activation memory** (18.2GB → 11.9GB at batch=2) by eliminating intermediate tensor allocations. + +3. **Batch size increase** — torch.compile's memory savings freed enough VRAM to increase from batch=2 to batch=4. At seq_len=8192 the GPU is already compute-saturated, so larger batches didn't meaningfully improve s/step (~27s in all configurations). The benefit was marginal reduction in gradient accumulation overhead. + +4. **Corpus subsampling** — the single biggest wall-time reduction. Ponnock (2025) showed diminishing returns past 250M tokens for SEC DAPT. Subsampling from 1.06B to 500M tokens (newest filings) halved training from ~29h to ~13.5h. + +5. **Fused AdamW + non-reentrant gradient checkpointing + tf32** — minor optimizations (~1-2% combined). Fused optimizer merges parameter updates into a single kernel. Non-reentrant checkpointing enables torch.compile compatibility. + +**What didn't work:** Increasing batch size beyond 2 provided no s/step improvement because the 3090 is compute-saturated at seq_len=8192 (attention is O(n²) FLOPs even with FA2). SDPA (PyTorch's native attention) couldn't replace external FA2 without OOMing due to different memory allocation patterns. torch.compile couldn't accelerate the attention bottleneck because FA2's custom CUDA kernels are opaque to Dynamo's graph tracer. + +**The fundamental constraint** is hardware: the RTX 3090's 35.6 bf16 TFLOPS sets a hard ceiling on throughput at 8192 seq_len. An AWS g7e.2xlarge (RTX PRO 6000 Blackwell, 236 bf16 TFLOPS, 96GB VRAM) could complete the same run in ~3.7 hours for ~$5 on spot pricing — the 96GB VRAM allows dropping gradient checkpointing entirely (eliminating activation recomputation) and running batch=16. + +Full procedure, optimization journey, and cloud cost analysis in `docs/DAPT-PROCEDURE.md`. --- @@ -911,3 +936,24 @@ Three models from three providers — minimizes correlated errors. - **Freeze originals, patch separately.** The single best data integrity decision was never modifying `paragraphs-clean.jsonl`. All fixes go through `.patched.jsonl` with the same UUIDs. This makes every change auditable, reversible, and safe to apply incrementally. Without this, the 6-patch iteration would have been terrifying. - **Tag everything you can.** Generator metadata, quality tiers, and anomaly flags cost almost nothing to compute but make targeted remediation possible. Without generator tags, the 36.8% orphan rate in EFiling/XDX would have been invisible — diluted into a 4.7% corpus average. - **Re-annotation is cheap and validating.** Re-running Stage 1 on 1,537 patched paragraphs cost $3.30 and took 9 minutes. It confirmed that 7.7% of consensus labels were wrong due to the data issue — an empirical validation that the patch was necessary, not just cosmetic. + +### On Training Infrastructure +- **Flash Attention is mandatory for long sequences.** Without FA2, ModernBERT at seq_len=8192 ran at ~47s/step on an RTX 3090. With FA2, the same configuration ran at ~25s/step — and enabled further optimizations (batch size increase, torch.compile) that pushed it further. +- **Align hyperparameters with the base model's pre-training config.** ModernBERT was trained with weight_decay=1e-5 and 30% MLM probability. Using the BERT/RoBERTa default of 0.01 weight decay would have been wrong. Both published ModernBERT DAPT papers (BioClinical, Patent) independently validated these values. +- **torch.compile + gradient_checkpointing together is more than the sum of its parts.** On ModernBERT, this combination resolves a memory anomaly specific to FA2 during MLM training (AnswerDotAI/ModernBERT#172), freeing VRAM for larger batch sizes. +- **Precompiled wheels save hours.** Building flash-attn from source requires matching CUDA toolkit versions, which is fragile. Precompiled wheels for the exact {python, torch, CUDA} combination avoid this entirely. +- **torch.compile's value can be memory, not speed.** When the bottleneck is opaque custom CUDA kernels (like FA2), torch.compile can't accelerate them. But it can still fuse the *surrounding* ops, dramatically reducing activation memory. In our case, compile provided 0% speedup but 35% memory reduction — enough to double the batch size. +- **Corpus subsampling is the biggest lever on consumer hardware.** When you're compute-bound, no software optimization can beat "process less data." The scaling laws literature (Ponnock 2025) provides empirical justification for stopping early. +- **At long sequence lengths, the GPU saturates at small batches.** Increasing batch from 2→4 at seq_len=8192 provided no s/step improvement on an RTX 3090 — the matmul dimensions are already large enough to fill all 82 SMs. This is the opposite of short-sequence fine-tuning where batch size scaling is the primary throughput lever. + +--- + +## References + +- Warner, B., Clavié, B., Soldaini, L., et al. (2024). "Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Fine-tuning and Inference." arXiv:2412.13663. +- Gururangan, S., Marasovic, A., Swayamdipta, S., Lo, K., Beltagy, I., Downey, D., & Smith, N.A. (2020). "Don't Stop Pretraining: Adapt Language Models to Domains and Tasks." *Proceedings of ACL 2020*, pp. 8342-8360. +- Ponnock, J. (2025). "The Data Efficiency Frontier of Financial Foundation Models: Scaling Laws from Continued Pretraining." arXiv:2512.12384. +- Sounack, T., et al. (2025). "BioClinical ModernBERT: A Domain-Adapted Encoder for Biomedical and Clinical NLP." arXiv:2506.10896. +- Luo, Z., et al. (2025). "Patent ModernBERT: A Pretrained Language Model for Intellectual Property." arXiv:2509.14926. +- Dao, T. (2024). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." *Proceedings of ICLR 2024*. +- Ringel, D.M. (2023). "Creating Synthetic Experts with Generative Artificial Intelligence." arXiv:2310.15560. diff --git a/docs/STATUS.md b/docs/STATUS.md index 4a870d6..75a9fbc 100644 --- a/docs/STATUS.md +++ b/docs/STATUS.md @@ -20,7 +20,7 @@ ### DAPT Corpus - [x] 14,568 documents, ~1.056B tokens, cleaned (XBRL, URLs, page numbers stripped) - [x] Training pipeline verified end-to-end (PyTorch 2.10, CUDA, ModernBERT loads, tokenization works) -- [x] Config: 8192 seq_len, batch=1, grad_accum=32, 1 epoch, bf16 +- [x] Config: 8192 seq_len, batch=4, grad_accum=8, 1 epoch, bf16, FA2, torch.compile, 500M tokens - [x] Procedure documented in `docs/DAPT-PROCEDURE.md` ### Documentation @@ -31,11 +31,11 @@ ## What's In Progress -### DAPT Training (~4-8h) +### DAPT Training (~13.5h, running) ```bash cd python && bun run py:train dapt --config configs/dapt/modernbert.yaml ``` -No dependencies. Run anytime. +Running on RTX 3090. 500M tokens (newest filings), batch=4, ~27s/step, ~1,950 steps. Checkpoints every 256 steps (~1.9h). Resume-safe — ctrl+c and restart to continue from last checkpoint. ### Human Labeling (139/1,200) - 3 of 6 annotators started: 68 + 50 + 21 paragraphs completed diff --git a/python/configs/dapt/modernbert.yaml b/python/configs/dapt/modernbert.yaml index 7f72c0f..d7242a1 100644 --- a/python/configs/dapt/modernbert.yaml +++ b/python/configs/dapt/modernbert.yaml @@ -8,6 +8,7 @@ data: corpus_path: ../data/dapt-corpus text_field: text max_seq_length: 8192 + max_tokens: 500_000_000 # Ponnock (2025): diminishing returns past 250M tokens validation_split: 0.02 training: @@ -15,15 +16,15 @@ training: learning_rate: 5.0e-5 mlm_probability: 0.30 num_train_epochs: 1 - per_device_train_batch_size: 1 - gradient_accumulation_steps: 32 # effective batch = 32 + per_device_train_batch_size: 4 + gradient_accumulation_steps: 8 # effective batch = 32 warmup_ratio: 0.05 - weight_decay: 0.01 + weight_decay: 1.0e-5 bf16: true gradient_checkpointing: true logging_steps: 50 - save_steps: 1000 - eval_steps: 1000 + save_steps: 256 + eval_steps: 256 save_total_limit: 3 dataloader_num_workers: 4 seed: 42 diff --git a/python/pyproject.toml b/python/pyproject.toml index ea120cc..b524c39 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -5,7 +5,7 @@ description = "SEC-cyBERT training pipeline: DAPT, TAPT, fine-tuning, and evalua readme = "README.md" requires-python = ">=3.13" dependencies = [ - "torch", + "torch>=2.11", "transformers", "datasets", "accelerate", @@ -18,3 +18,7 @@ decoder = ["unsloth"] [project.scripts] sec-cybert = "main:main" + +[[tool.uv.index]] +url = "https://pypi.org/simple/" +default = true diff --git a/python/src/common/config.py b/python/src/common/config.py index 894039e..44803ae 100644 --- a/python/src/common/config.py +++ b/python/src/common/config.py @@ -22,6 +22,7 @@ class DAPTDataConfig: corpus_path: str # directory of JSONL shards or single JSONL file text_field: str = "text" max_seq_length: int = 2048 + max_tokens: int | None = None # subsample corpus to ~N tokens (None = use all) validation_split: float = 0.02 diff --git a/python/src/dapt/train.py b/python/src/dapt/train.py index e937c7f..be4aff9 100644 --- a/python/src/dapt/train.py +++ b/python/src/dapt/train.py @@ -35,10 +35,17 @@ def train(config: DAPTConfig) -> None: trust_remote_code=config.model.trust_remote_code, ) - # Load model + # Use Flash Attention 2 if available (fastest kernel), else SDPA fallback + try: + import flash_attn # noqa: F401 + attn_impl = "flash_attention_2" + except ImportError: + attn_impl = "sdpa" model = AutoModelForMaskedLM.from_pretrained( config.model.name_or_path, trust_remote_code=config.model.trust_remote_code, + attn_implementation=attn_impl, + torch_dtype="bfloat16" if config.training.bf16 else None, ) print(f" Model parameters: {model.num_parameters() / 1e6:.0f}M") @@ -63,6 +70,24 @@ def train(config: DAPTConfig) -> None: if filtered > 0: print(f" Filtered {filtered} docs < {min_chars:,} chars → {len(dataset):,} remaining") + # Subsample corpus if max_tokens is set (Ponnock 2025: diminishing + # returns beyond ~250M tokens for SEC DAPT). Takes from the END of + # the corpus (newest filings first, since accessions sort chronologically). + if config.data.max_tokens is not None: + chars_per_token = 4.72 # empirical for ModernBERT tokenizer + max_chars = int(config.data.max_tokens * chars_per_token) + cumulative = 0 + n = len(dataset) + keep_from = n + for i in range(n - 1, -1, -1): + cumulative += len(dataset[i][config.data.text_field]) + keep_from = i + if cumulative >= max_chars: + break + dataset = dataset.select(range(keep_from, n)) + est_tokens = cumulative / chars_per_token + print(f" Subsampled to {n - keep_from:,} docs (~{est_tokens / 1e6:.0f}M tokens, newest filings, max_tokens={config.data.max_tokens:,})") + print(f" Tokenizing and chunking to {config.data.max_seq_length} tokens...") chunked = tokenize_and_chunk( dataset, @@ -102,6 +127,11 @@ def train(config: DAPTConfig) -> None: weight_decay=config.training.weight_decay, bf16=config.training.bf16, gradient_checkpointing=config.training.gradient_checkpointing, + gradient_checkpointing_kwargs={"use_reentrant": False}, + torch_compile=True, + optim="adamw_torch_fused", + tf32=True, + dataloader_persistent_workers=True, logging_steps=config.training.logging_steps, save_steps=config.training.save_steps, eval_strategy="steps",