SEC-cyBERT/scripts/show-hard-examples.py
2026-04-03 14:43:53 -04:00

531 lines
23 KiB
Python

"""
Show carefully selected hard-case paragraphs from the holdout set for each confusion axis.
Displays full paragraph text + compact 13-signal label table + vote tally.
Run: uv run --with numpy scripts/show-hard-examples.py
"""
import json
import os
from collections import Counter, defaultdict
from pathlib import Path
from textwrap import fill
import numpy as np
# Repository root: this script lives in scripts/, so go two levels up.
ROOT = Path(__file__).resolve().parent.parent
# ── Category abbreviation map ──────────────────────────────────────────────
# Maps full content-category names (as stored in the annotation files) to
# compact abbreviations used in the single-line signal tables printed below.
# Lookups fall back to the raw name when a category is not listed here.
FULL_TO_ABBR = {
    "Board Governance": "BG",
    "Incident Disclosure": "ID",
    "Management Role": "MR",
    "None/Other": "N/O",
    "Risk Management Process": "RMP",
    "Strategy Integration": "SI",
    "Third-Party Risk": "TPR",
}
# ── Short source-name helpers ──────────────────────────────────────────────
# Display abbreviations for the three Stage-1 labelling models, keyed by
# their full provider/model identifiers as found in stage1 provenance.
S1_MODEL_SHORT = {
    "google/gemini-3.1-flash-lite-preview": "gemini-lite",
    "x-ai/grok-4.1-fast": "grok-fast",
    "xiaomi/mimo-v2-flash": "mimo-flash",
}
# Display abbreviations for the benchmark output files, keyed by file stem.
BENCH_FILE_SHORT = {
    "gpt-5.4": "gpt-5.4",
    "gemini-3.1-pro-preview": "gemini-pro",
    "glm-5:exacto": "glm-5",
    "kimi-k2.5": "kimi",
    "mimo-v2-pro:exacto": "mimo-pro",
    "minimax-m2.7:exacto": "minimax",
}
# File stems (data/annotations/bench-holdout/<stem>.jsonl), in display order.
BENCH_FILES = [
    "gpt-5.4",
    "gemini-3.1-pro-preview",
    "glm-5:exacto",
    "kimi-k2.5",
    "mimo-v2-pro:exacto",
    "minimax-m2.7:exacto",
]
def load_jsonl(path: str | Path) -> list[dict]:
rows = []
with open(path) as f:
for line in f:
line = line.strip()
if line:
rows.append(json.loads(line))
return rows
# ── Load data ──────────────────────────────────────────────────────────────
print("Loading data...")
# Holdout paragraphs: pid → full paragraph record (text + filing metadata).
paragraphs_raw = load_jsonl(ROOT / "data/gold/paragraphs-holdout.jsonl")
para_map: dict[str, dict] = {p["id"]: p for p in paragraphs_raw}
holdout_pids = set(para_map.keys())
# Label sources: human annotations, Opus golden labels, Stage-1 model labels.
human_raw = load_jsonl(ROOT / "data/gold/human-labels-raw.jsonl")
opus_raw = load_jsonl(ROOT / "data/annotations/golden/opus.jsonl")
stage1_raw = load_jsonl(ROOT / "data/annotations/stage1.patched.jsonl")
# ── Build signal matrix: pid → {source_label: category_abbr} ─────────────
# Per the module docstring this should yield up to 13 signals per paragraph
# (humans + Opus + 3 Stage-1 models + 6 benchmark models) — TODO confirm
# the counts actually match the data files.
signals: dict[str, dict[str, str]] = defaultdict(dict)
# 1) Human annotators — one column per annotator, keyed "H:<name>".
for row in human_raw:
    pid = row["paragraphId"]
    name = row["annotatorName"]
    # Abbreviate known categories; unknown labels pass through unchanged.
    cat = FULL_TO_ABBR.get(row["contentCategory"], row["contentCategory"])
    signals[pid][f"H:{name}"] = cat
