diff --git a/docs/NARRATIVE.md b/docs/NARRATIVE.md index 9ffd935..0e293e4 100644 --- a/docs/NARRATIVE.md +++ b/docs/NARRATIVE.md @@ -779,6 +779,18 @@ The TAPT corpus is 72K Item 1C paragraphs (~10M tokens) — 50x smaller than the **Estimated training time:** ~2,138 steps/epoch × 5 epochs = ~10,700 total steps. At seq_len=512 on the 3090 (~0.5-1s/step), ballpark 1-3 hours. +### TAPT Launch — Whole-Word Masking Collator Bugs + +Launching TAPT hit two bugs in `transformers`' `DataCollatorForLanguageModeling` when `whole_word_mask=True`: + +1. **`offset_mapping` stripped before reaching the collator.** The Trainer's default `remove_unused_columns=True` drops any dataset column not in the model's `forward()` signature. Since `offset_mapping` is a collator input (not a model input), it was silently removed, causing the collator to receive a 0-dimensional array and crash with `IndexError: too many indices for array`. Fix: set `remove_unused_columns=False` when whole-word masking is enabled. + +2. **`offset_mapping` can't survive `tokenizer.pad()`.** Even with the column present, the collator's `torch_call()` passes all features — including `offset_mapping` — through `tokenizer.pad()`, which tries to tensorize the variable-length nested lists and crashes with `ValueError: expected sequence of length 138 at dim 1 (got 54)`. The collator pops `offset_mapping` *after* padding (line 784), but padding already failed. Fix: subclass `DataCollatorForLanguageModeling` with a `WholeWordMaskCollator` that strips `offset_mapping` from features before padding, pads `input_ids`/`attention_mask` normally, then manually pads `offset_mapping` to the batch's sequence length and passes it as a numpy array to `torch_mask_tokens()`. + +3. **Python 3.14 pickle requirement.** Python 3.14 changed the default multiprocessing start method on Linux from `fork` to `forkserver`, which requires all dataloader worker arguments to be picklable. The initial fix used a closure (nested function capturing `base_collator` and `tokenizer`), which can't be pickled. Refactored from a closure to a proper class (`WholeWordMaskCollator`) to make it serializable. + +These are upstream issues in `transformers` — whole-word masking with `offset_mapping` appears to be under-tested in the current release. The `WholeWordMaskCollator` class in `python/src/dapt/train.py` works around all three. + --- ## Cost and Time Ledger @@ -976,6 +988,8 @@ Three models from three providers — minimizes correlated errors. - **Re-annotation is cheap and validating.** Re-running Stage 1 on 1,537 patched paragraphs cost $3.30 and took 9 minutes. It confirmed that 7.7% of consensus labels were wrong due to the data issue — an empirical validation that the patch was necessary, not just cosmetic. ### On Training Infrastructure +- **Whole-word masking is a minefield in `transformers`.** Three separate bugs (column stripping, padding crash, pickle failure) had to be worked around to use `DataCollatorForLanguageModeling(whole_word_mask=True)` with `offset_mapping`. The feature exists in the API but is fragile in practice — expect to subclass the collator. +- **Python 3.14 breaks multiprocessing assumptions.** The switch from `fork` to `forkserver` means any closure or lambda passed as a dataloader collator will crash with a pickle error. Always use proper classes for collators when `dataloader_num_workers > 0`. - **Flash Attention is mandatory for long sequences.** Without FA2, ModernBERT at seq_len=8192 ran at ~47s/step on an RTX 3090. With FA2, the same configuration ran at ~25s/step — and enabled further optimizations (batch size increase, torch.compile) that pushed it further. - **Align hyperparameters with the base model's pre-training config.** ModernBERT was trained with weight_decay=1e-5 and 30% MLM probability. Using the BERT/RoBERTa default of 0.01 weight decay would have been wrong. Both published ModernBERT DAPT papers (BioClinical, Patent) independently validated these values. - **torch.compile + gradient_checkpointing together is more than the sum of its parts.** On ModernBERT, this combination resolves a memory anomaly specific to FA2 during MLM training (AnswerDotAI/ModernBERT#172), freeing VRAM for larger batch sizes. diff --git a/package.json b/package.json index 3e63b14..1088d87 100644 --- a/package.json +++ b/package.json @@ -19,7 +19,7 @@ "la:docker": "docker build -f labelapp/Dockerfile -t registry.claiborne.soy/labelapp:latest . --push", "ts:sec": "bun run --filter sec-cybert sec", "ts:typecheck": "bun run --filter sec-cybert typecheck", - "py:train": "cd python && uv run main.py", + "py:train": "cd python && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True uv run main.py", "typecheck": "bun run --filter '*' typecheck", "data:push": "./scripts/data-push.sh", "data:pull": "./scripts/data-pull.sh", @@ -34,4 +34,4 @@ "@types/bun": "^1.3.11", "@types/node": "^25.5.0" } -} \ No newline at end of file +} diff --git a/python/.python-version b/python/.python-version index 24ee5b1..6324d40 100644 --- a/python/.python-version +++ b/python/.python-version @@ -1 +1 @@ -3.13 +3.14 diff --git a/python/pyproject.toml b/python/pyproject.toml index 5f43454..0e52f54 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -3,17 +3,16 @@ name = "sec-cybert-train" version = "0.1.0" description = "SEC-cyBERT training pipeline: DAPT, TAPT, fine-tuning, and evaluation" readme = "README.md" -requires-python = ">=3.13" +requires-python = ">=3.14,<3.15" dependencies = [ - "torch>=2.11", - "transformers", - "datasets", - "accelerate", - "pyyaml", - "nvidia-nvshmem-cu13>=3.4.5", - "nvidia-cuda-cccl>=13.2.27", - "flash-attn", - "unsloth", + "torch>=2.11,<2.12", + "torchao>=0.17,<0.18", + "transformers>=5,<6", + "datasets>=4,<5", + "accelerate>=1,<2", + "pyyaml>=6,<7", + "flash-attn==2.6.3+cu130torch2.11", + "unsloth==2026.3.11", ] [project.scripts] @@ -29,7 +28,6 @@ url = "https://pypi.org/simple/" default = true [tool.uv.sources] -torch = [ - { index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" }, -] -flash-attn = { url = "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.9.4/flash_attn-2.6.3%2Bcu130torch2.11-cp313-cp313-linux_x86_64.whl" } +torch = [ { index = "pytorch-cu130" } ] +flash-attn = { url = "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.9.4/flash_attn-2.6.3%2Bcu130torch2.11-cp314-cp314-linux_x86_64.whl" } + diff --git a/python/src/dapt/train.py b/python/src/dapt/train.py index 3d3e9f1..34d0d2d 100644 --- a/python/src/dapt/train.py +++ b/python/src/dapt/train.py @@ -7,6 +7,7 @@ and the starting checkpoint (base model for DAPT, DAPT checkpoint for TAPT). from pathlib import Path +import numpy as np from transformers import ( AutoModelForMaskedLM, AutoTokenizer, @@ -14,11 +15,39 @@ from transformers import ( Trainer, TrainingArguments, ) +from transformers.data.data_collator import pad_without_fast_tokenizer_warning from ..common.config import DAPTConfig from ..data.corpus import load_corpus, tokenize_and_chunk +class WholeWordMaskCollator(DataCollatorForLanguageModeling): + """DataCollatorForLanguageModeling that handles offset_mapping correctly. + + The upstream torch_call() passes all features through tokenizer.pad(), + which chokes on offset_mapping (variable-length nested lists). This + subclass strips offset_mapping before padding and re-injects it as a + properly padded numpy array for the whole-word masking logic. + """ + + def torch_call(self, examples: list[dict]) -> dict: + oms = [ex.pop("offset_mapping") for ex in examples] + batch = pad_without_fast_tokenizer_warning( + self.tokenizer, examples, return_tensors="pt", + pad_to_multiple_of=self.pad_to_multiple_of, + ) + special_tokens_mask = batch.pop("special_tokens_mask", None) + seq_len = batch["input_ids"].shape[1] + padded_oms = [om + [(0, 0)] * (seq_len - len(om)) for om in oms] + offset_mapping = np.array(padded_oms) + batch["input_ids"], batch["labels"] = self.torch_mask_tokens( + batch["input_ids"], + special_tokens_mask=special_tokens_mask, + offset_mapping=offset_mapping, + ) + return batch + + def train(config: DAPTConfig) -> None: """Run DAPT or TAPT training from a config.""" print(f"\n{'='*60}") @@ -115,7 +144,8 @@ def train(config: DAPTConfig) -> None: print(f" Cached to {cache_dir}\n") # Data collator — handles dynamic masking each epoch - collator = DataCollatorForLanguageModeling( + collator_cls = WholeWordMaskCollator if config.training.whole_word_mask else DataCollatorForLanguageModeling + collator = collator_cls( tokenizer=tokenizer, mlm=True, mlm_probability=config.training.mlm_probability, @@ -150,6 +180,7 @@ def train(config: DAPTConfig) -> None: save_total_limit=config.training.save_total_limit, dataloader_num_workers=config.training.dataloader_num_workers, seed=config.training.seed, + remove_unused_columns=not config.training.whole_word_mask, report_to="none", load_best_model_at_end=True, metric_for_best_model="eval_loss",