/**
 * Model bias analysis for Stage 1 annotations.
 * Identifies which model is the outlier most often, systematic biases,
 * pairwise agreement, and category-specific dispute patterns.
 *
 * Usage: bun ts/scripts/model-bias-analysis.ts
 */

import { fileURLToPath } from "node:url";

import { Paragraph } from "@sec-cybert/schemas/paragraph.ts";

import { readJsonl, readJsonlRaw } from "../src/lib/jsonl.ts";

/**
 * Filesystem paths of the two input files, resolved relative to this script's
 * own location.
 *
 * `fileURLToPath` is used instead of reading `URL#pathname`: `pathname` stays
 * percent-encoded (a directory named "my data" would come out as "my%20data")
 * and is not a valid path on Windows ("/C:/...").
 */
const PARAGRAPHS_PATH = fileURLToPath(
  new URL("../../data/paragraphs/paragraphs-clean.jsonl", import.meta.url),
);
const ANNOTATIONS_PATH = fileURLToPath(
  new URL("../../data/annotations/stage1.jsonl", import.meta.url),
);

/**
 * Identifiers of the three annotator models compared by this script.
 * The order here fixes the column/row order of every printed table.
 */
const MODELS = [
  "google/gemini-3.1-flash-lite-preview",
  "x-ai/grok-4.1-fast",
  "xiaomi/mimo-v2-flash",
] as const;

/** Union of the literal identifiers listed in `MODELS`. */
type ModelId = (typeof MODELS)[number];

/** Short display name for each model, used as the label in printed tables. */
const SHORT: Record<ModelId, string> = {
  "google/gemini-3.1-flash-lite-preview": "Gemini",
  "x-ai/grok-4.1-fast": "Grok",
  "xiaomi/mimo-v2-flash": "Mimo",
};

/**
 * One model's annotation of one paragraph, as read from the Stage 1
 * annotations file.
 *
 * This script only reads `paragraphId`, `label.content_category`,
 * `label.specificity_level` and `provenance.modelId`. The remaining fields are
 * declared for completeness; their exact semantics are not established by this
 * file (NOTE(review): confirm against the annotation writer).
 */
interface Ann {
  /** Key used to group annotations of the same paragraph across models. */
  paragraphId: string;
  label: {
    /** Category label compared across models (e.g. "Management Role"). */
    content_category: string;
    /** Specificity level; the report tabulates levels 1 through 4. */
    specificity_level: number;
    category_confidence: string;
    specificity_confidence: string;
    reasoning: string;
  };
  provenance: {
    /** Model that produced the annotation; matched against `MODELS`. */
    modelId: string;
    costUsd: number;
    inputTokens: number;
    outputTokens: number;
    reasoningTokens: number;
    latencyMs: number;
    requestedAt: string;
  };
}

// ── Helpers ──────────────────────────────────────────────────────────────

/**
 * Formats `n / total` as a percentage with one decimal place, e.g. `"25.0%"`.
 *
 * @param n - Numerator.
 * @param total - Denominator. When it is 0 the result is `"0.0%"` instead of
 *   `"NaN%"`, so empty tables still print cleanly.
 * @returns The percentage followed by a `%` sign.
 */
function pct(n: number, total: number): string {
  if (total === 0) return "0.0%";
  return `${((100 * n) / total).toFixed(1)}%`;
}

/**
 * Pads `s` with trailing spaces up to `len` characters.
 *
 * @param s - Text to pad.
 * @param len - Target length; strings already at least this long are returned
 *   unchanged (never truncated).
 * @returns The left-aligned, space-padded string.
 */
function padRight(s: string, len: number): string {
  // The built-in pads with spaces by default and leaves longer strings alone.
  return s.padEnd(len);
}

/**
 * Pads `s` with leading spaces up to `len` characters.
 *
 * @param s - Text to pad.
 * @param len - Target length; strings already at least this long are returned
 *   unchanged (never truncated).
 * @returns The right-aligned, space-padded string.
 */
function padLeft(s: string, len: number): string {
  // The built-in pads with spaces by default and leaves longer strings alone.
  return s.padStart(len);
}

