"""
Comprehensive comparison of v3.0 vs v3.5f prompt on the 359 confusion-axis holdout paragraphs.

Covers per-model accuracy, per-axis breakdown, SI/NO asymmetry, rankings, convergence, and cost.
"""
|
import json
from collections import Counter
from itertools import combinations
from pathlib import Path

import numpy as np

# Repository root; every data path below is resolved relative to this.
ROOT = Path("/home/joey/Documents/sec-cyBERT")
# ---------------------------------------------------------------------------
# Model definitions
# ---------------------------------------------------------------------------
# Each entry: (display name, annotation subdirectory, file stem).
MODELS = [
    ("Opus", "golden", "opus"),
    ("GPT-5.4", "bench-holdout", "gpt-5.4"),
    ("Gemini-3.1-Pro", "bench-holdout", "gemini-3.1-pro-preview"),
    ("GLM-5", "bench-holdout", "glm-5:exacto"),
    ("Kimi-K2.5", "bench-holdout", "kimi-k2.5"),
    ("MIMO-v2-Pro", "bench-holdout", "mimo-v2-pro:exacto"),
    ("MiniMax-M2.7", "bench-holdout", "minimax-m2.7:exacto"),
]
|
# Display abbreviations for the five content categories.
CATEGORY_ABBREV = {
    "None/Other": "N/O",
    "Background": "BG",
    "Risk Management Process": "RMP",
    "Management Role": "MR",
    "Strategy Integration": "SI",
}


def abbrev(cat: str) -> str:
    """Short display form of a content category; unknown names pass through."""
    return CATEGORY_ABBREV.get(cat, cat)
|
# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
def load_jsonl(path: Path) -> list[dict]:
    """Parse a JSONL file and return its rows as dicts.

    Blank or whitespace-only lines are skipped, so trailing newlines and
    accidental empty lines do not raise JSON decode errors.
    """
    rows: list[dict] = []
    # Explicit UTF-8: annotation files are UTF-8, while the platform default
    # encoding may differ (e.g. on Windows).
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return rows
|
|
|
def load_model_labels(version_suffix: str, subdir: str, filename: str) -> dict[str, str]:
    """Return {paragraphId: content_category} for a model file.

    version_suffix: e.g. "v35"; an empty string selects the un-suffixed
        (v3.0) annotation directory.
    subdir: annotation subdirectory ("golden" or "bench-holdout").
    filename: file stem of the model's JSONL annotation file.
    """
    # NOTE(review): restored use of `filename` in the path — the previous
    # literal '(unknown).jsonl' ignored the parameter, so every model would
    # have read the same file.
    if version_suffix:
        base = ROOT / "data" / "annotations" / f"{subdir}-{version_suffix}" / f"{filename}.jsonl"
    else:
        base = ROOT / "data" / "annotations" / subdir / f"{filename}.jsonl"
    rows = load_jsonl(base)
    return {r["paragraphId"]: r["label"]["content_category"] for r in rows}
|
|
def load_model_rows(version_suffix: str, subdir: str, filename: str) -> list[dict]:
    """Return the raw annotation rows for a model file.

    Same path convention as load_model_labels: an empty version_suffix
    selects the un-suffixed (v3.0) directory.
    """
    # NOTE(review): restored use of `filename` in the path — the previous
    # literal '(unknown).jsonl' ignored the parameter.
    if version_suffix:
        base = ROOT / "data" / "annotations" / f"{subdir}-{version_suffix}" / f"{filename}.jsonl"
    else:
        base = ROOT / "data" / "annotations" / subdir / f"{filename}.jsonl"
    return load_jsonl(base)
|
|
# Load holdout PIDs and axes
holdout_rows = load_jsonl(ROOT / "data" / "gold" / "holdout-rerun-v35.jsonl")
HOLDOUT_PIDS = {r["paragraphId"] for r in holdout_rows}
PID_AXES: dict[str, list[str]] = {r["paragraphId"]: r["axes"] for r in holdout_rows}

