SEC-cyBERT/python/scripts/generate-comparison-figures.py
2026-04-05 15:37:50 -04:00

227 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Generate side-by-side comparison figures: CORAL baseline vs Independent threshold model."""
import json
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
# Evaluation artefacts live in "results/eval", three directory levels above
# this script; the comparison figures are written to a sub-folder of it.
RESULTS_DIR = Path(__file__).resolve().parent.parent.parent.joinpath("results", "eval")
OUTPUT_DIR = RESULTS_DIR.joinpath("comparison")

# Tick labels for the per-category F1 chart. The order must line up with the
# metric-key list used when the category F1 values are read in main().
CATEGORIES = [
    "Board Gov.",
    "Incident Disc.",
    "Mgmt Role",
    "None/Other",
    "Risk Mgmt Proc.",
    "Strategy Int.",
    "Third-Party",
]

# Tick labels for the per-level specificity F1 chart (same ordering rule).
SPEC_LABELS = [
    "L1: Generic",
    "L2: Domain",
    "L3: Firm-Spec.",
    "L4: Quantified",
]
def load_metrics(model_dir: str) -> dict:
    """Load the parsed ``metrics.json`` of one evaluated model.

    Args:
        model_dir: Name of the model's sub-directory under ``RESULTS_DIR``.

    Returns:
        The decoded JSON content of ``RESULTS_DIR / model_dir / "metrics.json"``.

    Raises:
        FileNotFoundError: If the metrics file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    metrics_path = RESULTS_DIR / model_dir / "metrics.json"
    # JSON is UTF-8 by specification; without an explicit encoding, open()
    # falls back to the platform locale and can mis-decode non-ASCII content.
    with open(metrics_path, encoding="utf-8") as f:
        return json.load(f)
# Shared styling for every CORAL-vs-Independent chart, kept in one place so the
# individual figures cannot drift apart.
_CORAL_LABEL = "CORAL (Epoch 5)"
_INDEP_LABEL = "Independent (Epoch 8)"
_CORAL_COLOR = "#DD8452"
_INDEP_COLOR = "#4C72B0"
_BAR_WIDTH = 0.35
_TARGET_SCORE = 0.80
_DPI = 200

# Display label -> metrics.json key for the summary and delta charts.
_SUMMARY_METRICS = {
    "Cat Macro F1": "cat_macro_f1",
    "Spec Macro F1": "spec_macro_f1",
    "Cat MCC": "cat_mcc",
    "Spec MCC": "spec_mcc",
    "Cat AUC": "cat_auc",
    "Spec AUC": "spec_auc",
    "Spec QWK": "spec_qwk",
    "Cat Kripp α": "cat_kripp_alpha",
    "Spec Kripp α": "spec_kripp_alpha",
}


def _metric(metrics, key):
    """Return ``metrics[key]``; fall back to 0 with a printed warning if absent or null."""
    value = metrics.get(key)
    if value is None:
        # A bar silently drawn at 0 is indistinguishable from a real score of
        # 0, so make the fallback visible on the console.
        print(f" Warning: metric {key!r} not found; plotting 0")
        return 0
    return value


def _save_figure(fig, filename):
    """Tighten the layout, write *fig* to ``OUTPUT_DIR / filename``, close it, report it."""
    fig.tight_layout()
    fig.savefig(OUTPUT_DIR / filename, dpi=_DPI)
    plt.close(fig)
    print(f" Saved: {filename}")


def _draw_grouped_bars(ax, tick_labels, coral_vals, indep_vals, *,
                       value_fmt, fontsize, tick_kwargs=None, target_label=None):
    """Draw paired CORAL/Independent bars, the target line, and value annotations on *ax*.

    ``value_fmt`` is a format spec (e.g. ``".2f"``) for the number printed above
    each bar; ``tick_kwargs`` is forwarded to ``set_xticklabels``; the target
    line only gets a legend entry when ``target_label`` is given.
    """
    x = np.arange(len(tick_labels))
    series = (
        (x - _BAR_WIDTH / 2, coral_vals, _CORAL_LABEL, _CORAL_COLOR),
        (x + _BAR_WIDTH / 2, indep_vals, _INDEP_LABEL, _INDEP_COLOR),
    )
    drawn = [
        (ax.bar(positions, values, _BAR_WIDTH, label=label, color=color, alpha=0.85), values)
        for positions, values, label, color in series
    ]

    line_kwargs = {} if target_label is None else {"label": target_label}
    ax.axhline(_TARGET_SCORE, color="red", linestyle="--", alpha=0.5, **line_kwargs)

    ax.set_xticks(x)
    ax.set_xticklabels(tick_labels, **(tick_kwargs or {}))

    for bars, values in drawn:
        for bar, value in zip(bars, values):
            ax.text(bar.get_x() + bar.get_width() / 2, value + 0.01,
                    format(value, value_fmt), ha="center", va="bottom", fontsize=fontsize)


