first finetune attempt

This commit is contained in:
Joey Eamigh 2026-04-05 12:16:16 -04:00
parent 531317f7d4
commit 42f8849b14
No known key found for this signature in database
GPG Key ID: CE8C05DFFC53C9CB
13 changed files with 1527 additions and 19 deletions

.gitignore

@ -58,3 +58,4 @@ python/*.whl
# Personal notes
docs/STRATEGY-NOTES.md
unsloth_compiled_cache/


@ -418,6 +418,163 @@ The cost overshoot ($200 vs $175 estimate) is entirely from annotating 72K parag
---
## Phase 8: Fine-Tuning — From 0.52 to 0.94 Specificity F1
### Training Data Assembly
Built `python/src/finetune/data.py` to merge Stage 1 consensus labels (72,045 paragraphs) with paragraph text, quality tiers, and specificity confidence metadata.
**Exclusions:**
- 1,200 holdout paragraphs (reserved for evaluation)
- 614 individually truncated paragraphs (initial plan was to exclude 72 entire filings, but paragraph-level filtering is more targeted and preserves more data)
**Sample weighting:** clean/headed/minor = 1.0×, degraded = 0.5× (4,331 paragraphs at half weight).
**Result:** 70,231 training paragraphs, stratified 90/10 into 63,214 train / 7,024 val.
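The exclusion and weighting rules above can be sketched as a small filter (field names like `paragraph_id` and `tier` are illustrative, not the repo's actual schema; the tier weights match the table in `data.py`):

```python
# Quality tier -> sample weight (clean/headed/minor 1.0, degraded 0.5)
TIER_WEIGHT = {"clean": 1.0, "headed": 1.0, "minor": 1.0, "degraded": 0.5}

def assemble(paragraphs, holdout_ids, truncated_ids):
    """Drop holdout and truncated paragraphs, attach quality-tier weights."""
    out = []
    for p in paragraphs:
        if p["paragraph_id"] in holdout_ids or p["paragraph_id"] in truncated_ids:
            continue  # excluded from training entirely
        out.append({**p, "weight": TIER_WEIGHT.get(p["tier"], 1.0)})
    return out
```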
### Architecture: Dual-Head ModernBERT
The model architecture: ModernBERT-large backbone (395M params) → pooled representation → dropout → two independent classification heads:
1. **Category head:** Linear(1024, 7) with weighted cross-entropy loss. Standard multi-class classification.
2. **Specificity head:** Ordinal classification. The specificity dimension (L1→L2→L3→L4) has natural ordering — predicting L1 when truth is L4 is worse than predicting L3. This ordering should be reflected in the model architecture and loss function.
The initial architecture used **CORAL** (Cao et al. 2020) for the specificity head: a single shared weight vector with learned bias offsets for each ordinal threshold. This is the standard approach for ordinal regression.
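A minimal sketch of what a CORAL-style head looks like, assuming a pooled hidden vector as input (illustrative, not the repo's `model.py`):

```python
import torch
import torch.nn as nn

class CoralHead(nn.Module):
    """CORAL-style ordinal head (sketch).

    A single shared weight vector produces one score w·x per example; K-1
    learned biases shift that score into per-threshold logits. This guarantees
    rank-monotonic predictions, but every threshold must use the same feature
    direction in the embedding space.
    """

    def __init__(self, hidden: int, num_levels: int = 4):
        super().__init__()
        self.fc = nn.Linear(hidden, 1, bias=False)               # shared w
        self.biases = nn.Parameter(torch.zeros(num_levels - 1))  # b_k offsets

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, 1) score + (K-1,) biases -> (batch, K-1) threshold logits
        return self.fc(x) + self.biases
```

The structural constraint is visible in the forward pass: the gap between any two threshold logits is a fixed bias difference, identical for every input.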
### Ablation Grid: 12 Configurations × 1 Epoch
Ran a systematic ablation over three axes:
- **Checkpoint:** base ModernBERT-large vs DAPT checkpoint vs TAPT checkpoint
- **Class weighting:** inverse-frequency weights vs uniform
- **Loss type:** cross-entropy vs focal loss (γ=2.0)
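The grid itself is just a Cartesian product; a sketch using the `ckpt_weighting_loss` run-naming scheme that appears in the ablation results:

```python
from itertools import product

CHECKPOINTS = ["base", "dapt", "tapt"]
WEIGHTING = ["weighted", "unweighted"]
LOSSES = ["ce", "focal"]

# 3 x 2 x 2 = 12 run names, e.g. "base_weighted_ce"
runs = [f"{c}_{w}_{l}" for c, w, l in product(CHECKPOINTS, WEIGHTING, LOSSES)]
```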
Results (1 epoch each, ~15 min/run, ~3 hours total):
| Rank | Configuration | Combined F1 | Cat F1 | Spec F1 |
|------|-------------|-------------|--------|---------|
| 1 | base + weighted + CE | **0.685** | **0.900** | 0.469 |
| 2 | DAPT + unweighted + focal | 0.684 | 0.892 | **0.476** |
| 3 | DAPT + weighted + CE | 0.681 | 0.896 | 0.466 |
| 4 | base + unweighted + CE | 0.680 | 0.892 | 0.467 |
| 5 | TAPT + weighted + CE | 0.675 | 0.896 | 0.455 |
| ... | | | | |
| 12 | TAPT + weighted + focal | 0.649 | 0.849 | 0.449 |
**Finding 1: DAPT/TAPT pre-training did not help.** Base ModernBERT-large outperformed both domain-adapted checkpoints. This is a noteworthy null result. ModernBERT-large was already pre-trained on a massive, diverse web corpus that likely includes SEC filings. Additional narrow-domain pre-training appears to cause mild catastrophic forgetting — the model loses general linguistic features while gaining domain-specific ones that the fine-tuning task doesn't benefit from. TAPT was consistently worst, suggesting the small corpus (72K paragraphs × 5 epochs at 30% masking) caused overfitting during MLM pre-training.
**Finding 2: Weighted CE is the best loss combination.** Class weighting helps category F1 significantly (0.900 vs 0.892 for base). Focal loss helps specificity slightly but hurts category. Weighted + focal = too much correction (consistently bottom tier) — both mechanisms independently reduce majority-class influence, and combining them over-corrects.
### Full Training: The CORAL Wall (5 Epochs)
Trained the top 2 configurations for 5 epochs each (~1.5 hours per run):
**base_weighted_ce (5 epochs):**
| Epoch | Combined | Cat F1 | Spec F1 | QWK |
|-------|----------|--------|---------|-----|
| 1 | 0.670 | 0.879 | 0.461 | 0.800 |
| 3 | 0.704 | 0.924 | 0.485 | 0.833 |
| 5 | **0.724** | **0.932** | **0.517** | **0.840** |
Category F1 reached 0.932 — well above the 0.80 target. But specificity F1 plateaued at 0.517. Per-class breakdown revealed the problem:
| Specificity | F1 |
|-------------|-----|
| L1 (Generic) | 0.79 |
| L2 (Domain-Adapted) | **0.29** |
| L3 (Firm-Specific) | **0.31** |
| L4 (Quantified) | 0.55 |
L2 and L3 were dragging macro F1 down to 0.52. QWK was 0.84 — meaning the model's ordinal *ranking* was good (rarely confusing L1 with L4), but the exact *boundary placement* between adjacent levels was fuzzy.
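QWK here is quadratic-weighted Cohen's kappa, computable with scikit-learn (already a project dependency). A toy example with made-up labels, showing how adjacent-level confusions are penalized far less than distant ones:

```python
from sklearn.metrics import cohen_kappa_score

y_true = [0, 1, 2, 3, 2, 1]
y_pred = [0, 2, 1, 3, 2, 1]  # two adjacent-level mistakes

# weights="quadratic" penalizes an L1-vs-L4 confusion 9x more than L1-vs-L2,
# so QWK stays high when the model's ordinal ranking is good but boundaries
# between neighboring levels are fuzzy -- exactly the pattern observed here.
qwk = cohen_kappa_score(y_true, y_pred, weights="quadratic")
```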
### The CORAL Diagnosis
CORAL uses a single weight vector **w** with shifted biases: logit_k = **w**·**x** + b_k. This means the *same features* separate L1 from L2 as separate L3 from L4. But the three specificity transitions require fundamentally different evidence:
- **L1→L2:** Cybersecurity terminology detection (the ERM test — does the paragraph use language a general business professional wouldn't?)
- **L2→L3:** Firm-unique fact detection (named roles, specific systems, internal programs)
- **L3→L4:** Quantified/verifiable claim detection (dollar amounts, dates, third-party firm names)
A single shared weight vector cannot simultaneously encode "presence of domain terminology," "presence of named entities," and "presence of numerical quantities" — these are orthogonal signal types in the embedding space. CORAL's structural constraint was forcing the model to find one feature direction that approximates all three, resulting in blurry boundaries everywhere.
Additionally, [CLS] token pooling loses distributed signals. A paragraph that mentions "CISO" once in a subordinate clause should be L3, but [CLS] may not attend strongly to that one token.
### Architecture Iteration: Independent Thresholds
Replaced CORAL with four changes (implemented in `python/src/finetune/model.py`):
1. **Independent threshold heads.** Three separate binary classifiers, each with its own `Linear(1024→256→1)` MLP:
- threshold_L2plus: "Has any qualifying facts?" (L1 vs L2+)
- threshold_L3plus: "Has firm-specific facts?" (≤L2 vs L3+)
- threshold_L4: "Has quantified facts?" (≤L3 vs L4)
Same cumulative binary targets as CORAL (label k → [1]×k + [0]×(3−k)), but each threshold learns independent features. The prediction is: level = count(sigmoid(logit_k) > 0.5).
2. **Attention pooling.** Replaced [CLS] with a learned attention pool over all token representations. This lets the model attend to specific evidence tokens (CISO, $2M, NIST) distributed anywhere in the paragraph.
3. **Specificity confidence filtering.** Only compute specificity loss on paragraphs where all 3 Grok runs agreed on specificity (91.3% of training data, as tracked in consensus `specificityAgreement.agreed`). The ~6K disagreement cases are exactly the noisy boundary labels that confuse the model. Category loss still uses all samples.
4. **Ordinal consistency regularization.** Penalty (weight 0.1) when threshold k fires but threshold k-1 doesn't — e.g., the model says "has firm-specific facts" but not "has domain terms." This enforces the cumulative structure without the rigidity of CORAL's shared weights.
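Changes 1, 2, and 4 can be sketched as follows (change 3 is a loss mask over samples, not an architecture change; all names and dimensions are illustrative, not the repo's exact `model.py`):

```python
import torch
import torch.nn as nn

class AttentionPool(nn.Module):
    """Learned attention over all token states, replacing [CLS] pooling."""

    def __init__(self, hidden: int):
        super().__init__()
        self.score = nn.Linear(hidden, 1)

    def forward(self, states: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # states: (B, T, H); mask: (B, T), 1 for real tokens
        w = self.score(states).squeeze(-1).masked_fill(mask == 0, -1e4)
        w = torch.softmax(w, dim=-1)
        return torch.einsum("bt,bth->bh", w, states)

class IndependentThresholds(nn.Module):
    """Three separate Linear(H -> 256 -> 1) binary classifiers, one per cut."""

    def __init__(self, hidden: int, mlp_dim: int = 256, num_thresholds: int = 3):
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, 1))
            for _ in range(num_thresholds)
        )

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        return torch.cat([h(pooled) for h in self.heads], dim=-1)  # (B, 3)

def ordinal_consistency_penalty(logits: torch.Tensor, weight: float = 0.1) -> torch.Tensor:
    """Penalize p(threshold k) exceeding p(threshold k-1)."""
    p = torch.sigmoid(logits)
    return weight * torch.relu(p[:, 1:] - p[:, :-1]).mean()
```

Unlike CORAL's shifted biases, each threshold here has its own MLP, so "domain terminology", "named entities", and "quantities" can each find their own feature direction; the soft penalty recovers the cumulative ordering without the shared-weight rigidity.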
### Results: The Independent Threshold Breakthrough
**Config:** `configs/finetune/iter1-independent.yaml` — base ModernBERT-large, independent thresholds with 256-dim MLP, attention pooling, spec confidence filtering, 15 epochs.
| Epoch | Combined | Cat F1 | Spec F1 | QWK | L2 F1 | L3 F1 |
|-------|----------|--------|---------|-----|-------|-------|
| 1 | 0.855 | 0.867 | **0.844** | 0.874 | 0.782 | 0.821 |
| 2 | 0.913 | 0.909 | **0.918** | 0.935 | 0.887 | 0.911 |
| 3 | 0.925 | 0.919 | 0.931 | 0.945 | 0.893 | 0.926 |
| 5 | 0.938 | 0.936 | 0.940 | 0.949 | — | — |
| **8** | **0.944** | **0.943** | **0.945** | **0.952** | **0.923** | **0.940** |
| 10 | 0.944 | 0.943 | 0.945 | 0.952 | — | — |
The model exceeded 0.80 on both heads **at epoch 1**. By epoch 8 it plateaued at **0.944 combined F1 (cat=0.943, spec=0.945, QWK=0.952)**. Training was stopped at epoch 11 — the train-eval loss gap (0.06 vs 0.49, ~8×) indicated the model was memorizing without further improving eval metrics.
**The improvement was transformative.** Spec F1: 0.517 → 0.945 (+0.428). L2 F1: 0.29 → 0.92. L3 F1: 0.31 → 0.94. The independent thresholds + attention pooling + confidence filtering combination addressed all three root causes simultaneously.
**What mattered most?** The independent thresholds were the primary driver. CORAL's shared weight vector was the bottleneck — when we let each ordinal transition learn its own features, the model immediately distinguished the three types of specificity evidence. Attention pooling and confidence filtering likely contributed meaningful improvements, but we did not run an ablation to isolate their individual contributions (the combined effect was so strong that decomposition was deprioritized).
### Overfitting Observations
Encoder models absolutely can overfit. The 8× train-eval loss gap by epoch 10 is substantial. However, eval *metrics* (F1, QWK) remained stable across epochs 8–11, exhibiting "benign overfitting" — the model becomes more confident on training examples (lower train loss) without changing its decision boundaries (stable eval F1). The practical implication: monitor eval F1 for model selection, not eval loss.
For future runs: increase `save_total_limit` to preserve all epoch checkpoints, and add early stopping with patience ≥ 3 on `spec_macro_f1`.
### Training Configuration Reference
| Parameter | Value |
|-----------|-------|
| Backbone | answerdotai/ModernBERT-large (395M params) |
| Pooling | Learned attention |
| Category head | Linear(1024, 7) + weighted CE |
| Specificity head | 3× Independent(Linear(1024→256→1)) + cumulative BCE |
| Ordinal consistency | 0.1 weight |
| Spec confidence filter | Unanimous labels only (91.3% of data) |
| Batch size | 32 |
| Learning rate | 5e-5 |
| Warmup | 10% of total steps |
| Precision | bf16 + tf32 |
| Attention | Flash Attention 2 |
| Compilation | torch.compile |
| Optimizer | AdamW (fused) |
| Peak VRAM | ~18 GB / 24.6 GB (RTX 3090) |
| Training speed | ~2.1 it/s (batch 32, seq 512) |
| Best epoch | 8 (stable through 11) |
**Checkpoint:** `checkpoints/finetune/iter1-independent/final/`
### What Remains
These metrics are on the validation set — same distribution as training (Grok ×3 consensus labels). The true test is the **holdout gold set** with human labels, which may reveal:
- Systematic Grok-vs-human disagreements (especially at L2/L3 boundaries)
- Whether the model learned Grok's biases rather than the underlying construct
- Per-class F1 on the more diverse holdout distribution (the training data overrepresents RMP at 43%)
As a proxy before human labels arrive, evaluation against GPT-5.4 and Opus benchmark labels on the holdout will provide an intermediate signal.
---
## v1 Reference
The complete v1 narrative — Stage 1 prompt engineering (12+ iterations), model benchmarking (21+ models, 12 providers), human labeling webapp, gold set adjudication (13-signal cross-analysis), codebook iterations v1.0–v3.5 — is preserved at `docs/NARRATIVE-v1.md`.


@ -0,0 +1,131 @@
# Specificity F1 Improvement Plan
**Goal:** Macro F1 > 0.80 on both category and specificity heads
**Current:** Cat F1=0.932 (passing), Spec F1=0.517 (needs ~+0.28)
**Constraint:** Specificity is paragraph-level and category-independent by design
## Diagnosis
Per-class spec F1 (best run, epoch 5):
- L1 (Generic): ~0.79
- L2 (Domain-Adapted): ~0.29
- L3 (Firm-Specific): ~0.31
- L4 (Quantified): ~0.55
L2 and L3 drag macro F1 from ~0.67 average to 0.52. QWK=0.840 shows ordinal
ranking is strong — the problem is exact boundary placement between adjacent levels.
### Root causes
1. **CORAL's shared weight vector.** CORAL uses `logit_k = w·x + b_k` — one weight
vector for all thresholds. But the three transitions require different features:
- L1→L2: cybersecurity terminology detection (ERM test)
- L2→L3: firm-unique fact detection (named roles, systems)
- L3→L4: quantified/verifiable claim detection (numbers, dates)
A single w can't capture all three signal types.
2. **[CLS] pooling loses distributed signals.** A single "CISO" mention anywhere
in a paragraph should bump to L3, but [CLS] may not attend to it.
3. **Label noise at boundaries.** 8.7% of training labels had Grok specificity
disagreement, concentrated at L1/L2 and L2/L3 boundaries.
4. **Insufficient training.** Model was still improving at epoch 5 — not converged.
## Ideas (ordered by estimated ROI)
### Tier 1 — Implement first
**A. Independent threshold heads (replace CORAL)**
Replace the single CORAL weight vector with 3 independent binary classifiers,
each with its own learned features:
- threshold_L2plus: Linear(hidden, 1) — "has any qualifying facts?"
- threshold_L3plus: Linear(hidden, 1) — "has firm-specific facts?"
- threshold_L4: Linear(hidden, 1) — "has quantified/verifiable facts?"
Same cumulative binary targets as CORAL, but each threshold has independent weights.
Optionally upgrade to 2-layer MLP (hidden→256→1) for richer decision boundaries.
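A sketch of the cumulative target encoding and decoding shared by CORAL and idea A (function names are illustrative):

```python
def encode_cumulative(level: int) -> list[int]:
    """Specificity level 0..3 -> 3 binary threshold targets.

    level k maps to [1]*k + [0]*(3-k):
      L1 (0) -> [0, 0, 0]   no thresholds passed
      L3 (2) -> [1, 1, 0]   L2+ and L3+ fire, L4 does not
    """
    return [1] * level + [0] * (3 - level)

def decode_cumulative(probs: list[float], cut: float = 0.5) -> int:
    """Predicted level = number of thresholds whose probability clears the cut."""
    return sum(p > cut for p in probs)
```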
**B. High-confidence label filtering**
Only train specificity on paragraphs where all 3 Grok runs agreed on specificity
level (~91.3% of data, ~59K of 65K). The ~6K disagreement cases are exactly the
noisy boundary labels that confuse the model. Category labels can still use all data.
**C. More epochs + early stopping on spec F1**
Run 15-20 epochs. Switch model selection metric from combined_macro_f1 to
spec_macro_f1 (since category already exceeds 0.80). Use patience=5.
**D. Attention pooling**
Replace [CLS] token pooling with learned attention pooling over all tokens.
This lets the model attend to specific evidence tokens (CISO, $2M, NIST)
distributed anywhere in the paragraph.
### Tier 2 — If Tier 1 insufficient
**E. Ordinal consistency regularization**
Add a penalty when threshold k fires but threshold k-1 doesn't (e.g., model
says "has firm-specific" but not "has domain terms"). Weight ~0.1.
**F. Differential learning rates**
Backbone: 1e-5, heads: 5e-4. Let the heads learn classification faster while
the backbone makes only fine adjustments.
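Idea F maps directly onto PyTorch parameter groups; a sketch assuming backbone parameters are namespaced under `backbone.` (the real attribute names depend on the model class):

```python
import torch

def build_optimizer(model, backbone_lr: float = 1e-5, head_lr: float = 5e-4):
    """AdamW with a slow backbone group and a fast heads group."""
    backbone, heads = [], []
    for name, param in model.named_parameters():
        (backbone if name.startswith("backbone") else heads).append(param)
    return torch.optim.AdamW(
        [
            {"params": backbone, "lr": backbone_lr},  # fine adjustments only
            {"params": heads, "lr": head_lr},         # heads learn fast
        ]
    )
```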
**G. Softmax head comparison**
Try standard 4-class CE (no ordinal constraint at all). If it outperforms both
CORAL and independent thresholds, the ordinal structure isn't helping.
**H. Multi-sample dropout**
Apply N different dropout masks, average logits. Reduces variance in the
specificity head's predictions, especially for boundary cases.
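A sketch of idea H (names illustrative):

```python
import torch
import torch.nn as nn

class MultiSampleDropoutHead(nn.Module):
    """Average logits over N dropout masks to reduce prediction variance.

    Each forward sample draws a fresh dropout mask (active in train mode
    only); averaging the resulting logits smooths boundary-case predictions.
    """

    def __init__(self, hidden: int, out: int, n_samples: int = 4, p: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(p)
        self.fc = nn.Linear(hidden, out)
        self.n_samples = n_samples

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        samples = [self.fc(self.dropout(x)) for _ in range(self.n_samples)]
        return torch.stack(samples).mean(dim=0)
```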
### Tier 3 — If nothing else works
**I. Specificity-focused auxiliary task**
The consensus labels include `specific_facts` arrays with classified fact types
(domain_term, named_role, quantified, etc.). Add a token-level auxiliary task
that detects these fact types. Specificity becomes "what's the highest-level
fact type present?" — making the ordinal structure explicit.
**J. Separate specificity model**
Train a dedicated model just for specificity with a larger head, more
specificity-focused features, or a different architecture (e.g., token-level
fact extraction → aggregation).
**K. Re-annotate boundary cases**
Use GPT-5.4 to re-judge the ~9,323 majority-vote cases where Grok had
specificity disagreement. Cleaner labels at boundaries.
## Experiment Log
### Experiment 1: Independent thresholds + attention pooling + MLP + filtering (15 epochs)
**Config:** `configs/finetune/iter1-independent.yaml`
- Specificity head: independent (3 separate Linear(1024→256→1) binary classifiers)
- Pooling: attention (learned attention over all tokens)
- Confidence filtering: only train spec on unanimous labels
- Ordinal consistency regularization: 0.1
- Class weighting: yes
- Base checkpoint: ModernBERT-large (no DAPT/TAPT)
- Epochs: 15
**Results:**
| Epoch | Combined | Cat F1 | Spec F1 | QWK |
|-------|----------|--------|---------|-----|
| 1 | 0.855 | 0.867 | 0.844 | 0.874 |
| 2 | 0.913 | 0.909 | 0.918 | 0.935 |
| 3 | 0.925 | 0.919 | 0.931 | 0.945 |
| 4 | 0.936 | 0.932 | 0.940 | 0.950 |
| 5 | 0.938 | 0.936 | 0.940 | 0.949 |
| **8** | **0.944** | **0.943** | **0.945** | **0.952** |
| 10 | 0.944 | 0.943 | 0.945 | 0.952 |
| 11 | 0.944 | 0.945 | 0.944 | 0.952 |
Stopped at epoch 11 — train-eval loss gap was 8× (0.06 vs 0.49) with no further
eval F1 improvement. Best checkpoint: epoch 8 (spec F1=0.945).
**Conclusion:** Massive improvement — spec F1 went from 0.517 (CORAL baseline) to
0.945 at epoch 8. Both targets (>0.80 cat and spec F1) exceeded by epoch 1.
Independent thresholds were the key insight — CORAL's shared weight vector was
the primary bottleneck. Attention pooling, MLP heads, and confidence filtering
all contributed. Tier 2 and Tier 3 ideas were not needed.


@ -1,6 +1,6 @@
# Project Status — v2 Pipeline
**Deadline:** 2026-04-24 | **Started:** 2026-04-03 | **Updated:** 2026-04-05 (Stage 1 complete, 72K×3 + judge)
**Deadline:** 2026-04-24 | **Started:** 2026-04-03 | **Updated:** 2026-04-05 (Fine-tuning done: cat F1=0.943, spec F1=0.945)
---
@ -109,22 +109,47 @@
- [ ] Bench Stage 2 accuracy against gold (if needed for additional disputed paragraphs)
- **Cost so far:** $5.76 | **Remaining budget:** ~$39
### 11. Training Data Assembly
- [ ] Unanimous Stage 1 → full weight, calibrated majority → full weight
- [ ] Quality tier weights: clean/headed/minor 1.0, degraded 0.5
- [ ] Exclude 72 truncated filings
### 11. Training Data Assembly — DONE
- [x] Merge Stage 1 consensus with paragraph data (`python/src/finetune/data.py`)
- [x] Exclude 1,200 holdout paragraphs (reserved for eval)
- [x] Exclude 614 individually truncated paragraphs (not entire filings — more targeted than original plan)
- [x] Quality tier weights: clean/headed/minor 1.0, degraded 0.5
- [x] Stratified train/val split (90/10) from training set
- **Training set size:** 70,231 paragraphs (72,045 − 1,200 holdout − 614 truncated)
- **Train/val split:** 63,214 / 7,024
### 12. Fine-Tuning
- [ ] Ablation: {base, +DAPT, +DAPT+TAPT} × {±class weighting} × {CE vs focal loss}
- [ ] Dual-head: shared ModernBERT backbone + category head (7-class) + specificity head (4-class ordinal)
- [ ] CORAL for ordinal specificity
- **Estimated time:** 12-20h GPU
### 12. Fine-Tuning — DONE
- [x] Ablation round 1: {base, +DAPT, +DAPT+TAPT} × {±class weighting} × {CE vs focal loss} = 12 configs × 1 epoch
- [x] Ablation round 1 winner: base_weighted_ce (CORAL head, [CLS] pooling)
- [x] CORAL limitation identified: shared weight vector can't capture 3 different transition signals (L1→L2: domain terms, L2→L3: firm facts, L3→L4: quantified claims)
- [x] Architecture iteration: replaced CORAL with independent threshold heads (3 separate MLP binary classifiers), attention pooling, specificity confidence filtering
- [x] **Final model (iter1-independent, epoch 8):** Cat F1=0.943, Spec F1=0.945, QWK=0.952, Combined=0.944
- **Architecture:** ModernBERT-large → attention pooling → dropout →
- Category: Linear(1024, 7) + weighted CE
- Specificity: 3× IndependentThreshold(Linear(1024→256→1)) + cumulative BCE + ordinal consistency reg.
- **Key findings (ablation round 1):**
- DAPT/TAPT pre-training did not help — base ModernBERT-large outperformed both
- Class weighting + CE is the best loss combination
- Focal loss + class weighting = too much correction (always bottom tier)
- TAPT consistently worst — likely overfitting on task paragraphs during MLM pre-training
- **Key findings (architecture iteration):**
- CORAL's shared weight vector was the primary bottleneck for specificity (0.517 → 0.940)
- Independent threshold heads let each L1→L2, L2→L3, L3→L4 transition learn different features
- Attention pooling captures distributed specificity signals (one "CISO" mention anywhere matters)
- Confidence filtering removes ~8.7% noisy boundary labels from specificity training
- **Training speed:** ~2.1 it/s, batch 32, seq 512, bf16, flash attention 2, torch.compile
- **Peak VRAM:** ~18-20 GB / 24.6 GB (RTX 3090)
- **Improvement plan:** `docs/SPECIFICITY-IMPROVEMENT-PLAN.md`
### 13. Evaluation & Paper
- [ ] Macro F1 on holdout (target > 0.80 both heads)
### 13. Evaluation & Paper ← CURRENT
- [ ] Proxy eval: run fine-tuned model on holdout, compare against GPT-5.4 and Opus benchmark labels
- [ ] Macro F1 on holdout gold (target > 0.80 both heads) — blocked on human labels
- [ ] Per-class F1 breakdown + GenAI benchmark table
- [ ] Error analysis, cost comparison, IGNITE slides
- [ ] Note in paper: specificity is paragraph-level (presence check), not category-conditional — acknowledge as limitation/future work
- [ ] Note in paper: DAPT/TAPT did not improve fine-tuning — noteworthy null result
- [ ] Note in paper: CORAL ordinal regression insufficient for multi-signal ordinal classification
- **Next:** evaluate fine-tuned model on holdout using GPT-5.4 + Opus labels as proxy gold
---
@ -165,6 +190,15 @@
| v2 Stage 1 judge | `data/annotations/v2-stage1/judge.jsonl` (212 tiebreakers) |
| Stage 1 distribution charts | `figures/stage1-*.png` (7 charts) |
| Stage 1 chart script | `scripts/plot-stage1-distributions.py` |
| Fine-tuning data loader | `python/src/finetune/data.py` |
| Dual-head model | `python/src/finetune/model.py` |
| Fine-tuning trainer | `python/src/finetune/train.py` |
| Fine-tune config | `python/configs/finetune/modernbert.yaml` |
| Ablation results | `checkpoints/finetune/ablation/ablation_results.json` |
| **Best model (final)** | `checkpoints/finetune/iter1-independent/final/` (cat=0.943, spec=0.945) |
| CORAL baseline (ablation winner) | `checkpoints/finetune/best-base_weighted_ce-ep5/final/` (cat=0.932, spec=0.517) |
| Spec improvement plan | `docs/SPECIFICITY-IMPROVEMENT-PLAN.md` |
### v2 Stage 1 Distribution (72,045 paragraphs, v4.5 prompt, Grok ×3 consensus + GPT-5.4 judge)


@ -0,0 +1,38 @@
model:
name_or_path: answerdotai/ModernBERT-large
data:
paragraphs_path: ../data/paragraphs/paragraphs-clean.patched.jsonl
consensus_path: ../data/annotations/v2-stage1/consensus.jsonl
quality_path: ../data/paragraphs/quality/quality-scores.jsonl
holdout_path: ../data/gold/v2-holdout-ids.json
max_seq_length: 512
validation_split: 0.1
training:
output_dir: ../checkpoints/finetune/iter1-independent
learning_rate: 0.00005
num_train_epochs: 15
per_device_train_batch_size: 32
per_device_eval_batch_size: 64
gradient_accumulation_steps: 1
warmup_ratio: 0.1
weight_decay: 0.01
dropout: 0.1
bf16: true
gradient_checkpointing: false
logging_steps: 50
save_total_limit: 3
dataloader_num_workers: 4
seed: 42
loss_type: ce
focal_gamma: 2.0
class_weighting: true
category_loss_weight: 1.0
specificity_loss_weight: 1.0
# New options
specificity_head: independent
spec_mlp_dim: 256
pooling: attention
ordinal_consistency_weight: 0.1
filter_spec_confidence: true


@ -0,0 +1,36 @@
model:
name_or_path: ../checkpoints/tapt/modernbert-large/final
# Alternatives for ablation:
# base: answerdotai/ModernBERT-large
# dapt: ../checkpoints/dapt/modernbert-large/final
# tapt: ../checkpoints/tapt/modernbert-large/final
data:
paragraphs_path: ../data/paragraphs/paragraphs-clean.patched.jsonl
consensus_path: ../data/annotations/v2-stage1/consensus.jsonl
quality_path: ../data/paragraphs/quality/quality-scores.jsonl
holdout_path: ../data/gold/v2-holdout-ids.json
max_seq_length: 512
validation_split: 0.1
training:
output_dir: ../checkpoints/finetune/modernbert-large
learning_rate: 0.00005
num_train_epochs: 3
per_device_train_batch_size: 32
per_device_eval_batch_size: 64
gradient_accumulation_steps: 1
warmup_ratio: 0.1
weight_decay: 0.01
dropout: 0.1
bf16: true
gradient_checkpointing: false
logging_steps: 50
save_total_limit: 3
dataloader_num_workers: 4
seed: 42
loss_type: ce
focal_gamma: 2.0
class_weighting: true
category_loss_weight: 1.0
specificity_loss_weight: 1.0


@ -27,6 +27,35 @@ def cmd_dapt(args: argparse.Namespace) -> None:
train(config)
def cmd_finetune(args: argparse.Namespace) -> None:
from src.common.config import FinetuneConfig
from src.finetune.train import train
config = FinetuneConfig.from_yaml(args.config)
config.apply_overrides(
model_path=args.model_path,
output_dir=args.output_dir,
loss_type=args.loss_type,
class_weighting=args.class_weighting,
epochs=args.epochs,
batch_size=args.batch_size,
learning_rate=args.learning_rate,
)
train(config)
def cmd_ablate(args: argparse.Namespace) -> None:
from src.common.config import FinetuneConfig
from src.finetune.train import ablate
config = FinetuneConfig.from_yaml(args.config)
if args.output_dir:
config.training.output_dir = args.output_dir
if args.epochs:
config.training.num_train_epochs = args.epochs
ablate(config)
def main() -> None:
parser = argparse.ArgumentParser(
description="SEC-cyBERT training pipeline",
@ -46,21 +75,30 @@ def main() -> None:
dapt.add_argument("--stage", choices=["dapt", "tapt"], help="Override stage label")
dapt.set_defaults(func=cmd_dapt)
# ── finetune (placeholder) ──
# ── finetune ──
ft = sub.add_parser("finetune", help="Fine-tune classifier (dual-head)")
ft.add_argument("--config", required=True, help="Path to YAML config file")
ft.set_defaults(func=lambda args: print("Fine-tuning not yet implemented."))
ft.add_argument("--model-path", help="Override model checkpoint path")
ft.add_argument("--output-dir", help="Override output directory")
ft.add_argument("--loss-type", choices=["ce", "focal"], help="Override loss type")
ft.add_argument("--class-weighting", type=lambda x: x.lower() == "true", help="Override class weighting (true/false)")
ft.add_argument("--epochs", type=int, help="Override number of epochs")
ft.add_argument("--batch-size", type=int, help="Override batch size")
ft.add_argument("--learning-rate", type=float, help="Override learning rate")
ft.set_defaults(func=cmd_finetune)
# ── ablate ──
ab = sub.add_parser("ablate", help="Run full ablation grid (3 ckpts x 2 weighting x 2 loss)")
ab.add_argument("--config", required=True, help="Path to YAML config file")
ab.add_argument("--output-dir", help="Override base output directory")
ab.add_argument("--epochs", type=int, help="Override epochs per ablation run (default: config value)")
ab.set_defaults(func=cmd_ablate)
# ── eval (placeholder) ──
ev = sub.add_parser("eval", help="Evaluate a trained model")
ev.add_argument("--config", required=True, help="Path to YAML config file")
ev.set_defaults(func=lambda args: print("Evaluation not yet implemented."))
# ── decoder (placeholder) ──
dec = sub.add_parser("decoder", help="Decoder experiment (Qwen LoRA)")
dec.add_argument("--config", required=True, help="Path to YAML config file")
dec.set_defaults(func=lambda args: print("Decoder experiment not yet implemented."))
args = parser.parse_args()
args.func(args)


@ -13,6 +13,8 @@ dependencies = [
"pyyaml>=6,<7",
"flash-attn==2.6.3+cu130torch2.11",
"unsloth==2026.3.11",
"coral-pytorch>=1.4.0",
"scikit-learn>=1.8.0",
]
[project.scripts]


@ -0,0 +1,106 @@
"""Analyze ablation results and launch the best config for a full training run.
Usage:
uv run python scripts/run-best-config.py [--epochs 5] [--dry-run]
"""
import json
import subprocess
import sys
from pathlib import Path
ABLATION_DIR = Path("../checkpoints/finetune/ablation")
RESULTS_FILE = ABLATION_DIR / "ablation_results.json"
# Maps ablation run name components to CLI args
CHECKPOINT_MAP = {
"base": "answerdotai/ModernBERT-large",
"dapt": "../checkpoints/dapt/modernbert-large/final",
"tapt": "../checkpoints/tapt/modernbert-large/final",
}
def main():
epochs = 5
dry_run = False
for arg in sys.argv[1:]:
if arg.startswith("--epochs"):
epochs = int(arg.split("=")[1] if "=" in arg else sys.argv[sys.argv.index(arg) + 1])
if arg == "--dry-run":
dry_run = True
if not RESULTS_FILE.exists():
print(f"No results file found at {RESULTS_FILE}")
print("Ablation may still be running. Check: ps aux | grep ablate")
sys.exit(1)
with open(RESULTS_FILE) as f:
results = json.load(f)
# Filter successful runs
successful = [r for r in results if "error" not in r]
if not successful:
print("No successful ablation runs found!")
sys.exit(1)
# Sort by combined macro F1
successful.sort(key=lambda r: r.get("eval_combined_macro_f1", 0), reverse=True)
# Print results table
print(f"\n{'='*80}")
print(" ABLATION RESULTS (sorted by combined F1)")
print(f"{'='*80}")
print(f" {'Run':<45} {'Combined':>10} {'Cat F1':>10} {'Spec F1':>10} {'QWK':>10}")
print(f" {'-'*45} {'-'*10} {'-'*10} {'-'*10} {'-'*10}")
for r in successful:
name = r["run"]
combined = r.get("eval_combined_macro_f1", 0)
cat = r.get("eval_cat_macro_f1", 0)
spec = r.get("eval_spec_macro_f1", 0)
qwk = r.get("eval_spec_qwk", 0)
marker = " <-- BEST" if r == successful[0] else ""
print(f" {name:<45} {combined:>10.4f} {cat:>10.4f} {spec:>10.4f} {qwk:>10.4f}{marker}")
# Parse best config
best = successful[0]
name = best["run"]
parts = name.split("_")
ckpt_name = parts[0]
weighting = parts[1] == "weighted"
loss_type = parts[2]
model_path = CHECKPOINT_MAP[ckpt_name]
output_dir = f"../checkpoints/finetune/best-{name}-ep{epochs}"
print(f"\n Best config: {name}")
print(f" Model: {model_path}")
print(f" Class weighting: {weighting}")
print(f" Loss: {loss_type}")
print(f" Epochs: {epochs}")
print(f" Output: {output_dir}")
cmd = [
"uv", "run", "python", "main.py", "finetune",
"--config", "configs/finetune/modernbert.yaml",
"--model-path", model_path,
"--output-dir", output_dir,
"--loss-type", loss_type,
"--class-weighting", str(weighting).lower(),
"--epochs", str(epochs),
]
if dry_run:
print(f"\n Dry run — would execute:")
print(f" {' '.join(cmd)}")
else:
print("\n Launching full training...")
import os
# Inherit the current environment, then force the allocator setting so an
# already-exported PYTORCH_CUDA_ALLOC_CONF cannot override it.
env = {**os.environ, "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"}
subprocess.run(cmd, env=env)
if __name__ == "__main__":
main()


@ -89,3 +89,108 @@ class DAPTConfig:
self.training.output_dir = output_dir
if stage is not None:
self.stage = stage
# ── Fine-tuning config ───────────────────────────────────────────────────────
@dataclass
class FinetuneDataConfig:
"""Data paths and preprocessing for fine-tuning."""
paragraphs_path: str
consensus_path: str
quality_path: str
holdout_path: str
max_seq_length: int = 512
validation_split: float = 0.1
@dataclass
class FinetuneTrainingConfig:
"""Training arguments for dual-head fine-tuning."""
output_dir: str
learning_rate: float = 5e-5
num_train_epochs: int = 3
per_device_train_batch_size: int = 32
per_device_eval_batch_size: int = 64
gradient_accumulation_steps: int = 1
warmup_ratio: float = 0.1
weight_decay: float = 0.01
dropout: float = 0.1
bf16: bool = True
gradient_checkpointing: bool = False
logging_steps: int = 50
save_total_limit: int = 3
dataloader_num_workers: int = 4
seed: int = 42
resume_from_checkpoint: str | None = None
# Dual-head specific
loss_type: str = "ce" # "ce" or "focal"
focal_gamma: float = 2.0
class_weighting: bool = True
category_loss_weight: float = 1.0
specificity_loss_weight: float = 1.0
# Architecture options
specificity_head: str = "independent" # "coral", "independent", "softmax"
spec_mlp_dim: int = 0 # 0 = linear, >0 = 2-layer MLP with this hidden dim
pooling: str = "cls" # "cls" or "attention"
ordinal_consistency_weight: float = 0.0
filter_spec_confidence: bool = False # only train spec on unanimous labels
@dataclass
class FinetuneConfig:
"""Full configuration for a fine-tuning run."""
model: ModelConfig
data: FinetuneDataConfig
training: FinetuneTrainingConfig
@classmethod
def from_yaml(cls, path: str | Path) -> "FinetuneConfig":
with open(path) as f:
raw = yaml.safe_load(f)
# Coerce float fields that PyYAML may parse as strings (e.g. 5e-5)
training = raw["training"]
for key in ("learning_rate", "warmup_ratio", "weight_decay", "dropout",
"focal_gamma", "category_loss_weight", "specificity_loss_weight",
"validation_split"):
if key in training and isinstance(training[key], str):
training[key] = float(training[key])
data = raw["data"]
if "validation_split" in data and isinstance(data["validation_split"], str):
data["validation_split"] = float(data["validation_split"])
return cls(
model=ModelConfig(**raw["model"]),
data=FinetuneDataConfig(**data),
training=FinetuneTrainingConfig(**training),
)
def apply_overrides(
self,
*,
model_path: str | None = None,
output_dir: str | None = None,
loss_type: str | None = None,
class_weighting: bool | None = None,
epochs: int | None = None,
batch_size: int | None = None,
learning_rate: float | None = None,
) -> None:
"""Apply CLI overrides on top of YAML config."""
if model_path is not None:
self.model.name_or_path = model_path
if output_dir is not None:
self.training.output_dir = output_dir
if loss_type is not None:
self.training.loss_type = loss_type
if class_weighting is not None:
self.training.class_weighting = class_weighting
if epochs is not None:
self.training.num_train_epochs = epochs
if batch_size is not None:
self.training.per_device_train_batch_size = batch_size
if learning_rate is not None:
self.training.learning_rate = learning_rate
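The float coercion in `from_yaml` exists because PyYAML's YAML 1.1 float resolver requires a dot in the mantissa, so a bare `5e-5` loads as a string while `0.1` loads as a float. A quick check (toy config snippet, not the real YAML file):

```python
import yaml

# "5e-5" has no dot in the mantissa, so PyYAML resolves it as a string;
# "0.1" matches the float pattern and resolves normally.
cfg = yaml.safe_load("learning_rate: 5e-5\nwarmup_ratio: 0.1\n")
print(type(cfg["learning_rate"]).__name__, type(cfg["warmup_ratio"]).__name__)
# → str float
```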

python/src/finetune/data.py
@@ -0,0 +1,211 @@
"""Training data assembly for dual-head fine-tuning.
Merges Stage 1 consensus labels with paragraph text, filters holdout and
truncated filings, assigns quality-tier sample weights, tokenizes, and
produces a train/val DatasetDict ready for the Trainer.
"""
import json
from pathlib import Path
from datasets import Dataset, DatasetDict
from transformers import PreTrainedTokenizerBase
# Canonical label maps — alphabetical for reproducibility
CATEGORIES = [
"Board Governance", # 0
"Incident Disclosure", # 1
"Management Role", # 2
"None/Other", # 3
"Risk Management Process",# 4
"Strategy Integration", # 5
"Third-Party Risk", # 6
]
CAT2ID = {c: i for i, c in enumerate(CATEGORIES)}
NUM_CATEGORIES = len(CATEGORIES)
# Specificity: ordinal 1-4 → 0-indexed 0-3
NUM_SPECIFICITY = 4
# Quality tier → sample weight (STATUS.md: clean/headed/minor 1.0, degraded 0.5)
TIER_WEIGHT = {
"clean": 1.0,
"headed": 1.0,
"minor": 1.0,
"degraded": 0.5,
}
def _load_jsonl(path: str | Path) -> list[dict]:
records = []
with open(path) as f:
for line in f:
line = line.strip()
if line:
records.append(json.loads(line))
return records
def load_finetune_data(
*,
paragraphs_path: str,
consensus_path: str,
quality_path: str,
holdout_path: str,
max_seq_length: int,
validation_split: float,
tokenizer: PreTrainedTokenizerBase,
seed: int = 42,
) -> DatasetDict:
"""Load, merge, filter, tokenize, and split training data.
Returns a DatasetDict with 'train' and 'test' splits, each containing:
input_ids, attention_mask, category_labels, specificity_labels, sample_weight
"""
print(" Loading data files...")
# 1. Load all sources
paragraphs = {p["id"]: p for p in _load_jsonl(paragraphs_path)}
consensus = {c["paragraphId"]: c for c in _load_jsonl(consensus_path)}
quality = {q["id"]: q for q in _load_jsonl(quality_path)}
with open(holdout_path) as f:
holdout_ids = set(json.load(f))
print(f" Paragraphs: {len(paragraphs):,}")
print(f" Consensus: {len(consensus):,}")
print(f" Quality: {len(quality):,}")
print(f" Holdout: {len(holdout_ids):,}")
# 2. Identify individually truncated paragraphs
truncated_ids: set[str] = set()
for q in quality.values():
if "truncated" in q.get("issues", []):
truncated_ids.add(q["id"])
print(f" Truncated paragraphs: {len(truncated_ids)}")
# 3. Merge and filter
records: list[dict] = []
skipped_holdout = 0
skipped_truncated = 0
skipped_no_consensus = 0
for pid, para in paragraphs.items():
# Exclude holdout
if pid in holdout_ids:
skipped_holdout += 1
continue
# Exclude individually truncated paragraphs
if pid in truncated_ids:
skipped_truncated += 1
continue
# Must have consensus label
cons = consensus.get(pid)
if cons is None:
skipped_no_consensus += 1
continue
label = cons["finalLabel"]
cat = label["content_category"]
spec = label["specificity_level"] # 1-4
# Quality tier weight
q = quality.get(pid, {})
tier = q.get("quality_tier", "clean")
weight = TIER_WEIGHT.get(tier, 1.0)
# Specificity confidence: 1.0 if all 3 runs agreed, 0.0 otherwise
spec_agreed = cons.get("specificityAgreement", {}).get("agreed", True)
records.append({
"text": para["text"],
"category_labels": CAT2ID[cat],
"specificity_labels": spec - 1, # 0-indexed
"sample_weight": weight,
"spec_mask": 1.0 if spec_agreed else 0.0,
})
print(f"\n Skipped holdout: {skipped_holdout:,}")
print(f" Skipped truncated: {skipped_truncated:,}")
print(f" Skipped no consensus: {skipped_no_consensus:,}")
print(f" Training samples: {len(records):,}")
# 4. Distributions
from collections import Counter
cat_counts = Counter(r["category_labels"] for r in records)
spec_counts = Counter(r["specificity_labels"] for r in records)
spec_agreed = sum(1 for r in records if r["spec_mask"] == 1.0)
print(f" Spec high-confidence: {spec_agreed:,} ({spec_agreed/len(records)*100:.1f}%)")
print(f"\n Category distribution:")
for i, name in enumerate(CATEGORIES):
print(f" {name}: {cat_counts[i]:,} ({cat_counts[i]/len(records)*100:.1f}%)")
print(f" Specificity distribution:")
for i in range(NUM_SPECIFICITY):
print(f" L{i+1}: {spec_counts[i]:,} ({spec_counts[i]/len(records)*100:.1f}%)")
# 5. Create HuggingFace dataset
dataset = Dataset.from_list(records)
# 6. Tokenize
print(f"\n Tokenizing to max_seq_length={max_seq_length}...")
def tokenize_fn(examples):
return tokenizer(
examples["text"],
truncation=True,
max_length=max_seq_length,
padding=False, # dynamic padding in collator
)
dataset = dataset.map(
tokenize_fn,
batched=True,
remove_columns=["text"],
desc="Tokenizing",
)
# 7. Stratified train/val split (sklearn, since HF's stratified split requires a ClassLabel column)
from sklearn.model_selection import StratifiedShuffleSplit
sss = StratifiedShuffleSplit(n_splits=1, test_size=validation_split, random_state=seed)
train_idx, val_idx = next(sss.split(range(len(dataset)), dataset["category_labels"]))
split = DatasetDict({
"train": dataset.select(train_idx),
"test": dataset.select(val_idx),
})
print(f" Train: {len(split['train']):,} | Val: {len(split['test']):,}")
return split
def compute_class_weights(dataset: Dataset) -> dict[str, list[float]]:
"""Compute inverse-frequency class weights from training data."""
from collections import Counter
n = len(dataset)
# Category weights
cat_counts = Counter(dataset["category_labels"])
cat_weights = []
for i in range(NUM_CATEGORIES):
count = cat_counts.get(i, 1)
# Inverse frequency, normalized so mean weight = 1.0
cat_weights.append(n / (NUM_CATEGORIES * count))
# Cap extreme weights (ID category is very rare)
max_weight = 10.0
cat_weights = [min(w, max_weight) for w in cat_weights]
# Specificity weights
spec_counts = Counter(dataset["specificity_labels"])
spec_weights = []
for i in range(NUM_SPECIFICITY):
count = spec_counts.get(i, 1)
spec_weights.append(n / (NUM_SPECIFICITY * count))
return {
"category_weights": cat_weights,
"specificity_weights": spec_weights,
}
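To sanity-check the weighting scheme, here is the same inverse-frequency formula applied to hypothetical counts (toy numbers, not the real label distribution):

```python
from collections import Counter

# Hypothetical 4-class label list: class 3 is very rare.
labels = [0] * 50 + [1] * 30 + [2] * 18 + [3] * 2
n, k = len(labels), 4
counts = Counter(labels)

# Inverse frequency, normalized so a perfectly balanced class gets weight 1.0,
# then capped at 10.0 as in compute_class_weights above.
weights = [min(n / (k * counts.get(i, 1)), 10.0) for i in range(k)]
print([round(w, 2) for w in weights])  # → [0.5, 0.83, 1.39, 10.0]
```

The rare class would get 100 / (4 × 2) = 12.5 without the cap; capping keeps a single tiny class from dominating the loss.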

@@ -0,0 +1,249 @@
"""Dual-head ModernBERT for category + ordinal specificity classification.
Architecture:
ModernBERT backbone → pooling (CLS or attention) → dropout
→ category head (7-class linear)
→ specificity head (options: coral, independent, softmax)
Specificity head options:
- "coral": single shared weight vector, 3 shifted biases (CoralLayer)
- "independent": 3 independent binary classifiers, one per threshold
(each learns its own feature direction, addressing the CORAL limitation)
- "softmax": standard 4-class CE (no ordinal constraint)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from coral_pytorch.layers import CoralLayer
class FocalLoss(nn.Module):
"""Focal Loss (Lin et al. 2017) for imbalanced classification."""
def __init__(self, weight: torch.Tensor | None = None, gamma: float = 2.0):
super().__init__()
self.gamma = gamma
self.register_buffer("weight", weight)
def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
ce = F.cross_entropy(logits, targets, weight=self.weight, reduction="none")
pt = torch.exp(-ce)
return ((1 - pt) ** self.gamma) * ce
class AttentionPooling(nn.Module):
"""Learned attention pooling over all token representations."""
def __init__(self, hidden_size: int):
super().__init__()
self.attention = nn.Linear(hidden_size, 1)
def forward(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
scores = self.attention(hidden_states).squeeze(-1) # (batch, seq_len)
scores = scores.masked_fill(~attention_mask.bool(), float("-inf"))
weights = F.softmax(scores, dim=1)
return (hidden_states * weights.unsqueeze(-1)).sum(dim=1)
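Numerically, the masking step works by sending padded positions to -inf before the softmax, so they receive exactly zero attention weight. A dependency-free sketch with toy scores and one padded token:

```python
import math

scores = [1.0, 2.0, 0.0]  # per-token attention scores (hypothetical)
mask = [1, 1, 0]          # third token is padding
masked = [s if m else float("-inf") for s, m in zip(scores, mask)]
exps = [math.exp(s) for s in masked]  # math.exp(-inf) == 0.0
weights = [e / sum(exps) for e in exps]
print(round(weights[2], 4))  # → 0.0  (padding never contributes to the pooled vector)
```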
class IndependentThresholdHead(nn.Module):
"""Independent binary classifiers for each ordinal threshold.
Unlike CORAL (shared weight vector, different biases), each threshold
has its own fully independent Linear layer. This lets each transition
learn different features:
- L1→L2: cybersecurity terminology detection
- L2→L3: firm-unique fact detection
- L3→L4: quantified/verifiable claim detection
"""
def __init__(self, in_features: int, num_classes: int, mlp_dim: int = 0):
super().__init__()
self.num_thresholds = num_classes - 1
if mlp_dim > 0:
self.thresholds = nn.ModuleList([
nn.Sequential(
nn.Linear(in_features, mlp_dim),
nn.GELU(),
nn.Linear(mlp_dim, 1),
)
for _ in range(self.num_thresholds)
])
else:
self.thresholds = nn.ModuleList([
nn.Linear(in_features, 1)
for _ in range(self.num_thresholds)
])
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.cat([t(x) for t in self.thresholds], dim=1)
class DualHeadModernBERT(nn.Module):
"""ModernBERT with category classification + ordinal specificity heads."""
def __init__(
self,
backbone: nn.Module,
hidden_size: int,
num_categories: int = 7,
num_specificity: int = 4,
dropout: float = 0.1,
loss_type: str = "ce",
focal_gamma: float = 2.0,
category_weights: torch.Tensor | None = None,
specificity_weights: list[float] | None = None,
category_loss_weight: float = 1.0,
specificity_loss_weight: float = 1.0,
specificity_head_type: str = "independent",
spec_mlp_dim: int = 0,
pooling: str = "cls",
ordinal_consistency_weight: float = 0.0,
):
super().__init__()
self.backbone = backbone
self.dropout = nn.Dropout(dropout)
self.num_specificity = num_specificity
self.category_loss_weight = category_loss_weight
self.specificity_loss_weight = specificity_loss_weight
self.specificity_head_type = specificity_head_type
self.ordinal_consistency_weight = ordinal_consistency_weight
# Pooling
self.pooling_type = pooling
if pooling == "attention":
self.pooler = AttentionPooling(hidden_size)
# Category head
self.category_head = nn.Linear(hidden_size, num_categories)
# Specificity head
if specificity_head_type == "coral":
self.specificity_head = CoralLayer(hidden_size, num_specificity)
elif specificity_head_type == "independent":
self.specificity_head = IndependentThresholdHead(
hidden_size, num_specificity, mlp_dim=spec_mlp_dim
)
elif specificity_head_type == "softmax":
self.specificity_head = nn.Linear(hidden_size, num_specificity)
else:
raise ValueError(f"Unknown specificity_head_type: {specificity_head_type}")
# Category loss
if loss_type == "focal":
self.category_loss_fn = FocalLoss(weight=category_weights, gamma=focal_gamma)
else:
self.category_loss_fn = nn.CrossEntropyLoss(weight=category_weights, reduction="none")
# Stored for reference; the cumulative BCE below does not apply per-class weights.
self.specificity_weights = specificity_weights
def _pool(
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
if self.pooling_type == "attention":
return self.pooler(hidden_states, attention_mask)
return hidden_states[:, 0] # [CLS]
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
category_labels: torch.Tensor | None = None,
specificity_labels: torch.Tensor | None = None,
sample_weight: torch.Tensor | None = None,
spec_mask: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
pooled = self._pool(outputs.last_hidden_state, attention_mask)
pooled = self.dropout(pooled)
cat_logits = self.category_head(pooled)
spec_logits = self.specificity_head(pooled)
result = {
"category_logits": cat_logits,
"specificity_logits": spec_logits,
}
if category_labels is not None and specificity_labels is not None:
# Category loss (all samples)
cat_loss_per_sample = self.category_loss_fn(cat_logits, category_labels)
if sample_weight is not None:
cat_loss_per_sample = cat_loss_per_sample * sample_weight
cat_loss = cat_loss_per_sample.mean()
# Specificity loss
if self.specificity_head_type == "softmax":
spec_loss_per_sample = F.cross_entropy(
spec_logits, specificity_labels, reduction="none"
)
else:
# Cumulative binary CE (works for both coral and independent)
spec_loss_per_sample = _cumulative_bce_per_sample(
spec_logits, specificity_labels, self.num_specificity
)
if sample_weight is not None:
spec_loss_per_sample = spec_loss_per_sample * sample_weight
# Apply spec_mask: only compute spec loss on high-confidence samples
if spec_mask is not None:
spec_loss_per_sample = spec_loss_per_sample * spec_mask
spec_loss = spec_loss_per_sample.sum() / (spec_mask.sum() + 1e-8)
else:
spec_loss = spec_loss_per_sample.mean()
total_loss = (
self.category_loss_weight * cat_loss
+ self.specificity_loss_weight * spec_loss
)
# Ordinal consistency regularization (for independent thresholds)
if self.ordinal_consistency_weight > 0 and self.specificity_head_type != "softmax":
consistency_loss = _ordinal_consistency_loss(spec_logits)
total_loss = total_loss + self.ordinal_consistency_weight * consistency_loss
result["loss"] = total_loss
result["category_loss"] = cat_loss.detach()
result["specificity_loss"] = spec_loss.detach()
return result
def _cumulative_bce_per_sample(
logits: torch.Tensor, labels: torch.Tensor, num_classes: int
) -> torch.Tensor:
"""Per-sample cumulative binary CE (shared by CORAL and independent heads)."""
levels = torch.zeros(
labels.size(0), num_classes - 1, device=logits.device, dtype=logits.dtype
)
for k in range(num_classes - 1):
levels[:, k] = (labels > k).float()
return F.binary_cross_entropy_with_logits(
logits, levels, reduction="none"
).sum(dim=1)
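The target construction above is the standard cumulative ("extended binary") encoding; written out for the four specificity levels:

```python
# Ordinal label (0-indexed L1..L4) becomes 3 binary targets [>L1, >L2, >L3].
def ordinal_levels(label: int, num_classes: int = 4) -> list[float]:
    return [1.0 if label > k else 0.0 for k in range(num_classes - 1)]

print([ordinal_levels(y) for y in range(4)])
# → [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]
```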
def _ordinal_consistency_loss(logits: torch.Tensor) -> torch.Tensor:
"""Penalize threshold k firing when threshold k-1 doesn't.
For cumulative thresholds, P(L3) should never exceed P(L2).
Penalty: max(0, sigmoid(logit_k) - sigmoid(logit_{k-1})) for all k.
"""
probs = torch.sigmoid(logits) # (batch, num_thresholds)
# probs[:, k] should be <= probs[:, k-1]
violations = F.relu(probs[:, 1:] - probs[:, :-1])
return violations.mean()
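A concrete instance of the penalty, using toy per-sample threshold probabilities:

```python
# Threshold probs should be non-increasing; here the third (0.7) exceeds the
# second (0.4), so only that adjacent pair contributes a violation.
probs = [0.9, 0.4, 0.7]
violations = [max(0.0, b - a) for a, b in zip(probs, probs[1:])]
penalty = sum(violations) / len(violations)
print(round(penalty, 2))  # → 0.15
```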
def ordinal_predict(logits: torch.Tensor) -> torch.Tensor:
"""Convert cumulative threshold logits to class predictions (0-indexed)."""
probs = torch.sigmoid(logits)
return (probs > 0.5).sum(dim=1).long()
def softmax_predict(logits: torch.Tensor) -> torch.Tensor:
"""Convert softmax logits to class predictions."""
return logits.argmax(dim=1)
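Decoding mirrors the encoding: a prediction is the number of thresholds whose probability clears 0.5, which is what `ordinal_predict` computes per sample. A dependency-free sketch with made-up logits:

```python
import math

def decode_ordinal(threshold_logits: list[float]) -> int:
    # sigmoid(z) > 0.5 iff z > 0, so this counts positive threshold logits.
    return sum(1 for z in threshold_logits if 1.0 / (1.0 + math.exp(-z)) > 0.5)

print(decode_ordinal([2.1, 0.4, -1.7]))  # → 2  (0-indexed, i.e. level L3)
```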

@@ -0,0 +1,400 @@
"""Dual-head fine-tuning trainer for ModernBERT.
Supports the ablation grid: {base, +DAPT, +DAPT+TAPT} × {unweighted, class-weighted} × {CE, focal}.
Uses HuggingFace Trainer with custom loss computation for the dual-head architecture.
"""
import unsloth # noqa: F401 — monkey-patches transformers for faster kernels
from dataclasses import dataclass
from pathlib import Path
from itertools import product
import numpy as np
import torch
from datasets import DatasetDict
from sklearn.metrics import f1_score, cohen_kappa_score
from transformers import (
AutoModel,
AutoTokenizer,
Trainer,
TrainingArguments,
EvalPrediction,
)
from ..common.config import FinetuneConfig
from .data import (
CATEGORIES,
NUM_CATEGORIES,
NUM_SPECIFICITY,
load_finetune_data,
compute_class_weights,
)
from .model import DualHeadModernBERT, ordinal_predict, softmax_predict
@dataclass
class DualHeadCollator:
"""Pads input_ids/attention_mask and stacks label columns into tensors."""
tokenizer: object
padding: str = "longest"
def __call__(self, features: list[dict]) -> dict:
# Pop labels before padding (they're scalars, not sequences)
cat_labels = [f.pop("category_labels") for f in features]
spec_labels = [f.pop("specificity_labels") for f in features]
weights = [f.pop("sample_weight") for f in features]
spec_masks = [f.pop("spec_mask") for f in features]
# Pad input sequences
batch = self.tokenizer.pad(
features, padding=self.padding, return_tensors="pt"
)
# Add labels back as tensors
batch["category_labels"] = torch.tensor(cat_labels, dtype=torch.long)
batch["specificity_labels"] = torch.tensor(spec_labels, dtype=torch.long)
batch["sample_weight"] = torch.tensor(weights, dtype=torch.float)
batch["spec_mask"] = torch.tensor(spec_masks, dtype=torch.float)
return batch
class DualHeadTrainer(Trainer):
"""HuggingFace Trainer adapted for dual-head model output."""
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
outputs = model(**inputs)
loss = outputs["loss"]
return (loss, outputs) if return_outputs else loss
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
inputs = self._prepare_inputs(inputs)
with torch.no_grad():
outputs = model(**inputs)
loss = outputs["loss"]
if prediction_loss_only:
return (loss, None, None)
cat_logits = outputs["category_logits"]
spec_logits = outputs["specificity_logits"]
cat_labels = inputs["category_labels"]
spec_labels = inputs["specificity_labels"]
return loss, (cat_logits, spec_logits), (cat_labels, spec_labels)
def build_compute_metrics(spec_head_type: str = "independent"):
"""Return a compute_metrics function for the dual-head model."""
def compute_metrics(eval_pred: EvalPrediction) -> dict:
(cat_logits, spec_logits) = eval_pred.predictions
(cat_labels, spec_labels) = eval_pred.label_ids
# Category predictions
cat_preds = np.argmax(cat_logits, axis=1)
cat_macro_f1 = f1_score(cat_labels, cat_preds, average="macro")
cat_weighted_f1 = f1_score(cat_labels, cat_preds, average="weighted")
cat_per_class = f1_score(
cat_labels, cat_preds, average=None, labels=range(NUM_CATEGORIES)
)
# Specificity predictions
spec_tensor = torch.tensor(spec_logits, dtype=torch.float32)
if spec_head_type == "softmax":
spec_preds = softmax_predict(spec_tensor).numpy()
else:
spec_preds = ordinal_predict(spec_tensor).numpy()
spec_macro_f1 = f1_score(
spec_labels, spec_preds, average="macro",
labels=range(NUM_SPECIFICITY),
)
spec_weighted_f1 = f1_score(
spec_labels, spec_preds, average="weighted",
labels=range(NUM_SPECIFICITY),
)
spec_per_class = f1_score(
spec_labels, spec_preds, average=None,
labels=range(NUM_SPECIFICITY),
)
# Quadratic weighted kappa for ordinal agreement
spec_qwk = cohen_kappa_score(
spec_labels, spec_preds, weights="quadratic"
)
# Aggregate metric for model selection
combined_f1 = (cat_macro_f1 + spec_macro_f1) / 2
metrics = {
"combined_macro_f1": combined_f1,
"cat_macro_f1": cat_macro_f1,
"cat_weighted_f1": cat_weighted_f1,
"spec_macro_f1": spec_macro_f1,
"spec_weighted_f1": spec_weighted_f1,
"spec_qwk": spec_qwk,
}
# Per-class F1 for category
for i, name in enumerate(CATEGORIES):
short = name.replace(" ", "").replace("/", "")[:6]
metrics[f"cat_f1_{short}"] = cat_per_class[i]
# Per-class F1 for specificity
for i in range(NUM_SPECIFICITY):
metrics[f"spec_f1_L{i+1}"] = spec_per_class[i]
return metrics
return compute_metrics
def train(config: FinetuneConfig) -> dict:
"""Run a single fine-tuning experiment. Returns final eval metrics."""
print(f"\n{'='*60}")
print(f" SEC-cyBERT Fine-tuning")
print(f" Model: {config.model.name_or_path}")
print(f" Loss: {config.training.loss_type}")
print(f" Class weights: {config.training.class_weighting}")
print(f" Output: {config.training.output_dir}")
print(f"{'='*60}\n")
# Load tokenizer
tokenizer_name = config.model.tokenizer or config.model.name_or_path
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
trust_remote_code=config.model.trust_remote_code,
)
# Load and prepare data (shared cache across ablation runs)
output_dir = Path(config.training.output_dir)
# Use a stable cache location based on data paths, not output_dir
cache_dir = Path(config.data.consensus_path).parent / ".finetune_data_cache"
if cache_dir.exists():
print(f" Loading cached dataset from {cache_dir}...")
split = DatasetDict.load_from_disk(str(cache_dir))
print(f" Train: {len(split['train']):,} | Val: {len(split['test']):,}\n")
else:
split = load_finetune_data(
paragraphs_path=config.data.paragraphs_path,
consensus_path=config.data.consensus_path,
quality_path=config.data.quality_path,
holdout_path=config.data.holdout_path,
max_seq_length=config.data.max_seq_length,
validation_split=config.data.validation_split,
tokenizer=tokenizer,
seed=config.training.seed,
)
cache_dir.mkdir(parents=True, exist_ok=True)
split.save_to_disk(str(cache_dir))
print(f" Cached to {cache_dir}\n")
# If not filtering spec confidence, override spec_mask to all 1.0
if not config.training.filter_spec_confidence:
split["train"] = split["train"].map(
lambda x: {"spec_mask": 1.0}, desc="Disabling spec filter"
)
split["test"] = split["test"].map(
lambda x: {"spec_mask": 1.0}, desc="Disabling spec filter"
)
# Compute class weights if needed
category_weights = None
specificity_weights = None
if config.training.class_weighting:
weights = compute_class_weights(split["train"])
category_weights = torch.tensor(weights["category_weights"], dtype=torch.float32)
specificity_weights = weights["specificity_weights"]
print(f" Category weights: {[f'{w:.2f}' for w in weights['category_weights']]}")
print(f" Specificity weights: {[f'{w:.2f}' for w in weights['specificity_weights']]}\n")
# Load backbone
try:
import flash_attn # noqa: F401
attn_impl = "flash_attention_2"
except ImportError:
attn_impl = "sdpa"
backbone = AutoModel.from_pretrained(
config.model.name_or_path,
trust_remote_code=config.model.trust_remote_code,
attn_implementation=attn_impl,
dtype=torch.bfloat16 if config.training.bf16 else None,
)
hidden_size = backbone.config.hidden_size
print(f" Backbone parameters: {sum(p.numel() for p in backbone.parameters()) / 1e6:.0f}M")
print(f" Hidden size: {hidden_size}")
print(f" Attention: {attn_impl}\n")
# Build dual-head model
spec_head = config.training.specificity_head
print(f" Specificity head: {spec_head}")
print(f" Spec MLP dim: {config.training.spec_mlp_dim}")
print(f" Pooling: {config.training.pooling}")
print(f" Filter spec confidence: {config.training.filter_spec_confidence}")
print(f" Ordinal consistency: {config.training.ordinal_consistency_weight}\n")
model = DualHeadModernBERT(
backbone=backbone,
hidden_size=hidden_size,
num_categories=NUM_CATEGORIES,
num_specificity=NUM_SPECIFICITY,
dropout=config.training.dropout,
loss_type=config.training.loss_type,
focal_gamma=config.training.focal_gamma,
category_weights=category_weights,
specificity_weights=specificity_weights,
category_loss_weight=config.training.category_loss_weight,
specificity_loss_weight=config.training.specificity_loss_weight,
specificity_head_type=spec_head,
spec_mlp_dim=config.training.spec_mlp_dim,
pooling=config.training.pooling,
ordinal_consistency_weight=config.training.ordinal_consistency_weight,
)
# Data collator
collator = DualHeadCollator(tokenizer=tokenizer)
# Training arguments
steps_per_epoch = len(split["train"]) // (
config.training.per_device_train_batch_size
* config.training.gradient_accumulation_steps
)
training_kwargs = dict(
output_dir=str(output_dir),
learning_rate=config.training.learning_rate,
num_train_epochs=config.training.num_train_epochs,
per_device_train_batch_size=config.training.per_device_train_batch_size,
per_device_eval_batch_size=config.training.per_device_eval_batch_size,
gradient_accumulation_steps=config.training.gradient_accumulation_steps,
warmup_steps=int(config.training.warmup_ratio * steps_per_epoch * config.training.num_train_epochs),
weight_decay=config.training.weight_decay,
bf16=config.training.bf16,
tf32=True,
torch_compile=True,
optim="adamw_torch_fused",
gradient_checkpointing=config.training.gradient_checkpointing,
dataloader_persistent_workers=True,
dataloader_num_workers=config.training.dataloader_num_workers,
logging_steps=config.training.logging_steps,
eval_strategy="epoch",
save_strategy="epoch",
save_total_limit=config.training.save_total_limit,
seed=config.training.seed,
remove_unused_columns=False, # we have custom columns
report_to="none",
load_best_model_at_end=True,
metric_for_best_model="spec_macro_f1", # optimize for specificity (category already > 0.80)
greater_is_better=True,
)
if config.training.gradient_checkpointing:
training_kwargs["gradient_checkpointing_kwargs"] = {"use_reentrant": False}
args = TrainingArguments(**training_kwargs)
trainer = DualHeadTrainer(
model=model,
args=args,
train_dataset=split["train"],
eval_dataset=split["test"],
data_collator=collator,
compute_metrics=build_compute_metrics(spec_head_type=spec_head),
)
# Train
resume = config.training.resume_from_checkpoint
if resume is None and any(output_dir.glob("checkpoint-*")):
resume = True
trainer.train(resume_from_checkpoint=resume)
# Save final model + tokenizer
final_dir = output_dir / "final"
print(f"\n Saving final model to {final_dir}...")
trainer.save_model(str(final_dir))
tokenizer.save_pretrained(str(final_dir))
# Final eval
metrics = trainer.evaluate()
print(f"\n Final metrics:")
for k, v in sorted(metrics.items()):
if isinstance(v, float):
print(f" {k}: {v:.4f}")
return metrics
# ── Ablation grid ────────────────────────────────────────────────────────────
ABLATION_CHECKPOINTS = {
"base": "answerdotai/ModernBERT-large",
"dapt": "../checkpoints/dapt/modernbert-large/final",
"tapt": "../checkpoints/tapt/modernbert-large/final",
}
def ablate(config: FinetuneConfig) -> None:
"""Run the full ablation grid: 3 checkpoints x 2 weighting x 2 loss = 12 runs."""
base_output = Path(config.training.output_dir)
results: list[dict] = []
checkpoints = list(ABLATION_CHECKPOINTS.items())
weightings = [False, True]
losses = ["ce", "focal"]
combos = list(product(checkpoints, weightings, losses))
print(f"\n Ablation grid: {len(combos)} configurations")
print(f" Checkpoints: {[c[0] for c in checkpoints]}")
print(f" Weighting: {weightings}")
print(f" Losses: {losses}")
print()
for i, ((ckpt_name, ckpt_path), weighting, loss_type) in enumerate(combos):
run_name = f"{ckpt_name}_{'weighted' if weighting else 'unweighted'}_{loss_type}"
run_output = base_output / run_name
print(f"\n{'#'*60}")
print(f" Ablation run {i+1}/{len(combos)}: {run_name}")
print(f"{'#'*60}")
# Override config for this run
config.model.name_or_path = ckpt_path
config.training.output_dir = str(run_output)
config.training.class_weighting = weighting
config.training.loss_type = loss_type
try:
metrics = train(config)
results.append({"run": run_name, **metrics})
except Exception as e:
print(f"\n FAILED: {run_name}: {e}")
results.append({"run": run_name, "error": str(e)})
# Summary table
print(f"\n\n{'='*80}")
print(" ABLATION RESULTS SUMMARY")
print(f"{'='*80}")
print(f" {'Run':<45} {'Combined':>10} {'Cat F1':>10} {'Spec F1':>10} {'QWK':>10}")
print(f" {'-'*45} {'-'*10} {'-'*10} {'-'*10} {'-'*10}")
for r in results:
name = r["run"]
if "error" in r:
print(f" {name:<45} {'FAILED':>10}")
else:
combined = r.get("eval_combined_macro_f1", 0)
cat = r.get("eval_cat_macro_f1", 0)
spec = r.get("eval_spec_macro_f1", 0)
qwk = r.get("eval_spec_qwk", 0)
print(f" {name:<45} {combined:>10.4f} {cat:>10.4f} {spec:>10.4f} {qwk:>10.4f}")
# Save results
import json
results_path = base_output / "ablation_results.json"
results_path.parent.mkdir(parents=True, exist_ok=True)
with open(results_path, "w") as f:
json.dump(results, f, indent=2, default=str)
print(f"\n Results saved to {results_path}")