SEC-cyBERT/scripts/sample-v2-holdout.py
2026-04-04 22:49:24 -04:00

347 lines
14 KiB
Python

"""
Sample v2 holdout set (1,200 paragraphs) using v1 Stage 1 consensus as guide.
Applies heuristic v2 specificity prediction:
- v1 L1 with domain terminology keywords → predicted L2
- v1 L3 with 1+ QV indicator → predicted L4
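Stratified by v1 consensus category (185 for each of the six categories other
than Incident Disclosure, 90 for Incident Disclosure), with:
- Max 2 paragraphs per company per category stratum
- Secondary floor: ≥100 per predicted v2 specificity level (best effort, via
  swaps within the same category)
- Random within strata (no difficulty weighting); paragraphs in the v1
  holdout are not excluded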
Outputs:
- data/gold/v2-holdout-ids.json (list of paragraph IDs)
- data/gold/v2-holdout-manifest.jsonl (full metadata per paragraph)
"""
import json
import random
import re
import sys
from collections import Counter, defaultdict
from pathlib import Path
# Fixed seed so the sampled holdout is reproducible from run to run.
random.seed(42)

DATA = Path("data")
_GOLD_DIR = DATA / "gold"

# Inputs
PARAGRAPHS_PATH = DATA.joinpath("paragraphs", "paragraphs-clean.patched.jsonl")
STAGE1_PATH = DATA.joinpath("annotations", "stage1.patched.jsonl")
# v1 holdout ID list; main() loads it but does not exclude those paragraphs.
V1_HOLDOUT_PATH = Path("labelapp").joinpath(".sampled-ids.original.json")

# Outputs
OUTPUT_IDS = _GOLD_DIR / "v2-holdout-ids.json"
OUTPUT_MANIFEST = _GOLD_DIR / "v2-holdout-manifest.jsonl"
# ── Allocation ──────────────────────────────────────────────────────────────
# Target holdout size and per-category quotas, keyed by v1 consensus category.
TOTAL = 1200
CATEGORY_ALLOC = {
    "Board Governance": 185,
    "Management Role": 185,
    "Risk Management Process": 185,
    "Third-Party Risk": 185,
    "Strategy Integration": 185,
    "None/Other": 185,
    "Incident Disclosure": 90,
}
# Explicit raise instead of ``assert`` so the check survives ``python -O``,
# which strips assert statements.
if sum(CATEGORY_ALLOC.values()) != TOTAL:
    raise ValueError(
        f"CATEGORY_ALLOC sums to {sum(CATEGORY_ALLOC.values())}, expected TOTAL={TOTAL}"
    )
SPECIFICITY_FLOOR = 100  # minimum per predicted v2 specificity level
MAX_PER_COMPANY_PER_STRATUM = 2  # per (category, company CIK) in the sample
# ── Domain terminology keywords (v2 codebook) ──────────────────────────────
# Regex fragments searched case-insensitively in v1 Level-1 paragraphs; any
# single hit moves the predicted v2 specificity up to Level 2 (see
# predict_v2_specificity). The groups below are concatenated, in this order,
# to form DOMAIN_TERMS.
#
# NOTE(review): with re.IGNORECASE the short acronym patterns also match
# ordinary lowercase words, e.g. r"\bAPT\b" matches "apt" as in "is apt to".
# This likely inflates predicted Level 2. Compiling the acronyms
# case-sensitively would fix it, but would also change the sampled holdout.

# Practices and activities
_PRACTICE_TERMS = [
    r"penetration test", r"pen test", r"vulnerability scan", r"vulnerability assess",
    r"red team", r"phishing simul", r"security awareness training",
    r"threat hunt", r"threat intellig", r"patch management",
    r"identity and access management", r"\bIAM\b",
    r"data loss prevention", r"\bDLP\b",
    r"network segmentation", r"encryption",
]
# Tools and infrastructure
_TOOL_TERMS = [
    r"\bSIEM\b", r"security information and event management",
    r"\bSOC\b", r"security operations center",
    r"\bEDR\b", r"\bXDR\b", r"\bMDR\b", r"endpoint detection",
    r"\bWAF\b", r"web application firewall",
    r"\bIDS\b", r"\bIPS\b", r"intrusion detection", r"intrusion prevention",
    r"\bMFA\b", r"multi-factor auth", r"two-factor auth", r"\b2FA\b",
    r"\bfirewall\b", r"antivirus", r"anti-malware",
]
# Architectural concepts
_ARCHITECTURE_TERMS = [r"zero trust", r"defense in depth", r"least privilege"]
# Named standards (already L2 in v1, but catch any misses)
_STANDARD_TERMS = [
    r"\bNIST\b", r"ISO 27001", r"ISO 27002", r"\bSOC 2\b", r"CIS Controls",
    r"PCI[ -]DSS", r"\bHIPAA\b", r"\bGDPR\b", r"\bCOBIT\b", r"MITRE ATT",
]
# Threat types
_THREAT_TERMS = [
    r"ransomware", r"\bmalware\b", r"phishing",
    r"\bDDoS\b", r"supply chain attack", r"social engineering",
    r"advanced persistent threat", r"\bAPT\b", r"zero[- ]day",
]

DOMAIN_TERMS = (
    _PRACTICE_TERMS
    + _TOOL_TERMS
    + _ARCHITECTURE_TERMS
    + _STANDARD_TERMS
    + _THREAT_TERMS
)
DOMAIN_PATTERNS = [re.compile(term, re.IGNORECASE) for term in DOMAIN_TERMS]
# ── QV indicators (v2 codebook: 1+ → Level 4) ──────────────────────────────
# Regexes searched case-insensitively in v1 Level-3 paragraphs; any single hit
# moves the predicted v2 specificity up to Level 4 (see predict_v2_specificity).
# The groups below are concatenated, in this order, to form QV_PATTERNS_RAW.
#
# NOTE(review): these are loose heuristics with likely false positives:
#   - the month/year date patterns also match fiscal-year dates such as
#     "December 31, 2023";
#   - with IGNORECASE, r"\bEY\b" also matches a bare "ey";
#   - r"\bPalo Alto\b" also matches the city name;
#   - the headcount pattern matches any "<number> years".
# r"\bCrowdStrike Falcon\b" is already covered by r"\bCrowdStrike\b".

# Monetary amounts
_AMOUNT_PATTERNS = [
    r"\$[\d,]+",  # dollar amounts
    r"\b\d{1,3}(?:,\d{3})*\s*(?:million|billion|thousand)\b",
]
# Specific dates (month + day + year, or month + year)
_DATE_PATTERNS = [
    r"\b(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2},?\s+\d{4}\b",
    r"\b(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{4}\b",
]
# Named certifications
_CERTIFICATION_PATTERNS = [r"\bCISSP\b", r"\bCISM\b", r"\bCEH\b", r"\bCRISC\b"]
# Named third-party firms
_FIRM_PATTERNS = [
    r"\bMandiant\b", r"\bCrowdStrike\b", r"\bDeloitte\b", r"\bKPMG\b",
    r"\bPricewaterhouse\b", r"\bPwC\b", r"\bErnst\s*&\s*Young\b", r"\bEY\b",
    r"\bAccenture\b", r"\bBooz Allen\b", r"\bProtiviti\b", r"\bKroll\b",
]
# Named products/tools
_PRODUCT_PATTERNS = [
    r"\bSplunk\b", r"\bAzure Sentinel\b", r"\bCrowdStrike Falcon\b",
    r"\bServiceNow\b", r"\bPalo Alto\b", r"\bFortinet\b", r"\bSailPoint\b",
    r"\bOkta\b", r"\bZscaler\b", r"\bCarbon Black\b", r"\bSentinelOne\b",
    r"\bQualys\b", r"\bTenable\b", r"\bRapid7\b", r"\bProofpoint\b",
    r"\bMimecast\b", r"\bDarktrace\b", r"\bKnowBe4\b",
]
# Headcounts / years of experience, then specific percentages
_COUNT_PATTERNS = [
    r"\b\d+\s+(?:years?|professionals?|employees?|staff|members?|individuals?)\b",
    r"\b\d+(?:\.\d+)?%",
]
# Advanced degrees cited as credentials
_DEGREE_PATTERNS = [r"\bPh\.?D\.?\b", r"\bMaster'?s?\s+(?:of|in)\b"]

QV_PATTERNS_RAW = (
    _AMOUNT_PATTERNS
    + _DATE_PATTERNS
    + _CERTIFICATION_PATTERNS
    + _FIRM_PATTERNS
    + _PRODUCT_PATTERNS
    + _COUNT_PATTERNS
    + _DEGREE_PATTERNS
)
QV_PATTERNS = [re.compile(pattern, re.IGNORECASE) for pattern in QV_PATTERNS_RAW]
def has_domain_terminology(text: str) -> bool:
    """Return True if *text* matches at least one pattern in DOMAIN_PATTERNS."""
    for pattern in DOMAIN_PATTERNS:
        if pattern.search(text):
            return True
    return False
def has_qv_indicator(text: str) -> bool:
    """Return True if *text* matches at least one pattern in QV_PATTERNS."""
    for pattern in QV_PATTERNS:
        if pattern.search(text):
            return True
    return False
def predict_v2_specificity(v1_spec: int, text: str) -> int:
    """Heuristically map a v1 consensus specificity level to a predicted v2 level.

    v1 Level 1 becomes 2 when the text contains domain terminology. v1 Level 3
    becomes 4 when the text contains a QV indicator. Every other input keeps
    its v1 level unchanged.
    """
    if v1_spec == 1:
        return 2 if has_domain_terminology(text) else v1_spec
    if v1_spec == 3:
        return 4 if has_qv_indicator(text) else v1_spec
    return v1_spec
# ── Main ────────────────────────────────────────────────────────────────────
# Predicted v2 specificity levels that the floor is enforced for.
_SPEC_LEVELS = (1, 2, 3, 4)
# Fields written to the manifest, in output order.
_MANIFEST_FIELDS = (
    "id",
    "v1_category",
    "v1_specificity",
    "v2_pred_specificity",
    "company_cik",
    "company_name",
    "word_count",
    "text_preview",
)


def _load_paragraphs(path: Path) -> dict[str, dict]:
    """Load a paragraphs JSONL file into a dict keyed by each record's ``id``.

    Blank lines are skipped.
    """
    paragraphs: dict[str, dict] = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            p = json.loads(line)
            paragraphs[p["id"]] = p
    return paragraphs


def _load_stage1_labels(path: Path) -> dict[str, list[dict]]:
    """Group Stage 1 ``label`` objects by ``paragraphId``, preserving file order."""
    ann_by_pid: dict[str, list[dict]] = defaultdict(list)
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            r = json.loads(line)
            ann_by_pid[r["paragraphId"]].append(r["label"])
    return ann_by_pid


def _majority(values: list):
    """Return the most common element of *values*.

    ``Counter.most_common`` lists equal counts in first-seen order. So when no
    value has a majority (e.g. a 1-1-1 split), the first-listed value wins.
    """
    return Counter(values).most_common(1)[0][0]


def _build_records(
    ann_by_pid: dict[str, list[dict]], paragraphs: dict[str, dict]
) -> tuple[list[dict], int]:
    """Build one record per paragraph that has exactly 3 labels and known text.

    Returns:
        ``(records, no_category_majority)``. The second item counts records
        whose three category labels were all different, i.e. where the v1
        category is only a first-listed tie-break, not a true consensus.
    """
    records: list[dict] = []
    no_category_majority = 0
    for pid, labels in ann_by_pid.items():
        if len(labels) != 3 or pid not in paragraphs:
            continue
        p = paragraphs[pid]
        text = p["text"]
        cats = [lab["content_category"] for lab in labels]
        if len(set(cats)) == len(cats):
            no_category_majority += 1
        cat = _majority(cats)
        spec = _majority([lab["specificity_level"] for lab in labels])
        records.append({
            "id": pid,
            "v1_category": cat,
            "v1_specificity": spec,
            "v2_pred_specificity": predict_v2_specificity(spec, text),
            "company_cik": p["filing"]["cik"],  # CIK is unique per company
            "company_name": p["filing"]["companyName"],
            "text_preview": text[:120],
            "word_count": p["wordCount"],
            "text": text,
        })
    return records, no_category_majority


def _sample_by_category(
    records: list[dict], category_alloc: dict[str, int], max_per_company: int
) -> list[dict]:
    """Randomly take up to each category's quota, with a per-company cap.

    Categories are processed in *category_alloc* order. Each pool is shuffled,
    then at most *max_per_company* records per company CIK are kept before the
    quota is taken. A warning is printed when a quota cannot be filled.
    """
    by_category: dict[str, list[dict]] = defaultdict(list)
    for r in records:
        by_category[r["v1_category"]].append(r)

    selected: list[dict] = []
    for cat, alloc in category_alloc.items():
        pool = by_category.get(cat, [])
        random.shuffle(pool)
        company_counts: dict[str, int] = defaultdict(int)
        eligible = []
        for r in pool:
            cik = r["company_cik"]
            if company_counts[cik] < max_per_company:
                eligible.append(r)
                company_counts[cik] += 1
        taken = eligible[:alloc]
        if len(taken) < alloc:
            print(f" WARNING: {cat} — only {len(taken)}/{alloc} available after company cap", file=sys.stderr)
        selected.extend(taken)
        print(f" {cat:30s} selected {len(taken):4d}/{alloc} (pool {len(pool)})", file=sys.stderr)
    return selected


def _print_spec_counts(counts: Counter, floor: int, below_status: str) -> None:
    """Print per-level counts with an OK / *below_status* marker to stderr.

    Every level in ``_SPEC_LEVELS`` is listed, even at 0, so an empty level
    is flagged instead of silently omitted.
    """
    for spec in sorted(set(_SPEC_LEVELS) | set(counts)):
        count = counts[spec]
        status = "OK" if count >= floor else below_status
        print(f" Level {spec}: {count:4d}  {status}", file=sys.stderr)


def _enforce_specificity_floor(
    records: list[dict],
    selected: list[dict],
    *,
    floor: int,
    max_per_company: int,
    donor_margin: int = 20,
) -> None:
    """Swap records into *selected* (in place) to lift every level to *floor*.

    For each level in ``_SPEC_LEVELS`` that is below *floor*, unselected
    records at that level are swapped in. Each one replaces a selected record
    of the SAME category, so category quotas do not change. Replaced records
    come from "donor" levels that had more than ``floor + donor_margin``
    records when boosting of that level began.

    A swap is skipped when it would:
      - push the donor level below *floor*. The donor list is computed once
        per target level, so without this check a large deficit could drain
        a donor below the floor.
      - give the incoming record's company more than *max_per_company*
        records in that category stratum.
    Progress is printed to stderr; the boost may fall short of *floor*.
    """
    selected_ids = {r["id"] for r in selected}
    spec_selected = Counter(r["v2_pred_specificity"] for r in selected)
    stratum_counts = Counter((r["v1_category"], r["company_cik"]) for r in selected)

    for target_level in _SPEC_LEVELS:
        deficit = floor - spec_selected[target_level]
        if deficit <= 0:
            continue
        print(f"\n Boosting Level {target_level} by {deficit}...", file=sys.stderr)

        candidates = [
            r for r in records
            if r["id"] not in selected_ids
            and r["v2_pred_specificity"] == target_level
        ]
        random.shuffle(candidates)

        over_levels = [
            lvl for lvl in _SPEC_LEVELS
            if spec_selected[lvl] > floor + donor_margin
        ]
        swappable = [r for r in selected if r["v2_pred_specificity"] in over_levels]
        random.shuffle(swappable)

        swapped = 0
        for cand in candidates:
            if swapped >= deficit or not swappable:
                break
            cand_key = (cand["v1_category"], cand["company_cik"])
            for i, out in enumerate(swappable):
                if out["v1_category"] != cand["v1_category"]:
                    continue
                donor_level = out["v2_pred_specificity"]
                if spec_selected[donor_level] <= floor:
                    continue  # donor would fall below the floor
                # Removing `out` frees a slot only if it is the same company.
                freed = 1 if out["company_cik"] == cand["company_cik"] else 0
                if stratum_counts[cand_key] - freed >= max_per_company:
                    continue  # would exceed the per-company cap
                selected_ids.remove(out["id"])
                selected.remove(out)
                selected_ids.add(cand["id"])
                selected.append(cand)
                swappable.pop(i)
                spec_selected[donor_level] -= 1
                spec_selected[target_level] += 1
                stratum_counts[(out["v1_category"], out["company_cik"])] -= 1
                stratum_counts[cand_key] += 1
                swapped += 1
                break

        if swapped < deficit:
            print(
                f" Could only boost by {swapped}/{deficit} "
                "(not enough same-category swaps within company cap / donor floor)",
                file=sys.stderr,
            )
        else:
            print(f" Boosted by {swapped}", file=sys.stderr)


def main():
    """Sample the v2 holdout and write the ID list and the manifest JSONL."""
    print("Loading paragraphs...", file=sys.stderr)
    paragraphs = _load_paragraphs(PARAGRAPHS_PATH)
    print(f" Loaded {len(paragraphs)} paragraphs", file=sys.stderr)

    # v1 holdout IDs are loaded only to report overlap. v1 holdout paragraphs
    # are deliberately NOT excluded, so the two holdouts may share paragraphs.
    v1_holdout: set[str] = set()
    if V1_HOLDOUT_PATH.exists():
        v1_holdout = set(json.loads(V1_HOLDOUT_PATH.read_text(encoding="utf-8")))
        print(f" Loaded {len(v1_holdout)} v1 holdout IDs (will NOT exclude — they're eligible)", file=sys.stderr)

    # Majority-vote v1 labels over paragraphs with exactly 3 annotations.
    print("Computing v1 consensus...", file=sys.stderr)
    records, no_category_majority = _build_records(
        _load_stage1_labels(STAGE1_PATH), paragraphs
    )
    print(f" {len(records)} paragraphs with 3-model consensus", file=sys.stderr)
    print(f" {no_category_majority} of these had no category majority (first-listed label used)", file=sys.stderr)
    if not records:
        sys.exit("ERROR: no paragraphs with exactly 3 Stage 1 annotations found; nothing to sample")

    # ── Distribution report ─────────────────────────────────────────────
    print("\n v1 Category distribution:", file=sys.stderr)
    cat_counts = Counter(r["v1_category"] for r in records)
    for cat, count in sorted(cat_counts.items(), key=lambda x: -x[1]):
        print(f" {cat:30s} {count:6d} ({count/len(records)*100:.1f}%)", file=sys.stderr)
    print("\n v2 Predicted specificity distribution:", file=sys.stderr)
    spec_counts = Counter(r["v2_pred_specificity"] for r in records)
    for spec in sorted(spec_counts):
        count = spec_counts[spec]
        print(f" Level {spec}: {count:6d} ({count/len(records)*100:.1f}%)", file=sys.stderr)

    # ── Stratified sampling ─────────────────────────────────────────────
    print("\nSampling holdout...", file=sys.stderr)
    selected_records = _sample_by_category(records, CATEGORY_ALLOC, MAX_PER_COMPANY_PER_STRATUM)
    print(f"\n Initial selection: {len(selected_records)}", file=sys.stderr)

    # ── Specificity floor enforcement ───────────────────────────────────
    print("\n Predicted v2 specificity in selection:", file=sys.stderr)
    _print_spec_counts(
        Counter(r["v2_pred_specificity"] for r in selected_records),
        SPECIFICITY_FLOOR,
        f"BELOW FLOOR ({SPECIFICITY_FLOOR})",
    )
    _enforce_specificity_floor(
        records,
        selected_records,
        floor=SPECIFICITY_FLOOR,
        max_per_company=MAX_PER_COMPANY_PER_STRATUM,
    )

    # ── Final report ────────────────────────────────────────────────────
    print(f"\n Final selection: {len(selected_records)}", file=sys.stderr)
    print("\n Final category distribution:", file=sys.stderr)
    final_cat = Counter(r["v1_category"] for r in selected_records)
    for cat, count in sorted(final_cat.items(), key=lambda x: -x[1]):
        print(f" {cat:30s} {count:4d}", file=sys.stderr)
    print("\n Final predicted v2 specificity distribution:", file=sys.stderr)
    _print_spec_counts(
        Counter(r["v2_pred_specificity"] for r in selected_records),
        SPECIFICITY_FLOOR,
        "BELOW",
    )
    # Company diversity check (guarded: the selection may be empty)
    companies = Counter(r["company_cik"] for r in selected_records)
    max_from_one = companies.most_common(1)[0][1] if companies else 0
    print(f"\n Companies represented: {len(companies)}", file=sys.stderr)
    print(f" Max paragraphs from one company: {max_from_one}", file=sys.stderr)
    overlap = sum(1 for r in selected_records if r["id"] in v1_holdout)
    print(f" Overlap with v1 holdout: {overlap}", file=sys.stderr)

    # ── Write outputs ───────────────────────────────────────────────────
    ids = sorted(r["id"] for r in selected_records)
    OUTPUT_IDS.parent.mkdir(parents=True, exist_ok=True)
    OUTPUT_IDS.write_text(json.dumps(ids, indent=2) + "\n", encoding="utf-8")
    print(f"\n Wrote {len(ids)} IDs to {OUTPUT_IDS}", file=sys.stderr)
    with open(OUTPUT_MANIFEST, "w", encoding="utf-8") as f:
        for r in sorted(selected_records, key=lambda x: x["v1_category"]):
            manifest = {field: r[field] for field in _MANIFEST_FIELDS}
            f.write(json.dumps(manifest) + "\n")
    print(f" Wrote manifest to {OUTPUT_MANIFEST}", file=sys.stderr)


if __name__ == "__main__":
    main()