# 2) Opus golden labels — a single "Opus" column.
for row in opus_raw:
    pid = row["paragraphId"]
    cat = FULL_TO_ABBR.get(row["label"]["content_category"], row["label"]["content_category"])
    signals[pid]["Opus"] = cat
# 3) Stage 1 (filter to holdout PIDs) — only this source is filtered,
# presumably because stage1 spans paragraphs beyond the holdout set.
for row in stage1_raw:
    pid = row["paragraphId"]
    if pid not in holdout_pids:
        continue
    model_id = row["provenance"]["modelId"]
    # Unknown model ids fall back to the raw identifier.
    short = S1_MODEL_SHORT.get(model_id, model_id)
    source = f"S1:{short}"
    cat = FULL_TO_ABBR.get(row["label"]["content_category"], row["label"]["content_category"])
    signals[pid][source] = cat
# 4) Benchmark models — one JSONL file per model, columns keyed by short name.
for bench_name in BENCH_FILES:
    path = ROOT / f"data/annotations/bench-holdout/{bench_name}.jsonl"
    short = BENCH_FILE_SHORT[bench_name]
    rows = load_jsonl(path)
    for row in rows:
        pid = row["paragraphId"]
        cat = FULL_TO_ABBR.get(row["label"]["content_category"], row["label"]["content_category"])
        signals[pid][short] = cat
# ── Ordered source list (for display) ─────────────────────────────────────
# Fixed column order for signal tables: humans (alphabetical), Opus,
# Stage-1 models (sorted by full model id), then benchmark models.
HUMAN_NAMES = sorted({r["annotatorName"] for r in human_raw})
ORDERED_SOURCES = (
    [f"H:{n}" for n in HUMAN_NAMES]
    + ["Opus"]
    + [f"S1:{S1_MODEL_SHORT[m]}" for m in sorted(S1_MODEL_SHORT)]
    + [BENCH_FILE_SHORT[b] for b in BENCH_FILES]
)
# ── Utility: compute axis stats ───────────────────────────────────────────
def axis_candidates(cat_a: str, cat_b: str, extra_cat: str | None = None) -> list[tuple[str, dict, Counter]]:
    """Find holdout PIDs whose signals contain both cat_a and cat_b.

    Optionally also require extra_cat to be present. Results are ordered by
    how close the cat_a/cat_b split is (most contested first), breaking ties
    by total signal count (richer tables first).
    """
    scored = []
    for pid, sigs in signals.items():
        if pid not in holdout_pids:
            continue
        tally = Counter(sigs.values())
        present = set(tally)
        if cat_a not in present or cat_b not in present:
            continue
        if extra_cat is not None and extra_cat not in present:
            continue
        # closeness ∈ (0, 0.5]: the minority share of the two contested
        # categories — higher means a more even split.
        n_signals = sum(tally.values())
        closeness = min(tally[cat_a], tally[cat_b]) / n_signals
        scored.append((closeness, n_signals, pid, sigs, tally))
    scored.sort(key=lambda item: (-item[0], -item[1]))
    return [(pid, sigs, tally) for _, _, pid, sigs, tally in scored]
