SEC-cyBERT/ts/src/label/annotate.ts
2026-03-28 20:39:36 -04:00

159 lines
4.7 KiB
TypeScript

import { generateText, Output } from "ai";
import { openrouter, providerOf } from "../lib/openrouter.ts";
import { LabelOutputRaw, toLabelOutput } from "../schemas/label.ts";
import type { Annotation } from "../schemas/annotation.ts";
import type { Paragraph } from "../schemas/paragraph.ts";
import { SYSTEM_PROMPT, buildUserPrompt, buildJudgePrompt, PROMPT_VERSION } from "./prompts.ts";
import { withRetry } from "../lib/retry.ts";
/** Reasoning effort levels accepted by OpenRouter. */
type ReasoningEffort = "low" | "medium" | "high";

/**
 * Assemble the OpenRouter `providerOptions` payload: requested reasoning
 * effort plus usage accounting so cost is reported on each response.
 */
function buildProviderOptions(effort: ReasoningEffort) {
  const openrouterOptions = {
    reasoning: { effort },
    usage: { include: true as const },
  };
  return { openrouter: openrouterOptions };
}
/**
 * Extract the request cost (USD) from a generateText result.
 *
 * Checks the raw OpenRouter usage payload first (`usage.raw.cost`), then
 * falls back to `providerMetadata.openrouter.usage.cost`, and finally 0
 * when neither is present.
 *
 * Fix: the original did `raw.raw?.cost` without guarding `raw` itself, so a
 * missing/undefined `usage` threw a TypeError instead of reaching the
 * fallback. `usage` is typed `unknown`, so that case is reachable.
 */
function extractCost(result: { usage: unknown; providerMetadata?: unknown }): number {
  // Primary: OpenRouter attaches the billed cost on the raw usage object.
  const usage = result.usage as { raw?: { cost?: number } } | null | undefined;
  const rawCost = usage?.raw?.cost;
  if (rawCost !== undefined) return rawCost;
  // Fallback: some responses surface cost under providerMetadata instead.
  const meta = result.providerMetadata as
    | { openrouter?: { usage?: { cost?: number } } }
    | undefined;
  return meta?.openrouter?.usage?.cost ?? 0;
}
/** Options controlling a single annotation call. */
export interface AnnotateOpts {
// OpenRouter model identifier passed to `openrouter(modelId)`.
modelId: string;
// Pipeline stage the annotation belongs to; recorded in provenance.
stage: "stage1" | "stage2-judge" | "benchmark";
// Identifier of the labeling run; recorded in provenance.
runId: string;
// Prompt version tag; defaults to PROMPT_VERSION when omitted.
promptVersion?: string;
// OpenRouter reasoning effort; defaults to "low" when omitted.
reasoningEffort?: ReasoningEffort;
}
/**
 * Annotate a single paragraph with one model.
 *
 * Sends the paragraph to the given model via OpenRouter (structured output,
 * temperature 0, 120 s timeout), retrying through `withRetry`, and wraps the
 * parsed label in a full Annotation record with provenance: tokens, cost,
 * latency, run/prompt identifiers.
 *
 * @throws Error when the model returns no structured output.
 */
export async function annotateParagraph(
  paragraph: Paragraph,
  opts: AnnotateOpts,
): Promise<Annotation> {
  const { modelId, stage, runId } = opts;
  const promptVersion = opts.promptVersion ?? PROMPT_VERSION;
  const effort = opts.reasoningEffort ?? "low";

  const requestedAt = new Date().toISOString();
  const startedAt = Date.now();

  const callModel = () =>
    generateText({
      model: openrouter(modelId),
      output: Output.object({ schema: LabelOutputRaw }),
      system: SYSTEM_PROMPT,
      prompt: buildUserPrompt(paragraph),
      temperature: 0,
      providerOptions: buildProviderOptions(effort),
      abortSignal: AbortSignal.timeout(120_000),
    });
  const result = await withRetry(callModel, { label: `${modelId}:${paragraph.id}` });
  const latencyMs = Date.now() - startedAt;

  const raw = result.output;
  if (!raw) {
    throw new Error(`No output from ${modelId} for ${paragraph.id}`);
  }

  const usage = result.usage;
  return {
    paragraphId: paragraph.id,
    label: toLabelOutput(raw),
    provenance: {
      modelId,
      provider: providerOf(modelId),
      generationId: result.response?.id ?? "unknown",
      stage,
      runId,
      promptVersion,
      inputTokens: usage.inputTokens ?? 0,
      outputTokens: usage.outputTokens ?? 0,
      reasoningTokens: usage.outputTokenDetails?.reasoningTokens ?? 0,
      costUsd: extractCost(result),
      latencyMs,
      requestedAt,
    },
  };
}
/** Options for a Stage 2 judge call. */
export interface JudgeOpts {
// Identifier of the labeling run; recorded in provenance.
runId: string;
// Prompt version tag; defaults to PROMPT_VERSION when omitted.
promptVersion?: string;
}
/**
 * Run the Stage 2 judge on a paragraph where Stage 1 models disagreed.
 *
 * The judge is a fixed Claude model that sees the paragraph together with
 * all three prior annotations (supplied in randomized order) and emits a
 * final label with "medium" reasoning effort and a 240 s timeout.
 *
 * @throws Error when the judge returns no structured output.
 */
export async function judgeParagraph(
  paragraph: Paragraph,
  priorAnnotations: Array<{
    content_category: string;
    specificity_level: number;
    reasoning: string;
  }>,
  opts: JudgeOpts,
): Promise<Annotation> {
  const modelId = "anthropic/claude-sonnet-4.6";
  const { runId } = opts;
  const promptVersion = opts.promptVersion ?? PROMPT_VERSION;

  const requestedAt = new Date().toISOString();
  const startedAt = Date.now();

  const callJudge = () =>
    generateText({
      model: openrouter(modelId),
      output: Output.object({ schema: LabelOutputRaw }),
      system: SYSTEM_PROMPT,
      prompt: buildJudgePrompt(paragraph, priorAnnotations),
      temperature: 0,
      // The judge gets a larger thinking budget than Stage 1 models.
      providerOptions: buildProviderOptions("medium"),
      abortSignal: AbortSignal.timeout(240_000),
    });
  const result = await withRetry(callJudge, { label: `judge:${paragraph.id}` });
  const latencyMs = Date.now() - startedAt;

  const raw = result.output;
  if (!raw) {
    throw new Error(`No judge output for ${paragraph.id}`);
  }

  return {
    paragraphId: paragraph.id,
    label: toLabelOutput(raw),
    provenance: {
      modelId,
      provider: providerOf(modelId),
      generationId: result.response?.id ?? "unknown",
      stage: "stage2-judge",
      runId,
      promptVersion,
      inputTokens: result.usage.inputTokens ?? 0,
      outputTokens: result.usage.outputTokens ?? 0,
      reasoningTokens: result.usage.outputTokenDetails?.reasoningTokens ?? 0,
      costUsd: extractCost(result),
      latencyMs,
      requestedAt,
    },
  };
}