fix some bugs for tapt
This commit is contained in:
parent
75ab92628b
commit
8190950f1a
@ -779,6 +779,18 @@ The TAPT corpus is 72K Item 1C paragraphs (~10M tokens) — 50x smaller than the
|
||||
|
||||
**Estimated training time:** ~2,138 steps/epoch × 5 epochs = ~10,700 total steps. At seq_len=512 on the 3090 (~0.5-1s/step), ballpark 1-3 hours.
|
||||
|
||||
### TAPT Launch — Whole-Word Masking Collator Bugs
|
||||
|
||||
Launching TAPT hit two bugs in `transformers`' `DataCollatorForLanguageModeling` when `whole_word_mask=True`:
|
||||
|
||||
1. **`offset_mapping` stripped before reaching the collator.** The Trainer's default `remove_unused_columns=True` drops any dataset column not in the model's `forward()` signature. Since `offset_mapping` is a collator input (not a model input), it was silently removed, causing the collator to receive a 0-dimensional array and crash with `IndexError: too many indices for array`. Fix: set `remove_unused_columns=False` when whole-word masking is enabled.
|
||||
|
||||
2. **`offset_mapping` can't survive `tokenizer.pad()`.** Even with the column present, the collator's `torch_call()` passes all features — including `offset_mapping` — through `tokenizer.pad()`, which tries to tensorize the variable-length nested lists and crashes with `ValueError: expected sequence of length 138 at dim 1 (got 54)`. The collator pops `offset_mapping` *after* padding (line 784), but padding already failed. Fix: subclass `DataCollatorForLanguageModeling` with a `WholeWordMaskCollator` that strips `offset_mapping` from features before padding, pads `input_ids`/`attention_mask` normally, then manually pads `offset_mapping` to the batch's sequence length and passes it as a numpy array to `torch_mask_tokens()`.
|
||||
|
||||
3. **Python 3.14 pickle requirement.** Python 3.14 changed the default multiprocessing start method on Linux from `fork` to `forkserver`, which requires all dataloader worker arguments to be picklable. The initial fix used a closure (nested function capturing `base_collator` and `tokenizer`), which can't be pickled. Refactored from a closure to a proper class (`WholeWordMaskCollator`) to make it serializable.
|
||||
|
||||
These are upstream issues in `transformers` — whole-word masking with `offset_mapping` appears to be under-tested in the current release. The `WholeWordMaskCollator` class in `python/src/dapt/train.py` works around all three.
|
||||
|
||||
---
|
||||
|
||||
## Cost and Time Ledger
|
||||
@ -976,6 +988,8 @@ Three models from three providers — minimizes correlated errors.
|
||||
- **Re-annotation is cheap and validating.** Re-running Stage 1 on 1,537 patched paragraphs cost $3.30 and took 9 minutes. It confirmed that 7.7% of consensus labels were wrong due to the data issue — an empirical validation that the patch was necessary, not just cosmetic.
|
||||
|
||||
### On Training Infrastructure
|
||||
- **Whole-word masking is a minefield in `transformers`.** Three separate bugs (column stripping, padding crash, pickle failure) had to be worked around to use `DataCollatorForLanguageModeling(whole_word_mask=True)` with `offset_mapping`. The feature exists in the API but is fragile in practice — expect to subclass the collator.
|
||||
- **Python 3.14 breaks multiprocessing assumptions.** The switch from `fork` to `forkserver` means any closure or lambda passed as a dataloader collator will crash with a pickle error. Always use proper classes for collators when `dataloader_num_workers > 0`.
|
||||
- **Flash Attention is mandatory for long sequences.** Without FA2, ModernBERT at seq_len=8192 ran at ~47s/step on an RTX 3090. With FA2, the same configuration ran at ~25s/step — and enabled further optimizations (batch size increase, torch.compile) that pushed it further.
|
||||
- **Align hyperparameters with the base model's pre-training config.** ModernBERT was trained with weight_decay=1e-5 and 30% MLM probability. Using the BERT/RoBERTa default of 0.01 weight decay would have been wrong. Both published ModernBERT DAPT papers (BioClinical, Patent) independently validated these values.
|
||||
- **torch.compile + gradient_checkpointing together is more than the sum of its parts.** On ModernBERT, this combination resolves a memory anomaly specific to FA2 during MLM training (AnswerDotAI/ModernBERT#172), freeing VRAM for larger batch sizes.
|
||||
|
||||
@ -19,7 +19,7 @@
|
||||
"la:docker": "docker build -f labelapp/Dockerfile -t registry.claiborne.soy/labelapp:latest . --push",
|
||||
"ts:sec": "bun run --filter sec-cybert sec",
|
||||
"ts:typecheck": "bun run --filter sec-cybert typecheck",
|
||||
"py:train": "cd python && uv run main.py",
|
||||
"py:train": "cd python && PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True uv run main.py",
|
||||
"typecheck": "bun run --filter '*' typecheck",
|
||||
"data:push": "./scripts/data-push.sh",
|
||||
"data:pull": "./scripts/data-pull.sh",
|
||||
|
||||
@ -1 +1 @@
|
||||
3.13
|
||||
3.14
|
||||
|
||||
@ -3,17 +3,16 @@ name = "sec-cybert-train"
|
||||
version = "0.1.0"
|
||||
description = "SEC-cyBERT training pipeline: DAPT, TAPT, fine-tuning, and evaluation"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
requires-python = ">=3.14,<3.15"
|
||||
dependencies = [
|
||||
"torch>=2.11",
|
||||
"transformers",
|
||||
"datasets",
|
||||
"accelerate",
|
||||
"pyyaml",
|
||||
"nvidia-nvshmem-cu13>=3.4.5",
|
||||
"nvidia-cuda-cccl>=13.2.27",
|
||||
"flash-attn",
|
||||
"unsloth",
|
||||
"torch>=2.11,<2.12",
|
||||
"torchao>=0.17,<0.18",
|
||||
"transformers>=5,<6",
|
||||
"datasets>=4,<5",
|
||||
"accelerate>=1,<2",
|
||||
"pyyaml>=6,<7",
|
||||
"flash-attn==2.6.3+cu130torch2.11",
|
||||
"unsloth==2026.3.11",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@ -29,7 +28,6 @@ url = "https://pypi.org/simple/"
|
||||
default = true
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cu130", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
flash-attn = { url = "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.9.4/flash_attn-2.6.3%2Bcu130torch2.11-cp313-cp313-linux_x86_64.whl" }
|
||||
torch = [ { index = "pytorch-cu130" } ]
|
||||
flash-attn = { url = "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.9.4/flash_attn-2.6.3%2Bcu130torch2.11-cp314-cp314-linux_x86_64.whl" }
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ and the starting checkpoint (base model for DAPT, DAPT checkpoint for TAPT).
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from transformers import (
|
||||
AutoModelForMaskedLM,
|
||||
AutoTokenizer,
|
||||
@ -14,11 +15,39 @@ from transformers import (
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.data.data_collator import pad_without_fast_tokenizer_warning
|
||||
|
||||
from ..common.config import DAPTConfig
|
||||
from ..data.corpus import load_corpus, tokenize_and_chunk
|
||||
|
||||
|
||||
class WholeWordMaskCollator(DataCollatorForLanguageModeling):
|
||||
"""DataCollatorForLanguageModeling that handles offset_mapping correctly.
|
||||
|
||||
The upstream torch_call() passes all features through tokenizer.pad(),
|
||||
which chokes on offset_mapping (variable-length nested lists). This
|
||||
subclass strips offset_mapping before padding and re-injects it as a
|
||||
properly padded numpy array for the whole-word masking logic.
|
||||
"""
|
||||
|
||||
def torch_call(self, examples: list[dict]) -> dict:
|
||||
oms = [ex.pop("offset_mapping") for ex in examples]
|
||||
batch = pad_without_fast_tokenizer_warning(
|
||||
self.tokenizer, examples, return_tensors="pt",
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
)
|
||||
special_tokens_mask = batch.pop("special_tokens_mask", None)
|
||||
seq_len = batch["input_ids"].shape[1]
|
||||
padded_oms = [om + [(0, 0)] * (seq_len - len(om)) for om in oms]
|
||||
offset_mapping = np.array(padded_oms)
|
||||
batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
|
||||
batch["input_ids"],
|
||||
special_tokens_mask=special_tokens_mask,
|
||||
offset_mapping=offset_mapping,
|
||||
)
|
||||
return batch
|
||||
|
||||
|
||||
def train(config: DAPTConfig) -> None:
|
||||
"""Run DAPT or TAPT training from a config."""
|
||||
print(f"\n{'='*60}")
|
||||
@ -115,7 +144,8 @@ def train(config: DAPTConfig) -> None:
|
||||
print(f" Cached to {cache_dir}\n")
|
||||
|
||||
# Data collator — handles dynamic masking each epoch
|
||||
collator = DataCollatorForLanguageModeling(
|
||||
collator_cls = WholeWordMaskCollator if config.training.whole_word_mask else DataCollatorForLanguageModeling
|
||||
collator = collator_cls(
|
||||
tokenizer=tokenizer,
|
||||
mlm=True,
|
||||
mlm_probability=config.training.mlm_probability,
|
||||
@ -150,6 +180,7 @@ def train(config: DAPTConfig) -> None:
|
||||
save_total_limit=config.training.save_total_limit,
|
||||
dataloader_num_workers=config.training.dataloader_num_workers,
|
||||
seed=config.training.seed,
|
||||
remove_unused_columns=not config.training.whole_word_mask,
|
||||
report_to="none",
|
||||
load_best_model_at_end=True,
|
||||
metric_for_best_model="eval_loss",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user