# Human labels: collect every annotator vote per holdout paragraph.
human_raw = load_jsonl(ROOT / "data" / "gold" / "human-labels-raw.jsonl")
human_by_pid: dict[str, list[str]] = {}
for rec in human_raw:
    rec_pid = rec["paragraphId"]
    if rec_pid in HOLDOUT_PIDS:
        human_by_pid.setdefault(rec_pid, []).append(rec["contentCategory"])

# Collapse each paragraph's votes to the single most common category
# (ties resolved by first occurrence, per Counter.most_common).
human_majority: dict[str, str] = {
    p: Counter(votes).most_common(1)[0][0] for p, votes in human_by_pid.items()
}
|
# Load v3.0 and v3.5f labels for all models
v30_labels: dict[str, dict[str, str]] = {}  # model_name -> {pid: cat}
v35_labels: dict[str, dict[str, str]] = {}
v35_rows_by_model: dict[str, list[dict]] = {}

for name, subdir, filename in MODELS:
    # v3.0: full 1200 file, filter to 359
    all_v30 = load_model_labels("", subdir, filename)
    v30_labels[name] = {p: c for p, c in all_v30.items() if p in HOLDOUT_PIDS}

    # v3.5f
    suffix = "v35"
    sub = "golden" if subdir == "golden" else "bench-holdout"
    v35_all = load_model_labels(suffix, sub, filename)
    v35_labels[name] = {p: c for p, c in v35_all.items() if p in HOLDOUT_PIDS}

    v35_rows_by_model[name] = load_model_rows(suffix, sub, filename)
|
|
# Common PID set (intersection of all models in both versions + human majority)
common_pids = set(HOLDOUT_PIDS) & set(human_majority)
for model_name, _, _ in MODELS:
    common_pids &= v30_labels[model_name].keys() & v35_labels[model_name].keys()
common_pids_sorted = sorted(common_pids)

N = len(common_pids_sorted)
print(f"Common paragraphs across all models + human majority: {N}")
print()
|
# ---------------------------------------------------------------------------
# Helper: 6-model majority (excl MiniMax)
# ---------------------------------------------------------------------------
TOP6_NAMES = [name for name, _, _ in MODELS if name != "MiniMax-M2.7"]
|
|
def majority_vote(labels_dict: dict[str, dict[str, str]], model_names: list[str], pid: str) -> str | None:
|
|
cats = []
|
|
for mn in model_names:
|
|
if pid in labels_dict[mn]:
|
|
cats.append(labels_dict[mn][pid])
|
|
if not cats:
|
|
return None
|
|
counter = Counter(cats)
|
|
return counter.most_common(1)[0][0]
|
|
|
|
|
|
# ===========================================================================
# 1. Per-model summary table
# ===========================================================================
print("=" * 90)
print("1. PER-MODEL SUMMARY TABLE (vs human majority)")
print("=" * 90)
header = f"{'Model':<20} {'v3.0 Acc':>10} {'v3.5f Acc':>10} {'Delta':>8} {'Change%':>9}"
print(header)
print("-" * len(header))

# Per-model accuracies are kept for the ranking table in section 4.
model_v30_acc = {}
model_v35_acc = {}

for name, _, _ in MODELS:
    n_right_30 = 0
    n_right_35 = 0
    n_flipped = 0
    # Single pass over the common set: count hits for each prompt version
    # and how often the label flipped between versions.
    for pid in common_pids_sorted:
        gold = human_majority[pid]
        l30 = v30_labels[name][pid]
        l35 = v35_labels[name][pid]
        n_right_30 += l30 == gold
        n_right_35 += l35 == gold
        n_flipped += l30 != l35

    acc30 = n_right_30 / N
    acc35 = n_right_35 / N
    delta = acc35 - acc30
    change_rate = n_flipped / N

    model_v30_acc[name] = acc30
    model_v35_acc[name] = acc35

    print(f"{name:<20} {acc30:>9.1%} {acc35:>9.1%} {delta:>+7.1%} {change_rate:>8.1%}")

