SEC-cyBERT/ts/scripts/pilot.ts
2026-03-28 20:39:36 -04:00

477 lines
20 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/**
* Prompt pilot: run all 3 Stage 1 models on a stratified sample of paragraphs.
*
* Usage:
* bun ts/scripts/pilot.ts [--n 40] [--seed 42] [--concurrency 5]
*
* Outputs (each suffixed with the current prompt version):
* data/pilot/pilot-sample-<version>.jsonl — the sampled paragraphs
* data/pilot/pilot-results-<version>.jsonl — all annotations (3 per paragraph)
* data/pilot/pilot-report-<version>.txt — human-readable comparison report
*/
import { z } from "zod";
import { readJsonl, writeJsonl, appendJsonl } from "../src/lib/jsonl.ts";
import { Paragraph } from "../src/schemas/paragraph.ts";
import { STAGE1_MODELS } from "../src/lib/openrouter.ts";
import { annotateParagraph, type AnnotateOpts } from "../src/label/annotate.ts";
import { PROMPT_VERSION } from "../src/label/prompts.ts";
import { v4 as uuidv4 } from "uuid";
import { writeFile, mkdir } from "node:fs/promises";
import { existsSync } from "node:fs";
import pLimit from "p-limit";
// ── Args ────────────────────────────────────────────────────────────────
const args = process.argv.slice(2);
function flag(name: string): string | undefined {
const idx = args.indexOf(`--${name}`);
return idx === -1 ? undefined : args[idx + 1];
}
const N = parseInt(flag("n") ?? "40", 10);
const SEED = parseInt(flag("seed") ?? "42", 10);
const CONCURRENCY = parseInt(flag("concurrency") ?? "5", 10);
// ── Seeded PRNG (mulberry32) ────────────────────────────────────────────
function mulberry32(seed: number) {
let s = seed | 0;
return () => {
s = (s + 0x6d2b79f5) | 0;
let t = Math.imul(s ^ (s >>> 15), 1 | s);
t = (t + Math.imul(t ^ (t >>> 7), 61 | t)) ^ t;
return ((t ^ (t >>> 14)) >>> 0) / 4294967296;
};
}
// ── Stratified sampling ─────────────────────────────────────────────────
function sampleStratified(paragraphs: Paragraph[], n: number, seed: number): Paragraph[] {
const rng = mulberry32(seed);
// Bucket by word count quartile + filing type
const buckets: Map<string, Paragraph[]> = new Map();
for (const p of paragraphs) {
const wcBucket =
p.wordCount < 50 ? "short" :
p.wordCount < 100 ? "medium" :
p.wordCount < 200 ? "long" : "very-long";
const key = `${wcBucket}|${p.filing.filingType}`;
const arr = buckets.get(key) ?? [];
arr.push(p);
buckets.set(key, arr);
}
// Draw proportionally from each bucket
const sampled: Paragraph[] = [];
const bucketKeys = [...buckets.keys()].sort();
const perBucket = Math.max(1, Math.floor(n / bucketKeys.length));
for (const key of bucketKeys) {
const pool = buckets.get(key)!;
// Fisher-Yates shuffle with seeded RNG
for (let i = pool.length - 1; i > 0; i--) {
const j = Math.floor(rng() * (i + 1));
[pool[i], pool[j]] = [pool[j]!, pool[i]!];
}
sampled.push(...pool.slice(0, perBucket));
}
// Fill remaining from the full pool
if (sampled.length < n) {
const usedIds = new Set(sampled.map((p) => p.id));
const remaining = paragraphs.filter((p) => !usedIds.has(p.id));
for (let i = remaining.length - 1; i > 0; i--) {
const j = Math.floor(rng() * (i + 1));
[remaining[i], remaining[j]] = [remaining[j]!, remaining[i]!];
}
sampled.push(...remaining.slice(0, n - sampled.length));
}
return sampled.slice(0, n);
}
// ── Main ────────────────────────────────────────────────────────────────
const PILOT_DIR = "../data/pilot";
const TRAINING_PATH = "../data/paragraphs/training.jsonl";
async function main() {
if (!existsSync(PILOT_DIR)) await mkdir(PILOT_DIR, { recursive: true });
// Load training data
console.error(`Loading training data from ${TRAINING_PATH}...`);
const { records: paragraphs, skipped } = await readJsonl(TRAINING_PATH, Paragraph);
if (skipped > 0) console.error(` ⚠ Skipped ${skipped} invalid lines`);
console.error(` Loaded ${paragraphs.length} paragraphs`);
// Sample
const sample = sampleStratified(paragraphs, N, SEED);
const samplePath = `${PILOT_DIR}/pilot-sample-${PROMPT_VERSION}.jsonl`;
await writeJsonl(samplePath, sample);
console.error(` Sampled ${sample.length} paragraphs (${samplePath})`);
// Show distribution
const filingTypes = new Map<string, number>();
const wcBuckets = new Map<string, number>();
for (const p of sample) {
filingTypes.set(p.filing.filingType, (filingTypes.get(p.filing.filingType) ?? 0) + 1);
const wc = p.wordCount < 50 ? "<50w" : p.wordCount < 100 ? "50-99w" : p.wordCount < 200 ? "100-199w" : "200+w";
wcBuckets.set(wc, (wcBuckets.get(wc) ?? 0) + 1);
}
console.error(` Filing types: ${[...filingTypes.entries()].map(([k, v]) => `${k}=${v}`).join(", ")}`);
console.error(` Word counts: ${[...wcBuckets.entries()].map(([k, v]) => `${k}=${v}`).join(", ")}`);
// Run all 3 models — with resume support
const runId = uuidv4();
const resultsPath = `${PILOT_DIR}/pilot-results-${PROMPT_VERSION}.jsonl`;
const limit = pLimit(CONCURRENCY);
type AnnotationResult = {
paragraphId: string;
modelId: string;
content_category: string;
specificity_level: number;
category_confidence: string;
specificity_confidence: string;
reasoning: string;
inputTokens: number;
outputTokens: number;
reasoningTokens: number;
costUsd: number;
latencyMs: number;
};
// Load existing results for resume
const doneKeys = new Set<string>();
const allResults: AnnotationResult[] = [];
let totalCost = 0;
if (existsSync(resultsPath)) {
const { records: existing } = await readJsonl(resultsPath, z.object({
paragraphId: z.string(),
provenance: z.object({ modelId: z.string(), costUsd: z.number() }),
label: z.object({
content_category: z.string(),
specificity_level: z.number(),
category_confidence: z.string(),
specificity_confidence: z.string(),
reasoning: z.string(),
}),
}));
for (const rec of existing) {
doneKeys.add(`${rec.paragraphId}|${rec.provenance.modelId}`);
allResults.push({
paragraphId: rec.paragraphId,
modelId: rec.provenance.modelId,
content_category: rec.label.content_category,
specificity_level: rec.label.specificity_level,
category_confidence: rec.label.category_confidence,
specificity_confidence: rec.label.specificity_confidence,
reasoning: rec.label.reasoning,
inputTokens: 0, outputTokens: 0, reasoningTokens: 0,
costUsd: rec.provenance.costUsd,
latencyMs: 0,
});
totalCost += rec.provenance.costUsd;
}
if (doneKeys.size > 0) {
console.error(` Resuming: ${doneKeys.size} annotations already done, skipping`);
}
}
for (const modelId of STAGE1_MODELS) {
console.error(`\n ═══ ${modelId} ═══`);
const modelResults: AnnotationResult[] = [];
let modelCost = 0;
let modelInputTokens = 0;
let modelOutputTokens = 0;
let modelReasoningTokens = 0;
const tasks = sample.map((paragraph) =>
limit(async () => {
// Skip if already done (resume)
if (doneKeys.has(`${paragraph.id}|${modelId}`)) return;
const opts: AnnotateOpts = {
modelId,
stage: "stage1",
runId,
promptVersion: PROMPT_VERSION,
reasoningEffort: "low",
};
try {
const ann = await annotateParagraph(paragraph, opts);
const result: AnnotationResult = {
paragraphId: paragraph.id,
modelId,
content_category: ann.label.content_category,
specificity_level: ann.label.specificity_level,
category_confidence: ann.label.category_confidence,
specificity_confidence: ann.label.specificity_confidence,
reasoning: ann.label.reasoning,
inputTokens: ann.provenance.inputTokens,
outputTokens: ann.provenance.outputTokens,
reasoningTokens: ann.provenance.reasoningTokens,
costUsd: ann.provenance.costUsd,
latencyMs: ann.provenance.latencyMs,
};
modelResults.push(result);
allResults.push(result);
await appendJsonl(resultsPath, ann);
modelCost += ann.provenance.costUsd;
modelInputTokens += ann.provenance.inputTokens;
modelOutputTokens += ann.provenance.outputTokens;
modelReasoningTokens += ann.provenance.reasoningTokens;
totalCost += ann.provenance.costUsd;
const doneForModel = allResults.filter(r => r.modelId === modelId).length;
process.stderr.write(`\r ${doneForModel}/${sample.length} done $${modelCost.toFixed(4)}`);
} catch (error) {
console.error(`\n ✖ ${modelId} failed on ${paragraph.id}: ${error instanceof Error ? error.message : String(error)}`);
}
}),
);
await Promise.all(tasks);
console.error(
`\n ${modelId}: ${modelResults.length}/${sample.length} done` +
`$${modelCost.toFixed(4)}` +
`${modelInputTokens.toLocaleString()} in / ${modelOutputTokens.toLocaleString()} out / ${modelReasoningTokens.toLocaleString()} reasoning`,
);
}
// ── Generate report ─────────────────────────────────────────────────
const report: string[] = [];
report.push(`SEC-cyBERT Prompt Pilot Report — ${new Date().toISOString()}`);
report.push(`Prompt version: ${PROMPT_VERSION}`);
report.push(`Sample: ${sample.length} paragraphs, seed=${SEED}`);
report.push(`Models: ${STAGE1_MODELS.join(", ")}`);
report.push(`Total cost: $${totalCost.toFixed(4)}`);
report.push("");
// Per-model stats
report.push("═══ PER-MODEL STATS ═══");
for (const modelId of STAGE1_MODELS) {
const modelAnns = allResults.filter((r) => r.modelId === modelId);
const cost = modelAnns.reduce((s, r) => s + r.costUsd, 0);
const inTok = modelAnns.reduce((s, r) => s + r.inputTokens, 0);
const outTok = modelAnns.reduce((s, r) => s + r.outputTokens, 0);
const reasonTok = modelAnns.reduce((s, r) => s + r.reasoningTokens, 0);
const avgLatency = modelAnns.length > 0
? Math.round(modelAnns.reduce((s, r) => s + r.latencyMs, 0) / modelAnns.length)
: 0;
report.push(`\n${modelId}:`);
report.push(` Cost: $${cost.toFixed(4)} ($${(cost / modelAnns.length).toFixed(6)}/para)`);
report.push(` Tokens: ${inTok.toLocaleString()} in, ${outTok.toLocaleString()} out, ${reasonTok.toLocaleString()} reasoning`);
report.push(` Avg latency: ${avgLatency}ms`);
// Category distribution
const catCounts = new Map<string, number>();
for (const r of modelAnns) {
catCounts.set(r.content_category, (catCounts.get(r.content_category) ?? 0) + 1);
}
report.push(` Categories: ${[...catCounts.entries()].sort((a, b) => b[1] - a[1]).map(([k, v]) => `${k}=${v}`).join(", ")}`);
// Specificity distribution
const specCounts = new Map<number, number>();
for (const r of modelAnns) {
specCounts.set(r.specificity_level, (specCounts.get(r.specificity_level) ?? 0) + 1);
}
report.push(` Specificity: ${[...specCounts.entries()].sort((a, b) => a[0] - b[0]).map(([k, v]) => `${k}=${v}`).join(", ")}`);
// Confidence distribution
const catConf = new Map<string, number>();
const specConf = new Map<string, number>();
for (const r of modelAnns) {
catConf.set(r.category_confidence, (catConf.get(r.category_confidence) ?? 0) + 1);
specConf.set(r.specificity_confidence, (specConf.get(r.specificity_confidence) ?? 0) + 1);
}
report.push(` Category confidence: ${[...catConf.entries()].map(([k, v]) => `${k}=${v}`).join(", ")}`);
report.push(` Specificity confidence: ${[...specConf.entries()].map(([k, v]) => `${k}=${v}`).join(", ")}`);
}
// Agreement analysis
report.push("\n\n═══ AGREEMENT ANALYSIS ═══");
const byParagraph = new Map<string, AnnotationResult[]>();
for (const r of allResults) {
const arr = byParagraph.get(r.paragraphId) ?? [];
arr.push(r);
byParagraph.set(r.paragraphId, arr);
}
let catAgree3 = 0, catAgree2 = 0, catDisagreeAll = 0;
let specAgree3 = 0, specAgree2 = 0, specDisagreeAll = 0;
let bothAgree3 = 0;
for (const [, anns] of byParagraph) {
if (anns.length !== 3) continue;
// Category agreement
const cats = anns.map((a) => a.content_category);
const uniqueCats = new Set(cats).size;
if (uniqueCats === 1) catAgree3++;
else if (uniqueCats === 2) catAgree2++;
else catDisagreeAll++;
// Specificity agreement
const specs = anns.map((a) => a.specificity_level);
const uniqueSpecs = new Set(specs).size;
if (uniqueSpecs === 1) specAgree3++;
else if (uniqueSpecs === 2) specAgree2++;
else specDisagreeAll++;
// Both agree
if (uniqueCats === 1 && uniqueSpecs === 1) bothAgree3++;
}
const total = byParagraph.size;
report.push(`Paragraphs with all 3 models: ${total}`);
report.push("");
report.push(`Content Category Agreement:`);
report.push(` 3/3 unanimous: ${catAgree3}/${total} (${((catAgree3/total)*100).toFixed(1)}%)`);
report.push(` 2/3 majority: ${catAgree2}/${total} (${((catAgree2/total)*100).toFixed(1)}%)`);
report.push(` All disagree: ${catDisagreeAll}/${total} (${((catDisagreeAll/total)*100).toFixed(1)}%)`);
report.push("");
report.push(`Specificity Level Agreement:`);
report.push(` 3/3 unanimous: ${specAgree3}/${total} (${((specAgree3/total)*100).toFixed(1)}%)`);
report.push(` 2/3 majority: ${specAgree2}/${total} (${((specAgree2/total)*100).toFixed(1)}%)`);
report.push(` All disagree: ${specDisagreeAll}/${total} (${((specDisagreeAll/total)*100).toFixed(1)}%)`);
report.push("");
report.push(`Both dimensions 3/3: ${bothAgree3}/${total} (${((bothAgree3/total)*100).toFixed(1)}%)`);
report.push(`Consensus (2/3+ on both): ${total - catDisagreeAll}/${total} (${(((total - catDisagreeAll)/total)*100).toFixed(1)}%)`);
// Specificity spread (mean absolute deviation across 3 models per paragraph)
const spreads: number[] = [];
for (const [, anns] of byParagraph) {
if (anns.length !== 3) continue;
const specs = anns.map((a) => a.specificity_level);
const mean = specs.reduce((s, v) => s + v, 0) / specs.length;
const mad = specs.reduce((s, v) => s + Math.abs(v - mean), 0) / specs.length;
spreads.push(mad);
}
const meanSpread = spreads.reduce((s, v) => s + v, 0) / spreads.length;
const maxSpread = Math.max(...spreads);
report.push(`\nSpecificity spread (MAD): mean=${meanSpread.toFixed(3)}, max=${maxSpread.toFixed(3)}`);
report.push(` Spread=0 (perfect): ${spreads.filter(s => s === 0).length}/${total} (${((spreads.filter(s => s === 0).length/total)*100).toFixed(1)}%)`);
report.push(` Spread≤0.33 (1 off): ${spreads.filter(s => s <= 0.34).length}/${total}`);
report.push(` Spread>0.67 (2+ off): ${spreads.filter(s => s > 0.67).length}/${total}`);
// Pairwise agreement (category + specificity)
report.push("\nPairwise agreement:");
for (let i = 0; i < STAGE1_MODELS.length; i++) {
for (let j = i + 1; j < STAGE1_MODELS.length; j++) {
let catAgree = 0, specAgree = 0, count = 0;
for (const [, anns] of byParagraph) {
const a = anns.find((r) => r.modelId === STAGE1_MODELS[i]);
const b = anns.find((r) => r.modelId === STAGE1_MODELS[j]);
if (a && b) {
count++;
if (a.content_category === b.content_category) catAgree++;
if (a.specificity_level === b.specificity_level) specAgree++;
}
}
const short = (m: string) => m.split("/")[1]!;
report.push(` ${short(STAGE1_MODELS[i]!)} × ${short(STAGE1_MODELS[j]!)}: cat=${((catAgree/count)*100).toFixed(1)}%, spec=${((specAgree/count)*100).toFixed(1)}%`);
}
}
// Category confusion matrix (which categories get mixed up)
report.push("\n\n═══ CATEGORY DISAGREEMENT PATTERNS ═══");
const catConfusion = new Map<string, number>();
for (const [, anns] of byParagraph) {
if (anns.length !== 3) continue;
const cats = anns.map((a) => a.content_category).sort();
const unique = new Set(cats);
if (unique.size > 1) {
const key = [...unique].sort().join(" ↔ ");
catConfusion.set(key, (catConfusion.get(key) ?? 0) + 1);
}
}
for (const [pair, count] of [...catConfusion.entries()].sort((a, b) => b[1] - a[1])) {
report.push(` ${pair}: ${count}`);
}
// Specificity disagreement patterns
report.push("\n═══ SPECIFICITY DISAGREEMENT PATTERNS ═══");
const specConfusion = new Map<string, number>();
for (const [, anns] of byParagraph) {
if (anns.length !== 3) continue;
const specs = anns.map((a) => a.specificity_level).sort();
const unique = new Set(specs);
if (unique.size > 1) {
const key = specs.join(",");
specConfusion.set(key, (specConfusion.get(key) ?? 0) + 1);
}
}
for (const [pattern, count] of [...specConfusion.entries()].sort((a, b) => b[1] - a[1])) {
report.push(` [${pattern}]: ${count}`);
}
// Per-category specificity agreement
report.push("\n═══ PER-CATEGORY AGREEMENT (where all 3 agree on category) ═══");
const catSpecAgreement = new Map<string, { total: number; specAgree: number }>();
for (const [, anns] of byParagraph) {
if (anns.length !== 3) continue;
const cats = anns.map((a) => a.content_category);
if (new Set(cats).size !== 1) continue;
const cat = cats[0]!;
const entry = catSpecAgreement.get(cat) ?? { total: 0, specAgree: 0 };
entry.total++;
if (new Set(anns.map((a) => a.specificity_level)).size === 1) entry.specAgree++;
catSpecAgreement.set(cat, entry);
}
for (const [cat, { total, specAgree }] of [...catSpecAgreement.entries()].sort((a, b) => b[1].total - a[1].total)) {
report.push(` ${cat.padEnd(28)} spec agree: ${specAgree}/${total} (${((specAgree/total)*100).toFixed(1)}%)`);
}
// Per-paragraph detail for disagreements
report.push("\n\n═══ DISAGREEMENT DETAILS ═══");
for (const [pid, anns] of byParagraph) {
if (anns.length !== 3) continue;
const cats = new Set(anns.map((a) => a.content_category));
const specs = new Set(anns.map((a) => a.specificity_level));
if (cats.size === 1 && specs.size === 1) continue; // skip agreements
const paragraph = sample.find((p) => p.id === pid);
const textPreview = paragraph ? paragraph.text.slice(0, 200) + (paragraph.text.length > 200 ? "..." : "") : "(not found)";
report.push(`\n--- ${pid} ---`);
report.push(`Company: ${paragraph?.filing.companyName ?? "?"}`);
report.push(`Text: ${textPreview}`);
for (const a of anns) {
const short = a.modelId.split("/")[1]!;
report.push(` ${short.padEnd(30)}${a.content_category.padEnd(25)} spec=${a.specificity_level} (cat:${a.category_confidence}, spec:${a.specificity_confidence})`);
report.push(` ${" ".repeat(30)} ${a.reasoning}`);
}
}
// Cost projections
report.push("\n\n═══ COST PROJECTIONS (50K paragraphs) ═══");
for (const modelId of STAGE1_MODELS) {
const modelAnns = allResults.filter((r) => r.modelId === modelId);
if (modelAnns.length === 0) continue;
const costPerPara = modelAnns.reduce((s, r) => s + r.costUsd, 0) / modelAnns.length;
const projected = costPerPara * 50000;
report.push(` ${modelId}: $${projected.toFixed(2)} ($${costPerPara.toFixed(6)}/para)`);
}
const totalCostPerPara = totalCost / (sample.length * STAGE1_MODELS.length);
const projectedTotal = totalCostPerPara * 50000 * 3;
report.push(` TOTAL Stage 1 (all 3 models): ~$${projectedTotal.toFixed(2)}`);
// Estimated judge cost (~17% disagreement rate from codebook)
const disagreeRate = (catAgree2 + catDisagreeAll) / total;
report.push(`\n Observed disagreement rate: ${(disagreeRate * 100).toFixed(1)}%`);
report.push(` Estimated Stage 2 judge calls: ~${Math.round(50000 * disagreeRate).toLocaleString()}`);
report.push(` (Judge cost depends on Sonnet 4.6 pricing — see OpenRouter)`);
const reportText = report.join("\n");
await writeFile(`${PILOT_DIR}/pilot-report-${PROMPT_VERSION}.txt`, reportText);
// Print to stdout
console.log(reportText);
}
main().catch((err) => {
console.error(err);
process.exit(1);
});