SEC-cyBERT/ts/scripts/stage1-run.ts
2026-03-28 23:44:37 -04:00

159 lines
6.1 KiB
TypeScript
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/**
* Stage 1 production run: annotate all paragraphs with 3 models.
*
* Features:
* - Crash-safe: appends one JSONL line per annotation, resumes on restart
* - All 3 models run in parallel per paragraph (not sequentially)
* - Real-time progress + cost logging
* - Configurable concurrency (total concurrent API calls)
*
* Usage:
* bun ts/scripts/stage1-run.ts [--concurrency 30] [--input ../data/paragraphs/training.jsonl]
*
* Output:
* ../data/annotations/stage1.jsonl — one Annotation per (paragraph, model) pair
*/
import { mkdir } from "node:fs/promises";
import { existsSync } from "node:fs";
import { fileURLToPath } from "node:url";
import { v4 as uuidv4 } from "uuid";
import pLimit from "p-limit";
import { Paragraph } from "@sec-cybert/schemas/paragraph.ts";
import { readJsonl, readJsonlRaw, appendJsonl } from "../src/lib/jsonl.ts";
import { STAGE1_MODELS } from "../src/lib/openrouter.ts";
import { annotateParagraph, type AnnotateOpts } from "../src/label/annotate.ts";
import { PROMPT_VERSION } from "../src/label/prompts.ts";
// ── Args ────────────────────────────────────────────────────────────────
const args = process.argv.slice(2);
function flag(name: string): string | undefined {
const idx = args.indexOf(`--${name}`);
return idx === -1 ? undefined : args[idx + 1];
}
const CONCURRENCY = parseInt(flag("concurrency") ?? "30", 10);
const INPUT_PATH = flag("input") ?? new URL("../../data/paragraphs/training.jsonl", import.meta.url).pathname;
const OUTPUT_DIR = new URL("../../data/annotations", import.meta.url).pathname;
const OUTPUT_PATH = `${OUTPUT_DIR}/stage1.jsonl`;
// ── Main ────────────────────────────────────────────────────────────────
async function main() {
if (!existsSync(OUTPUT_DIR)) await mkdir(OUTPUT_DIR, { recursive: true });
// Load training data
console.error(`Loading paragraphs from ${INPUT_PATH}...`);
const { records: paragraphs, skipped } = await readJsonl(INPUT_PATH, Paragraph);
if (skipped > 0) console.error(` ⚠ Skipped ${skipped} invalid lines`);
console.error(` Loaded ${paragraphs.length} paragraphs`);
console.error(` Models: ${STAGE1_MODELS.join(", ")}`);
console.error(` Prompt: ${PROMPT_VERSION}`);
console.error(` Concurrency: ${CONCURRENCY}`);
const totalJobs = paragraphs.length * STAGE1_MODELS.length;
console.error(` Total annotations needed: ${totalJobs.toLocaleString()}`);
// Load existing results for resume
const doneKeys = new Set<string>();
let resumedCost = 0;
if (existsSync(OUTPUT_PATH)) {
const { records: existing, skipped: badLines } = await readJsonlRaw(OUTPUT_PATH);
for (const rec of existing) {
const r = rec as { paragraphId?: string; provenance?: { modelId?: string; costUsd?: number } };
if (r.paragraphId && r.provenance?.modelId) {
doneKeys.add(`${r.paragraphId}|${r.provenance.modelId}`);
resumedCost += r.provenance.costUsd ?? 0;
}
}
if (doneKeys.size > 0) {
console.error(` Resuming: ${doneKeys.size.toLocaleString()} annotations already done ($${resumedCost.toFixed(2)}), ${(totalJobs - doneKeys.size).toLocaleString()} remaining`);
}
if (badLines > 0) console.error(`${badLines} corrupted lines in output (skipped)`);
}
if (doneKeys.size >= totalJobs) {
console.error(" ✓ All annotations already complete!");
return;
}
// Build job list: (paragraph, model) pairs not yet done
type Job = { paragraph: Paragraph; modelId: string };
const jobs: Job[] = [];
for (const paragraph of paragraphs) {
for (const modelId of STAGE1_MODELS) {
if (!doneKeys.has(`${paragraph.id}|${modelId}`)) {
jobs.push({ paragraph, modelId });
}
}
}
console.error(` Jobs to run: ${jobs.length.toLocaleString()}`);
// Run with concurrency limiter
const runId = uuidv4();
const limit = pLimit(CONCURRENCY);
let completed = doneKeys.size;
let failed = 0;
let sessionCost = 0;
const startTime = Date.now();
// Progress logging
const logInterval = setInterval(() => {
const elapsed = (Date.now() - startTime) / 1000;
const rate = (completed - doneKeys.size) / elapsed;
const remaining = totalJobs - completed;
const eta = rate > 0 ? remaining / rate : Infinity;
const etaMin = Math.round(eta / 60);
process.stderr.write(
`\r ${completed.toLocaleString()}/${totalJobs.toLocaleString()} (${((completed / totalJobs) * 100).toFixed(1)}%)` +
` $${(resumedCost + sessionCost).toFixed(2)}` +
` ${rate.toFixed(1)}/s` +
` ETA ${etaMin}m` +
` ${failed} failed `,
);
}, 2000);
const tasks = jobs.map((job) =>
limit(async () => {
const opts: AnnotateOpts = {
modelId: job.modelId,
stage: "stage1",
runId,
promptVersion: PROMPT_VERSION,
reasoningEffort: "low",
};
try {
const ann = await annotateParagraph(job.paragraph, opts);
await appendJsonl(OUTPUT_PATH, ann);
sessionCost += ann.provenance.costUsd;
completed++;
} catch (error) {
failed++;
const msg = error instanceof Error ? error.message : String(error);
// Log failures to stderr but don't crash — we can retry on next run
console.error(`\n ✖ ${job.modelId} × ${job.paragraph.id}: ${msg}`);
}
}),
);
await Promise.all(tasks);
clearInterval(logInterval);
const elapsed = ((Date.now() - startTime) / 1000).toFixed(0);
console.error(
`\n\n ═══ COMPLETE ═══` +
`\n Annotations: ${completed.toLocaleString()}/${totalJobs.toLocaleString()}` +
`\n Failed: ${failed}` +
`\n Session cost: $${sessionCost.toFixed(2)}` +
`\n Total cost: $${(resumedCost + sessionCost).toFixed(2)}` +
`\n Wall time: ${elapsed}s` +
`\n Output: ${OUTPUT_PATH}`,
);
if (failed > 0) {
console.error(`\n ⚠ ${failed} failures — re-run this script to retry them.`);
}
}
main().catch((err) => {
console.error(err);
process.exit(1);
});