SEC-cyBERT/ts/scripts/model-bias-analysis.ts
2026-03-28 23:44:37 -04:00

471 lines
16 KiB
TypeScript

/**
* Model bias analysis for Stage 1 annotations.
* Identifies which model is the outlier most often, systematic biases,
* pairwise agreement, and category-specific dispute patterns.
*
* Usage: bun ts/scripts/model-bias-analysis.ts
*/
import { fileURLToPath } from "node:url";
import { Paragraph } from "@sec-cybert/schemas/paragraph.ts";
import { readJsonl, readJsonlRaw } from "../src/lib/jsonl.ts";
/**
 * Filesystem path of the cleaned paragraph corpus, resolved relative to this
 * script file rather than the current working directory.
 *
 * `fileURLToPath` is used instead of `URL#pathname`: `pathname` stays
 * percent-encoded (a space in a parent directory becomes `%20`) and is not a
 * valid path on Windows (`/C:/...`), so opening it can fail.
 */
const PARAGRAPHS_PATH = fileURLToPath(
  new URL("../../data/paragraphs/paragraphs-clean.jsonl", import.meta.url),
);

/** Filesystem path of the Stage 1 annotations file, resolved the same way. */
const ANNOTATIONS_PATH = fileURLToPath(
  new URL("../../data/annotations/stage1.jsonl", import.meta.url),
);

/**
 * Identifiers of the models being compared. Annotations from any other model
 * are ignored, and this order fixes the column/row order of every table.
 */
const MODELS = [
  "google/gemini-3.1-flash-lite-preview",
  "x-ai/grok-4.1-fast",
  "xiaomi/mimo-v2-flash",
] as const;

/** Union of the model identifiers listed in `MODELS`. */
type ModelId = (typeof MODELS)[number];

/** Short display name for each model, used in table cells and log lines. */
const SHORT: Record<ModelId, string> = {
  "google/gemini-3.1-flash-lite-preview": "Gemini",
  "x-ai/grok-4.1-fast": "Grok",
  "xiaomi/mimo-v2-flash": "Mimo",
};
/**
 * Shape this script assumes for each record of the Stage 1 annotations file.
 * Records are cast to this type without runtime validation.
 */
interface Ann {
// Key used to group the annotations that different models gave the same paragraph.
paragraphId: string;
label: {
// Compared between models by exact string equality.
content_category: string;
// Compared between models by exact equality; levels 1-4 are tabulated.
specificity_level: number;
// The remaining label fields are declared but never read by this script.
category_confidence: string;
specificity_confidence: string;
reasoning: string;
};
provenance: {
// Matched against MODELS; annotations from any other model are skipped.
modelId: string;
// The remaining provenance fields are declared but never read by this script.
costUsd: number;
inputTokens: number;
outputTokens: number;
reasoningTokens: number;
latencyMs: number;
requestedAt: string;
};
}
// ── Helpers ──────────────────────────────────────────────────────────────
/**
 * Formats `n / total` as a percentage with one decimal place.
 *
 * @param n     Numerator.
 * @param total Denominator; when it is 0 the result is "0.0%" rather than NaN.
 * @returns The percentage followed by a "%" sign, e.g. "33.3%".
 */
function pct(n: number, total: number): string {
  const percentage = total === 0 ? 0 : (100 * n) / total;
  return `${percentage.toFixed(1)}%`;
}
/**
 * Pads `s` with trailing spaces up to `len` characters.
 * Strings already at least `len` long are returned unchanged (never truncated).
 */
function padRight(s: string, len: number): string {
  return s.padEnd(len);
}
/**
 * Pads `s` with leading spaces up to `len` characters.
 * Strings already at least `len` long are returned unchanged (never truncated).
 */
function padLeft(s: string, len: number): string {
  return s.padStart(len);
}
/**
 * Prints a left-aligned text table to stdout: the header line, a dashed
 * separator, then one line per row, with cells joined by a single space.
 *
 * @param headers   Column titles.
 * @param rows      Table body. A row shorter than `headers` is treated as
 *                  having empty trailing cells when column widths are sized.
 * @param colWidths Optional explicit column widths. When omitted, each column
 *                  is as wide as its longest header or cell.
 */
