fix some bugs for tapt

This commit is contained in:
Joey Eamigh 2026-03-30 20:44:10 -04:00
parent 75ab92628b
commit 8190950f1a
No known key found for this signature in database
GPG Key ID: CE8C05DFFC53C9CB
5 changed files with 61 additions and 18 deletions

View File

@@ -779,6 +779,18 @@ The TAPT corpus is 72K Item 1C paragraphs (~10M tokens) — 50x smaller than the
**Estimated training time:** ~2,138 steps/epoch × 5 epochs = ~10,700 total steps. At seq_len=512 on the 3090 (~0.5-1s/step), ballpark 1-3 hours.
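The estimate above works out as follows (a quick sanity check, using the per-step range observed on the 3090):

```python
steps_per_epoch = 2_138
epochs = 5
total_steps = steps_per_epoch * epochs  # 10,690, i.e. "~10,700 total steps"

# At 0.5-1.0 s/step on the 3090 at seq_len=512:
hours_fast = total_steps * 0.5 / 3600  # ~1.5 h
hours_slow = total_steps * 1.0 / 3600  # ~3.0 h
print(total_steps, round(hours_fast, 1), round(hours_slow, 1))
```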
### TAPT Launch — Whole-Word Masking Collator Bugs
Launching TAPT hit three bugs: two in `transformers`' `DataCollatorForLanguageModeling` when `whole_word_mask=True`, plus a Python 3.14 pickling failure in the first attempted fix:
1. **`offset_mapping` stripped before reaching the collator.** The Trainer's default `remove_unused_columns=True` drops any dataset column not in the model's `forward()` signature. Since `offset_mapping` is a collator input (not a model input), it was silently removed, causing the collator to receive a 0-dimensional array and crash with `IndexError: too many indices for array`. Fix: set `remove_unused_columns=False` when whole-word masking is enabled.
2. **`offset_mapping` can't survive `tokenizer.pad()`.** Even with the column present, the collator's `torch_call()` passes all features — including `offset_mapping` — through `tokenizer.pad()`, which tries to tensorize the variable-length nested lists and crashes with `ValueError: expected sequence of length 138 at dim 1 (got 54)`. The collator pops `offset_mapping` *after* padding (line 784), but padding already failed. Fix: subclass `DataCollatorForLanguageModeling` with a `WholeWordMaskCollator` that strips `offset_mapping` from features before padding, pads `input_ids`/`attention_mask` normally, then manually pads `offset_mapping` to the batch's sequence length and passes it as a numpy array to `torch_mask_tokens()`.
3. **Python 3.14 pickle requirement.** Python 3.14 changed the default multiprocessing start method on Linux from `fork` to `forkserver`, which requires all dataloader worker arguments to be picklable. The initial fix used a closure (nested function capturing `base_collator` and `tokenizer`), which can't be pickled. Refactored from a closure to a proper class (`WholeWordMaskCollator`) to make it serializable.
The first two are upstream issues in `transformers` — whole-word masking with `offset_mapping` appears to be under-tested in the current release; the third is a Python 3.14 behavior change. The `WholeWordMaskCollator` class in `python/src/dapt/train.py` works around all three.
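The pickle constraint behind bug 3 is easy to reproduce in isolation. A minimal sketch (toy collators, not the real `WholeWordMaskCollator`) showing why a closure dies under `forkserver` while a class instance survives:

```python
import pickle

def make_closure_collator(tokenizer):
    # Nested function capturing locals: pickle cannot serialize this,
    # so forkserver-based dataloader workers crash on it.
    def collate(features):
        return (tokenizer, features)
    return collate

class ClassCollator:
    # Top-level class with picklable attributes: safe to ship to workers.
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, features):
        return (self.tokenizer, features)

try:
    pickle.dumps(make_closure_collator("tok"))
    closure_picklable = True
except (pickle.PicklingError, AttributeError, TypeError):
    closure_picklable = False

restored = pickle.loads(pickle.dumps(ClassCollator("tok")))
print(closure_picklable, restored.tokenizer)
```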
---
## Cost and Time Ledger
@@ -976,6 +988,8 @@ Three models from three providers — minimizes correlated errors.
- **Re-annotation is cheap and validating.** Re-running Stage 1 on 1,537 patched paragraphs cost $3.30 and took 9 minutes. It confirmed that 7.7% of consensus labels were wrong due to the data issue — an empirical validation that the patch was necessary, not just cosmetic.
### On Training Infrastructure
- **Whole-word masking is a minefield in `transformers`.** Three separate bugs (column stripping, padding crash, pickle failure) had to be worked around to use `DataCollatorForLanguageModeling(whole_word_mask=True)` with `offset_mapping`. The feature exists in the API but is fragile in practice — expect to subclass the collator.
- **Python 3.14 breaks multiprocessing assumptions.** The switch from `fork` to `forkserver` means any closure or lambda passed as a dataloader collator will crash with a pickle error. Always use proper classes for collators when `dataloader_num_workers > 0`.
- **Flash Attention is mandatory for long sequences.** Without FA2, ModernBERT at seq_len=8192 ran at ~47s/step on an RTX 3090. With FA2, the same configuration ran at ~25s/step — and enabled further optimizations (batch size increase, torch.compile) that pushed it further.
- **Align hyperparameters with the base model's pre-training config.** ModernBERT was trained with weight_decay=1e-5 and 30% MLM probability. Using the BERT/RoBERTa default of 0.01 weight decay would have been wrong. Both published ModernBERT DAPT papers (BioClinical, Patent) independently validated these values.
- **torch.compile + gradient_checkpointing together is more than the sum of its parts.** On ModernBERT, this combination resolves a memory anomaly specific to FA2 during MLM training (AnswerDotAI/ModernBERT#172), freeing VRAM for larger batch sizes.
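The `offset_mapping` padding trick from the first bullet can be exercised without `transformers` at all. A minimal numpy sketch (hypothetical `pad_offset_mappings` helper, mirroring what the subclassed collator does after `tokenizer.pad()`):

```python
import numpy as np

def pad_offset_mappings(oms: list[list[tuple[int, int]]], seq_len: int) -> np.ndarray:
    """Right-pad variable-length offset_mapping lists with (0, 0) so they
    stack into a single (batch, seq_len, 2) array, the shape the
    whole-word masking logic needs after padding."""
    return np.array([om + [(0, 0)] * (seq_len - len(om)) for om in oms])

# Two "tokenized" examples of different lengths, as a fast tokenizer
# produces per-example with return_offsets_mapping=True:
batch = [[(0, 3), (4, 9)], [(0, 2)]]
padded = pad_offset_mappings(batch, seq_len=2)  # pad to the batch max
print(padded.shape)
```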

View File

@@ -19,7 +19,7 @@
    "la:docker": "docker build -f labelapp/Dockerfile -t registry.claiborne.soy/labelapp:latest . --push",
    "ts:sec": "bun run --filter sec-cybert sec",
    "ts:typecheck": "bun run --filter sec-cybert typecheck",
-   "py:train": "cd python && uv run main.py",
+   "py:train": "cd python && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True uv run main.py",
    "typecheck": "bun run --filter '*' typecheck",
    "data:push": "./scripts/data-push.sh",
    "data:pull": "./scripts/data-pull.sh",
@@ -34,4 +34,4 @@
    "@types/bun": "^1.3.11",
    "@types/node": "^25.5.0"
  }
}

View File

@@ -1 +1 @@
-3.13
+3.14

View File

@@ -3,17 +3,16 @@ name = "sec-cybert-train"
version = "0.1.0"
description = "SEC-cyBERT training pipeline: DAPT, TAPT, fine-tuning, and evaluation"
readme = "README.md"
-requires-python = ">=3.13"
+requires-python = ">=3.14,<3.15"
dependencies = [
-    "torch>=2.11",
-    "transformers",
-    "datasets",
-    "accelerate",
-    "pyyaml",
-    "nvidia-nvshmem-cu13>=3.4.5",
-    "nvidia-cuda-cccl>=13.2.27",
-    "flash-attn",
-    "unsloth",
+    "torch>=2.11,<2.12",
+    "torchao>=0.17,<0.18",
+    "transformers>=5,<6",
+    "datasets>=4,<5",
+    "accelerate>=1,<2",
+    "pyyaml>=6,<7",
+    "flash-attn==2.6.3+cu130torch2.11",
+    "unsloth==2026.3.11",
]

[project.scripts]
@@ -29,7 +28,6 @@ url = "https://pypi.org/simple/"
default = true

[tool.uv.sources]
-torch = [
-    { index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
-]
-flash-attn = { url = "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.9.4/flash_attn-2.6.3%2Bcu130torch2.11-cp313-cp313-linux_x86_64.whl" }
+torch = [ { index = "pytorch-cu130" } ]
+flash-attn = { url = "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.9.4/flash_attn-2.6.3%2Bcu130torch2.11-cp314-cp314-linux_x86_64.whl" }

View File

@@ -7,6 +7,7 @@ and the starting checkpoint (base model for DAPT, DAPT checkpoint for TAPT).
from pathlib import Path

+import numpy as np
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
@@ -14,11 +15,39 @@ from transformers import (
    Trainer,
    TrainingArguments,
)
+from transformers.data.data_collator import pad_without_fast_tokenizer_warning

from ..common.config import DAPTConfig
from ..data.corpus import load_corpus, tokenize_and_chunk
class WholeWordMaskCollator(DataCollatorForLanguageModeling):
    """DataCollatorForLanguageModeling that handles offset_mapping correctly.

    The upstream torch_call() passes all features through tokenizer.pad(),
    which chokes on offset_mapping (variable-length nested lists). This
    subclass strips offset_mapping before padding and re-injects it as a
    properly padded numpy array for the whole-word masking logic.
    """

    def torch_call(self, examples: list[dict]) -> dict:
        oms = [ex.pop("offset_mapping") for ex in examples]
        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer, examples, return_tensors="pt",
            pad_to_multiple_of=self.pad_to_multiple_of,
        )
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        seq_len = batch["input_ids"].shape[1]
        padded_oms = [om + [(0, 0)] * (seq_len - len(om)) for om in oms]
        offset_mapping = np.array(padded_oms)
        batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
            batch["input_ids"],
            special_tokens_mask=special_tokens_mask,
            offset_mapping=offset_mapping,
        )
        return batch
def train(config: DAPTConfig) -> None:
    """Run DAPT or TAPT training from a config."""
    print(f"\n{'='*60}")
@@ -115,7 +144,8 @@ def train(config: DAPTConfig) -> None:
    print(f"  Cached to {cache_dir}\n")

    # Data collator — handles dynamic masking each epoch
-    collator = DataCollatorForLanguageModeling(
+    collator_cls = WholeWordMaskCollator if config.training.whole_word_mask else DataCollatorForLanguageModeling
+    collator = collator_cls(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=config.training.mlm_probability,
@@ -150,6 +180,7 @@ def train(config: DAPTConfig) -> None:
        save_total_limit=config.training.save_total_limit,
        dataloader_num_workers=config.training.dataloader_num_workers,
        seed=config.training.seed,
+       remove_unused_columns=not config.training.whole_word_mask,
        report_to="none",
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",