/**
 * Prints a left-aligned text table to stdout: a header line, a dashed
 * separator, then one line per row. Columns are joined by a single space.
 *
 * @param headers - Column titles.
 * @param rows - Table body; a missing cell counts as an empty string when
 *   column widths are measured.
 * @param colWidths - Optional explicit column widths. When omitted, each
 *   column is as wide as its longest header or cell.
 */
function printTable(headers: string[], rows: string[][], colWidths?: number[]): void {
  // Fold over the rows rather than spreading them into Math.max: a spread
  // passes one argument per row and can exceed the engine's argument limit on
  // very large tables.
  const widths =
    colWidths ??
    headers.map((header, col) =>
      rows.reduce((widest, row) => Math.max(widest, (row[col] ?? "").length), header.length),
    );

  // A row may carry more cells than there are widths; such cells are printed
  // unpadded instead of being padded to `undefined`.
  const formatLine = (cells: readonly string[]): string =>
    cells.map((cell, col) => cell.padEnd(widths[col] ?? 0)).join(" ");

  console.log(formatLine(headers));
  console.log(widths.map((w) => "-".repeat(w)).join(" "));
  for (const row of rows) {
    console.log(formatLine(row));
  }
}

// ── Load data ────────────────────────────────────────────────────────────

/** True for any non-null object, i.e. something whose properties can be read. */
function isObjectLike(value: unknown): value is Record<string, unknown> {
  return typeof value === "object" && value !== null;
}

/**
 * Runtime check for the parts of an `Ann` record that this script reads:
 * `paragraphId`, `label.content_category`, `label.specificity_level` and
 * `provenance.modelId`. Other declared fields are not inspected.
 */
function isAnn(value: unknown): value is Ann {
  if (!isObjectLike(value)) return false;
  const { paragraphId, label, provenance } = value;
  return (
    typeof paragraphId === "string" &&
    isObjectLike(label) &&
    typeof label.content_category === "string" &&
    typeof label.specificity_level === "number" &&
    isObjectLike(provenance) &&
    typeof provenance.modelId === "string"
  );
}

console.log("Loading data...");

// Both files are independent, so read them concurrently.
const [paragraphFile, annotationFile] = await Promise.all([
  readJsonl(PARAGRAPHS_PATH, Paragraph),
  readJsonlRaw(ANNOTATIONS_PATH),
]);
const { records: paragraphs, skipped: pSkip } = paragraphFile;
const { records: rawAnns, skipped: aSkip } = annotationFile;

// The raw reader applies no schema, so validate the shape here instead of
// blindly casting: a record without `provenance` would otherwise crash the
// grouping step with an opaque TypeError.
const annotations: Ann[] = [];
let malformedAnns = 0;
for (const raw of rawAnns as unknown[]) {
  if (isAnn(raw)) annotations.push(raw);
  else malformedAnns++;
}
if (malformedAnns > 0) {
  console.warn(`Ignoring ${malformedAnns} annotation record(s) missing required fields`);
}

console.log(
  `Loaded ${paragraphs.length} paragraphs (${pSkip} skipped), ${annotations.length} annotations (${aSkip} skipped)\n`,
);

// ── Group annotations by paragraphId ─────────────────────────────────────

/** Narrows an arbitrary model id string to one of the ids listed in `MODELS`. */
function isKnownModel(id: string): id is ModelId {
  return (MODELS as readonly string[]).includes(id);
}

// paragraphId → (model → that model's annotation). Annotations from models
// outside MODELS are ignored; if a model annotated a paragraph more than once,
// the record that appears last in the file wins.
const byParagraph = new Map<string, Map<ModelId, Ann>>();
for (const ann of annotations) {
  const modelId = ann.provenance.modelId;
  if (!isKnownModel(modelId)) continue;

  let perModel = byParagraph.get(ann.paragraphId);
  if (perModel === undefined) {
    perModel = new Map();
    byParagraph.set(ann.paragraphId, perModel);
  }
  perModel.set(modelId, ann);
}

// Only keep paragraphs with all 3 models. Inner maps hold known models only,
// so "size equals MODELS.length" means every model is present.
const complete = new Map<string, Map<ModelId, Ann>>();
for (const [pid, models] of byParagraph) {
  if (models.size === MODELS.length) complete.set(pid, models);
}
console.log(`Paragraphs with all 3 models: ${complete.size}\n`);