function printTable(headers: string[], rows: string[][], colWidths?: number[]) {
  // Widths are found with a plain loop instead of `Math.max(...cells)`:
  // spreading one argument per row exceeds the engine's argument limit and
  // throws a RangeError once the table has enough rows.
  const widths =
    colWidths ??
    headers.map((header, col) => {
      let widest = header.length;
      for (const row of rows) {
        const cellLength = (row[col] ?? "").length;
        if (cellLength > widest) widest = cellLength;
      }
      return widest;
    });

  // A cell whose column has no known width (a row wider than `headers`, or a
  // short `colWidths`) is emitted unpadded.
  const formatLine = (cells: string[]): string =>
    cells.map((cell, col) => padRight(cell, widths[col] ?? 0)).join(" ");

  console.log(formatLine(headers));
  console.log(widths.map((w) => "-".repeat(w)).join(" "));
  for (const row of rows) {
    console.log(formatLine(row));
  }
}
// ── Load data ────────────────────────────────────────────────────────────
console.log("Loading data...");
// Both files are read concurrently. The paragraph reader is handed the
// `Paragraph` schema; annotations come from the raw reader and are only cast
// to `Ann`, not validated.
const [paragraphLoad, annotationLoad] = await Promise.all([
  readJsonl(PARAGRAPHS_PATH, Paragraph),
  readJsonlRaw(ANNOTATIONS_PATH),
]);
const { records: paragraphs, skipped: pSkip } = paragraphLoad;
const { records: rawAnns, skipped: aSkip } = annotationLoad;
const annotations = rawAnns as Ann[];
console.log(
  `Loaded ${paragraphs.length} paragraphs (${pSkip} skipped), ${annotations.length} annotations (${aSkip} skipped)\n`,
);
// ── Group annotations by paragraphId ─────────────────────────────────────
// paragraphId -> (model -> that model's annotation). Annotations from models
// not listed in MODELS are skipped; if the same (paragraph, model) pair occurs
// more than once, the last record seen wins.
const byParagraph = new Map<string, Map<ModelId, Ann>>();
for (const ann of annotations) {
  const mid = ann.provenance.modelId as ModelId;
  if (!MODELS.includes(mid)) continue;
  let perModel = byParagraph.get(ann.paragraphId);
  if (perModel === undefined) {
    perModel = new Map<ModelId, Ann>();
    byParagraph.set(ann.paragraphId, perModel);
  }
  perModel.set(mid, ann);
}
// Only keep paragraphs with all 3 models
const complete = new Map<string, Map<ModelId, Ann>>(
  [...byParagraph].filter(([, models]) => models.size === 3),
);
console.log(`Paragraphs with all 3 models: ${complete.size}\n`);
// ── 1. Outlier Analysis ──────────────────────────────────────────────────
// For every paragraph labelled by all three models, the category votes and the
// specificity votes are classified separately as unanimous, 2-vs-1 (the
// dissenting model is credited as the outlier), or a three-way split.
console.log("=".repeat(70));
console.log("1. OUTLIER ANALYSIS");
console.log("=".repeat(70));

/** Result of classifying three votes; a number is the index of the lone dissenter. */
type VoteSplit = "unanimous" | "three-way" | 0 | 1 | 2;

/** Classifies a triple of votes, compared with strict equality. */
function classifyVotes<T>(votes: readonly T[]): VoteSplit {
  const [first, second, third] = votes;
  if (first === second) return second === third ? "unanimous" : 2;
  if (first === third) return 1;
  if (second === third) return 0;
  return "three-way";
}

