227 lines
10 KiB
Python
227 lines
10 KiB
Python
"""Generate side-by-side comparison figures: CORAL baseline vs Independent threshold model."""
|
||
|
||
import json
|
||
from pathlib import Path
|
||
|
||
import matplotlib
|
||
matplotlib.use("Agg")
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import seaborn as sns
|
||
|
||
# Evaluation results are read from <three directories above this file>/results/eval.
RESULTS_DIR = Path(__file__).resolve().parent.parent.parent.joinpath("results", "eval")
# Every figure generated by this script is written below this directory.
OUTPUT_DIR = RESULTS_DIR.joinpath("comparison")

# Abbreviated category names used as x-axis tick labels. They are matched by
# position with the per-class metric keys read in main(), so order matters.
CATEGORIES = [
    "Board Gov.",
    "Incident Disc.",
    "Mgmt Role",
    "None/Other",
    "Risk Mgmt Proc.",
    "Strategy Int.",
    "Third-Party",
]

# Tick labels for the four specificity levels, lowest (L1) to highest (L4).
SPEC_LABELS = [
    "L1: Generic",
    "L2: Domain",
    "L3: Firm-Spec.",
    "L4: Quantified",
]
|
||
|
||
|
||
def load_metrics(model_dir: str) -> dict:
    """Load the evaluation metrics recorded for one model.

    Args:
        model_dir: Name of the model's sub-directory under ``RESULTS_DIR``.

    Returns:
        The parsed contents of ``RESULTS_DIR / model_dir / "metrics.json"``.

    Raises:
        FileNotFoundError: If the metrics file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    metrics_path = RESULTS_DIR / model_dir / "metrics.json"
    # Read explicitly as UTF-8 so decoding does not depend on the platform's
    # locale default.
    with open(metrics_path, encoding="utf-8") as f:
        return json.load(f)
|
||
|
||
|
||
# Bar styling shared by every CORAL-vs-Independent chart so legends and
# colours stay consistent across figures.
_CORAL_STYLE = {"label": "CORAL (Epoch 5)", "color": "#DD8452", "alpha": 0.85}
_INDEP_STYLE = {"label": "Independent (Epoch 8)", "color": "#4C72B0", "alpha": 0.85}
_BAR_WIDTH = 0.35
_TARGET_SCORE = 0.80

# Suffixes of the per-class F1 entries in metrics.json, listed in the same
# order as CATEGORIES / SPEC_LABELS.
# Category keys come from eval.py: name.replace(" ", "").replace("/", "")[:8]
_CAT_KEYS = ["BoardGov", "Incident", "Manageme", "NoneOthe", "RiskMana", "Strategy", "Third-Pa"]
# Specificity keys come from eval.py: name.replace(" ", "").replace(":", "")[:8]
_SPEC_KEYS = ["L1Generi", "L2Domain", "L3Firm-S", "L4Quanti"]

# Display label -> metrics.json key for the summary and delta charts.
_SUMMARY_METRICS = {
    "Cat Macro F1": "cat_macro_f1",
    "Spec Macro F1": "spec_macro_f1",
    "Cat MCC": "cat_mcc",
    "Spec MCC": "spec_mcc",
    "Cat AUC": "cat_auc",
    "Spec AUC": "spec_auc",
    "Spec QWK": "spec_qwk",
    "Cat Kripp α": "cat_kripp_alpha",
    "Spec Kripp α": "spec_kripp_alpha",
}


def _save_figure(fig, filename: str) -> None:
    """Tighten the layout, write *fig* to OUTPUT_DIR / filename, and close it."""
    fig.tight_layout()
    fig.savefig(OUTPUT_DIR / filename, dpi=200)
    plt.close(fig)
    print(f" Saved: {filename}")


def _grouped_bars(ax, n_groups, coral_vals, indep_vals, *, value_fmt, fontsize,
                  target_label=None):
    """Draw paired CORAL/Independent bars, the target line, and value labels.

    Args:
        ax: Axes to draw on.
        n_groups: Number of bar pairs (one x tick per pair).
        coral_vals: CORAL bar heights, one per group.
        indep_vals: Independent-model bar heights, one per group.
        value_fmt: Format spec (e.g. ``".2f"``) for the text above each bar.
        fontsize: Font size of the value labels.
        target_label: Legend label for the target line; when ``None`` the
            line is drawn without a legend entry.
    """
    x = np.arange(n_groups)
    coral_bars = ax.bar(x - _BAR_WIDTH / 2, coral_vals, _BAR_WIDTH, **_CORAL_STYLE)
    indep_bars = ax.bar(x + _BAR_WIDTH / 2, indep_vals, _BAR_WIDTH, **_INDEP_STYLE)

    # Only pass a label when one was requested so the line stays out of the
    # legend otherwise.
    target_kwargs = {} if target_label is None else {"label": target_label}
    ax.axhline(_TARGET_SCORE, color="red", linestyle="--", alpha=0.5, **target_kwargs)
    ax.set_xticks(x)

    for bars, values in ((coral_bars, coral_vals), (indep_bars, indep_vals)):
        for bar, value in zip(bars, values):
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                value + 0.01,
                format(value, value_fmt),
                ha="center",
                va="bottom",
                fontsize=fontsize,
            )


def _plot_per_class_f1(coral_gpt: dict, indep_gpt: dict) -> None:
    """Figure 1: per-category and per-specificity-level F1, side by side."""
    fig, (cat_ax, spec_ax) = plt.subplots(1, 2, figsize=(16, 6))
    target_label = f"Target ({_TARGET_SCORE:.2f})"

    # Missing per-class keys are plotted as 0 rather than raising.
    coral_cat_f1 = [coral_gpt.get(f"cat_f1_{k}", 0) for k in _CAT_KEYS]
    indep_cat_f1 = [indep_gpt.get(f"cat_f1_{k}", 0) for k in _CAT_KEYS]
    _grouped_bars(cat_ax, len(CATEGORIES), coral_cat_f1, indep_cat_f1,
                  value_fmt=".2f", fontsize=7, target_label=target_label)
    cat_ax.set_ylabel("F1 Score")
    cat_ax.set_title("Category F1 by Class")
    cat_ax.set_xticklabels(CATEGORIES, rotation=25, ha="right")
    cat_ax.set_ylim(0, 1.05)
    cat_ax.legend(loc="lower right")

    coral_spec_f1 = [coral_gpt.get(f"spec_f1_{k}", 0) for k in _SPEC_KEYS]
    indep_spec_f1 = [indep_gpt.get(f"spec_f1_{k}", 0) for k in _SPEC_KEYS]
    _grouped_bars(spec_ax, len(SPEC_LABELS), coral_spec_f1, indep_spec_f1,
                  value_fmt=".2f", fontsize=8, target_label=target_label)
    spec_ax.set_ylabel("F1 Score")
    spec_ax.set_title("Specificity F1 by Level")
    spec_ax.set_xticklabels(SPEC_LABELS)
    spec_ax.set_ylim(0, 1.05)
    spec_ax.legend(loc="lower right")

    fig.suptitle("CORAL Baseline vs Independent Thresholds — Holdout Set (vs GPT-5.4)",
                 fontsize=14, fontweight="bold")
    _save_figure(fig, "coral_vs_independent_f1.png")


def _plot_summary_metrics(labels, coral_vals, indep_vals) -> None:
    """Figure 2: grouped bars for every aggregate metric."""
    fig, ax = plt.subplots(figsize=(12, 6))
    _grouped_bars(ax, len(labels), coral_vals, indep_vals, value_fmt=".3f", fontsize=7)
    ax.set_ylabel("Score")
    ax.set_title("CORAL vs Independent — All Metrics (Holdout vs GPT-5.4)")
    ax.set_xticklabels(labels, rotation=30, ha="right")
    ax.set_ylim(0, 1.1)
    ax.legend()
    _save_figure(fig, "coral_vs_independent_all_metrics.png")


def _plot_deltas(labels, coral_vals, indep_vals) -> None:
    """Figure 3: horizontal bars of (Independent − CORAL) per metric."""
    deltas = [iv - cv for cv, iv in zip(coral_vals, indep_vals)]
    # Green for improvements (including ties), red for regressions.
    colors = ["#55a868" if d >= 0 else "#c44e52" for d in deltas]

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.barh(labels, deltas, color=colors, alpha=0.85)
    ax.axvline(0, color="black", linewidth=0.8)
    ax.set_xlabel("Improvement (Independent − CORAL)")
    ax.set_title("Metric Improvement: Independent Thresholds over CORAL")
    for row, d in enumerate(deltas):
        # Place the label just past the bar tip, on the side the bar points to.
        ax.text(d + 0.003 if d >= 0 else d - 0.003, row, f"{d:+.3f}",
                va="center", ha="left" if d >= 0 else "right", fontsize=9)
    _save_figure(fig, "improvement_delta.png")


def _plot_spec_confusion(coral_gpt: dict, indep_gpt: dict) -> None:
    """Figure 4: specificity confusion matrices for both models."""
    level_ticks = ["L1", "L2", "L3", "L4"]
    fig, (coral_ax, indep_ax) = plt.subplots(1, 2, figsize=(13, 5))
    panels = (
        (coral_ax, coral_gpt, _CORAL_STYLE["label"], "Oranges"),
        (indep_ax, indep_gpt, _INDEP_STYLE["label"], "Blues"),
    )
    for ax, metrics, model_label, cmap in panels:
        counts = np.array(metrics["spec_confusion_matrix"])
        # Colour by row-normalised share; clip keeps all-zero rows from
        # dividing by zero. Cells are annotated with the raw counts.
        row_share = counts.astype(float) / counts.sum(axis=1, keepdims=True).clip(min=1)
        sns.heatmap(row_share, annot=counts, fmt="d", cmap=cmap,
                    xticklabels=level_ticks, yticklabels=level_ticks,
                    ax=ax, vmin=0, vmax=1, cbar=False)
        # The score in the title is read from the metrics instead of being
        # hard-coded, so it cannot drift from the data being plotted.
        # NOTE(review): assumes "Spec F1" means spec_macro_f1 — confirm.
        ax.set_title(f"{model_label} — Spec F1={metrics['spec_macro_f1']:.3f}")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("GPT-5.4 Reference")

    fig.suptitle("Specificity Confusion Matrices — CORAL vs Independent",
                 fontsize=13, fontweight="bold")
    _save_figure(fig, "spec_confusion_comparison.png")


def _plot_comparison_table(coral_gpt: dict, indep_gpt: dict) -> None:
    """Figure 5: summary table of quality, latency, and cost per model."""
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.axis("off")

    header = ["Metric", "CORAL (Ep5)", "Independent (Ep8)", "GPT-5.4 (API)", "Opus-4.6 (API)"]
    rows = [
        ["Cat Macro F1", f"{coral_gpt['cat_macro_f1']:.4f}", f"{indep_gpt['cat_macro_f1']:.4f}", "—(reference)", "—(reference)"],
        ["Spec Macro F1", f"{coral_gpt['spec_macro_f1']:.4f}", f"{indep_gpt['spec_macro_f1']:.4f}", "—(reference)", "—(reference)"],
        ["Spec QWK", f"{coral_gpt['spec_qwk']:.4f}", f"{indep_gpt['spec_qwk']:.4f}", "—", "—"],
        ["MCC (Cat)", f"{coral_gpt['cat_mcc']:.4f}", f"{indep_gpt['cat_mcc']:.4f}", "—", "—"],
        # Latency, cost, and reproducibility are fixed values, not read from metrics.json.
        ["Latency/sample", "5.6ms", "5.6ms", "~2,900ms", "~6,000ms"],
        ["Cost/1M texts", "~$5", "~$5", "~$3,400", "~$5,000*"],
        ["Reproducible", "Yes", "Yes", "No", "No"],
    ]

    table = ax.table(cellText=rows, colLabels=header, cellLoc="center", loc="center")
    table.auto_set_font_size(False)
    table.set_fontsize(9)
    table.scale(1, 1.5)

    # Style header (table row 0 holds the column labels).
    for col in range(len(header)):
        table[0, col].set_facecolor("#4C72B0")
        table[0, col].set_text_props(color="white", fontweight="bold")

    # Highlight best specialist column (column 2 = Independent) on data rows.
    for row in range(1, len(rows) + 1):
        table[row, 2].set_facecolor("#d4edda")

    ax.set_title("Model Comparison Summary", fontsize=13, fontweight="bold", pad=20)
    _save_figure(fig, "comparison_table.png")


def main():
    """Generate all CORAL-vs-Independent comparison figures.

    Loads ``metrics.json`` for the ``coral-baseline`` and ``iter1-independent``
    runs, selects each model's scores against the GPT-5.4 reference, and
    writes five PNG files into ``OUTPUT_DIR`` (created if missing).

    Raises:
        KeyError: If an expected run key or required metric is absent from a
            metrics file.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    sns.set_theme(style="whitegrid", font_scale=1.1)

    coral = load_metrics("coral-baseline")
    indep = load_metrics("iter1-independent")

    # Use GPT-5.4 as the reference (1200 samples, complete)
    coral_gpt = coral["best-base_weighted_ce-ep5_vs_GPT-5.4"]
    indep_gpt = indep["iter1-independent_vs_GPT-5.4"]

    # 1. Side-by-side per-class F1 (category and specificity).
    _plot_per_class_f1(coral_gpt, indep_gpt)

    # 2. Summary metrics comparison; missing metrics are plotted as 0.
    labels = list(_SUMMARY_METRICS)
    coral_vals = [coral_gpt.get(key, 0) for key in _SUMMARY_METRICS.values()]
    indep_vals = [indep_gpt.get(key, 0) for key in _SUMMARY_METRICS.values()]
    _plot_summary_metrics(labels, coral_vals, indep_vals)

    # 3. Delta chart (improvement from CORAL → Independent).
    _plot_deltas(labels, coral_vals, indep_vals)

    # 4. Specificity confusion matrix side-by-side.
    _plot_spec_confusion(coral_gpt, indep_gpt)

    # 5. Cost/speed comparison table figure.
    _plot_comparison_table(coral_gpt, indep_gpt)

    print(f"\n All figures saved to {OUTPUT_DIR}")
|
||
|
||
|
||
# Script entry point: build the figures only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|