// ── 1. Outlier Analysis ──────────────────────────────────────────────────

console.log("=".repeat(70));
console.log("1. OUTLIER ANALYSIS");
console.log("=".repeat(70));

/** Agreement summary for one label dimension over all complete paragraphs. */
interface VoteTally {
  /** Paragraphs where all three models gave the same value. */
  unanimous: number;
  /** Paragraphs where exactly two models agreed. */
  twoVsOne: number;
  /** Paragraphs where all three models differed. */
  threeWay: number;
  /** How often each model was the lone dissenter in a 2v1 split. */
  outliers: Record<ModelId, number>;
}

/**
 * Classifies three votes (given in `MODELS` order).
 *
 * @returns The index (0-2) of the lone dissenter for a 2v1 split, or
 *   `"unanimous"` / `"three-way"` when there is no single outlier.
 */
function loneDissenter(votes: readonly unknown[]): 0 | 1 | 2 | "unanimous" | "three-way" {
  const [a, b, c] = votes;
  if (a === b && b === c) return "unanimous";
  if (a === b) return 2;
  if (a === c) return 1;
  if (b === c) return 0;
  return "three-way";
}

/**
 * Tallies unanimous / 2v1 / three-way outcomes for one label dimension.
 *
 * @param vote - Extracts the compared value from an annotation.
 */
function tallyVotes(vote: (ann: Ann) => string | number): VoteTally {
  const tally: VoteTally = {
    unanimous: 0,
    twoVsOne: 0,
    threeWay: 0,
    outliers: {
      "google/gemini-3.1-flash-lite-preview": 0,
      "x-ai/grok-4.1-fast": 0,
      "xiaomi/mimo-v2-flash": 0,
    },
  };

  for (const models of complete.values()) {
    // `complete` only holds paragraphs annotated by every model, so the
    // lookup cannot miss.
    const verdict = loneDissenter(MODELS.map((m) => vote(models.get(m)!)));
    if (verdict === "unanimous") {
      tally.unanimous++;
    } else if (verdict === "three-way") {
      tally.threeWay++;
    } else {
      tally.twoVsOne++;
      tally.outliers[MODELS[verdict]]++;
    }
  }
  return tally;
}

const catTally = tallyVotes((ann) => ann.label.content_category);
const specTally = tallyVotes((ann) => ann.label.specificity_level);

console.log(
  `\nCategory: ${catTally.unanimous} unanimous, ${catTally.twoVsOne} 2v1, ${catTally.threeWay} three-way disagree`,
);
console.log("\nCategory outlier counts (when one model disagrees with the other two):");
printTable(
  ["Model", "Outlier Count", "% of 2v1"],
  MODELS.map((m) => [
    SHORT[m],
    String(catTally.outliers[m]),
    pct(catTally.outliers[m], catTally.twoVsOne),
  ]),
);

console.log(
  `\nSpecificity: ${specTally.unanimous} unanimous, ${specTally.twoVsOne} 2v1, ${specTally.threeWay} three-way disagree`,
);
console.log("\nSpecificity outlier counts:");
printTable(
  ["Model", "Outlier Count", "% of 2v1"],
  MODELS.map((m) => [
    SHORT[m],
    String(specTally.outliers[m]),
    pct(specTally.outliers[m], specTally.twoVsOne),
  ]),
);

// ── 2. Category Bias ─────────────────────────────────────────────────────

console.log("\n" + "=".repeat(70));
console.log("2. CATEGORY BIAS");
console.log("=".repeat(70));

// Counts cover ALL annotations of a known model, not only those on paragraphs
// that every model annotated.
const allCategories = new Set<string>();

// model → (category → count). A Map is used for the inner dictionary because
// its keys come from data; a plain object would mis-count a category that
// happens to be named like an Object.prototype member.
const catCounts = new Map<ModelId, Map<string, number>>(
  MODELS.map((m): [ModelId, Map<string, number>] => [m, new Map()]),
);

// Number of annotations per model; also the denominator of the specificity
// tables printed further down.
const modelTotals = Object.fromEntries(MODELS.map((m) => [m, 0])) as Record<ModelId, number>;

