2026-04-05 00:55:53 -04:00

224 lines
6.7 KiB
TypeScript

// Fallback connection string pointing at a Postgres instance on localhost.
// `??=` only assigns when DATABASE_URL is currently null/undefined, so a value
// supplied by the environment always wins.
// NOTE(review): this statement sits above the `../db` import, presumably so the
// default is in place when that module loads. Under native ES module semantics
// imports are evaluated before any statement in this file runs — confirm the
// runtime/compiler preserves source order here, or that `../db` reads the
// variable lazily.
process.env.DATABASE_URL ??=
"postgresql://sec_cybert:sec_cybert@localhost:5432/sec_cybert";
import { readFile } from "node:fs/promises";
import { db } from "../db";
import * as schema from "../db/schema";
/**
 * Reads a JSON Lines file and parses every non-blank line as one JSON value.
 *
 * Lines that are empty or contain only whitespace are skipped. Parsed values
 * are cast to `T` without any runtime validation, so the caller is responsible
 * for the records actually having that shape.
 *
 * @param path - File to read; its content is decoded as UTF-8.
 * @returns The parsed values, in file order.
 * @throws SyntaxError when a line is not valid JSON. The message is prefixed
 *   with `path:lineNumber` (1-based) so the offending record can be located.
 */
async function readJsonl<T = unknown>(path: string): Promise<T[]> {
  const text = await readFile(path, "utf-8");
  const records: T[] = [];
  for (const [index, line] of text.split("\n").entries()) {
    // Blank lines (including a trailing newline at EOF) carry no record.
    if (!line.trim()) continue;
    try {
      records.push(JSON.parse(line) as T);
    } catch (err) {
      // Re-throw with file/line context; a bare JSON.parse message is useless
      // for finding one bad record among thousands.
      const reason = err instanceof Error ? err.message : String(err);
      throw new SyntaxError(
        `${path}:${index + 1}: invalid JSON line: ${reason}`,
      );
    }
  }
  return records;
}
/**
 * Loads a file as UTF-8 text and parses its entire content as a single JSON
 * document.
 *
 * The parsed value is cast to `T`; nothing checks at runtime that it really
 * has that shape.
 *
 * @param path - File to read.
 * @returns The parsed JSON value.
 */
async function readJson<T = unknown>(path: string): Promise<T> {
  const contents = await readFile(path, "utf-8");
  const parsed: unknown = JSON.parse(contents);
  return parsed as T;
}
/** Filing metadata carried under the `filing` key of each paragraph record. */
interface ParagraphFiling {
  companyName: string;
  cik: string;
  /** An empty value is stored as null when the seed builds its insert rows. */
  ticker: string;
  filingType: string;
  filingDate: string;
  fiscalYear: number;
  accessionNumber: string;
  secItem: string;
}

/**
 * Shape the seed script expects for each line of the paragraphs JSONL file.
 *
 * The shape is asserted via a type cast when the file is parsed; it is not
 * validated at runtime.
 */
interface ParagraphRow {
  /** Matched against the holdout ID set and used as the row id on insert. */
  id: string;
  text: string;
  /** Present in the input but not copied into the inserted row. */
  textHash: string;
  wordCount: number;
  paragraphIndex: number;
  filing: ParagraphFiling;
}
/** The label payload attached to a single annotation. */
interface AnnotationLabel {
  /** Voted on when computing the per-paragraph consensus. */
  content_category: string;
  /** Voted on when computing the per-paragraph consensus. */
  specificity_level: number;
  // The remaining fields are part of the input shape but are not read by the
  // code in this file.
  category_confidence: string;
  specificity_confidence: string;
  reasoning: string;
}

/**
 * Shape the seed script expects for each line of the annotations JSONL file.
 *
 * The shape is asserted via a type cast when the file is parsed; it is not
 * validated at runtime.
 */
interface AnnotationRow {
  /** Used to group annotations and to filter them to holdout paragraphs. */
  paragraphId: string;
  label: AnnotationLabel;
  /** Free-form metadata; not read by the code in this file. */
  provenance: Record<string, unknown>;
}
/**
 * Returns the most frequent value in `values` together with how many times it
 * occurred. Ties go to the value that appeared first in `values`. When
 * `values` is empty, `fallback` is returned with a vote count of 0.
 */