# 6-model majority row (MiniMax excluded from the vote)
correct_30_maj = 0
correct_35_maj = 0
changed_maj = 0
for pid in common_pids_sorted:
    m30 = majority_vote(v30_labels, TOP6_NAMES, pid)
    m35 = majority_vote(v35_labels, TOP6_NAMES, pid)
    correct_30_maj += m30 == human_majority[pid]
    correct_35_maj += m35 == human_majority[pid]
    changed_maj += m30 != m35

acc30_maj = correct_30_maj / N
acc35_maj = correct_35_maj / N
delta_maj = acc35_maj - acc30_maj
change_maj_rate = changed_maj / N

model_v30_acc["6-model majority"] = acc30_maj
model_v35_acc["6-model majority"] = acc35_maj

print("-" * len(header))
print(f"{'6-model maj (no MM)':<20} {acc30_maj:>9.1%} {acc35_maj:>9.1%} {delta_maj:>+7.1%} {change_maj_rate:>8.1%}")
print()
|
# ===========================================================================
# 2. Per-axis breakdown (6-model majority excl MiniMax)
# ===========================================================================
print("=" * 90)
print("2. PER-AXIS BREAKDOWN (6-model majority excl MiniMax vs human majority)")
print("=" * 90)

# Every axis tag appearing anywhere in the holdout set, sorted for stable output.
all_axes = sorted({ax for axes in PID_AXES.values() for ax in axes})
header2 = f"{'Axis':<12} {'N':>5} {'v3.0 Acc':>10} {'v3.5f Acc':>10} {'Delta':>8}"
print(header2)
print("-" * len(header2))

for axis in all_axes:
    axis_pids = [p for p in common_pids_sorted if axis in PID_AXES.get(p, [])]
    n_axis = len(axis_pids)
    if n_axis == 0:
        continue
    hits30 = sum(majority_vote(v30_labels, TOP6_NAMES, p) == human_majority[p] for p in axis_pids)
    hits35 = sum(majority_vote(v35_labels, TOP6_NAMES, p) == human_majority[p] for p in axis_pids)
    a30 = hits30 / n_axis
    a35 = hits35 / n_axis
    d = a35 - a30
    print(f"{axis:<12} {n_axis:>5} {a30:>9.1%} {a35:>9.1%} {d:>+7.1%}")

print()
|
|
# ===========================================================================
# 3. SI ↔ N/O asymmetry check
# ===========================================================================
print("=" * 90)
print("3. SI <-> N/O ASYMMETRY CHECK")
print("=" * 90)

si_no_pids = [p for p in common_pids_sorted if "SI_NO" in PID_AXES.get(p, [])]
print(f"SI↔N/O paragraphs in common set: {len(si_no_pids)}")
print()

# Directional disagreement of the 6-model majority vs humans, per version.
for version_label, labels_dict in [("v3.0", v30_labels), ("v3.5f", v35_labels)]:
    human_si_model_no = 0
    human_no_model_si = 0
    for p in si_no_pids:
        gold = human_majority[p]
        voted = majority_vote(labels_dict, TOP6_NAMES, p)
        if gold == "Strategy Integration" and voted == "None/Other":
            human_si_model_no += 1
        elif gold == "None/Other" and voted == "Strategy Integration":
            human_no_model_si += 1
    print(f"{version_label}:")
    print(f"  Human=SI, 6-model=N/O: {human_si_model_no}")
    print(f"  Human=N/O, 6-model=SI: {human_no_model_si}")
    print()

