"""Sample the v2 holdout set (1,200 paragraphs) using v1 Stage 1 consensus as a guide.

Applies heuristic v2 specificity prediction:
- v1 L1 with domain terminology keywords → predicted L2
- v1 L3 with 1+ QV indicator → predicted L4

Stratified by category (185 per non-Incident-Disclosure category, 90 for
Incident Disclosure), with:
- Max 2 paragraphs per company per category stratum
- Secondary floor: ≥100 per predicted v2 specificity level
- Random within strata (no difficulty weighting)

Outputs:
- data/gold/v2-holdout-ids.json (list of paragraph IDs)
- data/gold/v2-holdout-manifest.jsonl (full metadata per paragraph)
"""
|
|
|
|
import json
|
|
import random
|
|
import re
|
|
import sys
|
|
from collections import Counter, defaultdict
|
|
from pathlib import Path
|
|
|
|
random.seed(42)  # reproducible sampling: every shuffle below is deterministic

# ── File locations (relative to the repo root) ──────────────────────────────
DATA = Path("data")
# Input: cleaned paragraph corpus, one JSON object per line.
PARAGRAPHS_PATH = DATA / "paragraphs" / "paragraphs-clean.patched.jsonl"
# Input: Stage 1 per-model annotations (3 per paragraph expected).
STAGE1_PATH = DATA / "annotations" / "stage1.patched.jsonl"
# Input: v1 holdout IDs (loaded for reporting only; v1 paragraphs stay eligible).
V1_HOLDOUT_PATH = Path("labelapp") / ".sampled-ids.original.json"
# Outputs: sampled ID list (JSON array) and per-paragraph manifest (JSONL).
OUTPUT_IDS = DATA / "gold" / "v2-holdout-ids.json"
OUTPUT_MANIFEST = DATA / "gold" / "v2-holdout-manifest.jsonl"
|
|
|
|
# ── Allocation ──────────────────────────────────────────────────────────────

# Total holdout size and its split across v1 consensus categories.
# Incident Disclosure gets a smaller stratum (90); every other category 185.
TOTAL = 1200
CATEGORY_ALLOC = {
    "Board Governance": 185,
    "Management Role": 185,
    "Risk Management Process": 185,
    "Third-Party Risk": 185,
    "Strategy Integration": 185,
    "None/Other": 185,
    "Incident Disclosure": 90,
}
# Explicit check instead of `assert`: asserts are silently stripped under -O,
# and a mis-summed allocation should always abort the run.
if sum(CATEGORY_ALLOC.values()) != TOTAL:
    raise ValueError(
        f"CATEGORY_ALLOC sums to {sum(CATEGORY_ALLOC.values())}, expected {TOTAL}"
    )

# Secondary constraint: minimum paragraphs per predicted v2 specificity level.
SPECIFICITY_FLOOR = 100
# Cap on paragraphs drawn from a single company within one category stratum.
MAX_PER_COMPANY_PER_STRATUM = 2
|
|
|
|
# ── Domain terminology keywords (v2 codebook) ──────────────────────────────
# Applied to v1 L1 paragraphs: any match → predicted L2

DOMAIN_TERMS = [
    # Practices and activities
    r"penetration test", r"pen test", r"vulnerability scan", r"vulnerability assess",
    r"red team", r"phishing simul", r"security awareness training",
    r"threat hunt", r"threat intellig", r"patch management",
    r"identity and access management", r"\bIAM\b",
    r"data loss prevention", r"\bDLP\b",
    r"network segmentation", r"encryption",
    # Tools and infrastructure
    r"\bSIEM\b", r"security information and event management",
    r"\bSOC\b", r"security operations center",
    r"\bEDR\b", r"\bXDR\b", r"\bMDR\b", r"endpoint detection",
    r"\bWAF\b", r"web application firewall",
    r"\bIDS\b", r"\bIPS\b", r"intrusion detection", r"intrusion prevention",
    r"\bMFA\b", r"multi-factor auth", r"two-factor auth", r"\b2FA\b",
    r"\bfirewall\b", r"antivirus", r"anti-malware",
    # Architectural concepts
    r"zero trust", r"defense in depth", r"least privilege",
    # Named standards (already L2 in v1, but catch any misses)
    r"\bNIST\b", r"ISO 27001", r"ISO 27002", r"\bSOC 2\b", r"CIS Controls",
    r"PCI[ -]DSS", r"\bHIPAA\b", r"\bGDPR\b", r"\bCOBIT\b", r"MITRE ATT",
    # Threat types
    r"ransomware", r"\bmalware\b", r"phishing",
    r"\bDDoS\b", r"supply chain attack", r"social engineering",
    r"advanced persistent threat", r"\bAPT\b", r"zero[- ]day",
]
# Compiled once at import time; all matching is case-insensitive.
# NOTE(review): IGNORECASE means short acronym patterns also match ordinary
# lowercase words — e.g. \bAPT\b matches "apt" ("apt to...") and \bIPS\b /
# \bSOC\b could match stray lowercase tokens. Confirm the false-positive
# rate is acceptable for the L1→L2 bump.
DOMAIN_PATTERNS = [re.compile(t, re.IGNORECASE) for t in DOMAIN_TERMS]
|
|
|
|
# ── QV indicators (v2 codebook: 1+ → Level 4) ──────────────────────────────
# Applied to v1 L3 paragraphs: any match → predicted L4

