""" Identify holdout paragraphs on confusion axes that need v3.5 re-annotation. Builds a 13-signal matrix from all available sources: - 3 human annotators (per paragraph) - 1 Opus golden annotation - Up to 6 bench-holdout model annotations - Stage 1 patched annotations (filtered to holdout PIDs) Flags paragraphs splitting on: 1. SI <-> N/O (at least 2 signals each side) 2. MR <-> RMP (at least 2 signals each side) 3. BG <-> MR (at least 2 signals each side) 4. BG <-> RMP (at least 2 signals each side) 5. Materiality language present but majority says N/O """ import json import re from collections import Counter from pathlib import Path ROOT = Path(__file__).resolve().parent.parent DATA = ROOT / "data" # Short names for categories ABBREV = { "Board Governance": "BG", "Incident Disclosure": "ID", "Management Role": "MR", "None/Other": "NO", "Risk Management Process": "RMP", "Strategy Integration": "SI", "Third-Party Risk": "TPR", } # Materiality language patterns MATERIALITY_PATTERNS = [ re.compile(r"material(ly)?\s+(adverse|impact|effect|affect)", re.IGNORECASE), re.compile(r"materially\s+affect(ed)?", re.IGNORECASE), re.compile(r"material\s+cybersecurity\s+(incident|threat|event)", re.IGNORECASE), re.compile(r"not\s+(experienced|had|identified)\s+.{0,40}material", re.IGNORECASE), re.compile(r"reasonably\s+likely\s+to\s+materially", re.IGNORECASE), re.compile(r"material(ity)?\s+(assessment|conclusion|determination)", re.IGNORECASE), re.compile(r"no\s+material\s+(impact|effect|cybersecurity)", re.IGNORECASE), re.compile( r"have\s+not\s+.{0,30}materially\s+affect(ed)?", re.IGNORECASE ), ] def has_materiality_language(text: str) -> bool: return any(p.search(text) for p in MATERIALITY_PATTERNS) def majority_category(tally: Counter) -> str: if not tally: return "UNKNOWN" return tally.most_common(1)[0][0] def main(): # 1. Determine the 1,200 holdout PIDs from human labels holdout_pids: set[str] = set() human_labels: dict[str, list[str]] = {} # pid -> list of abbreviated cats with open(DATA / "gold" / "human-labels-raw.jsonl") as f: for line in f: d = json.loads(line) pid = d["paragraphId"] holdout_pids.add(pid) human_labels.setdefault(pid, []).append( ABBREV.get(d["contentCategory"], d["contentCategory"]) ) # Load paragraph texts for the holdout PIDs holdout_paragraphs: dict[str, str] = {} with open(DATA / "gold" / "paragraphs-holdout.jsonl") as f: for line in f: d = json.loads(line) if d["id"] in holdout_pids: holdout_paragraphs[d["id"]] = d["text"] print(f"Total holdout paragraphs: {len(holdout_pids)}") # 2. Build signal matrix: pid -> list of category strings (abbreviated) signals: dict[str, list[str]] = {pid: list(cats) for pid, cats in human_labels.items()} # 2a. Human labels already loaded above print(f"Paragraphs with human labels: {len(human_labels)}") # 2b. Opus golden with open(DATA / "annotations" / "golden" / "opus.jsonl") as f: for line in f: d = json.loads(line) pid = d["paragraphId"] if pid in holdout_pids: cat = ABBREV.get( d["label"]["content_category"], d["label"]["content_category"] ) signals[pid].append(cat) # 2c. Bench-holdout model annotations (skip error files) bench_dir = DATA / "annotations" / "bench-holdout" for fpath in sorted(bench_dir.glob("*.jsonl")): if "-errors" in fpath.name: continue with open(fpath) as f: for line in f: d = json.loads(line) pid = d.get("paragraphId") if pid and pid in holdout_pids and "label" in d: cat = ABBREV.get( d["label"]["content_category"], d["label"]["content_category"], ) signals[pid].append(cat) # 2d. Stage 1 patched (filter to holdout PIDs) with open(DATA / "annotations" / "stage1.patched.jsonl") as f: for line in f: d = json.loads(line) pid = d["paragraphId"] if pid in holdout_pids: cat = ABBREV.get( d["label"]["content_category"], d["label"]["content_category"] ) signals[pid].append(cat) # Report signal counts signal_counts = [len(signals[pid]) for pid in holdout_pids] print( f"Signals per paragraph: min={min(signal_counts)}, max={max(signal_counts)}, " f"mean={sum(signal_counts)/len(signal_counts):.1f}" ) # 3. Check confusion axes AXES = { "SI_NO": ("SI", "NO"), "MR_RMP": ("MR", "RMP"), "BG_MR": ("BG", "MR"), "BG_RMP": ("BG", "RMP"), } axis_counts: dict[str, int] = {k: 0 for k in AXES} materiality_no_count = 0 results: list[dict] = [] for pid in sorted(holdout_pids): tally = Counter(signals[pid]) maj = majority_category(tally) text = holdout_paragraphs[pid] mat_lang = has_materiality_language(text) # Check each axis flagged_axes: list[str] = [] for axis_name, (cat_a, cat_b) in AXES.items(): if tally.get(cat_a, 0) >= 2 and tally.get(cat_b, 0) >= 2: flagged_axes.append(axis_name) # Materiality language + majority N/O mat_no_flag = mat_lang and maj == "NO" if flagged_axes or mat_no_flag: for axis_name in flagged_axes: axis_counts[axis_name] += 1 if mat_no_flag: materiality_no_count += 1 # Build tally dict with full names for output readability tally_dict = dict(tally.most_common()) results.append( { "paragraphId": pid, "axes": flagged_axes if flagged_axes else [], "signalTally": tally_dict, "hasMaterialityLanguage": mat_lang, "currentMajority": maj, "materialityNoFlag": mat_no_flag, } ) # 4. Output out_path = DATA / "gold" / "holdout-rerun-v35.jsonl" with open(out_path, "w") as f: for r in results: f.write(json.dumps(r) + "\n") print(f"\n--- Confusion Axis Summary ---") print(f"SI <-> N/O splits: {axis_counts['SI_NO']}") print(f"MR <-> RMP splits: {axis_counts['MR_RMP']}") print(f"BG <-> MR splits: {axis_counts['BG_MR']}") print(f"BG <-> RMP splits: {axis_counts['BG_RMP']}") print(f"Materiality lang + majority N/O: {materiality_no_count}") print(f"\nTotal unique paragraphs needing re-run: {len(results)}") cost = len(results) * 0.005 * 5 print(f"Estimated cost at $0.005/paragraph x 5 models: ${cost:.2f}") if __name__ == "__main__": main()