roll back to python 3.13 to fix everything lol

This commit is contained in:
Joey Eamigh 2026-03-30 21:25:46 -04:00
parent 8190950f1a
commit 7b660fe361
No known key found for this signature in database
GPG Key ID: CE8C05DFFC53C9CB
5 changed files with 82 additions and 56 deletions


@ -777,19 +777,25 @@ The TAPT corpus is 72K Item 1C paragraphs (~10M tokens) — 50x smaller than the
**Whole-word masking and tokenization:** Whole-word masking requires `offset_mapping` from the tokenizer to determine word boundaries. This is incompatible with DAPT's concatenate-and-chunk approach (which destroys offset_mapping by merging documents). TAPT tokenizes each paragraph individually with truncation, preserving offset_mapping. The data collator handles dynamic padding per batch. This is a different code path from DAPT's concatenation, but the data justifies it: paragraphs are natural self-contained units, unlike DAPT's long filings that must be chunked.
**Training time:** ~2,139 steps/epoch × 5 epochs = ~10,695 total steps. At ~1.84 it/s on the 3090, ~1.6 hours total.
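The step arithmetic can be sanity-checked in a few lines:

```python
steps_per_epoch = 2139
epochs = 5
total_steps = steps_per_epoch * epochs   # 10,695 total steps
hours = total_steps / 1.84 / 3600        # at ~1.84 it/s on the 3090
print(total_steps, round(hours, 1))      # 10695 1.6
```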
### TAPT Launch — Whole-Word Masking Bugs
Launching TAPT required fighting through four bugs in `transformers`' `DataCollatorForLanguageModeling` when `whole_word_mask=True`, plus a Python 3.14 incompatibility that forced a version rollback.
**Bug 1: `offset_mapping` stripped before reaching the collator.** The Trainer's default `remove_unused_columns=True` drops any dataset column not in the model's `forward()` signature. Since `offset_mapping` is a collator input (not a model input), it was silently removed, causing the collator to receive a 0-dimensional array and crash with `IndexError: too many indices for array`. Fix: set `remove_unused_columns=False` when whole-word masking is enabled.
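The pruning mechanism can be sketched without a Trainer: the Trainer inspects the model's `forward()` signature and keeps only matching dataset columns. The `forward` below is illustrative, not ModernBERT's actual signature.

```python
import inspect

# Illustrative forward signature standing in for the model's.
def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
    pass

dataset_columns = ["input_ids", "attention_mask", "offset_mapping"]
signature_columns = set(inspect.signature(forward).parameters)

# With remove_unused_columns=True, only columns matching forward()'s
# parameter names survive — offset_mapping is silently dropped before
# the collator ever sees it.
kept = [c for c in dataset_columns if c in signature_columns]
print(kept)  # ['input_ids', 'attention_mask']
```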
**Bug 2: `offset_mapping` can't survive `tokenizer.pad()`.** Even with the column present, the collator's `torch_call()` passes all features — including `offset_mapping` — through `tokenizer.pad()`, which tries to tensorize the variable-length nested lists and crashes with `ValueError`. The collator pops `offset_mapping` *after* padding, but padding already failed. Fix: subclass `DataCollatorForLanguageModeling` to strip `offset_mapping` before padding.
**Bug 3: `offset_mapping` word boundary detection is broken for BPE tokenizers.** This was the most insidious bug — training ran but loss was ~6-8 (near-random, vs. expected ~1.5-2.0). The upstream `_calc_word_ids_and_prob_mask` detects word boundaries by checking if `token_start != prev_token_end` in the offset mapping. But BPE tokenizers (like ModernBERT's) absorb leading spaces into tokens, making ALL offsets contiguous: `"The" → (0,3), " company" → (3,11)`. Since 3 == 3, the algorithm treats the entire sequence as one giant "word." When 30% masking is applied to these mega-groups, it masks enormous contiguous spans, making prediction nearly impossible.
**Fix:** Replaced `offset_mapping` entirely with the tokenizer's `word_ids()` method, which correctly identifies word boundaries for any tokenizer type (BPE, WordPiece, SentencePiece). The `WholeWordMaskCollator` in `python/src/dapt/train.py` implements whole-word masking from scratch: extracts `word_ids` before padding, selects `mlm_probability` fraction of unique word IDs per sequence, and masks all tokens belonging to selected words.
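The failure mode reproduces without any tokenizer; the offsets below are what a space-absorbing BPE tokenizer emits for `"The company"`:

```python
# Offsets from a space-absorbing BPE tokenizer for "The company":
# the leading space belongs to the second token, so offsets are contiguous.
offsets = [(0, 3), (3, 11)]  # "The", " company"

def is_word_start(i, offsets):
    # Upstream heuristic: a new word starts where there is a gap between
    # this token's start offset and the previous token's end offset.
    return i == 0 or offsets[i][0] != offsets[i - 1][1]

word_starts = [is_word_start(i, offsets) for i in range(len(offsets))]
print(word_starts)  # [True, False] — two words collapse into one "word"
```

A fast tokenizer's `word_ids()` would return `[0, 1]` here, correctly splitting the two words regardless of how spaces are tokenized.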
**Python 3.14 incompatibility.** Two separate issues forced a rollback to Python 3.13:
1. Python 3.14 changed the multiprocessing start method from `fork` to `forkserver`, requiring picklable dataloader collators (closures crash with `PicklingError`).
2. Python 3.14 changed `pickle.Pickler._batch_setitems` to take 3 arguments, breaking `dill` (used by `datasets` for config hashing). This was unfixable — even `dill` 0.4.1 and `datasets` 4.8.4 crashed. The breakage is deep in the `datasets` builder machinery and hit every codepath (`load_dataset`, `Dataset.from_list`, `dataset.map`).
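The collator-pickling constraint is easy to reproduce in isolation; the names below are illustrative, not the project's:

```python
import pickle

def make_closure_collator(tokenizer):
    # Nested function capturing `tokenizer` — cannot be pickled, so
    # forkserver-based dataloader workers crash on it.
    def collate(examples):
        return (tokenizer, examples)
    return collate

try:
    pickle.dumps(make_closure_collator("tok"))
    closure_picklable = True
except (pickle.PicklingError, AttributeError):
    closure_picklable = False
print(closure_picklable)  # False

# A module-level class holding the same state in attributes pickles by
# reference instead — which is why the fix was a proper class.
class WordMaskCollatorSketch:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
    def __call__(self, examples):
        return (self.tokenizer, examples)
```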
Rolled `pyproject.toml` from `requires-python = ">=3.14"` to `">=3.13,<3.14"` and updated the flash-attn wheel URL from cp314 to cp313.
---
@ -826,7 +832,9 @@ Only nano's portion ($21.24) of the first run was wasted — the gemini and grok
| Data quality audit + remediation | ~4h | Generator investigation, 6 patches, orphan re-annotation, quality tier system, docs |
| Documentation + narrative | ~2h | Codebook updates, narrative writing, technical guide updates |
| Labelapp build + infrastructure | ~8h | Monorepo restructure, Next.js app, quiz/warmup/labeling flows, BIBD assignment, sampling, Docker deployment, timer + migration infrastructure |
| DAPT pre-training | ~14.5h GPU | 1 epoch on 500M tokens, RTX 3090. Two sessions (resumed from checkpoint-1280). |
| TAPT debugging + pre-training | ~2h dev + ~1.6h GPU | 4 bugs in transformers whole-word masking + Python 3.14 rollback. Training: 5 epochs on 72K paragraphs. |
| **Total to date** | **~53h** | Includes ~16h GPU time |
### Remaining Work (estimated)
@ -835,8 +843,6 @@ Only nano's portion ($21.24) of the first run was wasted — the gemini and grok
| Human labeling (1,200 paragraphs, 6 annotators) | ~6-8h | $0 (team labor) |
| Stage 2 judge production run (~3-5K paragraphs) | ~1h | ~$20-40 |
| Training data assembly | ~2h | $0 |
| Fine-tuning + ablations (7 experiments) | ~12-20h GPU | $0 |
| Full GenAI benchmark on 1,200 holdout (9 models) | ~1h | ~$30-50 |
| Evaluation + comparison + write-up | ~6-8h | $0 |
@ -988,8 +994,8 @@ Three models from three providers — minimizes correlated errors.
- **Re-annotation is cheap and validating.** Re-running Stage 1 on 1,537 patched paragraphs cost $3.30 and took 9 minutes. It confirmed that 7.7% of consensus labels were wrong due to the data issue — an empirical validation that the patch was necessary, not just cosmetic.
### On Training Infrastructure
- **Whole-word masking in `transformers` is broken for BPE tokenizers.** The upstream `DataCollatorForLanguageModeling(whole_word_mask=True)` uses `offset_mapping` to detect word boundaries by checking for gaps in character offsets. This fails silently for BPE tokenizers that absorb leading spaces — all offsets are contiguous, so the entire sequence becomes one "word." Loss appears to train but sits at ~6-8 (near-random). The fix is to use the tokenizer's `word_ids()` method, which correctly identifies word boundaries for any tokenizer type, and implement masking yourself.
- **Python 3.14 is not ready for ML.** Both `dill` (via `datasets`) and PyTorch's multiprocessing (`fork` → `forkserver`) have breaking incompatibilities. Rolling back to 3.13 was the only viable path.
- **Flash Attention is mandatory for long sequences.** Without FA2, ModernBERT at seq_len=8192 ran at ~47s/step on an RTX 3090. With FA2, the same configuration ran at ~25s/step — and enabled further optimizations (batch size increase, torch.compile) that pushed it further.
- **Align hyperparameters with the base model's pre-training config.** ModernBERT was trained with weight_decay=1e-5 and 30% MLM probability. Using the BERT/RoBERTa default of 0.01 weight decay would have been wrong. Both published ModernBERT DAPT papers (BioClinical, Patent) independently validated these values.
- **torch.compile + gradient_checkpointing together is more than the sum of its parts.** On ModernBERT, this combination resolves a memory anomaly specific to FA2 during MLM training (AnswerDotAI/ModernBERT#172), freeing VRAM for larger batch sizes.


@ -1,4 +1,4 @@
# Project Status — 2026-03-30
## What's Done
@ -17,22 +17,28 @@
- [x] Orphan word re-annotation: 1,537 paragraphs re-run ($3.30), merged into `stage1.patched.jsonl`
- [x] Codebook v3.0 with 3 major rulings
### DAPT + TAPT Pre-Training
- [x] DAPT corpus: 14,568 documents, ~1.056B tokens, cleaned (XBRL, URLs, page numbers stripped)
- [x] DAPT training complete: eval loss 0.7250, perplexity 1.65. 1 epoch on 500M tokens, ~14.5h on RTX 3090.
- [x] DAPT checkpoint at `checkpoints/dapt/modernbert-large/final/`
- [x] TAPT config: 5 epochs, whole-word masking, seq_len=512, batch=32
- [x] Custom `WholeWordMaskCollator` (upstream `transformers` collator broken for BPE tokenizers)
- [x] Python 3.14 → 3.13 rollback (dill/datasets pickle incompatibility)
- [x] Procedure documented in `docs/DAPT-PROCEDURE.md`
### Documentation
- [x] `docs/DATA-QUALITY-AUDIT.md` — full audit with all patches and quality tiers
- [x] `docs/EDGAR-FILING-GENERATORS.md` — 14 generators with signatures and quality profiles
- [x] `docs/DAPT-PROCEDURE.md` — pre-flight checklist, commands, monitoring guide
- [x] `docs/NARRATIVE.md` — 11 phases documented through TAPT launch
## What's In Progress
### DAPT Training — Complete
Final eval loss: 0.7250, perplexity: 1.65. Loss: 0.80 → 0.72 over 1 epoch on 500M tokens. ~14h total across 2 sessions on RTX 3090. Checkpoint at `checkpoints/dapt/modernbert-large/final/`.
### TAPT Training — Running
Training on 72K Item 1C paragraphs using DAPT checkpoint. 5 epochs, whole-word masking, seq_len=512, batch=32. Early loss: 1.46 → 1.40 (first 1% of training). Expected ~1.6h total on RTX 3090. Expecting final loss ~1.0-1.2.
```bash
bun run py:train dapt --config configs/tapt/modernbert.yaml
```
### Human Labeling (139/1,200)
- 3 of 6 annotators started: 68 + 50 + 21 paragraphs completed
@ -41,13 +47,7 @@ Final eval loss: 0.7250, perplexity: 1.65. Loss: 0.80 → 0.72 over 1 epoch on 5
## What's Next (in dependency order)
### 1. Fine-tuning pipeline (no blockers — can build now)
Build the dual-head classifier (7-class category + 4-class specificity) with:
- Shared ModernBERT backbone + 2 linear classification heads
- Sample weighting from quality tiers (1.0 clean/headed/minor, 0.5 degraded)
@ -70,28 +70,28 @@ Combine all annotation sources into final training dataset:
- Judge low-confidence → downweight or exclude
- Quality tier sample weights applied
### 4. Judge production run (blocked on human gold labels)
Run judge on ~409 unresolved + flagged majority cases. Validate against expanded gold set from human labels.
### 5. Fine-tuning + ablations (blocked on steps 1-3)
7 experiments: {base, +DAPT, +DAPT+TAPT} × {with/without SCL} + best config.
### 6. Evaluation + paper (blocked on everything above)
Full GenAI benchmark (9 models) on 1,200 holdout. Comparison tables. Write-up.
## Parallel Tracks
```
Track A (GPU): DAPT ✓ → TAPT (running) → Fine-tuning → Eval
Track B (API): Judge v3 → Judge run ───────────
Track C (Human): Labeling (139/1200) → Gold set validation
Track D (Code): Fine-tune pipeline build ───────
```
TAPT finishes in ~1.5h. Track D (fine-tune pipeline) can proceed now. Track B can start (prompt update) but production run waits for Track C. Everything converges at fine-tuning.
## Key File Locations
@ -102,4 +102,6 @@ Tracks A and D can proceed now. Track B can start (prompt update) but production
| Quality scores | `data/paragraphs/quality/quality-scores.jsonl` (72,045) |
| DAPT corpus | `data/dapt-corpus/shard-*.jsonl` (14,756 docs) |
| DAPT config | `python/configs/dapt/modernbert.yaml` |
| TAPT config | `python/configs/tapt/modernbert.yaml` |
| DAPT checkpoint | `checkpoints/dapt/modernbert-large/final/` |
| Training CLI | `python/main.py dapt --config ...` |


@ -1 +1 @@
3.13


@ -3,7 +3,7 @@ name = "sec-cybert-train"
version = "0.1.0"
description = "SEC-cyBERT training pipeline: DAPT, TAPT, fine-tuning, and evaluation"
readme = "README.md"
requires-python = ">=3.13,<3.14"
dependencies = [
"torch>=2.11,<2.12",
"torchao>=0.17,<0.18",
@ -29,5 +29,5 @@ default = true
[tool.uv.sources]
torch = [ { index = "pytorch-cu130" } ]
flash-attn = { url = "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.9.4/flash_attn-2.6.3%2Bcu130torch2.11-cp313-cp313-linux_x86_64.whl" }


@ -7,7 +7,7 @@ and the starting checkpoint (base model for DAPT, DAPT checkpoint for TAPT).
from pathlib import Path
import numpy as np
import torch
from transformers import (
AutoModelForMaskedLM,
AutoTokenizer,
@ -22,29 +22,47 @@ from ..data.corpus import load_corpus, tokenize_and_chunk
class WholeWordMaskCollator(DataCollatorForLanguageModeling):
"""Whole-word masking collator using the tokenizer's word_ids.
The upstream DataCollatorForLanguageModeling uses offset_mapping for
word boundary detection, which fails for BPE tokenizers that absorb
leading spaces (all token offsets are contiguous, so every token
appears to be the same "word"). This collator uses the tokenizer's
word_ids instead, which correctly identifies word boundaries for any
tokenizer type.
"""
def torch_call(self, examples: list[dict]) -> dict:
if self.seed and self.generator is None:
self.create_rng()
# Strip word_ids before padding (not a model input)
word_ids_list = [ex.pop("word_ids") for ex in examples]
batch = pad_without_fast_tokenizer_warning(
self.tokenizer, examples, return_tensors="pt",
pad_to_multiple_of=self.pad_to_multiple_of,
)
labels = batch["input_ids"].clone()
batch_size, seq_len = labels.shape
masked_indices = torch.zeros(batch_size, seq_len, dtype=torch.bool)
for i, wids in enumerate(word_ids_list):
# Unique word IDs (>= 0; -1 = special/padding tokens)
unique_words = list({w for w in wids if w >= 0})
if not unique_words:
continue
n_mask = max(1, round(len(unique_words) * self.mlm_probability))
perm = torch.randperm(len(unique_words), generator=self.generator)
words_to_mask = {unique_words[perm[j].item()] for j in range(n_mask)}
for j, wid in enumerate(wids):
if wid in words_to_mask:
masked_indices[i, j] = True
labels[~masked_indices] = -100
batch["input_ids"][masked_indices] = self.tokenizer.mask_token_id
batch["labels"] = labels
return batch