const catOutlierCount: Record<ModelId, number> = {
  "google/gemini-3.1-flash-lite-preview": 0,
  "x-ai/grok-4.1-fast": 0,
  "xiaomi/mimo-v2-flash": 0,
};
const specOutlierCount: Record<ModelId, number> = {
  "google/gemini-3.1-flash-lite-preview": 0,
  "x-ai/grok-4.1-fast": 0,
  "xiaomi/mimo-v2-flash": 0,
};
let catDisagree = 0;
let specDisagree = 0;
let catUnanimous = 0;
let specUnanimous = 0;
let threeWayDisagreeCat = 0;
let threeWayDisagreeSpec = 0;

for (const models of complete.values()) {
  const cats = MODELS.map((m) => models.get(m)!.label.content_category);
  const specs = MODELS.map((m) => models.get(m)!.label.specificity_level);

  // Category outlier
  const catSplit = classifyVotes(cats);
  switch (catSplit) {
    case "unanimous":
      catUnanimous++;
      break;
    case "three-way":
      threeWayDisagreeCat++;
      break;
    default:
      catDisagree++;
      catOutlierCount[MODELS[catSplit]]++;
  }

  // Specificity outlier
  const specSplit = classifyVotes(specs);
  switch (specSplit) {
    case "unanimous":
      specUnanimous++;
      break;
    case "three-way":
      threeWayDisagreeSpec++;
      break;
    default:
      specDisagree++;
      specOutlierCount[MODELS[specSplit]]++;
  }
}

console.log(`\nCategory: ${catUnanimous} unanimous, ${catDisagree} 2v1, ${threeWayDisagreeCat} three-way disagree`);
console.log("\nCategory outlier counts (when one model disagrees with the other two):");
printTable(
  ["Model", "Outlier Count", "% of 2v1"],
  MODELS.map((m) => {
    const timesOutlier = catOutlierCount[m];
    return [SHORT[m], String(timesOutlier), pct(timesOutlier, catDisagree)];
  }),
);
console.log(`\nSpecificity: ${specUnanimous} unanimous, ${specDisagree} 2v1, ${threeWayDisagreeSpec} three-way disagree`);
console.log("\nSpecificity outlier counts:");
printTable(
  ["Model", "Outlier Count", "% of 2v1"],
  MODELS.map((m) => {
    const timesOutlier = specOutlierCount[m];
    return [SHORT[m], String(timesOutlier), pct(timesOutlier, specDisagree)];
  }),
);
// ── 2. Category Bias ─────────────────────────────────────────────────────
// Compares how often each model assigns each category. Counts cover every
// annotation from a model in MODELS, not only paragraphs that all three
// models labelled.
console.log("\n" + "=".repeat(70));
console.log("2. CATEGORY BIAS");
console.log("=".repeat(70));

const allCategories = new Set<string>();
// catCounts[model][category] = number of annotations with that category.
const catCounts = {} as Record<ModelId, Record<string, number>>;
for (const m of MODELS) catCounts[m] = {};
for (const ann of annotations) {
  const mid = ann.provenance.modelId as ModelId;
  if (!MODELS.includes(mid)) continue;
  const cat = ann.label.content_category;
  allCategories.add(cat);
  catCounts[mid][cat] = (catCounts[mid][cat] ?? 0) + 1;
}
const categories = [...allCategories].sort();

// Total annotations per model; the denominator for every share below.
const modelTotals = {} as Record<ModelId, number>;
for (const m of MODELS) {
  modelTotals[m] = Object.values(catCounts[m]).reduce((a, b) => a + b, 0);
}

// Share (in percent) of a model's annotations that carry `cat`. A model with
// no annotations contributes 0 instead of NaN, matching how pct() treats an
// empty denominator; otherwise one missing model would turn every average and
// every bias cell into NaN.
const categoryShare = (m: ModelId, cat: string): number =>
  modelTotals[m] === 0 ? 0 : (100 * (catCounts[m][cat] ?? 0)) / modelTotals[m];

// Per-category shares and their unweighted mean across models, computed once
// and reused by both tables.
const categoryShares = categories.map((cat) => {
  const pcts = MODELS.map((m) => categoryShare(m, cat));
  const avg = pcts.reduce((a, b) => a + b, 0) / MODELS.length;
  return { cat, pcts, avg };
});

