"""
|
|
Identify holdout paragraphs on confusion axes that need v3.5 re-annotation.
|
|
|
|
Builds a 13-signal matrix from all available sources:
|
|
- 3 human annotators (per paragraph)
|
|
- 1 Opus golden annotation
|
|
- Up to 6 bench-holdout model annotations
|
|
- Stage 1 patched annotations (filtered to holdout PIDs)
|
|
|
|
Flags paragraphs splitting on:
|
|
1. SI <-> N/O (at least 2 signals each side)
|
|
2. MR <-> RMP (at least 2 signals each side)
|
|
3. BG <-> MR (at least 2 signals each side)
|
|
4. BG <-> RMP (at least 2 signals each side)
|
|
5. Materiality language present but majority says N/O
|
|
"""

import json
import re
from collections import Counter
from pathlib import Path

# Repo root is two levels above this script; all inputs/outputs live under data/.
ROOT = Path(__file__).resolve().parent.parent
DATA = ROOT / "data"

# Full category name -> short code used throughout the signal matrix.
ABBREV = {
    "Board Governance": "BG",
    "Incident Disclosure": "ID",
    "Management Role": "MR",
    "None/Other": "NO",
    "Risk Management Process": "RMP",
    "Strategy Integration": "SI",
    "Third-Party Risk": "TPR",
}

# Raw regexes that indicate materiality-determination language; all are
# compiled case-insensitively below.
_MATERIALITY_REGEXES = (
    r"material(ly)?\s+(adverse|impact|effect|affect)",
    r"materially\s+affect(ed)?",
    r"material\s+cybersecurity\s+(incident|threat|event)",
    r"not\s+(experienced|had|identified)\s+.{0,40}material",
    r"reasonably\s+likely\s+to\s+materially",
    r"material(ity)?\s+(assessment|conclusion|determination)",
    r"no\s+material\s+(impact|effect|cybersecurity)",
    r"have\s+not\s+.{0,30}materially\s+affect(ed)?",
)

# Materiality language patterns (pre-compiled once at import time).
MATERIALITY_PATTERNS = [re.compile(rx, re.IGNORECASE) for rx in _MATERIALITY_REGEXES]
|
|
|
|
|
|
def has_materiality_language(text: str) -> bool:
    """Report whether *text* matches any of the materiality-language patterns."""
    for pattern in MATERIALITY_PATTERNS:
        if pattern.search(text):
            return True
    return False
|
|
|
|
|
|
def majority_category(tally: Counter) -> str:
    """Return the most frequent category in *tally*, or "UNKNOWN" when empty."""
    if tally:
        ((winner, _count),) = tally.most_common(1)
        return winner
    return "UNKNOWN"
|
|
|
|
|
|
def _load_human_labels() -> tuple[set[str], dict[str, list[str]]]:
    """Load raw human annotations.

    Returns:
        (holdout_pids, human_labels): the set of holdout paragraph IDs and a
        mapping of pid -> abbreviated category labels (one entry per annotator).
    """
    holdout_pids: set[str] = set()
    human_labels: dict[str, list[str]] = {}
    with open(DATA / "gold" / "human-labels-raw.jsonl") as f:
        for line in f:
            d = json.loads(line)
            pid = d["paragraphId"]
            holdout_pids.add(pid)
            human_labels.setdefault(pid, []).append(
                ABBREV.get(d["contentCategory"], d["contentCategory"])
            )
    return holdout_pids, human_labels


def _load_paragraph_texts(holdout_pids: set[str]) -> dict[str, str]:
    """Load paragraph texts for the given holdout PIDs (pid -> text)."""
    texts: dict[str, str] = {}
    with open(DATA / "gold" / "paragraphs-holdout.jsonl") as f:
        for line in f:
            d = json.loads(line)
            if d["id"] in holdout_pids:
                texts[d["id"]] = d["text"]
    return texts


def _add_label_signals(
    path: Path, signals: dict[str, list[str]], holdout_pids: set[str]
) -> None:
    """Append ``label.content_category`` signals from one JSONL annotation file.

    Records missing a paragraphId or label are skipped (previously only the
    bench-holdout loader tolerated partial records; the Opus/stage-1 loaders
    raised KeyError on them).
    """
    with open(path) as f:
        for line in f:
            d = json.loads(line)
            pid = d.get("paragraphId")
            if pid and pid in holdout_pids and "label" in d:
                raw = d["label"]["content_category"]
                signals[pid].append(ABBREV.get(raw, raw))