def _plot_per_class_f1(coral_gpt, indep_gpt):
    """Figure 1: per-class category F1 (left) and per-level specificity F1 (right)."""
    # Keys come from eval.py: name.replace(" ", "").replace("/", "")[:8]
    cat_keys = ["BoardGov", "Incident", "Manageme", "NoneOthe", "RiskMana", "Strategy", "Third-Pa"]
    # Keys come from eval.py: name.replace(" ", "").replace(":", "")[:8]
    spec_keys = ["L1Generi", "L2Domain", "L3Firm-S", "L4Quanti"]

    fig, (ax_cat, ax_spec) = plt.subplots(1, 2, figsize=(16, 6))
    panels = (
        (ax_cat, CATEGORIES, "cat_f1_", cat_keys, "Category F1 by Class",
         7, {"rotation": 25, "ha": "right"}),
        (ax_spec, SPEC_LABELS, "spec_f1_", spec_keys, "Specificity F1 by Level",
         8, {}),
    )
    for ax, tick_labels, prefix, keys, title, fontsize, tick_kwargs in panels:
        coral_f1 = [_metric(coral_gpt, f"{prefix}{k}") for k in keys]
        indep_f1 = [_metric(indep_gpt, f"{prefix}{k}") for k in keys]
        _draw_grouped_bars(ax, tick_labels, coral_f1, indep_f1,
                           value_fmt=".2f", fontsize=fontsize,
                           tick_kwargs=tick_kwargs, target_label="Target (0.80)")
        ax.set_ylabel("F1 Score")
        ax.set_title(title)
        ax.set_ylim(0, 1.05)
        ax.legend(loc="lower right")

    fig.suptitle("CORAL Baseline vs Independent Thresholds — Holdout Set (vs GPT-5.4)",
                 fontsize=14, fontweight="bold")
    _save_figure(fig, "coral_vs_independent_f1.png")


def _plot_all_metrics(labels, coral_vals, indep_vals):
    """Figure 2: grouped bars for every summary metric."""
    fig, ax = plt.subplots(figsize=(12, 6))
    _draw_grouped_bars(ax, labels, coral_vals, indep_vals,
                       value_fmt=".3f", fontsize=7,
                       tick_kwargs={"rotation": 30, "ha": "right"})
    ax.set_ylabel("Score")
    ax.set_title("CORAL vs Independent — All Metrics (Holdout vs GPT-5.4)")
    ax.set_ylim(0, 1.1)
    ax.legend()
    _save_figure(fig, "coral_vs_independent_all_metrics.png")


def _plot_improvement_delta(labels, coral_vals, indep_vals):
    """Figure 3: horizontal bars of (Independent - CORAL) per summary metric."""
    deltas = [iv - cv for cv, iv in zip(coral_vals, indep_vals)]
    colors = ["#55a868" if d >= 0 else "#c44e52" for d in deltas]

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.barh(labels, deltas, color=colors, alpha=0.85)
    ax.axvline(0, color="black", linewidth=0.8)
    # The label spells out the subtraction that `deltas` actually performs.
    ax.set_xlabel("Improvement (Independent - CORAL)")
    ax.set_title("Metric Improvement: Independent Thresholds over CORAL")
    for row, delta in enumerate(deltas):
        gain = delta >= 0
        # Put the number just outside the bar tip, on the side the bar points to.
        ax.text(delta + (0.003 if gain else -0.003), row, f"{delta:+.3f}",
                va="center", ha="left" if gain else "right", fontsize=9)
    _save_figure(fig, "improvement_delta.png")


def _plot_spec_confusion(coral_gpt, indep_gpt):
    """Figure 4: row-normalised specificity confusion matrices, annotated with raw counts."""
    spec_labels_short = ["L1", "L2", "L3", "L4"]
    fig, (ax_coral, ax_indep) = plt.subplots(1, 2, figsize=(13, 5))
    panels = (
        (ax_coral, coral_gpt, _CORAL_LABEL, "Oranges"),
        (ax_indep, indep_gpt, _INDEP_LABEL, "Blues"),
    )
    for ax, metrics, model_label, cmap in panels:
        counts = np.array(metrics["spec_confusion_matrix"])
        # clip(min=1) keeps an all-zero row from dividing by zero.
        row_totals = counts.sum(axis=1, keepdims=True).clip(min=1)
        sns.heatmap(counts.astype(float) / row_totals, annot=counts, fmt="d", cmap=cmap,
                    xticklabels=spec_labels_short, yticklabels=spec_labels_short,
                    ax=ax, vmin=0, vmax=1, cbar=False)
        # Read the score from the loaded metrics rather than hard-coding it, so
        # the title cannot go stale when metrics.json is regenerated.
        # NOTE(review): assumes the "Spec F1" in the title is spec_macro_f1 — confirm.
        spec_f1 = _metric(metrics, "spec_macro_f1")
        ax.set_title(f"{model_label} — Spec F1={spec_f1:.3f}")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("GPT-5.4 Reference")

    fig.suptitle("Specificity Confusion Matrices — CORAL vs Independent",
                 fontsize=13, fontweight="bold")
    _save_figure(fig, "spec_confusion_comparison.png")


