SEC-cyBERT/scripts/audit_paragraphs.py
2026-03-29 20:33:39 -04:00

406 lines
15 KiB
Python

"""
Audit SEC-cyBERT paragraph corpus for boundary errors.
Run from project root: python3 scripts/audit_paragraphs.py
"""
import json
import math
import random
import re
import sys
from collections import Counter, defaultdict
from pathlib import Path
DATA_PATH = Path("data/paragraphs/paragraphs-clean.jsonl")


def load_paragraphs():
    """Load the cleaned paragraph corpus (JSONL: one JSON object per line).

    Returns:
        list[dict]: one dict per paragraph, in file order.

    Raises:
        FileNotFoundError: if DATA_PATH does not exist (run from project root).
    """
    # Explicit encoding: filing text is UTF-8 and must not depend on the
    # locale of whoever runs the audit.
    with DATA_PATH.open(encoding="utf-8") as f:
        # Skip blank lines so a trailing newline in the JSONL can't crash
        # the audit with a JSONDecodeError.
        return [json.loads(line) for line in f if line.strip()]
def section_header(title):
    """Print a full-width banner separating the audit sections."""
    rule = "=" * 80
    print("\n" + rule)
    print(f" {title}")
    print(rule)
def truncate(text, n):
    """Return *text* unchanged when it fits in *n* chars, else clip and add '...'."""
    return text if len(text) <= n else text[:n] + "..."
# ---------------------------------------------------------------------------
# Load
# ---------------------------------------------------------------------------
print("Loading paragraphs...")
paragraphs = load_paragraphs()
print(f"Loaded {len(paragraphs):,} paragraphs")

# Bucket paragraphs by the filing they came from; the accession number is
# the SEC's unique identifier for a filing.
by_filing = defaultdict(list)
for p in paragraphs:
    by_filing[p["filing"]["accessionNumber"]].append(p)
print(f"Unique filings: {len(by_filing):,}")
# ---------------------------------------------------------------------------
# 1. Paragraphs-per-filing distribution
# ---------------------------------------------------------------------------
section_header("1. PARAGRAPHS-PER-FILING DISTRIBUTION")
# Sorted counts are required by the percentile() helper defined below.
counts = sorted(len(ps) for ps in by_filing.values())
n = len(counts)
mean = sum(counts) / n
# Population variance (divide by n, not n-1): these counts are the whole
# corpus, not a sample from it.
variance = sum((c - mean) ** 2 for c in counts) / n
stdev = math.sqrt(variance)
def percentile(sorted_list, pct):
    """Linear-interpolation percentile of an already-sorted list.

    *pct* is in [0, 100]; interpolates between the two neighbouring
    elements when the rank falls between indices.
    """
    pos = pct / 100 * (len(sorted_list) - 1)
    below = int(math.floor(pos))
    above = int(math.ceil(pos))
    if below == above:
        # Rank landed exactly on an element — no interpolation needed.
        return sorted_list[below]
    weight = pos - below
    return sorted_list[below] * (1 - weight) + sorted_list[above] * weight
# Summary statistics of paragraphs-per-filing.
print(f" Min: {counts[0]}")
print(f" P5: {percentile(counts, 5):.1f}")
print(f" P25: {percentile(counts, 25):.1f}")
print(f" Median: {percentile(counts, 50):.1f}")
print(f" P75: {percentile(counts, 75):.1f}")
print(f" P95: {percentile(counts, 95):.1f}")
print(f" Max: {counts[-1]}")
print(f" Stdev: {stdev:.2f}")
print(f" Mean: {mean:.2f}")

# Histogram over half-open buckets (lower, upper].
buckets = [1, 2, 3, 5, 10, 15, 20, 30, 50, 100, 200]
print("\n Histogram:")
for lower, upper in zip([0] + buckets[:-1], buckets):
    filings_in_bucket = sum(1 for x in counts if lower < x <= upper)
    if filings_in_bucket > 0:
        print(f" ({lower+1}-{upper}]: {filings_in_bucket:>5} filings")
overflow = sum(1 for x in counts if x > buckets[-1])
if overflow > 0:
    print(f" (>{buckets[-1]}): {overflow:>5} filings")
# Fewest paragraphs: filings that produced almost nothing are likely
# extraction failures or ultra-short disclosures worth eyeballing.
print("\n --- 10 filings with FEWEST paragraphs ---")
sorted_filings = sorted(by_filing.items(), key=lambda x: len(x[1]))
for acc, ps in sorted_filings[:10]:
    company = ps[0]["filing"]["companyName"]
    # FIX: company name and count were fused in the output
    # ("Acme Corp3 paragraph(s)"); separate them.
    print(f"\n [{acc}] {company} - {len(ps)} paragraph(s):")
    for p in sorted(ps, key=lambda x: x["paragraphIndex"]):
        print(f" p{p['paragraphIndex']} ({p['wordCount']}w): {truncate(p['text'], 150)}")