console.log("\nCategory distribution (% of each model's annotations):\n");
const catHeaders = ["Category", ...MODELS.map((m) => SHORT[m]), "Average"];
const catRows: string[][] = categoryShares.map(({ cat, pcts, avg }) => [
  cat,
  ...pcts.map((p) => p.toFixed(1) + "%"),
  avg.toFixed(1) + "%",
]);
printTable(catHeaders, catRows);

console.log("\nOver/under-indexing vs average (percentage points):\n");
const biasHeaders = ["Category", ...MODELS.map((m) => SHORT[m])];
const biasRows: string[][] = categoryShares.map(({ cat, pcts, avg }) => [
  cat,
  ...pcts.map((p) => {
    const diff = p - avg;
    // Negative values already carry their own sign from toFixed.
    const sign = diff >= 0 ? "+" : "";
    return sign + diff.toFixed(1) + "pp";
  }),
]);
printTable(biasHeaders, biasRows);
// ── 3. Specificity Bias ──────────────────────────────────────────────────
// Compares how often each model assigns each specificity level, then reports
// each model's mean level. Counts cover every annotation from a model in
// MODELS, not only paragraphs that all three models labelled.
console.log("\n" + "=".repeat(70));
console.log("3. SPECIFICITY BIAS");
console.log("=".repeat(70));

// specCounts[model][level] = number of annotations with that level.
const specCounts = {} as Record<ModelId, Record<number, number>>;
for (const m of MODELS) specCounts[m] = {};
for (const ann of annotations) {
  const mid = ann.provenance.modelId as ModelId;
  if (!MODELS.includes(mid)) continue;
  const spec = ann.label.specificity_level;
  specCounts[mid][spec] = (specCounts[mid][spec] ?? 0) + 1;
}

// Only these levels appear in the two tables; any other value is still
// counted above and therefore still contributes to the mean further down.
const specLevels = [1, 2, 3, 4];

// Share (in percent) of a model's annotations at `lvl`. A model with no
// annotations contributes 0 instead of NaN, matching how pct() treats an
// empty denominator.
const specShare = (m: ModelId, lvl: number): number =>
  modelTotals[m] === 0 ? 0 : (100 * (specCounts[m][lvl] ?? 0)) / modelTotals[m];

// Per-level shares and their unweighted mean across models, computed once and
// reused by both tables.
const specShares = specLevels.map((lvl) => {
  const pcts = MODELS.map((m) => specShare(m, lvl));
  const avg = pcts.reduce((a, b) => a + b, 0) / MODELS.length;
  return { lvl, pcts, avg };
});

console.log("\nSpecificity distribution (% of each model's annotations):\n");
const specHeaders = ["Spec Level", ...MODELS.map((m) => SHORT[m]), "Average"];
const specRows: string[][] = specShares.map(({ lvl, pcts, avg }) => [
  String(lvl),
  ...pcts.map((p) => p.toFixed(1) + "%"),
  avg.toFixed(1) + "%",
]);
printTable(specHeaders, specRows);

console.log("\nOver/under-indexing vs average (percentage points):\n");
const specBiasRows: string[][] = specShares.map(({ lvl, pcts, avg }) => [
  String(lvl),
  ...pcts.map((p) => {
    const diff = p - avg;
    // Negative values already carry their own sign from toFixed.
    const sign = diff >= 0 ? "+" : "";
    return sign + diff.toFixed(1) + "pp";
  }),
]);
printTable(["Spec Level", ...MODELS.map((m) => SHORT[m])], specBiasRows);