function pluralityVote<T>(
  values: readonly T[],
  fallback: T,
): { winner: T; votes: number } {
  const counts = new Map<T, number>();
  for (const value of values) {
    counts.set(value, (counts.get(value) ?? 0) + 1);
  }
  let winner = fallback;
  let votes = 0;
  // Map iteration follows insertion order, and the strict `>` keeps the
  // earliest-seen value when counts are tied.
  for (const [value, count] of counts) {
    if (count > votes) {
      winner = value;
      votes = count;
    }
  }
  return { winner, votes };
}

/**
 * Collapses the annotations of one paragraph into a single consensus label.
 *
 * Category and specificity are voted on independently; each winner is the most
 * frequent value, with ties resolved in favour of the value seen first. Note
 * that the winner is a plurality, not necessarily a strict majority.
 *
 * @param annotations - All annotations for one paragraph; must be non-empty.
 * @returns
 *   - `category` / `specificity`: the winning values.
 *   - `method`: `"unanimous"` when every annotation carries both winning
 *     values, otherwise `"majority"`.
 *   - `confidence`: fraction of annotations that match BOTH winning values.
 *     This can be lower than either individual vote share (even 0) when the
 *     two winners come from different annotations.
 * @throws Error when `annotations` is empty — there is nothing to vote on, and
 *   the computation would otherwise yield an empty category and a NaN
 *   confidence labelled "unanimous".
 */
function computeConsensus(annotations: AnnotationRow[]): {
  category: string;
  specificity: number;
  method: string;
  confidence: number;
} {
  const total = annotations.length;
  if (total === 0) {
    throw new Error("computeConsensus requires at least one annotation");
  }

  const categoryVote = pluralityVote(
    annotations.map((a) => a.label.content_category),
    "",
  );
  const specificityVote = pluralityVote(
    annotations.map((a) => a.label.specificity_level),
    0,
  );

  const method =
    categoryVote.votes === total && specificityVote.votes === total
      ? "unanimous"
      : "majority";

  const agreedOnBoth = annotations.filter(
    (a) =>
      a.label.content_category === categoryVote.winner &&
      a.label.specificity_level === specificityVote.winner,
  ).length;

  return {
    category: categoryVote.winner,
    specificity: specificityVote.winner,
    method,
    confidence: agreedOnBoth / total,
  };
}
/**
 * Seeds the database with the v2 holdout paragraphs and the annotator accounts.
 *
 * Steps, in order:
 *  1. Load the holdout paragraph IDs from a JSON file.
 *  2. Load the stage-1 annotations (JSONL), keep those belonging to holdout
 *     paragraphs, and compute one consensus per paragraph.
 *  3. Load the paragraphs (JSONL), keep the holdout ones, and insert them in
 *     batches; the stage1* columns are null for paragraphs without
 *     annotations.
 *  4. Insert the fixed list of annotator accounts.
 *
 * Each input path can be overridden with SEED_PARAGRAPHS_PATH,
 * SEED_ANNOTATIONS_PATH and SEED_HOLDOUT_IDS_PATH. Both inserts are issued
 * with onConflictDoNothing(), so the logged counts reflect rows submitted.
 * On success the process exits with code 0.
 */