# Most paragraphs: show only the first 5 per filing to keep output readable.
print("\n --- 10 filings with MOST paragraphs ---")
for acc, ps in sorted_filings[-10:]:
    company = ps[0]["filing"]["companyName"]
    print(f"\n [{acc}] {company} - {len(ps)} paragraph(s):")
    for p in sorted(ps, key=lambda x: x["paragraphIndex"])[:5]:
        print(f" p{p['paragraphIndex']} ({p['wordCount']}w): {truncate(p['text'], 150)}")
    if len(ps) > 5:
        print(f" ... ({len(ps) - 5} more)")
# ---------------------------------------------------------------------------
# 2. Suspiciously long paragraphs
# ---------------------------------------------------------------------------
section_header("2. SUSPICIOUSLY LONG PARAGRAPHS (top 20 by word count)")
sorted_by_wc = sorted(paragraphs, key=lambda p: p["wordCount"], reverse=True)
for i, p in enumerate(sorted_by_wc[:20]):
    text = p["text"]
    print(f"\n #{i+1}: {p['wordCount']} words | p{p['paragraphIndex']} | {p['filing']['companyName']}")
    print(f" Acc: {p['filing']['accessionNumber']}")
    print(f" FIRST 200: {text[:200]}")
    tail = text[-200:] if len(text) > 400 else ""
    if tail:
        print(f" LAST 200: {tail}")
    # Accumulate signals that this "paragraph" is really several merged ones.
    issues = []
    if p["wordCount"] > 300:
        issues.append("VERY LONG (>300w)")
    # Embedded newlines suggest multiple source blocks were concatenated.
    line_count = text.count("\n") + 1
    if line_count > 1:
        issues.append(f"CONTAINS {line_count} LINES (possible merge)")
    # A very high sentence count is another merge smell.
    sentence_count = len(re.split(r'(?<=[.!?])\s+', text))
    if sentence_count > 8:
        issues.append(f"{sentence_count} sentences")
    if issues:
        print(f" FLAGS: {', '.join(issues)}")
# ---------------------------------------------------------------------------
# 3. Suspiciously short paragraphs
# ---------------------------------------------------------------------------
section_header("3. SUSPICIOUSLY SHORT PARAGRAPHS (<25 words)")
short = [p for p in paragraphs if p["wordCount"] < 25]
print(f"\n Total paragraphs <25 words: {len(short)} ({100*len(short)/len(paragraphs):.1f}%)")

# Buckets for the classification loop below.
headings, standalone, fragments, list_items = [], [], [], []

# Section titles commonly used in Item 1C cybersecurity disclosures; a short
# paragraph starting with one of these is almost certainly a heading.
heading_patterns = re.compile(
    r"^(risk management|cybersecurity|governance|strategy|board|"
    r"oversight|incident|material|information security|"
    r"risk factors|item 1c|risk management and strategy|"
    r"risk management, strategy|governance, risk management)"
    , re.IGNORECASE
)
# Classify every short paragraph into exactly one bucket, with precedence
# heading > list item > fragment > standalone sentence.
for p in short:
    text = p["text"].strip()
    lower = text.lower()
    word_total = len(text.split())
    # Heading: short without a terminal period, a known section title,
    # or an ALL-CAPS line.
    is_heading = (
        (word_total <= 8 and not text.endswith("."))
        or bool(heading_patterns.match(lower))
        or (text.isupper() and word_total <= 10)
    )
    # List item: opens with a bullet glyph, dash, number, or letter marker.
    is_list = bool(re.match(r"^(\d+[.)]\s|[-•●◦▪]\s|[a-z][.)]\s|\([a-z]\)\s|\(\d+\)\s)", text))
    # Fragment: not a heading/list and lacks sentence-ending punctuation.
    is_fragment = not is_heading and not is_list and not re.search(r'[.!?"]$', text.rstrip())
    if is_heading:
        bucket = headings
    elif is_list:
        bucket = list_items
    elif is_fragment:
        bucket = fragments
    else:
        bucket = standalone
    bucket.append(p)

print(f" Headings: {len(headings)}")
print(f" Standalone sentences:{len(standalone)}")
print(f" Fragments: {len(fragments)}")
print(f" List items: {len(list_items)}")
def show_examples(label, items, count):
    """Print up to *count* examples from *items*, sampling when there are more."""
    sample = random.sample(items, count) if len(items) > count else items[:count]
    print(f"\n --- {label} (showing {len(sample)} of {len(items)}) ---")
    for p in sample:
        print(f" [{p['wordCount']}w] p{p['paragraphIndex']} | {truncate(p['text'], 120)}")
        print(f" {p['filing']['companyName']} | {p['filing']['accessionNumber']}")


