SEC-cyBERT/labelapp/lib/sampling.ts
2026-03-29 00:32:24 -04:00

278 lines
7.7 KiB
TypeScript

/**
 * A candidate paragraph together with its stage-1 labels and the raw
 * stage-1 annotation votes the split-vote strata inspect.
 */
export interface ParagraphWithVotes {
  /** Identifier that the sampling functions return for selected paragraphs. */
  id: string;
  /** Stage-1 category label, or null when none was recorded. */
  stage1Category: string | null;
  /** Stage-1 specificity level, or null when none was recorded. */
  stage1Specificity: number | null;
  /** Raw category votes from stage1 annotations (one entry per vote). */
  categoryVotes: string[];
  /** Raw specificity votes from stage1 annotations (one entry per vote). */
  specificityVotes: number[];
}

/** One named stratum: up to `count` paragraphs drawn from those matching `filter`. */
export interface StratumConfig {
  /** Human-readable label, used in progress logging. */
  name: string;
  /** Target number of paragraphs to draw for this stratum. */
  count: number;
  /** Membership predicate deciding whether a paragraph is eligible. */
  filter: (p: ParagraphWithVotes) => boolean;
}

/** Full sampling plan: overall target size plus the ordered list of named strata. */
export interface SamplingConfig {
  /** Overall number of paragraphs the sample should contain. */
  total: number;
  /** Named strata, processed in array order. */
  strata: StratumConfig[];
}
/**
 * Randomly permute `arr` in place with the Fisher-Yates algorithm, drawing
 * indices from `Math.random`. The same array instance is returned so calls
 * can be chained.
 */
function shuffle<T>(arr: T[]): T[] {
  let cursor = arr.length;
  while (cursor > 1) {
    cursor -= 1;
    // Pick a partner from the not-yet-fixed prefix [0, cursor].
    const pick = Math.floor(Math.random() * (cursor + 1));
    [arr[cursor], arr[pick]] = [arr[pick], arr[cursor]];
  }
  return arr;
}
/**
 * Report whether a paragraph's raw category votes contain at least one vote
 * for `catA` and at least one vote for `catB`.
 * Passing the same category twice reduces to a single membership check.
 */
function hasCategorySplit(
  p: ParagraphWithVotes,
  catA: string,
  catB: string,
): boolean {
  const votes = p.categoryVotes;
  return [catA, catB].every((category) => votes.includes(category));
}
/**
 * Report whether a paragraph's raw specificity votes contain at least one
 * vote equal to `specA` and at least one vote equal to `specB`.
 * Passing the same value twice reduces to a single membership check.
 */
function hasSpecificitySplit(
  p: ParagraphWithVotes,
  specA: number,
  specB: number,
): boolean {
  const { specificityVotes } = p;
  if (!specificityVotes.includes(specA)) {
    return false;
  }
  return specificityVotes.includes(specB);
}
/**
 * Draw `count` paragraph IDs, split across category x specificity cells in
 * proportion to each cell's share of `eligible` (largest-remainder method).
 *
 * Every cell first gets floor(cellSize / eligible.length * count) slots. The
 * leftover slots go one apiece to the cells with the largest fractional
 * parts. Ties keep first-seen cell order because `Array.prototype.sort` is
 * stable. Within a cell, members are shuffled and the first `quota` are
 * taken. A cell never yields more IDs than it has members, so the result can
 * be shorter than `count` when `count` exceeds `eligible.length`.
 *
 * A null category is bucketed as "unknown" and a null specificity as 0.
 *
 * @param eligible - candidate paragraphs (not mutated; per-cell copies are shuffled)
 * @param count - number of IDs to draw
 * @returns selected IDs, grouped by cell in descending-remainder order
 */
function proportionalSample(
  eligible: ParagraphWithVotes[],
  count: number,
): string[] {
  // Bucket paragraphs by "category|specificity", preserving first-seen order.
  const byCell = new Map<string, ParagraphWithVotes[]>();
  for (const p of eligible) {
    const cellKey = `${p.stage1Category ?? "unknown"}|${p.stage1Specificity ?? 0}`;
    let members = byCell.get(cellKey);
    if (!members) {
      members = [];
      byCell.set(cellKey, members);
    }
    members.push(p);
  }

  const poolSize = eligible.length;
  const plan = [...byCell.values()].map((members) => {
    const share = (members.length / poolSize) * count;
    const base = Math.floor(share);
    return { members, quota: base, fraction: share - base };
  });

  // Hand out the slots lost to flooring, biggest fractional part first.
  let leftover = count - plan.reduce((sum, cell) => sum + cell.quota, 0);
  plan.sort((a, b) => b.fraction - a.fraction);
  for (const cell of plan) {
    if (leftover <= 0) break;
    cell.quota += 1;
    leftover -= 1;
  }

  return plan.flatMap(({ members, quota }) => {
    shuffle(members);
    // Math.max guards against a negative quota (negative `count`) turning
    // slice's end index into "all but the last k".
    return members.slice(0, Math.max(0, quota)).map((p) => p.id);
  });
}
/**
 * Default sampling plan: 1,200 paragraphs in total. 360 of them are requested
 * from four split-vote strata, i.e. paragraphs whose raw stage-1 votes include
 * both members of a specific category pair or specificity pair.
 * `stratifiedSample` fills the remaining slots in its later phases.
 */