async function main() {
  // NOTE(review): absolute, machine-specific default location — on any other
  // machine the three SEED_*_PATH variables have to be set.
  const ROOT = "/home/joey/Documents/sec-cyBERT";
  const { env } = process;
  const PARAGRAPHS_PATH =
    env.SEED_PARAGRAPHS_PATH ??
    `${ROOT}/data/paragraphs/paragraphs-clean.jsonl`;
  const ANNOTATIONS_PATH =
    env.SEED_ANNOTATIONS_PATH ?? `${ROOT}/data/annotations/stage1.jsonl`;
  const HOLDOUT_IDS_PATH =
    env.SEED_HOLDOUT_IDS_PATH ?? `${ROOT}/data/gold/v2-holdout-ids.json`;

  // Step 1 — holdout IDs (the v2 paragraph set).
  console.log("Loading v2 holdout IDs...");
  const holdoutIdList = await readJson<string[]>(HOLDOUT_IDS_PATH);
  const holdoutIds = new Set(holdoutIdList);
  console.log(` ${holdoutIds.size} holdout IDs`);

  // Step 2 — annotations, grouped per holdout paragraph, then consensus.
  console.log("Reading annotations...");
  const annotations = await readJsonl<AnnotationRow>(ANNOTATIONS_PATH);
  console.log(` ${annotations.length} total annotations loaded`);

  const annotationsByParagraph = new Map<string, AnnotationRow[]>();
  for (const annotation of annotations) {
    const pid = annotation.paragraphId;
    if (!holdoutIds.has(pid)) continue;
    let bucket = annotationsByParagraph.get(pid);
    if (bucket === undefined) {
      bucket = [];
      annotationsByParagraph.set(pid, bucket);
    }
    bucket.push(annotation);
  }
  console.log(
    ` ${annotationsByParagraph.size} holdout paragraphs have annotations`,
  );

  type Consensus = ReturnType<typeof computeConsensus>;
  const consensusMap = new Map<string, Consensus>();
  annotationsByParagraph.forEach((group, pid) => {
    consensusMap.set(pid, computeConsensus(group));
  });

  // Step 3 — paragraphs: restrict to the holdout set and insert in batches.
  console.log("Reading paragraphs...");
  const allParagraphs = await readJsonl<ParagraphRow>(PARAGRAPHS_PATH);
  const paragraphs = allParagraphs.filter((p) => holdoutIds.has(p.id));
  console.log(
    ` ${allParagraphs.length} total → ${paragraphs.length} holdout paragraphs`,
  );
  if (paragraphs.length !== holdoutIds.size) {
    console.warn(
      ` WARNING: expected ${holdoutIds.size} holdout paragraphs but found ${paragraphs.length} in paragraphs file`,
    );
  }

  // Flattens one paragraph record (plus its consensus, if any) into a row.
  const toRow = (p: ParagraphRow) => {
    const consensus = consensusMap.get(p.id);
    const { filing } = p;
    return {
      id: p.id,
      text: p.text,
      wordCount: p.wordCount,
      paragraphIndex: p.paragraphIndex,
      companyName: filing.companyName,
      cik: filing.cik,
      // `||` on purpose: an empty ticker string is stored as null.
      ticker: filing.ticker || null,
      filingType: filing.filingType,
      filingDate: filing.filingDate,
      fiscalYear: filing.fiscalYear,
      accessionNumber: filing.accessionNumber,
      secItem: filing.secItem,
      stage1Category: consensus?.category ?? null,
      stage1Specificity: consensus?.specificity ?? null,
      stage1Method: consensus?.method ?? null,
      stage1Confidence: consensus?.confidence ?? null,
    };
  };

  const BATCH_SIZE = 500;
  const paragraphCount = paragraphs.length;
  for (let start = 0; start < paragraphCount; start += BATCH_SIZE) {
    const end = Math.min(start + BATCH_SIZE, paragraphCount);
    const rows = paragraphs.slice(start, end).map((p) => toRow(p));
    await db.insert(schema.paragraphs).values(rows).onConflictDoNothing();
    console.log(` Inserted ${end}/${paragraphs.length} paragraphs`);
  }

  // Step 4 — annotator accounts (joey is admin, no separate admin account).
  console.log("Creating annotator accounts...");
  // NOTE(review): the password is handed to the insert as a plain literal —
  // confirm hashing happens elsewhere before relying on these accounts.
  const sharedPassword = "sec-cybert";
  const annotatorAccounts = [
    { id: "aaryan", displayName: "Aaryan" },
    { id: "anuj", displayName: "Anuj" },
    { id: "meghan", displayName: "Meghan" },
    { id: "xander", displayName: "Xander" },
    { id: "elisabeth", displayName: "Elisabeth" },
    { id: "joey", displayName: "Joey" },
  ].map((account) => ({ ...account, password: sharedPassword }));
  await db
    .insert(schema.annotators)
    .values(annotatorAccounts)
    .onConflictDoNothing();
  console.log(` Created ${annotatorAccounts.length} annotator accounts`);

  console.log("Seed complete.");
  process.exit(0);
}
/** Logs a seed failure and terminates the process with exit code 1. */
function reportSeedFailure(failure: unknown): void {
  console.error("Seed failed:", failure);
  process.exit(1);
}

// Script entry point: run the seed; any rejection ends the process as a failure.
main().catch(reportSeedFailure);