# Fixed seed keeps the sampled examples reproducible between runs.
random.seed(42)
show_examples("Headings", headings, 10)
show_examples("Standalone sentences", standalone, 8)
show_examples("Fragments", fragments, 8)
show_examples("List items", list_items, 4)
# ---------------------------------------------------------------------------
# 4. Sequential paragraph coherence
# ---------------------------------------------------------------------------
# Inspect consecutive paragraph pairs within a sample of filings, looking for
# boundaries that split one sentence across two paragraphs.
section_header("4. SEQUENTIAL PARAGRAPH COHERENCE (20 random filings)")
# Fixed seed so the sampled filings (and thus the report) are reproducible.
random.seed(123)
sample_accs = random.sample(list(by_filing.keys()), min(20, len(by_filing)))
mid_sentence_breaks = []
# NOTE(review): topic_shifts is never used below — shift_examples is built
# instead in the intra-paragraph check; candidate for removal.
topic_shifts = []
for acc in sample_accs:
    # Walk paragraph pairs in document order.
    ps = sorted(by_filing[acc], key=lambda x: x["paragraphIndex"])
    for i in range(len(ps) - 1):
        curr = ps[i]
        nxt = ps[i + 1]
        curr_text = curr["text"].strip()
        nxt_text = nxt["text"].strip()
        # Check: does current paragraph end mid-sentence?
        # Signs: ends with comma, semicolon, conjunction, lowercase word, no terminal punctuation
        ends_mid = False
        if curr_text and not re.search(r'[.!?:"\)]$', curr_text):
            ends_mid = True
        if curr_text and re.search(r'(,|;|\band\b|\bor\b|\bbut\b|\bthat\b|\bwhich\b)\s*$', curr_text):
            ends_mid = True
        # Check: does next paragraph start with lowercase (continuation)?
        starts_lower = bool(nxt_text) and nxt_text[0].islower()
        if ends_mid or starts_lower:
            # Keep only 150-char snippets of each side of the boundary for the
            # report printed below.
            mid_sentence_breaks.append({
                "acc": acc,
                "company": curr["filing"]["companyName"],
                "curr_idx": curr["paragraphIndex"],
                "nxt_idx": nxt["paragraphIndex"],
                "curr_end": curr_text[-150:] if len(curr_text) > 150 else curr_text,
                "nxt_start": nxt_text[:150] if len(nxt_text) > 150 else nxt_text,
                "ends_mid": ends_mid,
                "starts_lower": starts_lower,
            })
print(f"\n Checked {len(sample_accs)} filings")
print(f" Potential mid-sentence breaks found: {len(mid_sentence_breaks)}")
print("\n --- Examples of mid-sentence / continuation breaks ---")
# Show at most five illustrative boundaries with the reason(s) they tripped.
for ex in mid_sentence_breaks[:5]:
    print(f"\n [{ex['acc']}] {ex['company']}")
    print(f" p{ex['curr_idx']} ENDS: ...{ex['curr_end']}")
    print(f" p{ex['nxt_idx']} STARTS: {ex['nxt_start']}...")
    flags = []
    if ex["ends_mid"]:
        flags.append("no terminal punctuation")
    if ex["starts_lower"]:
        flags.append("next starts lowercase")
    print(f" FLAGS: {', '.join(flags)}")
if len(mid_sentence_breaks) == 0:
    print(" (none found)")
# Also check for topic shifts within single paragraphs (long ones in sampled filings)
print("\n --- Checking for intra-paragraph topic shifts ---")
shift_examples = []
# Section-title words appearing right after a sentence end, mid-paragraph,
# suggest two disclosure sections were merged into one paragraph.
embedded_heading = re.compile(
    r'(?<=[.!?]\s)(Risk Management|Governance|Strategy|Cybersecurity|'
    r'Board of Directors|Incident Response|Overview|Third.Party)'
)
for acc in sample_accs:
    for p in by_filing[acc]:
        if p["wordCount"] < 150:
            continue
        text = p["text"]
        hit = embedded_heading.search(text)
        if hit is None:
            continue
        shift_examples.append({
            "acc": acc,
            "company": p["filing"]["companyName"],
            "idx": p["paragraphIndex"],
            "wordCount": p["wordCount"],
            "match": hit.group(),
            "context": text[max(0, hit.start() - 80):hit.end() + 80],
        })
print(f" Paragraphs with possible embedded topic headers: {len(shift_examples)}")
for ex in shift_examples[:5]:
    print(f"\n [{ex['acc']}] {ex['company']} p{ex['idx']} ({ex['wordCount']}w)")
    print(f" Found '{ex['match']}' mid-paragraph:")
    print(f" ...{ex['context']}...")