def print_example(pid: str, sigs: dict, counts: Counter, sub_pattern: str, note: str = ""):
    """Print one example paragraph: header, wrapped full text, compact
    signal table, vote tally, and an optional reviewer note.

    Args:
        pid: paragraph id (looked up in the module-level para_map).
        sigs: source-label → category-abbr mapping for this paragraph.
        counts: Counter over sigs.values() (the vote tally).
        sub_pattern: short description of which confusion sub-pattern this is.
        note: optional multi-line commentary, rendered with a "▸" marker.
    """
    para = para_map.get(pid)
    if not para:
        print(f" [paragraph {pid} not found]")
        return
    print(f" ┌─ Paragraph {pid}")
    print(f" │ Company: {para.get('companyName', '?')} | Filing: {para.get('filingType', '?')} {para.get('filingDate', '?')}")
    print(f" │ Sub-pattern: {sub_pattern}")
    print()
    # Full text — wrap each original line at 100 chars.
    for line in para["text"].split("\n"):
        print(fill(line, width=100))
    print()
    # Signal table — compact single line, in the fixed ORDERED_SOURCES order.
    parts = [f"{src}={sigs[src]}" for src in ORDERED_SOURCES if src in sigs]
    print(f" │ Signals: {', '.join(parts)}")
    # Vote tally, most common category first.
    tally_parts = [f"{cat}: {n}" for cat, n in counts.most_common()]
    print(f" │ Tally: {', '.join(tally_parts)} (out of {sum(counts.values())})")
    if note:
        print()
        for line in note.split("\n"):
            print(fill(line, width=100, initial_indent=" │ ▸ ", subsequent_indent=""))
    # BUG FIX: the original printed f"{'' * 78}" — the empty string repeated
    # 78 times, i.e. a blank line. A horizontal rule closing the box was
    # clearly intended, matching the ┌─/│ box-drawing above.
    print("─" * 78)
    print()
def pick_diverse(candidates: list[tuple[str, dict, Counter]], n: int, min_signals: int = 10) -> list[tuple[str, dict, Counter]]:
    """Pick up to n diverse examples: prefer signal-rich rows and, when the
    pool is large enough, at most one example per company."""
    if len(candidates) <= n:
        return candidates
    # Keep only rows with enough signals for a meaningful table; if that
    # leaves too few, fall back to the full candidate list.
    pool = [cand for cand in candidates if sum(cand[2].values()) >= min_signals]
    if len(pool) < n:
        pool = candidates
    companies_used: set[str] = set()
    chosen: list[tuple[str, dict, Counter]] = []
    for cand in pool:
        company = para_map.get(cand[0], {}).get("companyName", "")
        # Skip repeat companies only when the pool is comfortably large.
        if company in companies_used and len(pool) > n * 2:
            continue
        chosen.append(cand)
        companies_used.add(company)
        if len(chosen) >= n * 3:
            break
    return chosen[:n]
# ══════════════════════════════════════════════════════════════════════════
# AXIS 1: MR ↔ RMP
# ══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 80)
print(" AXIS 1: MR ↔ RMP — Management Role vs. Risk Management Process")
print("=" * 80)
# All holdout paragraphs where at least one signal says MR and one says RMP.
mr_rmp = axis_candidates("MR", "RMP")
print(f"\n Total paragraphs with both MR and RMP in signals: {len(mr_rmp)}\n")
def classify_mr_rmp_subpattern(text: str) -> str:
    """Heuristic guess at the sub-pattern driving MR↔RMP confusion.

    Splits the paragraph into rough sentences (periods, newlines flattened)
    and checks whether person-role keywords and/or process keywords appear
    in the first 80 characters of each sentence — a cheap proxy for the
    grammatical subject.

    Returns one of: "person-subject", "process-subject", "mixed", "other".
    """
    # NOTE: the original computed an unused text_lower here; removed.
    sentences = [s.strip() for s in text.replace("\n", " ").split(".") if s.strip()]
    person_keywords = [
        "ciso", "chief information security", "chief information officer",
        "cio", "vp ", "vice president", "director", "officer", "head of",
        "manager", "leader", "executive", "cto", "chief technology",
    ]
    process_keywords = [
        "program", "framework", "process", "policy", "policies",
        "procedures", "controls", "assessment", "monitoring",
        "risk management", "incident response", "vulnerability",
    ]
    person_subject_sentences = 0
    process_subject_sentences = 0
    for sent in sentences:
        sent_lower = sent.lower().strip()
        # Only the sentence's first 80 chars count: a keyword appearing early
        # is more likely to be the subject rather than an object/modifier.
        has_person = any(kw in sent_lower[:80] for kw in person_keywords)
        has_process = any(kw in sent_lower[:80] for kw in process_keywords)
        if has_person:
            person_subject_sentences += 1
        if has_process:
            process_subject_sentences += 1
    if person_subject_sentences > 0 and process_subject_sentences == 0:
        return "person-subject"
    elif process_subject_sentences > 0 and person_subject_sentences == 0:
        return "process-subject"
    elif person_subject_sentences > 0 and process_subject_sentences > 0:
        return "mixed"
    else:
        return "other"
