SEC-cyBERT/scripts/identify-holdout-rerun.py
2026-04-03 14:43:53 -04:00

202 lines
7.0 KiB
Python

"""
Identify holdout paragraphs on confusion axes that need v3.5 re-annotation.
Builds a 13-signal matrix from all available sources:
- 3 human annotators (per paragraph)
- 1 Opus golden annotation
- Up to 6 bench-holdout model annotations
- Stage 1 patched annotations (filtered to holdout PIDs)
Flags paragraphs splitting on:
1. SI <-> N/O (at least 2 signals each side)
2. MR <-> RMP (at least 2 signals each side)
3. BG <-> MR (at least 2 signals each side)
4. BG <-> RMP (at least 2 signals each side)
5. Materiality language present but majority says N/O
"""
import json
import re
from collections import Counter
from pathlib import Path
# Repository root (this script lives in scripts/, one level below it)
# and the shared data directory.
ROOT = Path(__file__).resolve().parent.parent
DATA = ROOT / "data"

# Short codes for the full content-category names used across the pipeline.
ABBREV = {
    "Board Governance": "BG",
    "Incident Disclosure": "ID",
    "Management Role": "MR",
    "None/Other": "NO",
    "Risk Management Process": "RMP",
    "Strategy Integration": "SI",
    "Third-Party Risk": "TPR",
}

# Pre-compiled, case-insensitive patterns that signal SEC-style
# materiality language in a paragraph.
MATERIALITY_PATTERNS = [
    re.compile(r"material(ly)?\s+(adverse|impact|effect|affect)", re.IGNORECASE),
    re.compile(r"materially\s+affect(ed)?", re.IGNORECASE),
    re.compile(r"material\s+cybersecurity\s+(incident|threat|event)", re.IGNORECASE),
    re.compile(r"not\s+(experienced|had|identified)\s+.{0,40}material", re.IGNORECASE),
    re.compile(r"reasonably\s+likely\s+to\s+materially", re.IGNORECASE),
    re.compile(r"material(ity)?\s+(assessment|conclusion|determination)", re.IGNORECASE),
    re.compile(r"no\s+material\s+(impact|effect|cybersecurity)", re.IGNORECASE),
    re.compile(r"have\s+not\s+.{0,30}materially\s+affect(ed)?", re.IGNORECASE),
]
def has_materiality_language(text: str) -> bool:
    """Return True if *text* matches any of the MATERIALITY_PATTERNS regexes."""
    for pattern in MATERIALITY_PATTERNS:
        if pattern.search(text):
            return True
    return False
def majority_category(tally: Counter) -> str:
    """Return the most common category in *tally*, or "UNKNOWN" when empty.

    Ties between equally common categories go to the first-inserted key,
    matching ``Counter.most_common``'s insertion-order tie-break.
    """
    if not tally:
        return "UNKNOWN"
    return max(tally, key=tally.__getitem__)
def _abbrev(category: str) -> str:
    """Map a full category name to its short code; unknown names pass through."""
    return ABBREV.get(category, category)


def _iter_jsonl(path):
    """Yield each decoded JSON object from a JSON-lines file at *path*."""
    with open(path) as f:
        for line in f:
            yield json.loads(line)