# ---------------------------------------------------------------------------
# 5. Paragraph index gaps
# ---------------------------------------------------------------------------
section_header("5. PARAGRAPH INDEX GAPS & DUPLICATES")
gap_filings = []
dup_filings = []
for acc, plist in by_filing.items():
    idx_list = sorted(p["paragraphIndex"] for p in plist)
    # Duplicate indices within one filing.
    if len(idx_list) != len(set(idx_list)):
        repeated = {k: v for k, v in Counter(idx_list).items() if v > 1}
        dup_filings.append((acc, plist[0]["filing"]["companyName"], repeated))
    # Gaps: indices should form a contiguous run from the first one.
    contiguous = list(range(idx_list[0], idx_list[0] + len(idx_list)))
    if idx_list != contiguous:
        missing = set(contiguous) - set(idx_list)
        extra = set(idx_list) - set(contiguous)
        if missing or extra:
            gap_filings.append((acc, plist[0]["filing"]["companyName"], sorted(missing), sorted(extra), idx_list))

print(f"\n Filings with duplicate paragraph indices: {len(dup_filings)}")
for acc, company, dups in dup_filings[:10]:
    print(f" [{acc}] {company}: duplicates at indices {dups}")

print(f"\n Filings with index gaps: {len(gap_filings)}")
for acc, company, missing, extra, indices in gap_filings[:10]:
    print(f" [{acc}] {company}")
    if missing:
        print(f" Missing indices: {missing}")
    if extra:
        print(f" Unexpected indices: {extra}")
    print(f" Actual indices: {indices}")

# Every filing's paragraph numbering should begin at index 0.
non_zero_start = [(acc, ps) for acc, ps in by_filing.items()
                  if min(p["paragraphIndex"] for p in ps) != 0]
print(f"\n Filings not starting at index 0: {len(non_zero_start)}")
for acc, ps in non_zero_start[:5]:
    start = min(p["paragraphIndex"] for p in ps)
    print(f" [{acc}] {ps[0]['filing']['companyName']}: starts at {start}")
# ---------------------------------------------------------------------------
# 6. Cross-filing duplicate paragraphs
# ---------------------------------------------------------------------------
section_header("6. CROSS-FILING DUPLICATE PARAGRAPHS")
# Group by textHash
by_hash = defaultdict(list)
for p in paragraphs:
    by_hash[p["textHash"]].append(p)

# A text is a cross-filing duplicate when its hash occurs under more than
# one accession number.
cross_filing_dupes = {
    h: ps for h, ps in by_hash.items()
    if len({p["filing"]["accessionNumber"] for p in ps}) > 1
}
total_dupe_paragraphs = sum(len(ps) for ps in cross_filing_dupes.values())
unique_dupe_texts = len(cross_filing_dupes)
print(f"\n Unique paragraph texts appearing in >1 filing: {unique_dupe_texts}")
print(f" Total paragraphs that are cross-filing duplicates: {total_dupe_paragraphs} ({100*total_dupe_paragraphs/len(paragraphs):.1f}%)")

# A hash repeated WITHIN a filing shows up as a paragraph list longer than
# its set of accession numbers.
within_filing_dupes = sum(
    1 for ps in by_hash.values()
    if len(ps) != len({p["filing"]["accessionNumber"] for p in ps})
)
print(f" Hashes duplicated WITHIN a single filing: {within_filing_dupes}")

# Most widely copied texts first.
sorted_dupes = sorted(cross_filing_dupes.items(), key=lambda x: len(x[1]), reverse=True)
print("\n --- Top 20 most duplicated texts across filings ---")
for i, (h, ps) in enumerate(sorted_dupes[:20]):
    n_filings = len({p["filing"]["accessionNumber"] for p in ps})
    print(f"\n #{i+1}: hash={h} | {n_filings} filings | {ps[0]['wordCount']}w")
    print(f" TEXT: {truncate(ps[0]['text'], 200)}")

# Boilerplate analysis: texts appearing in 3+ filings
boilerplate_threshold = 3
boilerplate_hashes = {
    h for h, ps in cross_filing_dupes.items()
    if len({p["filing"]["accessionNumber"] for p in ps}) >= boilerplate_threshold
}
boilerplate_paragraphs = sum(len(by_hash[h]) for h in boilerplate_hashes)
print(f"\n Boilerplate (text in {boilerplate_threshold}+ filings):")
print(f" Unique texts: {len(boilerplate_hashes)}")
print(f" Total paragraphs: {boilerplate_paragraphs} ({100*boilerplate_paragraphs/len(paragraphs):.1f}%)")

print("\n" + "=" * 80)
print(" AUDIT COMPLETE")
print("=" * 80)