# Bucket candidates by sub-pattern
buckets: dict[str, list] = {"person-subject": [], "process-subject": [], "mixed": [], "other": []}
for pid, sigs, counts in mr_rmp:
    text = para_map.get(pid, {}).get("text", "")
    sp = classify_mr_rmp_subpattern(text)
    buckets[sp].append((pid, sigs, counts))
print(f" Sub-pattern distribution: person-subject={len(buckets['person-subject'])}, "
      f"process-subject={len(buckets['process-subject'])}, mixed={len(buckets['mixed'])}, "
      f"other={len(buckets['other'])}")
print()
# (a) Person is grammatical subject
print(" ── (a) Person is the grammatical subject, doing process-like things ──\n")
for pid, sigs, counts in pick_diverse(buckets["person-subject"], 2):
    text = para_map[pid]["text"]  # NOTE(review): unused — leftover from an earlier revision?
    # Subject test note
    note = "SUBJECT TEST → MR (person is the main subject)"
    print_example(pid, sigs, counts, "Person as subject doing process-like things", note)
# (b) Process/framework is subject
print(" ── (b) Process/framework is the subject, person mentioned as responsible ──\n")
for pid, sigs, counts in pick_diverse(buckets["process-subject"], 2):
    text = para_map[pid]["text"]  # NOTE(review): unused, as above
    note = "SUBJECT TEST → RMP (process/framework is the main subject)"
    print_example(pid, sigs, counts, "Process as subject, person mentioned", note)
# (c) Mixed
print(" ── (c) Mixed — both person and process are subjects ──\n")
for pid, sigs, counts in pick_diverse(buckets["mixed"], 2):
    note = "SUBJECT TEST → AMBIGUOUS (both person and process serve as subjects)"
    print_example(pid, sigs, counts, "Mixed subjects", note)
# (d) Edge cases — closest splits from "other" or overall closest
print(" ── (d) Edge cases — genuinely hard to call ──\n")
# Take from overall closest that aren't already shown
# NOTE(review): this assumes sections (a)-(c) displayed the FIRST two of
# each bucket, but pick_diverse may skip entries for company diversity —
# so the exclusion set can differ from what was actually printed. Confirm
# whether exact de-duplication matters here.
shown_pids = set()
for bucket in buckets.values():
    for pid, _, _ in bucket[:2]:
        shown_pids.add(pid)
edge_cases = [(p, s, c) for p, s, c in mr_rmp if p not in shown_pids][:20]
for pid, sigs, counts in pick_diverse(edge_cases, 2):
    mr_count = counts.get("MR", 0)
    rmp_count = counts.get("RMP", 0)
    note = f"SUBJECT TEST → unclear; split is {mr_count}-{rmp_count} MR-RMP"
    print_example(pid, sigs, counts, "Edge case", note)