def main():
    """Build the signal matrix, flag confusion-axis splits, write the re-run list."""
    # 1. Determine the 1,200 holdout PIDs from human labels
    holdout_pids, human_labels = _load_human_labels()

    # Load paragraph texts for the holdout PIDs
    holdout_paragraphs = _load_paragraph_texts(holdout_pids)

    print(f"Total holdout paragraphs: {len(holdout_pids)}")

    # 2. Build signal matrix: pid -> list of abbreviated category strings.
    # 2a. Seed with human labels (copied so appends don't mutate human_labels).
    signals: dict[str, list[str]] = {
        pid: list(cats) for pid, cats in human_labels.items()
    }
    print(f"Paragraphs with human labels: {len(human_labels)}")

    # 2b. Opus golden
    _add_label_signals(
        DATA / "annotations" / "golden" / "opus.jsonl", signals, holdout_pids
    )

    # 2c. Bench-holdout model annotations (skip error files)
    bench_dir = DATA / "annotations" / "bench-holdout"
    for fpath in sorted(bench_dir.glob("*.jsonl")):
        if "-errors" in fpath.name:
            continue
        _add_label_signals(fpath, signals, holdout_pids)

    # 2d. Stage 1 patched (filter to holdout PIDs)
    _add_label_signals(
        DATA / "annotations" / "stage1.patched.jsonl", signals, holdout_pids
    )

    # Report signal counts
    signal_counts = [len(signals[pid]) for pid in holdout_pids]
    print(
        f"Signals per paragraph: min={min(signal_counts)}, max={max(signal_counts)}, "
        f"mean={sum(signal_counts)/len(signal_counts):.1f}"
    )

    # 3. Check confusion axes: a split needs >= 2 signals on each side.
    AXES = {
        "SI_NO": ("SI", "NO"),
        "MR_RMP": ("MR", "RMP"),
        "BG_MR": ("BG", "MR"),
        "BG_RMP": ("BG", "RMP"),
    }

    axis_counts: dict[str, int] = {k: 0 for k in AXES}
    materiality_no_count = 0
    results: list[dict] = []

    for pid in sorted(holdout_pids):
        tally = Counter(signals[pid])
        maj = majority_category(tally)
        # .get guards against a holdout PID with no paragraph text on disk;
        # plain indexing here would raise KeyError and abort the whole run.
        text = holdout_paragraphs.get(pid, "")
        mat_lang = has_materiality_language(text)

        # Axes where both sides have at least 2 supporting signals
        flagged_axes = [
            axis_name
            for axis_name, (cat_a, cat_b) in AXES.items()
            if tally.get(cat_a, 0) >= 2 and tally.get(cat_b, 0) >= 2
        ]

        # Materiality language + majority N/O
        mat_no_flag = mat_lang and maj == "NO"

        if flagged_axes or mat_no_flag:
            for axis_name in flagged_axes:
                axis_counts[axis_name] += 1
            if mat_no_flag:
                materiality_no_count += 1

            results.append(
                {
                    "paragraphId": pid,
                    "axes": flagged_axes,
                    # most_common() ordering makes the tally readable in output
                    "signalTally": dict(tally.most_common()),
                    "hasMaterialityLanguage": mat_lang,
                    "currentMajority": maj,
                    "materialityNoFlag": mat_no_flag,
                }
            )

    # 4. Output
    out_path = DATA / "gold" / "holdout-rerun-v35.jsonl"
    with open(out_path, "w") as f:
        for r in results:
            f.write(json.dumps(r) + "\n")

    print("\n--- Confusion Axis Summary ---")
    print(f"SI <-> N/O splits: {axis_counts['SI_NO']}")
    print(f"MR <-> RMP splits: {axis_counts['MR_RMP']}")
    print(f"BG <-> MR splits: {axis_counts['BG_MR']}")
    print(f"BG <-> RMP splits: {axis_counts['BG_RMP']}")
    print(f"Materiality lang + majority N/O: {materiality_no_count}")
    print(f"\nTotal unique paragraphs needing re-run: {len(results)}")
    # 5 models re-annotate each flagged paragraph at ~$0.005 per call
    cost = len(results) * 0.005 * 5
    print(f"Estimated cost at $0.005/paragraph x 5 models: ${cost:.2f}")
if __name__ == "__main__":
|
|
main()