SEC-cyBERT/ts/scripts/judge-bench.ts
2026-03-28 23:44:37 -04:00

456 lines
21 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/**
* Benchmark Stage 2 judge candidates on disagreement paragraphs.
* Runs each model as a judge and compares against Stage 1 majority vote.
*
* Usage: bun ts/scripts/judge-bench.ts <model-id> [--n 50] [--concurrency 10] [--mode structured|tool]
*/
import { generateText, tool, Output } from "ai";
import { openrouter, providerOf } from "../src/lib/openrouter.ts";
import { readJsonl, readJsonlRaw, appendJsonl } from "../src/lib/jsonl.ts";
import { Paragraph } from "@sec-cybert/schemas/paragraph.ts";
import { LabelOutputRaw, toLabelOutput } from "@sec-cybert/schemas/label.ts";
import { SYSTEM_PROMPT, buildJudgePrompt, PROMPT_VERSION } from "../src/label/prompts.ts";
import { withRetry } from "../src/lib/retry.ts";
import { v4 as uuidv4 } from "uuid";
import { existsSync } from "node:fs";
import { mkdir } from "node:fs/promises";
import pLimit from "p-limit";
// ── CLI parsing & output paths ──────────────────────────────────────────
const args = process.argv.slice(2);

/** Returns the value following `--name`, or undefined when the flag is absent. */
function flag(name: string): string | undefined {
  const idx = args.indexOf(`--${name}`);
  return idx === -1 ? undefined : args[idx + 1];
}