for (const ann of annotations) {
  const mid = ann.provenance.modelId as ModelId;
  const perCategory = catCounts.get(mid);
  if (perCategory === undefined) continue; // annotation from a model outside MODELS

  const cat = ann.label.content_category;
  allCategories.add(cat);
  perCategory.set(cat, (perCategory.get(cat) ?? 0) + 1);
  modelTotals[mid]++;
}

const categories = [...allCategories].sort();

/** Share of `count` in `total` as a percentage; 0 (not NaN) for an empty total. */
const catShare = (count: number, total: number): number =>
  total === 0 ? 0 : (100 * count) / total;

// Compute each category's per-model share and cross-model average once and
// feed both tables from it.
const catDistribution = categories.map((cat) => {
  const pcts = MODELS.map((m) => catShare(catCounts.get(m)?.get(cat) ?? 0, modelTotals[m]));
  const avg = pcts.reduce((a, b) => a + b, 0) / MODELS.length;
  return { cat, pcts, avg };
});

console.log("\nCategory distribution (% of each model's annotations):\n");
const catHeaders = ["Category", ...MODELS.map((m) => SHORT[m]), "Average"];
const catRows: string[][] = catDistribution.map(({ cat, pcts, avg }) => [
  cat,
  ...pcts.map((p) => p.toFixed(1) + "%"),
  avg.toFixed(1) + "%",
]);
printTable(catHeaders, catRows);

console.log("\nOver/under-indexing vs average (percentage points):\n");
const biasHeaders = ["Category", ...MODELS.map((m) => SHORT[m])];
const biasRows: string[][] = catDistribution.map(({ cat, pcts, avg }) => [
  cat,
  ...pcts.map((p) => {
    const diff = p - avg;
    // Negative numbers already carry their sign; only "+" must be added.
    const sign = diff >= 0 ? "+" : "";
    return sign + diff.toFixed(1) + "pp";
  }),
]);
printTable(biasHeaders, biasRows);

// ── 3. Specificity Bias ──────────────────────────────────────────────────

console.log("\n" + "=".repeat(70));
console.log("3. SPECIFICITY BIAS");
console.log("=".repeat(70));

// model → (specificity level → count), over all annotations of known models.
const specCounts = new Map<ModelId, Map<number, number>>(
  MODELS.map((m): [ModelId, Map<number, number>] => [m, new Map()]),
);

for (const ann of annotations) {
  const perLevel = specCounts.get(ann.provenance.modelId as ModelId);
  if (perLevel === undefined) continue; // annotation from a model outside MODELS
  const spec = ann.label.specificity_level;
  perLevel.set(spec, (perLevel.get(spec) ?? 0) + 1);
}

// Only these levels are tabulated; any other value still counts towards the
// per-model mean printed at the end of this section.
const specLevels = [1, 2, 3, 4];

/** Share of `count` in `total` as a percentage; 0 (not NaN) for an empty total. */
const specShare = (count: number, total: number): number =>
  total === 0 ? 0 : (100 * count) / total;

// Per-level shares and cross-model average, computed once for both tables.
// The denominator is each model's total annotation count (`modelTotals`).
const specDistribution = specLevels.map((lvl) => {
  const pcts = MODELS.map((m) => specShare(specCounts.get(m)?.get(lvl) ?? 0, modelTotals[m]));
  const avg = pcts.reduce((a, b) => a + b, 0) / MODELS.length;
  return { lvl, pcts, avg };
});

console.log("\nSpecificity distribution (% of each model's annotations):\n");
const specHeaders = ["Spec Level", ...MODELS.map((m) => SHORT[m]), "Average"];
const specRows: string[][] = specDistribution.map(({ lvl, pcts, avg }) => [
  String(lvl),
  ...pcts.map((p) => p.toFixed(1) + "%"),
  avg.toFixed(1) + "%",
]);
printTable(specHeaders, specRows);

console.log("\nOver/under-indexing vs average (percentage points):\n");
const specBiasRows: string[][] = specDistribution.map(({ lvl, pcts, avg }) => [
  String(lvl),
  ...pcts.map((p) => {
    const diff = p - avg;
    // Negative numbers already carry their sign; only "+" must be added.
    const sign = diff >= 0 ? "+" : "";
    return sign + diff.toFixed(1) + "pp";
  }),
]);
printTable(["Spec Level", ...MODELS.map((m) => SHORT[m])], specBiasRows);