// Mean specificity per model
console.log("\nMean specificity per model:");
for (const m of MODELS) {
  let sum = 0;
  let count = 0;
  // Object keys are strings, so each level is converted back to a number.
  for (const [lvl, n] of Object.entries(specCounts[m])) {
    sum += Number(lvl) * n;
    count += n;
  }
  // A model with no annotations has no mean; report that instead of "NaN".
  const mean = count === 0 ? "n/a" : (sum / count).toFixed(3);
  console.log(` ${SHORT[m]}: ${mean}`);
}
// ── 4. Pairwise Agreement ────────────────────────────────────────────────
// For each unordered pair of models, counts the fully-labelled paragraphs on
// which the two agree on category, on specificity, and on both at once.
console.log("\n" + "=".repeat(70));
console.log("4. PAIRWISE AGREEMENT");
console.log("=".repeat(70));

// Every pair (i, j) with i < j, in MODELS order.
const pairs: [ModelId, ModelId][] = MODELS.flatMap((first, i) =>
  MODELS.slice(i + 1).map((second): [ModelId, ModelId] => [first, second]),
);
console.log("");
const pairHeaders = ["Pair", "Cat Agree", "Cat %", "Spec Agree", "Spec %", "Both Agree", "Both %"];
const pairRows: string[][] = pairs.map(([a, b]) => {
  const agreed = { category: 0, specificity: 0, both: 0 };
  let total = 0;
  for (const models of complete.values()) {
    const labelA = models.get(a)!.label;
    const labelB = models.get(b)!.label;
    total += 1;
    const sameCategory = labelA.content_category === labelB.content_category;
    const sameSpecificity = labelA.specificity_level === labelB.specificity_level;
    if (sameCategory) agreed.category += 1;
    if (sameSpecificity) agreed.specificity += 1;
    if (sameCategory && sameSpecificity) agreed.both += 1;
  }
  return [
    `${SHORT[a]} - ${SHORT[b]}`,
    String(agreed.category),
    pct(agreed.category, total),
    String(agreed.specificity),
    pct(agreed.specificity, total),
    String(agreed.both),
    pct(agreed.both, total),
  ];
});
printTable(pairHeaders, pairRows);
// ── 5. Conditional Outlier ───────────────────────────────────────────────
console.log("\n" + "=".repeat(70));
console.log("5. CONDITIONAL OUTLIER: What does the outlier model say?");
console.log("=".repeat(70));
// For each model, when it's the category outlier, what label does it give vs the majority?
for (const outlierModel of MODELS) {
  // The other two models do not change per paragraph, so they are resolved
  // once per outlier candidate instead of once per paragraph.
  const [otherA, otherB] = MODELS.filter((m) => m !== outlierModel);

  // wrongLabelDist[majorityLabel][outlierLabel] = count
  const wrongLabelDist: Record<string, Record<string, number>> = {};
  for (const models of complete.values()) {
    const majority = models.get(otherA)!.label.content_category;
    // The other two must agree with each other to form a majority...
    if (majority !== models.get(otherB)!.label.content_category) continue;
    const outlierCat = models.get(outlierModel)!.label.content_category;
    // ...and this model must dissent from it.
    if (outlierCat === majority) continue;
    const dist = (wrongLabelDist[majority] ??= {});
    dist[outlierCat] = (dist[outlierCat] ?? 0) + 1;
  }

  console.log(`\n${SHORT[outlierModel]} as outlier — what it says vs majority:`);
  const majorityLabels = Object.keys(wrongLabelDist).sort();
  if (majorityLabels.length === 0) {
    console.log(" (no outlier cases)");
    continue;
  }
  for (const maj of majorityLabels) {
    // Most frequent dissenting label first.
    const entries = Object.entries(wrongLabelDist[maj]).sort((a, b) => b[1] - a[1]);
    const total = entries.reduce((s, [, n]) => s + n, 0);
    console.log(` Majority="${maj}" (${total} cases):`);
    for (const [label, count] of entries) {
      console.log(` -> "${label}": ${count} (${pct(count, total)})`);
    }
  }
}
// ── 6. Spec 4 Analysis ──────────────────────────────────────────────────
// Looks only at paragraphs where exactly two of the three models chose
// specificity 4, and records which model dissented and what level it gave.
console.log("\n" + "=".repeat(70));
console.log("6. SPEC 4 ANALYSIS: Who disagrees when majority says Spec 4?");
console.log("=".repeat(70));