# ══════════════════════════════════════════════════════════════════════════
# AXIS 2: BG ↔ MR
# ══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 80)
print(" AXIS 2: BG ↔ MR — Board Governance vs. Management Role")
print("=" * 80)
# All holdout paragraphs where at least one signal says BG and one says MR.
bg_mr = axis_candidates("BG", "MR")
print(f"\n Total paragraphs with both BG and MR in signals: {len(bg_mr)}\n")
def classify_bg_mr_subpattern(text: str) -> str:
    """Keyword heuristic for the BG↔MR confusion sub-pattern.

    Returns "board-actor" (board terms only), "mgmt-reporting-to-board"
    (both board and management terms), "mgmt-only" (management terms only),
    or "mixed-governance" (neither set of terms present).
    """
    lowered = text.lower()
    board_terms = ("board", "committee", "audit committee", "directors")
    mgmt_terms = ("ciso", "chief information", "officer", "vp", "vice president",
                  "director of", "head of", "reports to", "briefing", "briefs",
                  "presents to", "reporting")
    board_hit = any(term in lowered for term in board_terms)
    mgmt_hit = any(term in lowered for term in mgmt_terms)
    if board_hit:
        return "mgmt-reporting-to-board" if mgmt_hit else "board-actor"
    if mgmt_hit:
        return "mgmt-only"
    return "mixed-governance"
# Bucket axis-2 candidates by heuristic sub-pattern.
buckets_bg: dict[str, list] = defaultdict(list)
for pid, sigs, counts in bg_mr:
    sp = classify_bg_mr_subpattern(para_map.get(pid, {}).get("text", ""))
    buckets_bg[sp].append((pid, sigs, counts))
print(f" Sub-pattern distribution: {dict((k, len(v)) for k, v in buckets_bg.items())}")
print()
# (a) Board/committee is clearly the actor
print(" ── (a) Board/committee is clearly the actor ──\n")
# Fall back to the mixed-governance bucket if no pure board-actor cases exist.
pool = buckets_bg.get("board-actor", []) or buckets_bg.get("mixed-governance", [])
for pid, sigs, counts in pick_diverse(pool, 2):
    print_example(pid, sigs, counts, "Board as actor")
# (b) Management officer reporting TO the board
print(" ── (b) Management officer reporting TO/briefing the board ──\n")
pool = buckets_bg.get("mgmt-reporting-to-board", [])
for pid, sigs, counts in pick_diverse(pool, 2):
    note = "KEY QUESTION: Is this BG (board receiving info) or MR (officer doing the briefing)?"
    print_example(pid, sigs, counts, "Management reporting to board", note)
# (c) Mixed governance
print(" ── (c) Mixed governance language ──\n")
# Exclude the first two entries of each bucket as a proxy for "already
# shown" (see the matching NOTE on axis 1: pick_diverse may have shown
# different rows, so this exclusion is approximate).
remaining = [x for x in bg_mr if x[0] not in {p for bucket in buckets_bg.values() for p, _, _ in bucket[:2]}]
for pid, sigs, counts in pick_diverse(remaining, 2):
    note = "Could be BG, MR, or RMP depending on interpretation"
    print_example(pid, sigs, counts, "Mixed governance", note)
# ══════════════════════════════════════════════════════════════════════════
# AXIS 3: SI ↔ N/O
# ══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 80)
print(" AXIS 3: SI ↔ N/O — Strategy Integration vs. None/Other")
print("=" * 80)
# All holdout paragraphs where at least one signal says SI and one says N/O.
si_no = axis_candidates("SI", "N/O")
print(f"\n Total paragraphs with both SI and N/O in signals: {len(si_no)}\n")
def classify_si_no_subpattern(text: str) -> str:
    """Keyword heuristic for the SI↔N/O confusion sub-pattern.

    Priority order: actual incident (incident words, no negation) →
    negative assertion → hypothetical (conditional words, nothing specific)
    → specific-no-incident → other.
    """
    lowered = text.lower()
    incident_words = ["incident", "breach", "attack", "compromised", "unauthorized access",
                      "data breach", "ransomware", "phishing"]
    negative_words = ["have not experienced", "not experienced", "no material",
                      "has not been materially", "not been the subject",
                      "not aware of any", "no known", "have not had"]
    hypothetical_words = ["could", "may", "might", "would", "if ", "potential",
                          "face threats", "subject to"]
    specific_words = ["$", "million", "vendor", "contract", "insurance",
                      "specific", "particular", "named"]

    def hits(words: list[str]) -> bool:
        # Substring match against the lowercased paragraph.
        return any(w in lowered for w in words)

    if hits(incident_words) and not hits(negative_words):
        return "actual-incident"
    if hits(negative_words):
        return "negative-assertion"
    if hits(hypothetical_words) and not hits(specific_words):
        return "hypothetical"
    if hits(specific_words):
        return "specific-no-incident"
    return "other"