// Mean specificity per model
console.log("\nMean specificity per model:");
for (const m of MODELS) {
  let sum = 0;
  let count = 0;
  for (const [lvl, n] of specCounts.get(m) ?? []) {
    sum += lvl * n;
    count += n;
  }
  // A model without annotations has no mean; print "n/a" instead of "NaN".
  const mean = count === 0 ? "n/a" : (sum / count).toFixed(3);
  console.log(` ${SHORT[m]}: ${mean}`);
}

// ── 4. Pairwise Agreement ────────────────────────────────────────────────

console.log("\n" + "=".repeat(70));
console.log("4. PAIRWISE AGREEMENT");
console.log("=".repeat(70));

// Every unordered pair of models, in MODELS order: (0,1), (0,2), (1,2).
// Derived from MODELS so the list cannot drift from it.
const pairs: [ModelId, ModelId][] = MODELS.flatMap((a, i) =>
  MODELS.slice(i + 1).map((b): [ModelId, ModelId] => [a, b]),
);

console.log("");
const pairHeaders = ["Pair", "Cat Agree", "Cat %", "Spec Agree", "Spec %", "Both Agree", "Both %"];

// Agreement is measured over paragraphs annotated by all models, so every
// pair shares the same denominator.
const pairTotal = complete.size;

const pairRows: string[][] = pairs.map(([a, b]) => {
  let catAgree = 0;
  let specAgree = 0;
  let bothAgree = 0;

  for (const models of complete.values()) {
    // `complete` guarantees an annotation for every model.
    const labelA = models.get(a)!.label;
    const labelB = models.get(b)!.label;
    const sameCategory = labelA.content_category === labelB.content_category;
    const sameSpecificity = labelA.specificity_level === labelB.specificity_level;
    if (sameCategory) catAgree++;
    if (sameSpecificity) specAgree++;
    if (sameCategory && sameSpecificity) bothAgree++;
  }

  return [
    `${SHORT[a]} - ${SHORT[b]}`,
    String(catAgree),
    pct(catAgree, pairTotal),
    String(specAgree),
    pct(specAgree, pairTotal),
    String(bothAgree),
    pct(bothAgree, pairTotal),
  ];
});
printTable(pairHeaders, pairRows);

// ── 5. Conditional Outlier ───────────────────────────────────────────────

console.log("\n" + "=".repeat(70));
console.log("5. CONDITIONAL OUTLIER: What does the outlier model say?");
console.log("=".repeat(70));

// For each model, when it's the category outlier, what label does it give vs the majority?
for (const outlierModel of MODELS) {
  // The two models that must agree for `outlierModel` to be the lone
  // dissenter; the same for every paragraph, so computed once per model.
  const others = MODELS.filter((m) => m !== outlierModel);

  // majority label → (outlier's label → number of paragraphs)
  const wrongLabelDist = new Map<string, Map<string, number>>();

  for (const models of complete.values()) {
    // `complete` guarantees an annotation for every model.
    const [majority, second] = others.map((m) => models.get(m)!.label.content_category);
    if (majority === undefined || majority !== second) continue; // not a 2v1 with this model as outlier

    const outlierCat = models.get(outlierModel)!.label.content_category;
    if (outlierCat === majority) continue; // this model agrees

    let byOutlierLabel = wrongLabelDist.get(majority);
    if (byOutlierLabel === undefined) {
      byOutlierLabel = new Map();
      wrongLabelDist.set(majority, byOutlierLabel);
    }
    byOutlierLabel.set(outlierCat, (byOutlierLabel.get(outlierCat) ?? 0) + 1);
  }

  console.log(`\n${SHORT[outlierModel]} as outlier — what it says vs majority:`);
  if (wrongLabelDist.size === 0) {
    console.log(" (no outlier cases)");
    continue;
  }

  for (const maj of [...wrongLabelDist.keys()].sort()) {
    // Most frequent disagreeing label first.
    const entries = [...(wrongLabelDist.get(maj) ?? [])].sort((a, b) => b[1] - a[1]);
    const total = entries.reduce((s, [, n]) => s + n, 0);
    console.log(` Majority="${maj}" (${total} cases):`);
    for (const [label, count] of entries) {
      console.log(` -> "${label}": ${count} (${pct(count, total)})`);
    }
  }
}

