SEC-cyBERT/ts/scripts/model-bench.ts
2026-03-28 23:44:37 -04:00

260 lines
12 KiB
TypeScript
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/**
* Benchmark a single model on the 500-sample pilot set.
* Outputs JSONL + comparison report against Stage 1 annotations.
*
* Usage: bun ts/scripts/model-bench.ts <model-id> [--smoke] [--concurrency 15]
*
* --smoke: run only 5 paragraphs to check schema compliance
*/
import { readJsonl, readJsonlRaw, appendJsonl } from "../src/lib/jsonl.ts";
import { Paragraph } from "@sec-cybert/schemas/paragraph.ts";
import { annotateParagraph, type AnnotateOpts } from "../src/label/annotate.ts";
import { PROMPT_VERSION } from "../src/label/prompts.ts";
import { v4 as uuidv4 } from "uuid";
import { existsSync } from "node:fs";
import pLimit from "p-limit";
import { mkdir } from "node:fs/promises";

// ── CLI parsing & output locations ──────────────────────────────────────
// First non-flag argument is the model id, e.g. "openai/gpt-5.4-nano".
const args = process.argv.slice(2);
const MODEL = args.find(a => !a.startsWith("--"));
if (!MODEL) { console.error("Usage: bun ts/scripts/model-bench.ts <model-id> [--smoke]"); process.exit(1); }
const SMOKE = args.includes("--smoke");

// Parse `--concurrency <n>`. Guard against a missing or non-numeric value:
// parseInt(undefined) is NaN, and pLimit(NaN) throws — fall back to 15.
const concIdx = args.indexOf("--concurrency");
const rawConcurrency = concIdx !== -1 ? parseInt(args[concIdx + 1] ?? "", 10) : NaN;
const CONCURRENCY = Number.isFinite(rawConcurrency) && rawConcurrency > 0 ? rawConcurrency : 15;

// Data locations, resolved relative to this script so CWD doesn't matter.
const PILOT_SAMPLE = new URL("../../data/pilot/pilot-sample-v2.5.jsonl", import.meta.url).pathname;
const STAGE1_PATH = new URL("../../data/annotations/stage1.jsonl", import.meta.url).pathname;

// Model ids contain a "/" (provider/name); flatten it into a safe file name.
const slug = MODEL.replace("/", "_");
const OUTPUT_PATH = new URL(`../../data/bench/${slug}.jsonl`, import.meta.url).pathname;
const benchDir = new URL("../../data/bench", import.meta.url).pathname;
if (!existsSync(benchDir)) await mkdir(benchDir, { recursive: true });
/**
 * Minimal shape of a Stage 1 annotation record as read from stage1.jsonl.
 * Only the fields this script touches are declared.
 */
interface S1Ann {
  // Joins against the pilot paragraph ids.
  paragraphId: string;
  // The label under comparison: category + specificity level.
  label: { content_category: string; specificity_level: number };
  // Which Stage 1 model produced this annotation.
  provenance: { modelId: string };
}
/** Format `n` out of `total` as a one-decimal percentage string, e.g. "42.3%". */
function pct(n: number, total: number): string {
  const ratio = (n / total) * 100;
  return ratio.toFixed(1) + "%";
}
/**
 * End-to-end benchmark driver:
 *  1. Load the pilot paragraphs (only 5 under --smoke).
 *  2. Annotate each remaining paragraph with MODEL, appending results to
 *     OUTPUT_PATH as they complete (resumable after a crash).
 *  3. Compare the labels against Stage 1 annotations and write a JSON report.
 */