# Bucket axis-3 candidates by heuristic sub-pattern.
buckets_si: dict[str, list] = defaultdict(list)
for pid, sigs, counts in si_no:
    sp = classify_si_no_subpattern(para_map.get(pid, {}).get("text", ""))
    buckets_si[sp].append((pid, sigs, counts))
print(f" Sub-pattern distribution: {dict((k, len(v)) for k, v in buckets_si.items())}")
print()
# Also find the 23 cases where humans=SI but GenAI=N/O
# NOTE(review): "23" looks like a hard-coded observation from one run of
# the data; the actual count is computed below.
human_si_genai_no = []
for pid, sigs, counts in si_no:
    # Split signals into human columns ("H:<name>") vs everything else.
    human_cats = [sigs.get(f"H:{n}") for n in HUMAN_NAMES if f"H:{n}" in sigs]
    genai_cats = [v for k, v in sigs.items() if not k.startswith("H:")]
    human_si = sum(1 for c in human_cats if c == "SI")
    human_no = sum(1 for c in human_cats if c == "N/O")
    genai_si = sum(1 for c in genai_cats if c == "SI")
    genai_no = sum(1 for c in genai_cats if c == "N/O")
    # Directional disagreement: human majority SI vs GenAI majority N/O.
    if human_si > human_no and genai_no > genai_si:
        human_si_genai_no.append((pid, sigs, counts))
print(f" Cases where humans lean SI but GenAI leans N/O: {len(human_si_genai_no)}")
print()
# (a) Clear actual incident
print(" ── (a) Clear actual incident described ──\n")
for pid, sigs, counts in pick_diverse(buckets_si.get("actual-incident", []), 2):
    print_example(pid, sigs, counts, "Actual incident")
# (b) Negative assertion
print(" ── (b) Negative assertion — 'we have not experienced material incidents' ──\n")
neg_pool = buckets_si.get("negative-assertion", [])
# Prefer ones in the human-SI-genAI-NO set
neg_human_si = [x for x in neg_pool if x[0] in {p for p, _, _ in human_si_genai_no}]
neg_other = [x for x in neg_pool if x[0] not in {p for p, _, _ in human_si_genai_no}]
pool = neg_human_si[:2] if len(neg_human_si) >= 2 else (neg_human_si + neg_other)[:2]
for pid, sigs, counts in pool:
    human_cats = [sigs.get(f"H:{n}") for n in HUMAN_NAMES if f"H:{n}" in sigs]
    genai_cats = [v for k, v in sigs.items() if not k.startswith("H:")]
    note = (f"CRUX: Humans keyed on the materiality assessment language. "
            f"Human votes: {Counter(human_cats).most_common()}, "
            f"GenAI votes: {Counter(genai_cats).most_common()}")
    print_example(pid, sigs, counts, "Negative assertion", note)
# (c) Hypothetical/conditional
print(" ── (c) Hypothetical/conditional language ──\n")
for pid, sigs, counts in pick_diverse(buckets_si.get("hypothetical", []), 2):
    print_example(pid, sigs, counts, "Hypothetical/conditional")
# (d) Specific programs/vendors/amounts but no incident
print(" ── (d) Specific programs/vendors/amounts but no incident ──\n")
spec_pool = buckets_si.get("specific-no-incident", [])
if len(spec_pool) < 2:
    # Too few: pad with the "other" bucket.
    spec_pool += buckets_si.get("other", [])