# Also show per-model breakdown for SI↔N/O
print("Per-model SI↔N/O errors:")
header3 = f"{'Model':<20} {'v3.0 H=SI,M=NO':>16} {'v3.0 H=NO,M=SI':>16} {'v3.5 H=SI,M=NO':>16} {'v3.5 H=NO,M=SI':>16}"
print(header3)
print("-" * len(header3))
for name, _, _ in MODELS:
    counts = []
    # counts fills as [v3.0 H=SI→N/O, v3.0 H=N/O→SI, v3.5 H=SI→N/O, v3.5 H=N/O→SI]
    for labels_dict in (v30_labels, v35_labels):
        hsi_mno = 0
        hno_msi = 0
        for p in si_no_pids:
            gold = human_majority[p]
            lab = labels_dict[name].get(p)
            if lab is None:
                continue
            if gold == "Strategy Integration" and lab == "None/Other":
                hsi_mno += 1
            elif gold == "None/Other" and lab == "Strategy Integration":
                hno_msi += 1
        counts.extend([hsi_mno, hno_msi])
    print(f"{name:<20} {counts[0]:>16} {counts[1]:>16} {counts[2]:>16} {counts[3]:>16}")

print()
|
|
# ===========================================================================
# 4. Per-model ranking
# ===========================================================================
print("=" * 90)
print("4. PER-MODEL RANKING")
print("=" * 90)

all_names = [m[0] for m in MODELS]

rank_v30 = sorted(all_names, key=lambda mn: model_v30_acc[mn], reverse=True)
rank_v35 = sorted(all_names, key=lambda mn: model_v35_acc[mn], reverse=True)

header4 = f"{'Rank':>4} {'v3.0 Model':<20} {'Acc':>8} {'v3.5f Model':<20} {'Acc':>8}"
print(header4)
print("-" * len(header4))
# Side-by-side leaderboard: rank i under v3.0 next to rank i under v3.5f.
for rank, (n30, n35) in enumerate(zip(rank_v30, rank_v35), start=1):
    print(f"{rank:>4} {n30:<20} {model_v30_acc[n30]:>7.1%} {n35:<20} {model_v35_acc[n35]:>7.1%}")

print()
|
|
# ===========================================================================
# 5. Model convergence (average pairwise agreement)
# ===========================================================================
print("=" * 90)
print("5. MODEL CONVERGENCE (average pairwise agreement)")
print("=" * 90)


def avg_pairwise_agreement(labels_dict: dict[str, dict[str, str]], model_names: list[str], pids: list[str]) -> float:
    """Mean fraction of *pids* on which each pair of models agrees exactly.

    Missing labels compare via .get(), so two models both lacking a pid
    count as agreement (None == None) — same semantics as before.
    """
    n = len(pids)
    rates = [
        sum(labels_dict[m1].get(p) == labels_dict[m2].get(p) for p in pids) / n
        for m1, m2 in combinations(model_names, 2)
    ]
    return float(np.mean(rates))
|
|
|
|
# Report convergence with and without MiniMax in the pool.
for group_label, group_names in [("All 7 models", all_names), ("Top 6 (excl MiniMax)", TOP6_NAMES)]:
    a30 = avg_pairwise_agreement(v30_labels, group_names, common_pids_sorted)
    a35 = avg_pairwise_agreement(v35_labels, group_names, common_pids_sorted)
    delta = a35 - a30
    print(f"{group_label}:")
    print(f"  v3.0 avg pairwise agreement: {a30:.1%}")
    print(f"  v3.5f avg pairwise agreement: {a35:.1%}")
    print(f"  Delta: {delta:+.1%}")
    print()
|
|
# ===========================================================================
# 6. Cost summary
# ===========================================================================
print("=" * 90)
print("6. v3.5f RE-RUN COST SUMMARY")
print("=" * 90)

total_cost = 0.0
header6 = f"{'Model':<20} {'Records':>8} {'Cost ($)':>10}"
print(header6)
print("-" * len(header6))
for name, _, _ in MODELS:
    model_rows = v35_rows_by_model[name]
    # Records missing provenance/costUsd contribute 0 to the total.
    run_cost = sum(r.get("provenance", {}).get("costUsd", 0) for r in model_rows)
    total_cost += run_cost
    print(f"{name:<20} {len(model_rows):>8} {run_cost:>10.4f}")

print("-" * len(header6))
print(f"{'TOTAL':<20} {'':<8} {total_cost:>10.4f}")
print()