QV_PATTERNS_RAW = [
    r"\$[\d,]+",  # dollar amounts
    # Rounded magnitudes ("12 million", "1,500 thousand")
    r"\b\d{1,3}(?:,\d{3})*\s*(?:million|billion|thousand)\b",
    # Specific dates (month + year or exact)
    r"\b(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2},?\s+\d{4}\b",
    r"\b(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{4}\b",
    # Named certifications
    r"\bCISSP\b", r"\bCISM\b", r"\bCEH\b", r"\bCRISC\b",
    # Named third-party firms
    r"\bMandiant\b", r"\bCrowdStrike\b", r"\bDeloitte\b", r"\bKPMG\b",
    r"\bPricewaterhouse\b", r"\bPwC\b", r"\bErnst\s*&\s*Young\b", r"\bEY\b",
    r"\bAccenture\b", r"\bBooz Allen\b", r"\bProtiviti\b", r"\bKroll\b",
    # Named products/tools
    r"\bSplunk\b", r"\bAzure Sentinel\b", r"\bCrowdStrike Falcon\b",
    r"\bServiceNow\b", r"\bPalo Alto\b", r"\bFortinet\b", r"\bSailPoint\b",
    r"\bOkta\b", r"\bZscaler\b", r"\bCarbon Black\b", r"\bSentinelOne\b",
    r"\bQualys\b", r"\bTenable\b", r"\bRapid7\b", r"\bProofpoint\b",
    r"\bMimecast\b", r"\bDarktrace\b", r"\bKnowBe4\b",
    # Headcounts / years of experience
    r"\b\d+\s+(?:years?|professionals?|employees?|staff|members?|individuals?)\b",
    # Specific percentages
    r"\b\d+(?:\.\d+)?%",
    # Advanced-degree credentials (PhD, Master of/in ...)
    r"\bPh\.?D\.?\b", r"\bMaster'?s?\s+(?:of|in)\b",
]
# Compiled once; IGNORECASE helps certs/firm names written in running prose.
# NOTE(review): IGNORECASE also lets \bEY\b match a stray lowercase "ey"
# token and the month patterns match lowercase "may 2023" — verify these
# edge cases against the corpus before trusting the L3→L4 bump.
QV_PATTERNS = [re.compile(p, re.IGNORECASE) for p in QV_PATTERNS_RAW]
|
|
|
|
|
|
def has_domain_terminology(text: str) -> bool:
    """Return True when *text* matches at least one domain-terminology pattern."""
    for pattern in DOMAIN_PATTERNS:
        if pattern.search(text):
            return True
    return False
|
|
|
|
|
|
def has_qv_indicator(text: str) -> bool:
    """Return True when *text* contains at least one QV (quantified/verifiable) indicator."""
    return next((True for pattern in QV_PATTERNS if pattern.search(text)), False)
|
|
|
|
|
|
def predict_v2_specificity(v1_spec: int, text: str) -> int:
    """Heuristic v2 specificity prediction from v1 consensus + text scan.

    A v1 Level 1 with domain terminology is bumped to 2; a v1 Level 3 with
    a QV indicator is bumped to 4; every other level passes through as-is.
    """
    # Dispatch table: v1 level → (text predicate, bumped v2 level).
    bump_rules = {
        1: (has_domain_terminology, 2),
        3: (has_qv_indicator, 4),
    }
    rule = bump_rules.get(v1_spec)
    if rule is not None:
        predicate, bumped_level = rule
        if predicate(text):
            return bumped_level
    return v1_spec