async function main() {
  const shortName = MODEL.split("/").pop()!;
  console.error(`\n[${shortName}] Loading data...`);
  const { records: allParagraphs } = await readJsonl(PILOT_SAMPLE, Paragraph);
  // --smoke trims the run to 5 paragraphs — just enough to check schema compliance.
  const paragraphs = SMOKE ? allParagraphs.slice(0, 5) : allParagraphs;
  console.error(`[${shortName}] ${paragraphs.length} paragraphs ${SMOKE ? "(smoke test)" : ""}`);

  // Resume support: skip any paragraph id already present in the output JSONL,
  // so an interrupted run can be restarted without duplicating API calls.
  const doneKeys = new Set<string>();
  if (existsSync(OUTPUT_PATH)) {
    const { records: existing } = await readJsonlRaw(OUTPUT_PATH);
    for (const r of existing) {
      const a = r as { paragraphId?: string };
      if (a.paragraphId) doneKeys.add(a.paragraphId);
    }
    if (doneKeys.size > 0) console.error(`[${shortName}] Resuming: ${doneKeys.size} already done`);
  }
  const remaining = paragraphs.filter(p => !doneKeys.has(p.id));
  if (remaining.length === 0) {
    console.error(`[${shortName}] All done, skipping to analysis`);
  } else {
    console.error(`[${shortName}] Running ${remaining.length} annotations (concurrency=${CONCURRENCY})...\n`);
    const runId = uuidv4();
    // Caps the number of annotation requests in flight at once.
    const limit = pLimit(CONCURRENCY);
    // Shared mutable counters; safe here because JS is single-threaded and each
    // task only touches them between awaits.
    let completed = 0, failed = 0, totalCost = 0;
    const errors: { id: string; msg: string }[] = [];
    const startTime = Date.now();
    const tasks = remaining.map(p => limit(async () => {
      const opts: AnnotateOpts = {
        modelId: MODEL,
        stage: "benchmark",
        runId,
        promptVersion: PROMPT_VERSION,
        reasoningEffort: "low",
      };
      try {
        const ann = await annotateParagraph(p, opts);
        // Append immediately so partial progress survives a crash (enables resume).
        await appendJsonl(OUTPUT_PATH, ann);
        totalCost += ann.provenance.costUsd;
        completed++;
        // Progress line every 50 completions (every completion under --smoke).
        if (completed % 50 === 0 || SMOKE) {
          const elapsed = (Date.now() - startTime) / 1000;
          process.stderr.write(`\r[${shortName}] ${completed}/${remaining.length} (${(completed / elapsed).toFixed(1)}/s, $${totalCost.toFixed(4)}) `);
        }
      } catch (err) {
        failed++;
        const msg = err instanceof Error ? err.message : String(err);
        errors.push({ id: p.id.slice(0, 8), msg: msg.slice(0, 200) });
        // Only echo the first few failures to keep stderr readable.
        if (SMOKE || failed <= 5) {
          console.error(`\n[${shortName}] ✖ ${p.id.slice(0, 8)}: ${msg.slice(0, 300)}`);
        }
      }
    }));
    await Promise.all(tasks);
    const elapsed = ((Date.now() - startTime) / 1000).toFixed(0);
    console.error(`\n[${shortName}] Done: ${completed} ok, ${failed} failed, $${totalCost.toFixed(4)}, ${elapsed}s`);
    if (errors.length > 5) {
      console.error(`[${shortName}] ... and ${errors.length - 5} more errors`);
    }
    if (SMOKE) {
      console.error(`\n[${shortName}] Smoke test complete.`);
      return;
    }
  }
  // ── Analysis ─────────────────────────────────────────────────────────
  // Second smoke guard: covers the remaining.length === 0 branch above,
  // which doesn't hit the early return inside the else block.
  if (SMOKE) return;
  const pilotIds = new Set(paragraphs.map(p => p.id));
  console.error(`[${shortName}] Loading Stage 1 data for comparison...`);
  // Index Stage 1 annotations by paragraph id (multiple per paragraph, one per model).
  const { records: allAnns } = await readJsonlRaw(STAGE1_PATH);
  const s1ByParagraph = new Map<string, S1Ann[]>();
  for (const raw of allAnns) {
    const a = raw as S1Ann;
    if (!pilotIds.has(a.paragraphId)) continue;
    let arr = s1ByParagraph.get(a.paragraphId);
    if (!arr) { arr = []; s1ByParagraph.set(a.paragraphId, arr); }
    arr.push(a);
  }
  // Flatten this model's bench output: label fields + provenance metrics per paragraph.
  const { records: benchRaw } = await readJsonlRaw(OUTPUT_PATH);
  const benchByParagraph = new Map<string, { content_category: string; specificity_level: number; costUsd: number; latencyMs: number; outputTokens: number; reasoningTokens: number }>();
  for (const r of benchRaw) {
    const a = r as { paragraphId: string; label: { content_category: string; specificity_level: number }; provenance: { costUsd: number; latencyMs: number; outputTokens: number; reasoningTokens: number } };
    benchByParagraph.set(a.paragraphId, { ...a.label, costUsd: a.provenance.costUsd, latencyMs: a.provenance.latencyMs, outputTokens: a.provenance.outputTokens, reasoningTokens: a.provenance.reasoningTokens });
  }
  const n = benchByParagraph.size;
  // The three Stage 1 annotator models compared against.
  const s1Models = ["google/gemini-3.1-flash-lite-preview", "openai/gpt-5.4-nano", "x-ai/grok-4.1-fast"];
  // Short display name: everything after the provider prefix.
  const sn = (m: string) => m.split("/").pop()!;
  let totalCost = 0, totalLatency = 0, totalOutput = 0, totalReasoning = 0;
  for (const v of benchByParagraph.values()) {
    totalCost += v.costUsd;
    totalLatency += v.latencyMs;
    totalOutput += v.outputTokens;
    totalReasoning += v.reasoningTokens;
  }
  // Output structured JSON report for aggregation
  const report: Record<string, unknown> = {
    model: MODEL,
    shortName,
    n,
    totalCost: +totalCost.toFixed(4),
    avgCost: +(totalCost / n).toFixed(6),
    avgLatencyMs: +(totalLatency / n).toFixed(0),
    avgOutputTokens: +(totalOutput / n).toFixed(0),
    avgReasoningTokens: +(totalReasoning / n).toFixed(0),
    pairwise: {} as Record<string, unknown>,
  };
  console.log(`\n═══ ${shortName} (n=${n}) ═══`);
  console.log(` Cost: $${totalCost.toFixed(4)} total, $${(totalCost / n).toFixed(6)}/ann`);
  console.log(` Latency: ${(totalLatency / n).toFixed(0)}ms avg`);
  console.log(` Output: ${(totalOutput / n).toFixed(0)} tokens avg, ${(totalReasoning / n).toFixed(0)} reasoning avg`);
  // Pairwise agreement with each individual Stage 1 model.
  // NOTE(review): if a Stage 1 model has no overlapping annotations, total stays 0
  // and the percentages below become NaN — confirm this can't happen in practice.
  console.log("\n Pairwise vs Stage 1:");
  for (const model of s1Models) {
    let catAgree = 0, specAgree = 0, bothAgree = 0, total = 0;
    for (const [pid, bl] of benchByParagraph) {
      const s1anns = s1ByParagraph.get(pid);
      if (!s1anns) continue;
      const s1 = s1anns.find(a => a.provenance.modelId === model);
      if (!s1) continue;
      total++;
      if (s1.label.content_category === bl.content_category) catAgree++;
      if (s1.label.specificity_level === bl.specificity_level) specAgree++;
      if (s1.label.content_category === bl.content_category && s1.label.specificity_level === bl.specificity_level) bothAgree++;
    }
    (report.pairwise as Record<string, unknown>)[sn(model)] = { cat: +(catAgree / total * 100).toFixed(1), spec: +(specAgree / total * 100).toFixed(1), both: +(bothAgree / total * 100).toFixed(1) };
    console.log(` × ${sn(model).padEnd(30)} cat ${pct(catAgree, total).padStart(6)} spec ${pct(specAgree, total).padStart(6)} both ${pct(bothAgree, total).padStart(6)}`);
  }
  // Agreement with the Stage 1 majority vote (2-of-3), where all 3 annotations exist.
  let catMajAgree = 0, specMajAgree = 0, bothMajAgree = 0, totalMaj = 0;
  for (const [pid, bl] of benchByParagraph) {
    const s1anns = s1ByParagraph.get(pid);
    if (!s1anns || s1anns.length !== 3) continue;
    totalMaj++;
    const cats = s1anns.map(a => a.label.content_category);
    const catFreq = new Map<string, number>();
    for (const c of cats) catFreq.set(c, (catFreq.get(c) ?? 0) + 1);
    // undefined when all three annotators disagree (no 2-of-3 majority)
    const majCat = [...catFreq.entries()].find(([, v]) => v >= 2)?.[0];
    const specs = s1anns.map(a => a.label.specificity_level);
    const specFreq = new Map<number, number>();
    for (const s of specs) specFreq.set(s, (specFreq.get(s) ?? 0) + 1);
    const majSpec = [...specFreq.entries()].find(([, v]) => v >= 2)?.[0];
    if (majCat && bl.content_category === majCat) catMajAgree++;
    // majSpec may legitimately be 0, so compare against undefined, not truthiness.
    if (majSpec !== undefined && bl.specificity_level === majSpec) specMajAgree++;
    if (majCat && bl.content_category === majCat && majSpec !== undefined && bl.specificity_level === majSpec) bothMajAgree++;
  }
  report.vsMajority = { cat: +(catMajAgree / totalMaj * 100).toFixed(1), spec: +(specMajAgree / totalMaj * 100).toFixed(1), both: +(bothMajAgree / totalMaj * 100).toFixed(1) };
  console.log(`\n vs Majority Vote: cat ${pct(catMajAgree, totalMaj).padStart(6)} spec ${pct(specMajAgree, totalMaj).padStart(6)} both ${pct(bothMajAgree, totalMaj).padStart(6)}`);
  // Hypothetical: swap this model in for gpt-nano in the 3-model ensemble and
  // measure how 3-way unanimity changes.
  let newCatUnan = 0, newSpecUnan = 0, newBothUnan = 0;
  let oldCatUnan = 0, oldSpecUnan = 0, oldBothUnan = 0;
  let nCompare = 0;
  for (const [pid, bl] of benchByParagraph) {
    const s1anns = s1ByParagraph.get(pid);
    if (!s1anns || s1anns.length !== 3) continue;
    nCompare++;
    // NOTE(review): the `!` assertions assume each triple contains exactly one
    // gemini/nano/grok record — a malformed triple would throw at runtime.
    const gemini = s1anns.find(a => a.provenance.modelId.includes("gemini"))!;
    const nano = s1anns.find(a => a.provenance.modelId.includes("nano"))!;
    const grok = s1anns.find(a => a.provenance.modelId.includes("grok"))!;
    const oldCats = [gemini, nano, grok].map(a => a.label.content_category);
    const oldSpecs = [gemini, nano, grok].map(a => a.label.specificity_level);
    // Set size 1 means all three ensemble members gave the same answer.
    if (new Set(oldCats).size === 1) oldCatUnan++;
    if (new Set(oldSpecs).size === 1) oldSpecUnan++;
    if (new Set(oldCats).size === 1 && new Set(oldSpecs).size === 1) oldBothUnan++;
    const newCats = [gemini.label.content_category, bl.content_category, grok.label.content_category];
    const newSpecs = [gemini.label.specificity_level, bl.specificity_level, grok.label.specificity_level];
    if (new Set(newCats).size === 1) newCatUnan++;
    if (new Set(newSpecs).size === 1) newSpecUnan++;
    if (new Set(newCats).size === 1 && new Set(newSpecs).size === 1) newBothUnan++;
  }
  report.replaceNano = {
    oldBothUnan: +(oldBothUnan / nCompare * 100).toFixed(1),
    newBothUnan: +(newBothUnan / nCompare * 100).toFixed(1),
    // percentage-point change in both-fields unanimity from the swap
    deltaBothPp: +((newBothUnan - oldBothUnan) / nCompare * 100).toFixed(1),
  };
  console.log(`\n Replace nano hypothetical (n=${nCompare}):`);
  console.log(` Both-unan: ${pct(oldBothUnan, nCompare)} → ${pct(newBothUnan, nCompare)} (${((newBothUnan - oldBothUnan) / nCompare * 100).toFixed(1)}pp)`);
  // Outlier rate: cases where gemini and grok agree with each other but this
  // model disagrees — a proxy for "this model is the odd one out".
  let benchCatOut = 0, benchSpecOut = 0;
  for (const [pid, bl] of benchByParagraph) {
    const s1anns = s1ByParagraph.get(pid);
    if (!s1anns || s1anns.length !== 3) continue;
    const gemini = s1anns.find(a => a.provenance.modelId.includes("gemini"))!;
    const grok = s1anns.find(a => a.provenance.modelId.includes("grok"))!;
    if (gemini.label.content_category === grok.label.content_category && bl.content_category !== gemini.label.content_category) benchCatOut++;
    if (gemini.label.specificity_level === grok.label.specificity_level && bl.specificity_level !== gemini.label.specificity_level) benchSpecOut++;
  }
  report.outlierVsGeminiGrok = { cat: +(benchCatOut / nCompare * 100).toFixed(1), spec: +(benchSpecOut / nCompare * 100).toFixed(1) };
  console.log(`\n Outlier (gemini×grok agree, ${shortName} differs): cat ${pct(benchCatOut, nCompare)}, spec ${pct(benchSpecOut, nCompare)}`);
  // Write report JSON
  const reportPath = new URL(`../../data/bench/${slug}.report.json`, import.meta.url).pathname;
  await Bun.write(reportPath, JSON.stringify(report, null, 2) + "\n");
  console.error(`\n[${shortName}] Report saved to ${reportPath}`);
}
main().catch(err => { console.error(err); process.exit(1); });