def _plot_comparison_table(coral_gpt, indep_gpt):
    """Figure 5: summary table of quality, latency and cost per model."""
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.axis("off")

    def scores(key):
        # CORAL and Independent cells for one metric, formatted to 4 decimals.
        return [f"{coral_gpt[key]:.4f}", f"{indep_gpt[key]:.4f}"]

    header = ["Metric", "CORAL (Ep5)", "Independent (Ep8)", "GPT-5.4 (API)", "Opus-4.6 (API)"]
    # Latency, cost and reproducibility cells are fixed text, not read from metrics.json.
    rows = [
        ["Cat Macro F1", *scores("cat_macro_f1"), "—(reference)", "—(reference)"],
        ["Spec Macro F1", *scores("spec_macro_f1"), "—(reference)", "—(reference)"],
        ["Spec QWK", *scores("spec_qwk"), "", ""],
        ["MCC (Cat)", *scores("cat_mcc"), "", ""],
        ["Latency/sample", "5.6ms", "5.6ms", "~2,900ms", "~6,000ms"],
        ["Cost/1M texts", "~$5", "~$5", "~$3,400", "~$5,000*"],
        ["Reproducible", "Yes", "Yes", "No", "No"],
    ]
    table = ax.table(cellText=rows, colLabels=header, cellLoc="center", loc="center")
    table.auto_set_font_size(False)
    table.set_fontsize(9)
    table.scale(1, 1.5)

    # Style header (row 0 of the table holds the column labels)
    for col in range(len(header)):
        table[0, col].set_facecolor("#4C72B0")
        table[0, col].set_text_props(color="white", fontweight="bold")
    # Highlight best specialist column (data rows start at table row 1)
    for row in range(1, len(rows) + 1):
        table[row, 2].set_facecolor("#d4edda")

    ax.set_title("Model Comparison Summary", fontsize=13, fontweight="bold", pad=20)
    _save_figure(fig, "comparison_table.png")


def main() -> None:
    """Render all CORAL-vs-Independent comparison figures into ``OUTPUT_DIR``.

    Loads ``metrics.json`` for both models, selects the GPT-5.4 entries, and
    writes five PNG files: per-class F1, all summary metrics, the metric
    deltas, the specificity confusion matrices, and a summary table.

    Raises:
        KeyError: If a required run entry, confusion matrix, or table metric
            is missing from the loaded metrics.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    sns.set_theme(style="whitegrid", font_scale=1.1)

    coral = load_metrics("coral-baseline")
    indep = load_metrics("iter1-independent")
    # Use GPT-5.4 as the reference (1200 samples, complete)
    coral_gpt = coral["best-base_weighted_ce-ep5_vs_GPT-5.4"]
    indep_gpt = indep["iter1-independent_vs_GPT-5.4"]

    # ── 1. Side-by-side per-class F1 (Category + Specificity) ────────────────
    _plot_per_class_f1(coral_gpt, indep_gpt)

    # ── 2. Summary metrics comparison ────────────────────────────────────────
    labels = list(_SUMMARY_METRICS)
    coral_vals = [_metric(coral_gpt, key) for key in _SUMMARY_METRICS.values()]
    indep_vals = [_metric(indep_gpt, key) for key in _SUMMARY_METRICS.values()]
    _plot_all_metrics(labels, coral_vals, indep_vals)

    # ── 3. Delta chart (improvement from CORAL → Independent) ────────────────
    _plot_improvement_delta(labels, coral_vals, indep_vals)

    # ── 4. Specificity confusion matrix side-by-side ─────────────────────────
    _plot_spec_confusion(coral_gpt, indep_gpt)

    # ── 5. Cost/speed comparison table figure ────────────────────────────────
    _plot_comparison_table(coral_gpt, indep_gpt)

    print(f"\n All figures saved to {OUTPUT_DIR}")
# Generate the figures only when executed as a script, not when imported.
if __name__ == "__main__":
    main()