def main():
    """Build the per-paragraph signal matrix and flag confusion-axis splits.

    Reads human labels, the Opus golden annotation, bench-holdout model
    annotations, and Stage 1 patched annotations; writes flagged paragraphs
    to data/gold/holdout-rerun-v35.jsonl and prints summary statistics.
    """
    # 1. Determine the holdout PIDs from the human labels. Every labeled
    # paragraph belongs to the holdout set; a PID may carry several labels
    # (one per annotator).
    holdout_pids: set[str] = set()
    human_labels: dict[str, list[str]] = {}  # pid -> abbreviated categories
    for d in _iter_jsonl(DATA / "gold" / "human-labels-raw.jsonl"):
        pid = d["paragraphId"]
        holdout_pids.add(pid)
        human_labels.setdefault(pid, []).append(_abbrev(d["contentCategory"]))

    # Paragraph texts for the holdout PIDs (needed for the materiality check).
    holdout_paragraphs: dict[str, str] = {}
    for d in _iter_jsonl(DATA / "gold" / "paragraphs-holdout.jsonl"):
        if d["id"] in holdout_pids:
            holdout_paragraphs[d["id"]] = d["text"]

    print(f"Total holdout paragraphs: {len(holdout_pids)}")

    # 2. Signal matrix: pid -> list of abbreviated category signals.
    # 2a. Seed with the human labels (copied so appends don't alias them).
    signals: dict[str, list[str]] = {
        pid: list(cats) for pid, cats in human_labels.items()
    }
    print(f"Paragraphs with human labels: {len(human_labels)}")

    # 2b. Opus golden annotation.
    for d in _iter_jsonl(DATA / "annotations" / "golden" / "opus.jsonl"):
        pid = d["paragraphId"]
        if pid in holdout_pids:
            signals[pid].append(_abbrev(d["label"]["content_category"]))

    # 2c. Bench-holdout model annotations. "-errors" files record failed
    # requests rather than labels, so they are skipped.
    bench_dir = DATA / "annotations" / "bench-holdout"
    for fpath in sorted(bench_dir.glob("*.jsonl")):
        if "-errors" in fpath.name:
            continue
        for d in _iter_jsonl(fpath):
            pid = d.get("paragraphId")
            if pid and pid in holdout_pids and "label" in d:
                signals[pid].append(_abbrev(d["label"]["content_category"]))

    # 2d. Stage 1 patched annotations, filtered to the holdout PIDs.
    for d in _iter_jsonl(DATA / "annotations" / "stage1.patched.jsonl"):
        pid = d["paragraphId"]
        if pid in holdout_pids:
            signals[pid].append(_abbrev(d["label"]["content_category"]))

    signal_counts = [len(signals[pid]) for pid in holdout_pids]
    print(
        f"Signals per paragraph: min={min(signal_counts)}, max={max(signal_counts)}, "
        f"mean={sum(signal_counts)/len(signal_counts):.1f}"
    )

    # 3. Check confusion axes: a paragraph splits on an axis when at least
    # 2 signals land on each side.
    AXES = {
        "SI_NO": ("SI", "NO"),
        "MR_RMP": ("MR", "RMP"),
        "BG_MR": ("BG", "MR"),
        "BG_RMP": ("BG", "RMP"),
    }
    axis_counts: dict[str, int] = {k: 0 for k in AXES}
    materiality_no_count = 0
    results: list[dict] = []
    for pid in sorted(holdout_pids):
        tally = Counter(signals[pid])
        maj = majority_category(tally)
        # BUGFIX: .get() instead of [] — a holdout PID missing from
        # paragraphs-holdout.jsonl previously raised KeyError; an empty
        # text simply never matches the materiality patterns.
        mat_lang = has_materiality_language(holdout_paragraphs.get(pid, ""))
        flagged_axes = [
            axis_name
            for axis_name, (cat_a, cat_b) in AXES.items()
            if tally.get(cat_a, 0) >= 2 and tally.get(cat_b, 0) >= 2
        ]
        # Materiality language present while the majority vote is None/Other.
        mat_no_flag = mat_lang and maj == "NO"
        if flagged_axes or mat_no_flag:
            for axis_name in flagged_axes:
                axis_counts[axis_name] += 1
            if mat_no_flag:
                materiality_no_count += 1
            results.append(
                {
                    "paragraphId": pid,
                    "axes": flagged_axes,
                    # most_common() ordering keeps the output human-readable.
                    "signalTally": dict(tally.most_common()),
                    "hasMaterialityLanguage": mat_lang,
                    "currentMajority": maj,
                    "materialityNoFlag": mat_no_flag,
                }
            )

    # 4. Write flagged paragraphs and print the summary.
    out_path = DATA / "gold" / "holdout-rerun-v35.jsonl"
    with open(out_path, "w") as f:
        f.writelines(json.dumps(r) + "\n" for r in results)
    print("\n--- Confusion Axis Summary ---")
    print(f"SI <-> N/O splits: {axis_counts['SI_NO']}")
    print(f"MR <-> RMP splits: {axis_counts['MR_RMP']}")
    print(f"BG <-> MR splits: {axis_counts['BG_MR']}")
    print(f"BG <-> RMP splits: {axis_counts['BG_RMP']}")
    print(f"Materiality lang + majority N/O: {materiality_no_count}")
    print(f"\nTotal unique paragraphs needing re-run: {len(results)}")
    # Cost estimate: 5 models re-annotating at $0.005 per paragraph each.
    cost = len(results) * 0.005 * 5
    print(f"Estimated cost at $0.005/paragraph x 5 models: ${cost:.2f}")


if __name__ == "__main__":
    main()