for pid, sigs, counts in pick_diverse(spec_pool, 2):
    note = "SI because specific details? Or N/O because no event/strategy content?"
    print_example(pid, sigs, counts, "Specific but no incident", note)
# Extra: show human-SI / genAI-N/O cases not already shown
# (same approximate "already shown" proxy as axes 1-2: first two per bucket).
shown_si = set()
for bucket in buckets_si.values():
    for p, _, _ in bucket[:2]:
        shown_si.add(p)
extra_human_si = [x for x in human_si_genai_no if x[0] not in shown_si]
if extra_human_si:
    print(" ── (extra) Additional human=SI, GenAI=N/O cases ──\n")
    for pid, sigs, counts in pick_diverse(extra_human_si, 2):
        human_cats = [sigs.get(f"H:{n}") for n in HUMAN_NAMES if f"H:{n}" in sigs]
        genai_cats = [v for k, v in sigs.items() if not k.startswith("H:")]
        note = (f"Humans: {Counter(human_cats).most_common()}, "
                f"GenAI: {Counter(genai_cats).most_common()}")
        print_example(pid, sigs, counts, "Human=SI, GenAI=N/O", note)
# ══════════════════════════════════════════════════════════════════════════
# AXIS 4: Three-way BG ↔ MR ↔ RMP
# ══════════════════════════════════════════════════════════════════════════
print("\n" + "=" * 80)
print(" AXIS 4: Three-way BG ↔ MR ↔ RMP")
print("=" * 80)
# Collect holdout paragraphs where all three categories received votes.
three_way = []
for pid, sigs in signals.items():
    if pid not in holdout_pids:
        continue
    counts = Counter(sigs.values())
    if "BG" in counts and "MR" in counts and "RMP" in counts:
        # Score by how evenly split the three are
        vals = [counts["BG"], counts["MR"], counts["RMP"]]
        total_3 = sum(vals)  # NOTE(review): unused
        # All three counts are >= 1 inside this branch, so the max(vals) > 0
        # guard can never actually be false.
        evenness = min(vals) / max(vals) if max(vals) > 0 else 0
        three_way.append((pid, sigs, counts, evenness))
# Most even three-way splits first; richer signal tables break ties.
three_way.sort(key=lambda x: (-x[3], -sum(x[2].values())))
print(f"\n Total paragraphs with all three of BG, MR, RMP: {len(three_way)}\n")
# Pick diverse examples with enough signals
seen_co: set[str] = set()
three_way_selected = []
for pid, sigs, counts, evenness in three_way:
    if sum(counts.values()) < 10:  # require a rich signal table
        continue
    co = para_map.get(pid, {}).get("companyName", "")
    if co in seen_co:  # at most one example per company
        continue
    seen_co.add(co)
    three_way_selected.append((pid, sigs, counts, evenness))
    if len(three_way_selected) >= 3:
        break
for pid, sigs, counts, evenness in three_way_selected:
    bg_c, mr_c, rmp_c = counts["BG"], counts["MR"], counts["RMP"]
    note = (f"Three-way split: BG={bg_c}, MR={mr_c}, RMP={rmp_c}. "
            f"This paragraph intertwines governance, management roles, and process descriptions.")
    print_example(pid, sigs, counts, "Three-way BG/MR/RMP", note)
# ── Summary statistics ────────────────────────────────────────────────────
# One-line recap of each axis's candidate count computed above.
print("\n" + "=" * 80)
print(" SUMMARY")
print("=" * 80)
print(f"""
Axis 1 (MR↔RMP): {len(mr_rmp)} paragraphs with split signals
Axis 2 (BG↔MR): {len(bg_mr)} paragraphs with split signals
Axis 3 (SI↔N/O): {len(si_no)} paragraphs with split signals
Axis 4 (BG↔MR↔RMP): {len(three_way)} paragraphs with three-way split
Human=SI/GenAI=N/O: {len(human_si_genai_no)} cases (directional asymmetry)
""")