|
|
|
|
|
|
# ── Main ────────────────────────────────────────────────────────────────────
|
|
|
|
def main() -> None:
    """Sample the v2 holdout and write the ID list + manifest.

    Pipeline: load paragraphs → compute v1 Stage 1 consensus → predict v2
    specificity heuristically → stratified sample per category with a
    per-company cap → swap records to meet the specificity floor →
    write outputs. All progress reporting goes to stderr.
    """
    # Load the paragraph corpus keyed by ID (one JSON object per line).
    print("Loading paragraphs...", file=sys.stderr)
    paragraphs = {}
    with open(PARAGRAPHS_PATH) as f:
        for line in f:
            p = json.loads(line)
            paragraphs[p["id"]] = p
    print(f"  Loaded {len(paragraphs)} paragraphs", file=sys.stderr)

    # Load v1 holdout IDs for reporting only — v1 holdout paragraphs remain
    # eligible for v2 sampling. NOTE(review): `v1_holdout` is never read
    # again after this point; confirm whether overlap was meant to be
    # excluded or the variable is intentionally informational.
    v1_holdout = set()
    if V1_HOLDOUT_PATH.exists():
        v1_holdout = set(json.loads(V1_HOLDOUT_PATH.read_text()))
        print(f"  Loaded {len(v1_holdout)} v1 holdout IDs (will NOT exclude — they're eligible)", file=sys.stderr)

    # Load stage1 annotations, grouping the per-model labels by paragraph.
    print("Computing v1 consensus...", file=sys.stderr)
    ann_by_pid: dict[str, list[dict]] = defaultdict(list)
    with open(STAGE1_PATH) as f:
        for line in f:
            r = json.loads(line)
            ann_by_pid[r["paragraphId"]].append(r["label"])

    # Build consensus records — one dict of sampling metadata per paragraph.
    records = []
    for pid, labels in ann_by_pid.items():
        if len(labels) != 3:
            continue  # consensus requires exactly 3 model annotations
        if pid not in paragraphs:
            continue  # annotation references a paragraph missing from the corpus

        p = paragraphs[pid]
        text = p["text"]
        company = p["filing"]["cik"]  # CIK is unique per company

        # Majority vote. NOTE(review): on a 3-way tie, most_common(1)
        # returns the first-encountered label (Counter insertion order).
        cats = [l["content_category"] for l in labels]
        cat = Counter(cats).most_common(1)[0][0]

        specs = [l["specificity_level"] for l in labels]
        spec = Counter(specs).most_common(1)[0][0]

        v2_spec = predict_v2_specificity(spec, text)

        records.append({
            "id": pid,
            "v1_category": cat,
            "v1_specificity": spec,
            "v2_pred_specificity": v2_spec,
            "company_cik": company,
            "company_name": p["filing"]["companyName"],
            "text_preview": text[:120],
            "word_count": p["wordCount"],
            "text": text,
        })

    print(f"  {len(records)} paragraphs with 3-model consensus", file=sys.stderr)

    # ── Distribution report ─────────────────────────────────────────────
    print("\n  v1 Category distribution:", file=sys.stderr)
    cat_counts = Counter(r["v1_category"] for r in records)
    for cat, count in sorted(cat_counts.items(), key=lambda x: -x[1]):
        print(f"    {cat:30s} {count:6d} ({count/len(records)*100:.1f}%)", file=sys.stderr)

    print("\n  v2 Predicted specificity distribution:", file=sys.stderr)
    spec_counts = Counter(r["v2_pred_specificity"] for r in records)
    for spec in sorted(spec_counts):
        count = spec_counts[spec]
        print(f"    Level {spec}: {count:6d} ({count/len(records)*100:.1f}%)", file=sys.stderr)

    # ── Stratified sampling ─────────────────────────────────────────────
    print("\nSampling holdout...", file=sys.stderr)

    # Group candidates by consensus category (the sampling strata).
    by_category: dict[str, list[dict]] = defaultdict(list)
    for r in records:
        by_category[r["v1_category"]].append(r)

    selected_ids: set[str] = set()
    selected_records: list[dict] = []

    for cat, alloc in CATEGORY_ALLOC.items():
        pool = by_category.get(cat, [])
        random.shuffle(pool)  # random within stratum (seeded at module load)

        # Enforce the per-company cap within this category stratum.
        company_counts: dict[str, int] = defaultdict(int)
        eligible = []
        for r in pool:
            cik = r["company_cik"]
            if company_counts[cik] < MAX_PER_COMPANY_PER_STRATUM:
                eligible.append(r)
                company_counts[cik] += 1

        # Take up to the category allocation.
        taken = eligible[:alloc]
        if len(taken) < alloc:
            print(f"  WARNING: {cat} — only {len(taken)}/{alloc} available after company cap", file=sys.stderr)

        for r in taken:
            selected_ids.add(r["id"])
            selected_records.append(r)

        print(f"    {cat:30s} selected {len(taken):4d}/{alloc} (pool {len(pool)})", file=sys.stderr)

    print(f"\n  Initial selection: {len(selected_records)}", file=sys.stderr)

    # ── Specificity floor enforcement ───────────────────────────────────
    spec_selected = Counter(r["v2_pred_specificity"] for r in selected_records)
    print("\n  Predicted v2 specificity in selection:", file=sys.stderr)
    for spec in sorted(spec_selected):
        count = spec_selected[spec]
        status = "OK" if count >= SPECIFICITY_FLOOR else f"BELOW FLOOR ({SPECIFICITY_FLOOR})"
        print(f"    Level {spec}: {count:4d} — {status}", file=sys.stderr)

    # If any level is below floor, swap in unselected records of that level
    # and drop same-category records from over-represented levels, so the
    # category allocation stays intact. NOTE(review): swapped-in candidates
    # bypass the per-company cap enforced above — confirm that is acceptable.
    for target_level in [1, 2, 3, 4]:
        deficit = SPECIFICITY_FLOOR - spec_selected.get(target_level, 0)
        if deficit <= 0:
            continue

        print(f"\n  Boosting Level {target_level} by {deficit}...", file=sys.stderr)

        # Find candidates NOT yet selected, matching target specificity.
        candidates = [
            r for r in records
            if r["id"] not in selected_ids
            and r["v2_pred_specificity"] == target_level
        ]
        random.shuffle(candidates)

        # Only donate from levels comfortably above the floor (+20 margin),
        # so a swap cannot push a donor level below the floor itself.
        over_levels = [
            lvl for lvl in [1, 2, 3, 4]
            if spec_selected.get(lvl, 0) > SPECIFICITY_FLOOR + 20
        ]

        swappable = [
            r for r in selected_records
            if r["v2_pred_specificity"] in over_levels
        ]
        random.shuffle(swappable)

        swapped = 0
        for cand in candidates:
            if swapped >= deficit:
                break
            if not swappable:
                break

            # Find a swappable record from the same category, so category
            # balance is maintained by every swap.
            for i, swap_r in enumerate(swappable):
                if swap_r["v1_category"] == cand["v1_category"]:
                    # Swap: remove the donor, insert the candidate, and
                    # keep the running specificity tallies in sync.
                    selected_ids.remove(swap_r["id"])
                    selected_records.remove(swap_r)
                    selected_ids.add(cand["id"])
                    selected_records.append(cand)
                    swappable.pop(i)
                    spec_selected[swap_r["v2_pred_specificity"]] -= 1
                    spec_selected[target_level] += 1
                    swapped += 1
                    break

        if swapped < deficit:
            print(f"    Could only boost by {swapped}/{deficit} (not enough same-category swaps)", file=sys.stderr)
        else:
            print(f"    Boosted by {swapped}", file=sys.stderr)

    # ── Final report ────────────────────────────────────────────────────
    print(f"\n  Final selection: {len(selected_records)}", file=sys.stderr)

    print("\n  Final category distribution:", file=sys.stderr)
    final_cat = Counter(r["v1_category"] for r in selected_records)
    for cat, count in sorted(final_cat.items(), key=lambda x: -x[1]):
        print(f"    {cat:30s} {count:4d}", file=sys.stderr)

    print("\n  Final predicted v2 specificity distribution:", file=sys.stderr)
    final_spec = Counter(r["v2_pred_specificity"] for r in selected_records)
    for spec in sorted(final_spec):
        count = final_spec[spec]
        status = "OK" if count >= SPECIFICITY_FLOOR else "BELOW"
        print(f"    Level {spec}: {count:4d} — {status}", file=sys.stderr)

    # Company diversity check. The cap is per category stratum, so one
    # company may legitimately exceed MAX_PER_COMPANY_PER_STRATUM overall.
    companies = Counter(r["company_cik"] for r in selected_records)
    print(f"\n  Companies represented: {len(companies)}", file=sys.stderr)
    print(f"  Max paragraphs from one company: {companies.most_common(1)[0][1]}", file=sys.stderr)

    # ── Write outputs ───────────────────────────────────────────────────
    ids = sorted(r["id"] for r in selected_records)
    OUTPUT_IDS.parent.mkdir(parents=True, exist_ok=True)
    OUTPUT_IDS.write_text(json.dumps(ids, indent=2) + "\n")
    print(f"\n  Wrote {len(ids)} IDs to {OUTPUT_IDS}", file=sys.stderr)

    # Manifest: one JSON line per selected paragraph, sorted by category.
    # The full text is deliberately omitted — only the 120-char preview.
    with open(OUTPUT_MANIFEST, "w") as f:
        for r in sorted(selected_records, key=lambda x: x["v1_category"]):
            manifest = {
                "id": r["id"],
                "v1_category": r["v1_category"],
                "v1_specificity": r["v1_specificity"],
                "v2_pred_specificity": r["v2_pred_specificity"],
                "company_cik": r["company_cik"],
                "company_name": r["company_name"],
                "word_count": r["word_count"],
                "text_preview": r["text_preview"],
            }
            f.write(json.dumps(manifest) + "\n")
    print(f"  Wrote manifest to {OUTPUT_MANIFEST}", file=sys.stderr)
|
|
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|