// spec4Outliers[model][level] = times that model answered `level` against a Spec 4 majority.
const spec4Outliers = {} as Record<ModelId, Record<number, number>>;
for (const m of MODELS) spec4Outliers[m] = {};
let spec4DisagreeTotal = 0;

for (const models of complete.values()) {
  const specs = MODELS.map((m) => models.get(m)!.label.specificity_level);
  // Positions of the models that did not answer 4.
  const dissenters = specs.flatMap((level, i) => (level === 4 ? [] : [i]));
  // Zero dissenters means unanimous; two or three means 4 is not the majority.
  if (dissenters.length !== 1) continue;
  const who = dissenters[0];
  const tally = spec4Outliers[MODELS[who]];
  spec4DisagreeTotal += 1;
  tally[specs[who]] = (tally[specs[who]] ?? 0) + 1;
}

console.log(`\nTotal paragraphs where majority=Spec4 but one disagrees: ${spec4DisagreeTotal}\n`);
for (const m of MODELS) {
  // Object keys are strings; convert back to numbers and list levels ascending.
  const byLevel = Object.entries(spec4Outliers[m])
    .map(([lvl, n]): [number, number] => [Number(lvl), n])
    .sort(([levelA], [levelB]) => levelA - levelB);
  let total = 0;
  for (const [, n] of byLevel) total += n;

  if (total === 0) {
    console.log(`${SHORT[m]}: never the outlier on Spec 4`);
  } else {
    console.log(`${SHORT[m]}: ${total} times the outlier (${pct(total, spec4DisagreeTotal)} of Spec4 disputes)`);
    for (const [lvl, n] of byLevel) {
      console.log(` -> says Spec ${lvl}: ${n} times`);
    }
  }
}
// ── 7. Management Role vs RMP Analysis ───────────────────────────────────
// Counts, per model, which side it takes on paragraphs where the three votes
// include both "Management Role" and "Risk Management Process".
console.log("\n" + "=".repeat(70));
console.log("7. MANAGEMENT ROLE vs RMP DISPUTES");
console.log("=".repeat(70));

const MGMT_LABEL = "Management Role";
const RMP_LABEL = "Risk Management Process";

const mgmtRmpCounts = {} as Record<ModelId, { management: number; rmp: number }>;
for (const m of MODELS) mgmtRmpCounts[m] = { management: 0, rmp: 0 };
let mgmtRmpTotal = 0;

for (const models of complete.values()) {
  const cats = MODELS.map((m) => models.get(m)!.label.content_category);
  // A paragraph is a dispute when both labels occur among its votes. That
  // alone guarantees at least two models are on one side or the other, so no
  // further count check is needed. The third vote may be either label or a
  // different category; the latter is what the "Says Other" column reports.
  if (!cats.includes(MGMT_LABEL) || !cats.includes(RMP_LABEL)) continue;
  mgmtRmpTotal++;
  MODELS.forEach((m, i) => {
    if (cats[i] === MGMT_LABEL) mgmtRmpCounts[m].management++;
    if (cats[i] === RMP_LABEL) mgmtRmpCounts[m].rmp++;
  });
}

console.log(`\nParagraphs with Management Role vs RMP dispute: ${mgmtRmpTotal}\n`);
printTable(
  ["Model", "Says Management", "Says RMP", "Says Other"],
  MODELS.map((m) => {
    const { management, rmp } = mgmtRmpCounts[m];
    // Whatever is left of the disputed paragraphs got a third category from this model.
    const other = mgmtRmpTotal - management - rmp;
    return [
      SHORT[m],
      `${management} (${pct(management, mgmtRmpTotal)})`,
      `${rmp} (${pct(rmp, mgmtRmpTotal)})`,
      `${other} (${pct(other, mgmtRmpTotal)})`,
    ];
  }),
);
console.log("\nDone.");