// ── 6. Spec 4 Analysis ──────────────────────────────────────────────────

console.log("\n" + "=".repeat(70));
console.log("6. SPEC 4 ANALYSIS: Who disagrees when majority says Spec 4?");
console.log("=".repeat(70));

// model → (level it chose instead of 4 → number of paragraphs)
const spec4Outliers = new Map<ModelId, Map<number, number>>(
  MODELS.map((m): [ModelId, Map<number, number>] => [m, new Map()]),
);
let spec4DisagreeTotal = 0;

for (const models of complete.values()) {
  // `complete` guarantees an annotation for every model.
  const specs = MODELS.map((m) => models.get(m)!.label.specificity_level);

  // Keep only paragraphs where all models but one chose 4: fewer means 4 is
  // not the majority, more means the vote was unanimous.
  const countOf4 = specs.filter((s) => s === 4).length;
  if (countOf4 !== MODELS.length - 1) continue;

  // One model disagrees
  MODELS.forEach((m, i) => {
    const chosen = specs[i];
    if (chosen === undefined || chosen === 4) return;
    spec4DisagreeTotal++;
    const perLevel = spec4Outliers.get(m);
    perLevel?.set(chosen, (perLevel.get(chosen) ?? 0) + 1);
  });
}

console.log(`\nTotal paragraphs where majority=Spec4 but one disagrees: ${spec4DisagreeTotal}\n`);
for (const m of MODELS) {
  // Levels in ascending order.
  const entries = [...(spec4Outliers.get(m) ?? [])].sort((a, b) => a[0] - b[0]);
  const total = entries.reduce((s, [, n]) => s + n, 0);
  if (total === 0) {
    console.log(`${SHORT[m]}: never the outlier on Spec 4`);
    continue;
  }
  console.log(`${SHORT[m]}: ${total} times the outlier (${pct(total, spec4DisagreeTotal)} of Spec4 disputes)`);
  for (const [lvl, n] of entries) {
    console.log(` -> says Spec ${lvl}: ${n} times`);
  }
}

// ── 7. Management Role vs RMP Analysis ───────────────────────────────────

console.log("\n" + "=".repeat(70));
console.log("7. MANAGEMENT ROLE vs RMP DISPUTES");
console.log("=".repeat(70));

const MGMT_CATEGORY = "Management Role";
const RMP_CATEGORY = "Risk Management Process";

// model → how often it chose each side across the disputed paragraphs
const mgmtRmpCounts = new Map<ModelId, { management: number; rmp: number }>(
  MODELS.map((m): [ModelId, { management: number; rmp: number }] => [
    m,
    { management: 0, rmp: 0 },
  ]),
);
let mgmtRmpTotal = 0;

for (const models of complete.values()) {
  // `complete` guarantees an annotation for every model.
  const cats = MODELS.map((m) => models.get(m)!.label.content_category);

  // A dispute is a paragraph on which both labels occur. That already implies
  // at least two models are involved, so no further count check is needed
  // (the third model may have picked either label or something else).
  if (!cats.includes(MGMT_CATEGORY) || !cats.includes(RMP_CATEGORY)) continue;

  mgmtRmpTotal++;
  MODELS.forEach((m, i) => {
    const tally = mgmtRmpCounts.get(m);
    if (tally === undefined) return;
    if (cats[i] === MGMT_CATEGORY) tally.management++;
    if (cats[i] === RMP_CATEGORY) tally.rmp++;
  });
}

console.log(`\nParagraphs with Management Role vs RMP dispute: ${mgmtRmpTotal}\n`);
printTable(
  ["Model", "Says Management", "Says RMP", "Says Other"],
  MODELS.map((m) => {
    const { management, rmp } = mgmtRmpCounts.get(m) ?? { management: 0, rmp: 0 };
    // Whatever is neither label on a disputed paragraph is "other".
    const other = mgmtRmpTotal - management - rmp;
    return [
      SHORT[m],
      `${management} (${pct(management, mgmtRmpTotal)})`,
      `${rmp} (${pct(rmp, mgmtRmpTotal)})`,
      `${other} (${pct(other, mgmtRmpTotal)})`,
    ];
  }),
);

console.log("\nDone.");