export function defaultSamplingConfig(): SamplingConfig {
  // Stratum filter: the paragraph's category votes include both labels.
  const categorySplit =
    (first: string, second: string) =>
    (p: ParagraphWithVotes): boolean =>
      hasCategorySplit(p, first, second);

  const strata: StratumConfig[] = [
    {
      name: "Mgmt↔RMP split votes",
      count: 120,
      filter: categorySplit("Management Role", "Risk Management Process"),
    },
    {
      name: "None/Other↔Strategy splits",
      count: 80,
      filter: categorySplit("None/Other", "Strategy Integration"),
    },
    {
      name: "Spec [3,4] splits",
      count: 80,
      filter: (p) => hasSpecificitySplit(p, 3, 4),
    },
    {
      name: "Board↔Mgmt splits",
      count: 80,
      filter: categorySplit("Board Governance", "Management Role"),
    },
  ];

  return { total: 1200, strata };
}
/**
 * Run stratified sampling. Returns selected paragraph IDs.
 *
 * Process:
 * 1. For each stratum, filter eligible paragraphs and randomly select up to `count`.
 * 2. Already-selected paragraphs are excluded from later strata.
 * 3. "Rare category guarantee" (budget of 120): every stage-1 category with
 *    fewer than 15 selected paragraphs is topped up, rarest first. Any budget
 *    left over goes to "Incident Disclosure".
 * 4. Final fill: proportional stratified random sample from
 *    category x specificity cells for the slots still open.
 *
 * `config.total` is a hard cap. Phases 1 and 2 never select more than the
 * slots still open, and phase 3 only requests the remainder. The result can
 * be shorter than `config.total` when too few paragraphs are eligible.
 * Progress for each phase is written with `console.log`.
 *
 * @param paragraphs - candidate pool
 * @param config - overall target size and the ordered named strata
 * @returns selected IDs in selection order (strata, rare top-up, proportional fill)
 */
export function stratifiedSample(
  paragraphs: ParagraphWithVotes[],
  config: SamplingConfig,
): string[] {
  const selected = new Set<string>();
  // Slots still available under the overall target (never negative).
  const openSlots = (): number => Math.max(0, config.total - selected.size);

  // Phase 1: Named strata (split-vote strata)
  for (const stratum of config.strata) {
    const eligible = paragraphs.filter(
      (p) => !selected.has(p.id) && stratum.filter(p),
    );
    shuffle(eligible);
    const take = Math.min(stratum.count, eligible.length, openSlots());
    for (const p of eligible.slice(0, take)) {
      selected.add(p.id);
    }
    console.log(
      ` Stratum "${stratum.name}": wanted ${stratum.count}, eligible ${eligible.length}, selected ${take}`,
    );
  }

  // Phase 2: Rare category guarantee (120 slots, >= 15 per category)
  const RARE_GUARANTEE_TOTAL = 120;
  const MIN_PER_CATEGORY = 15;
  const rareStartSize = selected.size;
  // The guarantee must not push the sample past config.total.
  const rareBudget = Math.min(RARE_GUARANTEE_TOTAL, openSlots());

  // Single pass over the pool collects every category seen, how many of
  // each are already selected, and the not-yet-selected candidates per
  // category. Falsy categories (null or "") are skipped throughout this phase.
  const allCategories = new Set<string>();
  const selectedByCat = new Map<string, number>();
  const unselectedByCat = new Map<string, ParagraphWithVotes[]>();
  for (const p of paragraphs) {
    const cat = p.stage1Category;
    if (!cat) continue;
    allCategories.add(cat);
    if (selected.has(p.id)) {
      selectedByCat.set(cat, (selectedByCat.get(cat) ?? 0) + 1);
    } else {
      const bucket = unselectedByCat.get(cat);
      if (bucket) {
        bucket.push(p);
      } else {
        unselectedByCat.set(cat, [p]);
      }
    }
  }

  // Top up categories below MIN_PER_CATEGORY, fewest-selected first so the
  // rarest categories are served before the budget runs out.
  let rareAdded = 0;
  const sortedCats = [...allCategories].sort(
    (a, b) => (selectedByCat.get(a) ?? 0) - (selectedByCat.get(b) ?? 0),
  );
  for (const cat of sortedCats) {
    if (rareAdded >= rareBudget) break;
    const current = selectedByCat.get(cat) ?? 0;
    if (current >= MIN_PER_CATEGORY) continue;
    const eligible = (unselectedByCat.get(cat) ?? []).filter(
      (p) => !selected.has(p.id),
    );
    shuffle(eligible);
    const take = Math.min(
      MIN_PER_CATEGORY - current,
      eligible.length,
      rareBudget - rareAdded,
    );
    for (const p of eligible.slice(0, take)) {
      selected.add(p.id);
    }
    rareAdded += take;
  }

  // Give extra slots to "Incident Disclosure" if budget remains
  if (rareAdded < rareBudget) {
    const incidentEligible = (
      unselectedByCat.get("Incident Disclosure") ?? []
    ).filter((p) => !selected.has(p.id));
    shuffle(incidentEligible);
    const take = Math.min(rareBudget - rareAdded, incidentEligible.length);
    for (const p of incidentEligible.slice(0, take)) {
      selected.add(p.id);
    }
  }
  console.log(
    ` Rare category guarantee: added ${selected.size - rareStartSize} (budget ${RARE_GUARANTEE_TOTAL})`,
  );

  // Phase 3: Proportional stratified random fill
  const remaining = config.total - selected.size;
  if (remaining > 0) {
    const eligible = paragraphs.filter(
      (p) => !selected.has(p.id) && p.stage1Category != null,
    );
    const filled = proportionalSample(eligible, remaining);
    for (const id of filled) {
      selected.add(id);
    }
    console.log(
      ` Proportional fill: added ${filled.length} (target ${remaining})`,
    );
  }

  console.log(` Total selected: ${selected.size}`);
  return [...selected];
}