// First positional argument is the judge model id (e.g. "openai/gpt-5").
// Exit before use instead of non-null-asserting and checking afterwards.
const modelArg = args.find(a => !a.startsWith("--"));
if (!modelArg) { console.error("Usage: bun ts/scripts/judge-bench.ts <model-id>"); process.exit(1); }
const MODEL: string = modelArg;
const N = parseInt(flag("n") ?? "50", 10);
const CONCURRENCY = parseInt(flag("concurrency") ?? "10", 10);
const MODE = (flag("mode") ?? "structured") as "structured" | "tool";
// Last path segment of the model id, for log prefixes.
const shortName = MODEL.split("/").pop() ?? MODEL;
// Replace ALL slashes (not just the first) so the output filename can never
// contain a path separator.
const slug = MODEL.replace(/\//g, "_");
const STAGE1_PATH = new URL("../../data/annotations/stage1.jsonl", import.meta.url).pathname;
const PARAGRAPHS_PATH = new URL("../../data/paragraphs/training.jsonl", import.meta.url).pathname;
const BENCH_DIR = new URL("../../data/bench/judges", import.meta.url).pathname;
const SAMPLE_PATH = `${BENCH_DIR}/judge-sample.jsonl`;
const OUTPUT_PATH = `${BENCH_DIR}/${slug}.jsonl`;
if (!existsSync(BENCH_DIR)) await mkdir(BENCH_DIR, { recursive: true });
/**
 * Minimal shape of a Stage 1 annotation record as read from stage1.jsonl.
 * Only the fields this script actually touches are declared; the on-disk
 * records may carry more (they are read via readJsonlRaw and cast).
 */
interface S1Ann {
  paragraphId: string;
  // Label fields mirror the judge's output schema; reasoning is free text
  // that gets forwarded to the Stage 2 judge prompt.
  label: { content_category: string; specificity_level: number; reasoning: string };
  // provenance.modelId identifies which Stage 1 model produced this label.
  provenance: { modelId: string };
}
/** Format n out of total as a percentage string with one decimal, e.g. "50.0%". */
function pct(n: number, total: number): string {
  const percentage = (n / total) * 100;
  return percentage.toFixed(1) + "%";
}
/**
 * Benchmark driver.
 *
 * Loads Stage 1 annotations, selects a stable sample of disagreement
 * paragraphs (shared across all judge candidates), runs the candidate model
 * as a Stage 2 judge over the sample (with resume support), then prints
 * agreement/accuracy analyses and writes `<slug>.report.json`.
 */
async function main() {
  // ── Load Stage 1 annotations ────────────────────────────────────────
  console.error(`[${shortName}] Loading Stage 1 data...`);
  const { records: allAnns } = await readJsonlRaw(STAGE1_PATH);
  // Group annotations by paragraph id (expected: one per Stage 1 model).
  const s1ByParagraph = new Map<string, S1Ann[]>();
  for (const raw of allAnns) {
    const a = raw as S1Ann;
    let arr = s1ByParagraph.get(a.paragraphId);
    if (!arr) { arr = []; s1ByParagraph.set(a.paragraphId, arr); }
    arr.push(a);
  }

  // ── Find disagreement paragraphs ────────────────────────────────────
  // A paragraph is a "disagreement" when its three Stage 1 annotators were
  // not unanimous on both category and specificity. Paragraphs without
  // exactly three annotations are skipped entirely.
  const disagreementIds: string[] = [];
  for (const [pid, anns] of s1ByParagraph) {
    if (anns.length !== 3) continue;
    const cats = new Set(anns.map(a => a.label.content_category));
    const specs = new Set(anns.map(a => a.label.specificity_level));
    if (cats.size > 1 || specs.size > 1) {
      disagreementIds.push(pid);
    }
  }
  console.error(`[${shortName}] ${disagreementIds.length.toLocaleString()} disagreement paragraphs total`);

  // ── Load or create stable sample ────────────────────────────────────
  // The sample file is shared by every judge candidate so all models are
  // benchmarked on the identical paragraph set.
  let sampleIds: string[];
  if (existsSync(SAMPLE_PATH)) {
    const { records } = await readJsonlRaw(SAMPLE_PATH);
    sampleIds = (records as { id: string }[]).map(r => r.id);
    console.error(`[${shortName}] Using existing sample of ${sampleIds.length} paragraphs`);
  } else {
    // Seeded shuffle for reproducibility: fixed-seed LCG + Fisher–Yates.
    const seed = 42;
    let rng = seed;
    const nextRng = () => { rng = (rng * 1664525 + 1013904223) & 0x7fffffff; return rng / 0x7fffffff; };
    const shuffled = [...disagreementIds];
    for (let i = shuffled.length - 1; i > 0; i--) {
      const j = Math.floor(nextRng() * (i + 1));
      [shuffled[i], shuffled[j]] = [shuffled[j]!, shuffled[i]!];
    }
    sampleIds = shuffled.slice(0, N);
    // Save stable sample
    for (const id of sampleIds) {
      await appendJsonl(SAMPLE_PATH, { id });
    }
    console.error(`[${shortName}] Created new sample of ${sampleIds.length} paragraphs`);
  }

  // ── Load paragraph texts ────────────────────────────────────────────
  console.error(`[${shortName}] Loading paragraph texts...`);
  const { records: allParagraphs } = await readJsonl(PARAGRAPHS_PATH, Paragraph);
  const paragraphMap = new Map(allParagraphs.map(p => [p.id, p]));

  // ── Resume support ──────────────────────────────────────────────────
  // Any paragraph already present in the output file is skipped on re-run.
  const doneKeys = new Set<string>();
  if (existsSync(OUTPUT_PATH)) {
    const { records: existing } = await readJsonlRaw(OUTPUT_PATH);
    for (const r of existing) {
      const a = r as { paragraphId?: string };
      if (a.paragraphId) doneKeys.add(a.paragraphId);
    }
    if (doneKeys.size > 0) console.error(`[${shortName}] Resuming: ${doneKeys.size} already done`);
  }
  const remaining = sampleIds.filter(id => !doneKeys.has(id));
  if (remaining.length === 0) {
    console.error(`[${shortName}] All done, skipping to analysis`);
  } else {
    console.error(`[${shortName}] Running ${remaining.length} judge calls (concurrency=${CONCURRENCY})...\n`);
    const runId = uuidv4();
    const limit = pLimit(CONCURRENCY);
    let completed = 0, failed = 0, totalCost = 0;
    const startTime = Date.now();
    const tasks = remaining.map(pid => limit(async () => {
      const paragraph = paragraphMap.get(pid);
      if (!paragraph) { failed++; return; }
      const priorAnns = s1ByParagraph.get(pid)!;
      // Strip provenance before showing the prior labels to the judge.
      const priorForJudge = priorAnns.map(a => ({
        content_category: a.label.content_category,
        specificity_level: a.label.specificity_level,
        reasoning: a.label.reasoning,
      }));
      const requestedAt = new Date().toISOString();
      const start = Date.now();
      try {
        const providerOpts = {
          openrouter: {
            reasoning: { effort: "medium" as const },
            usage: { include: true },
            provider: { require_parameters: true },
          },
        };
        let rawOutput: LabelOutputRaw;
        let responseId: string;
        let usage: { inputTokens?: number; outputTokens?: number; outputTokenDetails?: { reasoningTokens?: number }; raw?: { cost?: number } };
        if (MODE === "tool") {
          // Tool mode: force the model to emit its label through a required tool call.
          const r = await withRetry(
            () => generateText({
              model: openrouter(MODEL),
              system: SYSTEM_PROMPT,
              prompt: buildJudgePrompt(paragraph, priorForJudge),
              temperature: 0,
              tools: {
                submit_label: tool({
                  description: "Submit your final label for this paragraph",
                  inputSchema: LabelOutputRaw,
                }),
              },
              toolChoice: "required",
              providerOptions: providerOpts,
              abortSignal: AbortSignal.timeout(240_000),
            }),
            { label: `${shortName}:${pid.slice(0, 8)}` },
          );
          const tc = r.toolCalls[0];
          if (!tc) throw new Error(`No tool call from ${shortName} for ${pid}`);
          rawOutput = tc.input as LabelOutputRaw;
          responseId = r.response?.id ?? "unknown";
          usage = r.usage as typeof usage;
        } else {
          // Structured mode: schema-constrained object output.
          const r = await withRetry(
            () => generateText({
              model: openrouter(MODEL),
              output: Output.object({ schema: LabelOutputRaw }),
              system: SYSTEM_PROMPT,
              prompt: buildJudgePrompt(paragraph, priorForJudge),
              temperature: 0,
              providerOptions: providerOpts,
              abortSignal: AbortSignal.timeout(240_000),
            }),
            { label: `${shortName}:${pid.slice(0, 8)}` },
          );
          if (!r.output) throw new Error(`No output from ${shortName} for ${pid}`);
          rawOutput = r.output;
          responseId = r.response?.id ?? "unknown";
          usage = r.usage as typeof usage;
        }
        const latencyMs = Date.now() - start;
        const label = toLabelOutput(rawOutput);
        // NOTE(review): cost is read from usage.raw.cost — presumably the
        // OpenRouter usage-accounting field; verify against the SDK version in use.
        const costUsd = usage.raw?.cost ?? 0;
        const annotation = {
          paragraphId: pid,
          label,
          provenance: {
            modelId: MODEL,
            provider: providerOf(MODEL),
            generationId: responseId,
            stage: "stage2-judge" as const,
            runId,
            promptVersion: PROMPT_VERSION,
            inputTokens: usage.inputTokens ?? 0,
            outputTokens: usage.outputTokens ?? 0,
            reasoningTokens: usage.outputTokenDetails?.reasoningTokens ?? 0,
            costUsd,
            latencyMs,
            requestedAt,
          },
        };
        await appendJsonl(OUTPUT_PATH, annotation);
        totalCost += costUsd;
        completed++;
        if (completed % 10 === 0) {
          process.stderr.write(`\r[${shortName}] ${completed}/${remaining.length} ($${totalCost.toFixed(4)}) `);
        }
      } catch (err) {
        failed++;
        const msg = err instanceof Error ? err.message : String(err);
        // Only echo the first few failures so the progress line stays readable.
        if (failed <= 3) console.error(`\n[${shortName}] ✖ ${pid.slice(0, 8)}: ${msg.slice(0, 200)}`);
      }
    }));
    await Promise.all(tasks);
    const elapsed = ((Date.now() - startTime) / 1000).toFixed(0);
    console.error(`\n[${shortName}] Done: ${completed} ok, ${failed} failed, $${totalCost.toFixed(4)}, ${elapsed}s`);
  }

  // ── Analysis ────────────────────────────────────────────────────────
  const { records: judgeRaw } = await readJsonlRaw(OUTPUT_PATH);
  const judgeResults = new Map<string, { content_category: string; specificity_level: number; category_confidence: string; specificity_confidence: string; costUsd: number; outputTokens: number; reasoningTokens: number; latencyMs: number }>();
  for (const r of judgeRaw) {
    const a = r as { paragraphId: string; label: { content_category: string; specificity_level: number; category_confidence: string; specificity_confidence: string }; provenance: { costUsd: number; outputTokens: number; reasoningTokens: number; latencyMs: number } };
    judgeResults.set(a.paragraphId, { ...a.label, ...a.provenance });
  }
  const n = judgeResults.size;
  // Guard: every statistic below divides by n; bail out rather than print NaN
  // when no judge call succeeded.
  if (n === 0) {
    console.error(`[${shortName}] No judge results to analyze`);
    return;
  }
  let totalCost = 0, totalOutput = 0, totalReasoning = 0, totalLatency = 0;
  for (const v of judgeResults.values()) {
    totalCost += v.costUsd;
    totalOutput += v.outputTokens;
    totalReasoning += v.reasoningTokens;
    totalLatency += v.latencyMs;
  }
  console.log(`\n═══ ${shortName} as Judge (n=${n}) ═══`);
  console.log(` Cost: $${totalCost.toFixed(4)} total, $${(totalCost / n).toFixed(5)}/call`);
  console.log(` Latency: ${(totalLatency / n).toFixed(0)}ms avg`);
  console.log(` Output: ${(totalOutput / n).toFixed(0)} tokens avg, ${(totalReasoning / n).toFixed(0)} reasoning avg`);
  console.log(` Est. full Stage 2 cost (14,623 calls): $${(totalCost / n * 14623).toFixed(0)}`);

  // Majority vote over the 3 Stage 1 annotations: first value with >= 2 votes
  // wins; undefined on a three-way split. Shared by every comparison below so
  // the (previously triplicated) logic cannot drift apart.
  const majorityOf = (anns: S1Ann[]) => {
    const catFreq = new Map<string, number>();
    for (const a of anns) catFreq.set(a.label.content_category, (catFreq.get(a.label.content_category) ?? 0) + 1);
    const majCat = [...catFreq.entries()].find(([, v]) => v >= 2)?.[0];
    const specFreq = new Map<number, number>();
    for (const a of anns) specFreq.set(a.label.specificity_level, (specFreq.get(a.label.specificity_level) ?? 0) + 1);
    const majSpec = [...specFreq.entries()].find(([, v]) => v >= 2)?.[0];
    return { majCat, majSpec };
  };

  // ── Load gold labels ───────────────────────────────────────────────
  const GOLD_PATH = `${BENCH_DIR}/gold-final.json`;
  let goldLabels: Record<string, { cat: string; spec: number }> = {};
  if (existsSync(GOLD_PATH)) {
    goldLabels = JSON.parse(await Bun.file(GOLD_PATH).text());
    console.log(`\n Gold labels loaded: ${Object.keys(goldLabels).length} paragraphs`);
  } else {
    console.log(`\n ⚠ No gold labels found at ${GOLD_PATH} — skipping gold comparison`);
  }

  // ── Compare vs gold labels ─────────────────────────────────────────
  const hasGold = Object.keys(goldLabels).length > 0;
  let goldCatMatch = 0, goldSpecMatch = 0, goldBothMatch = 0, goldTotal = 0;
  let majGoldCatMatch = 0, majGoldSpecMatch = 0, majGoldBothMatch = 0, majGoldTotal = 0;
  // Confidence breakdown vs gold accuracy
  const confBuckets = { high: { correct: 0, total: 0 }, medium: { correct: 0, total: 0 }, low: { correct: 0, total: 0 } };
  // Per-category accuracy vs gold
  const catAccuracy = new Map<string, { correct: number; total: number }>();
  // Confusion matrix for category errors
  const catConfusions: { gold: string; judge: string }[] = [];
  if (hasGold) {
    for (const [pid, judgeLabel] of judgeResults) {
      const gold = goldLabels[pid];
      if (!gold) continue;
      goldTotal++;
      const catOk = judgeLabel.content_category === gold.cat;
      const specOk = judgeLabel.specificity_level === gold.spec;
      if (catOk) goldCatMatch++;
      if (specOk) goldSpecMatch++;
      if (catOk && specOk) goldBothMatch++;
      // Track confidence vs accuracy (use lower of the two confidences)
      const worstConf = judgeLabel.category_confidence === "low" || judgeLabel.specificity_confidence === "low"
        ? "low"
        : judgeLabel.category_confidence === "medium" || judgeLabel.specificity_confidence === "medium"
          ? "medium"
          : "high";
      confBuckets[worstConf].total++;
      if (catOk && specOk) confBuckets[worstConf].correct++;
      // Per-category
      if (!catAccuracy.has(gold.cat)) catAccuracy.set(gold.cat, { correct: 0, total: 0 });
      const ca = catAccuracy.get(gold.cat)!;
      ca.total++;
      if (catOk) ca.correct++;
      // Confusion matrix entries for errors
      if (!catOk) catConfusions.push({ gold: gold.cat, judge: judgeLabel.content_category });
      // Majority vote vs gold
      const { majCat, majSpec } = majorityOf(s1ByParagraph.get(pid)!);
      majGoldTotal++;
      if (majCat === gold.cat) majGoldCatMatch++;
      if (majSpec === gold.spec) majGoldSpecMatch++;
      if (majCat === gold.cat && majSpec === gold.spec) majGoldBothMatch++;
    }
    console.log(`\n ── vs GOLD LABELS (n=${goldTotal}) ──`);
    console.log(` Judge: cat ${pct(goldCatMatch, goldTotal)}, spec ${pct(goldSpecMatch, goldTotal)}, both ${pct(goldBothMatch, goldTotal)}`);
    console.log(` Majority: cat ${pct(majGoldCatMatch, majGoldTotal)}, spec ${pct(majGoldSpecMatch, majGoldTotal)}, both ${pct(majGoldBothMatch, majGoldTotal)}`);
    console.log(` Delta: cat +${((goldCatMatch - majGoldCatMatch) / goldTotal * 100).toFixed(1)}pp, spec +${((goldSpecMatch - majGoldSpecMatch) / goldTotal * 100).toFixed(1)}pp, both +${((goldBothMatch - majGoldBothMatch) / goldTotal * 100).toFixed(1)}pp`);
    // Confidence calibration
    console.log(`\n ── CONFIDENCE CALIBRATION ──`);
    for (const [level, bucket] of Object.entries(confBuckets)) {
      if (bucket.total > 0) {
        console.log(` ${level.padEnd(8)} ${pct(bucket.correct, bucket.total).padStart(6)} both-correct (n=${bucket.total})`);
      }
    }
    // Per-category accuracy
    console.log(`\n ── PER-CATEGORY ACCURACY (vs gold) ──`);
    for (const [cat, acc] of [...catAccuracy.entries()].sort((a, b) => b[1].total - a[1].total)) {
      console.log(` ${cat.padEnd(30)} ${pct(acc.correct, acc.total).padStart(6)} (${acc.correct}/${acc.total})`);
    }
    // Category confusions
    if (catConfusions.length > 0) {
      console.log(`\n ── CATEGORY ERRORS (${catConfusions.length} total) ──`);
      const confusionCounts = new Map<string, number>();
      for (const { gold, judge } of catConfusions) {
        // Arrow separator so the gold/judge pair prints unambiguously
        // (plain concatenation ran the two category names together).
        const key = `${gold} → ${judge}`;
        confusionCounts.set(key, (confusionCounts.get(key) ?? 0) + 1);
      }
      for (const [pair, count] of [...confusionCounts.entries()].sort(([, a], [, b]) => b - a)) {
        console.log(` ${pair}: ${count}`);
      }
    }
  }

  // ── Compare judge vs Stage 1 majority vote ─────────────────────────
  let agreeMajCat = 0, agreeMajSpec = 0, agreeMajBoth = 0;
  const modelAgreement = new Map<string, { cat: number; spec: number; total: number }>();
  for (const [pid, judgeLabel] of judgeResults) {
    const s1anns = s1ByParagraph.get(pid)!;
    const { majCat, majSpec } = majorityOf(s1anns);
    if (majCat && judgeLabel.content_category === majCat) agreeMajCat++;
    if (majSpec !== undefined && judgeLabel.specificity_level === majSpec) agreeMajSpec++;
    if (majCat && judgeLabel.content_category === majCat && majSpec !== undefined && judgeLabel.specificity_level === majSpec) agreeMajBoth++;
    // Pairwise agreement between the judge and each individual Stage 1 model.
    for (const s1 of s1anns) {
      const m = s1.provenance.modelId.split("/").pop()!;
      if (!modelAgreement.has(m)) modelAgreement.set(m, { cat: 0, spec: 0, total: 0 });
      const ma = modelAgreement.get(m)!;
      ma.total++;
      if (s1.label.content_category === judgeLabel.content_category) ma.cat++;
      if (s1.label.specificity_level === judgeLabel.specificity_level) ma.spec++;
    }
  }
  console.log(`\n ── vs Stage 1 Majority ──`);
  console.log(` cat ${pct(agreeMajCat, n)}, spec ${pct(agreeMajSpec, n)}, both ${pct(agreeMajBoth, n)}`);
  console.log(`\n vs Individual Stage 1 models:`);
  for (const [m, a] of [...modelAgreement.entries()].sort()) {
    console.log(` × ${m.padEnd(30)} cat ${pct(a.cat, a.total).padStart(6)} spec ${pct(a.spec, a.total).padStart(6)}`);
  }

  // How often does judge side with outlier vs majority?
  let sidesMajority = 0, sidesOutlier = 0, sidesNeither = 0;
  for (const [pid, judgeLabel] of judgeResults) {
    const s1anns = s1ByParagraph.get(pid)!;
    const cats = s1anns.map(a => a.label.content_category);
    const { majCat } = majorityOf(s1anns);
    if (!majCat) { sidesNeither++; continue; }
    const outlierCats = cats.filter(c => c !== majCat);
    if (judgeLabel.content_category === majCat) sidesMajority++;
    else if (outlierCats.includes(judgeLabel.content_category)) sidesOutlier++;
    else sidesNeither++;
  }
  console.log(`\n Judge category decision pattern:`);
  console.log(` Sides with majority: ${sidesMajority} (${pct(sidesMajority, n)})`);
  console.log(` Sides with outlier: ${sidesOutlier} (${pct(sidesOutlier, n)})`);
  console.log(` Neither (own pick): ${sidesNeither} (${pct(sidesNeither, n)})`);

  // ── Confidence distribution ─────────────────────────────────────────
  const catConfDist = { high: 0, medium: 0, low: 0 };
  const specConfDist = { high: 0, medium: 0, low: 0 };
  for (const v of judgeResults.values()) {
    catConfDist[v.category_confidence as keyof typeof catConfDist]++;
    specConfDist[v.specificity_confidence as keyof typeof specConfDist]++;
  }
  console.log(`\n ── CONFIDENCE DISTRIBUTION ──`);
  console.log(` Category: high=${catConfDist.high} medium=${catConfDist.medium} low=${catConfDist.low}`);
  console.log(` Specificity: high=${specConfDist.high} medium=${specConfDist.medium} low=${specConfDist.low}`);

  // Write report JSON
  const report = {
    model: MODEL, shortName, n,
    totalCost: +totalCost.toFixed(4),
    costPerCall: +(totalCost / n).toFixed(5),
    estFullCost: +(totalCost / n * 14623).toFixed(0),
    avgOutputTokens: +(totalOutput / n).toFixed(0),
    avgReasoningTokens: +(totalReasoning / n).toFixed(0),
    avgLatencyMs: +(totalLatency / n).toFixed(0),
    vsGold: hasGold ? { cat: +(goldCatMatch / goldTotal * 100).toFixed(1), spec: +(goldSpecMatch / goldTotal * 100).toFixed(1), both: +(goldBothMatch / goldTotal * 100).toFixed(1) } : null,
    vsMajority: { cat: +(agreeMajCat / n * 100).toFixed(1), spec: +(agreeMajSpec / n * 100).toFixed(1), both: +(agreeMajBoth / n * 100).toFixed(1) },
    majorityVsGold: hasGold ? { cat: +(majGoldCatMatch / majGoldTotal * 100).toFixed(1), spec: +(majGoldSpecMatch / majGoldTotal * 100).toFixed(1), both: +(majGoldBothMatch / majGoldTotal * 100).toFixed(1) } : null,
    confidenceCalibration: hasGold ? Object.fromEntries(Object.entries(confBuckets).map(([k, v]) => [k, { accuracy: v.total > 0 ? +(v.correct / v.total * 100).toFixed(1) : null, n: v.total }])) : null,
    sidesMajority: +(sidesMajority / n * 100).toFixed(1),
    sidesOutlier: +(sidesOutlier / n * 100).toFixed(1),
  };
  await Bun.write(`${BENCH_DIR}/${slug}.report.json`, JSON.stringify(report, null, 2) + "\n");
  console.error(`\n[${shortName}] Report saved`);
}
// Entry point: surface any unhandled failure and exit non-zero so callers notice.
main().catch((err: unknown) => {
  console.error(err);
  process.exit(1);
});