diff --git a/docs/NARRATIVE.md b/docs/NARRATIVE.md index 53a12db..f02f261 100644 --- a/docs/NARRATIVE.md +++ b/docs/NARRATIVE.md @@ -795,7 +795,24 @@ The TAPT corpus is 72K Item 1C paragraphs (~10M tokens) — 50x smaller than the **Whole-word masking and tokenization:** Whole-word masking requires `offset_mapping` from the tokenizer to determine word boundaries. This is incompatible with DAPT's concatenate-and-chunk approach (which destroys offset_mapping by merging documents). TAPT tokenizes each paragraph individually with truncation, preserving offset_mapping. The data collator handles dynamic padding per batch. This is a different code path from DAPT's concatenation, but the data justifies it: paragraphs are natural self-contained units, unlike DAPT's long filings that must be chunked. -**Training time:** ~2,139 steps/epoch × 5 epochs = ~10,695 total steps. At ~1.84 it/s on the 3090, ~1.6 hours total. +**Training time:** ~2,139 steps/epoch × 5 epochs = ~10,695 total steps. 50 minutes on the RTX 3090 at ~3.56 steps/s (averaged over full run including torch.compile warmup). + +### TAPT Results + +| Metric | Value | +|--------|-------| +| Epochs | 5 | +| Total steps | 10,695 | +| Training time | 50 minutes | +| Initial loss | 1.46 | +| Final train loss (avg) | 0.6428 | +| Final eval loss | 1.0754 | +| Final perplexity | 2.11 | +| Throughput | 114 samples/s, 3.56 steps/s | + +Loss dropped from 1.46 → 1.08 over 5 epochs. For comparison, DAPT ended at eval loss 0.72 with standard subword masking at the same 30% rate — the gap reflects the harder whole-word masking objective (no subword hints), not a weaker model. The model learns to predict masked domain terms ("CISO", "materiality", "tabletop") from surrounding paragraph context alone, which is exactly the inductive bias TAPT is designed to create. + +The TAPT checkpoint is saved at `checkpoints/tapt/modernbert-large/final/` and is ready for fine-tuning. ### TAPT Launch — Whole-Word Masking Bugs @@ -851,8 +868,8 @@ Only nano's portion ($21.24) of the first run was wasted — the gemini and grok | Documentation + narrative | ~2h | Codebook updates, narrative writing, technical guide updates | | Labelapp build + infrastructure | ~8h | Monorepo restructure, Next.js app, quiz/warmup/labeling flows, BIBD assignment, sampling, Docker deployment, timer + migration infrastructure | | DAPT pre-training | ~14.5h GPU | 1 epoch on 500M tokens, RTX 3090. Two sessions (resumed from checkpoint-1280). | -| TAPT debugging + pre-training | ~2h dev + ~1.6h GPU | 4 bugs in transformers whole-word masking + Python 3.14 rollback. Training: 5 epochs on 72K paragraphs. | -| **Total to date** | **~53h** | Includes ~16h GPU time | +| TAPT debugging + pre-training | ~2h dev + ~50min GPU | 4 bugs in transformers whole-word masking + Python 3.14 rollback. Training: 5 epochs on 72K paragraphs, 50 min. | +| **Total to date** | **~52h** | Includes ~15.3h GPU time | ### Remaining Work (estimated) diff --git a/docs/STATUS.md b/docs/STATUS.md index 0171857..9fb544a 100644 --- a/docs/STATUS.md +++ b/docs/STATUS.md @@ -21,7 +21,8 @@ - [x] DAPT corpus: 14,568 documents, ~1.056B tokens, cleaned (XBRL, URLs, page numbers stripped) - [x] DAPT training complete: eval loss 0.7250, perplexity 1.65. 1 epoch on 500M tokens, ~14.5h on RTX 3090. - [x] DAPT checkpoint at `checkpoints/dapt/modernbert-large/final/` -- [x] TAPT config: 5 epochs, whole-word masking, seq_len=512, batch=32 +- [x] TAPT training complete: eval loss 1.0754, perplexity 2.11. 5 epochs, whole-word masking, ~50 min on RTX 3090. Loss: 1.46 → 1.08. +- [x] TAPT checkpoint at `checkpoints/tapt/modernbert-large/final/` - [x] Custom `WholeWordMaskCollator` (upstream `transformers` collator broken for BPE tokenizers) - [x] Python 3.14 → 3.13 rollback (dill/datasets pickle incompatibility) - [x] Procedure documented in `docs/DAPT-PROCEDURE.md` @@ -30,16 +31,10 @@ - [x] `docs/DATA-QUALITY-AUDIT.md` — full audit with all patches and quality tiers - [x] `docs/EDGAR-FILING-GENERATORS.md` — 14 generators with signatures and quality profiles - [x] `docs/DAPT-PROCEDURE.md` — pre-flight checklist, commands, monitoring guide -- [x] `docs/NARRATIVE.md` — 11 phases documented through TAPT launch +- [x] `docs/NARRATIVE.md` — 11 phases documented through TAPT completion ## What's In Progress -### TAPT Training — Running -Training on 72K Item 1C paragraphs using DAPT checkpoint. 5 epochs, whole-word masking, seq_len=512, batch=32. Early loss: 1.46 → 1.40 (first 1% of training). Expected ~1.6h total on RTX 3090. Expecting final loss ~1.0-1.2. -```bash -bun run py:train dapt --config configs/tapt/modernbert.yaml -``` - ### Human Labeling (139/1,200) - 3 of 6 annotators started: 68 + 50 + 21 paragraphs completed - Deployed via labelapp with quiz gating + warmup @@ -82,7 +77,7 @@ Full GenAI benchmark (9 models) on 1,200 holdout. Comparison tables. Write-up. ## Parallel Tracks ``` -Track A (GPU): DAPT ✓ → TAPT (running) → Fine-tuning → Eval +Track A (GPU): DAPT ✓ → TAPT ✓ → Fine-tuning → Eval ↑ Track B (API): Judge v3 → Judge run ───────────┤ ↑ @@ -91,7 +86,7 @@ Track C (Human): Labeling (139/1200) → Gold set validation Track D (Code): Fine-tune pipeline build ───────┘ ``` -TAPT finishes in ~1.5h. Track D (fine-tune pipeline) can proceed now. Track B can start (prompt update) but production run waits for Track C. Everything converges at fine-tuning. +DAPT + TAPT complete. Track D (fine-tune pipeline) can proceed now. Track B can start (prompt update) but production run waits for Track C. Everything converges at fine-tuning. ## Key File Locations diff --git a/docs/reference/Capstone_assn_instructions.md b/docs/reference/Capstone_assn_instructions.md new file mode 100644 index 0000000..315574f --- /dev/null +++ b/docs/reference/Capstone_assn_instructions.md @@ -0,0 +1,449 @@ +# Capstone: Build a Business-Grade Text Classifier + +**Due:** April 23 by 12pm | **Points:** 35 + +--- + +## Team Assignment: Build a Business-Grade Text Classifier + +**Team size:** 5–6 students: at least half must be en rolled in COMP488 or BUSI488. + +One-sentence summary: Your team will build an end-to-end system that turns raw text (reviews, filings, reports, speeches, etc.) into reliable business labels—and you’ll compare two approaches: genAI labeling vs a fine-tuned specialist model. + +## Form Your Team + +**People > Groups > CAPSTONE : People** + +- Each team member must join an EXISTING project group for THEIR SECTION +- DO NOT create your own (additional) projects groups. +- Each group must satisfy the following: + - Have at least two students enrolled in COMP488 + - Have at least two students enrolled in BUSI488 +- You can have at most three students enrolled in COMP488 and/or BUSI488 in your team, conditional on: + - All other teams having at least two of each in their team (see above). + - If other teams have less than two COMP488 and BUSI488 students, you cannot join a team that already has this minimum requirement. +- **ALL TEAM MEMBERS MUST BE IN THE SAME SECTION (1 vs. 2)** + +## Why This Assignment Matters (Business Reality) + +### Businesses Are Drowning in Text + +- Customer reviews, support chats, emails +- Social posts, news, press releases +- 10-K/10-Q filings, earnings calls, ESG reports +- Policies, regulatory communications, safety notices +- Research abstracts, patents, job postings + +### Text Classification Enables Business Value + +Text classification is one of the most common, high-ROI uses of AI because it converts messy language into structured signals that businesses can: + +- Count — How many "billing issues" this week? +- Trend — Are "delivery failures" rising? +- Segment — Which regions complain about what? +- Route — Send messages/tickets to the right team automatically +- Flag risk — Safety/adverse events, compliance issues, reputational threats +- Inform strategy — Competitor moves, pricing tactics, CX weaknesses +- Input into (predictive) models — As independent variables + +### Why Not Manual Reading? + +Firms don't just "read it manually" because volume is too high, humans are slow and inconsistent across people/time, and decisions often need speed (minutes/hours, not weeks). Leadership needs dashboards and measurable KPIs. + +This will feel like a stretch—and that's the point. You will be supported by structure, templates, and teamwork. + +## What You Will Produce (End Result) + +A working classifier for a well-documented, established, theoretically grounded construct (your choice), plus a business-style memo explaining: + +- What the construct of interest is that you selected from the seven options: + https://www.ringel.ai/UNC/2026/BUSI488/Class23/Ringel_488-2026_Capstone_Constructs.pdf +- How it was established and where it is theoretically anchored and motivated +- Why the business should care +- How well your classifier works +- What it costs (time + money) +- Whether it's reliable and reproducible + +## Choose a Construct of Interest: Meaningful and Labelable + +Your construct should be: + +- **Business-relevant** — addresses a real business decision +- **Theoretically grounded** — anchored in established literature +- **Well documented** — clearly defined in academic or industry sources +- **Observable in text** — detectable in your data source +- **Definable with clear rules** — specific enough for reliable labeling +- **Complex & nuanced** — more than just sentiment (not positive/negative) + +You must pick one of the seven provided constructs of interest from here: +https://www.ringel.ai/UNC/2026/BUSI488/Class23/Ringel_488-2026_Capstone_Constructs.pdf + +## Potential Public Data Sources + +Below are places to obtain public text data for your classifier (downloads and/or official APIs). For each of the seven constructs of interest, there are already some data sources suggested. Below are additional sources to consider. Use sources with clear educational/research access terms whenever possible. Or use your own source and data (but observe guardrails & ethics as outlined below). + +**Important:** Choose a source that fits your construct and is realistically useful to a firm. + +### Customer Reviews & Product Voice + +**Amazon Reviews (research datasets)** + +- McAuley Lab Amazon datasets +- Stanford SNAP Amazon dataset page +- Hugging Face: Amazon Reviews 2023 + +**Yelp Open Dataset** + +- Kaggle: Yelp Open Dataset + +### Corporate Filings & Investor Communications + +**SEC EDGAR (10-K / 10-Q / 8-K etc.) — Official APIs** + +- EDGAR Application Programming Interfaces +- SEC Developer Resources +- SEC Disclosure Data API announcement + +### Regulation, Enforcement, Safety & Compliance + +**FDA Warning Letters** + +- FDA: Warning Letters +- Data.gov: Warning Letters dataset + +**U.S. Consumer Product Safety Commission (CPSC) Recalls — API** + +- CPSC Recalls API information + +### Health & Public Policy (Public) + +**ClinicalTrials.gov (Modernized API)** + +- ClinicalTrials.gov Data API + +**PubMed / NCBI (E-utilities API)** + +- NCBI E-utilities documentation + +### Research & Innovation + +**arXiv (bulk data)** + +- arXiv bulk data help +- arXiv bulk data on AWS S3 + +**Patents (USPTO + PatentsView)** + +- USPTO Open Data Portal: bulk data +- PatentsView: bulk downloads + +### Government Documents & Speeches + +**govinfo API + documentation** (Congressional docs, Federal Register, etc.) + +- govinfo API overview +- govinfo API docs + +**Congress.gov API** + +- Congress.gov API (GPO) + +**U.S. Presidential speeches** (Miller Center data portal) + +- Miller Center: Presidential Speech Data + +### News / Media Monitoring + +**GDELT** (global news monitoring) + +- GDELT data downloads +- GDELT 2.0 API documentation (blog) + +**Common Crawl News dataset** (CC-NEWS) + +- Common Crawl: News dataset announcement + +### Security / Technical Risk + +**NIST National Vulnerability Database (NVD)** — CVE APIs & feeds + +- NVD: Vulnerabilities API +- NVD: Data feeds + +### Earnings Calls / Transcripts (public option) + +**Academic/open dataset option** + +- STRUX dataset page + +### Tips & Rules + +- **Tip:** Before committing, do a quick pilot on 100–200 texts to confirm your construct appears in the source and that your labels are workable +- **Rule:** Use public data or properly de-identified data only. No sensitive internal company data. + +--- + +## The Workflow: What You Must Do (Step-by-Step) + +### Step 1 — Construct Choice + Business Case + +**Goal:** Pick a well-documented, theoretically founded construct of interest. Explain why a firm would want to classify this construct at scale. +https://www.ringel.ai/UNC/2026/BUSI488/Class23/Ringel_488-2026_Capstone_Constructs.pdf + +**Include:** + +- Who is the stakeholder? (CX lead, compliance, product manager, investor relations, etc.) +- What decision/action will classification enable? +- What happens if the classifier is wrong? (false positives vs false negatives) +- Why now? (volume, speed, competitive need) + +**Deliverable:** 1–2 page concept brief + +### Step 2 — Define the Construct Precisely + +**Goal:** Turn the concept into labels humans can apply consistently. + +**You must create:** + +- Label set (classes) +- Clear definitions and decision rules +- Borderline cases: what to do when unclear +- "None/other" policy IF APPLICABLE (multi-class yes, multi-label no) +- 2-3 example texts per label (your own) + +**Decide:** multi-class (one label) vs multi-label (multiple labels can apply). + +**Deliverable:** Labeling codebook (PDF) + +### Step 3 — Identify and Justify the Data Source + +**Goal:** Show that your source actually contains your construct. + +**Include:** + +- Why this source fits the business purpose +- How a firm would use it regularly (weekly monitoring, quarterly reporting, etc.) +- Basic access plan (download/API/scrape—follow site rules) + +**Deliverable:** Data source plan + sampling approach + +### Step 4 — Collect Your Dataset + +**Targets:** + +- 20,000 texts total for train/test +- 1,200 texts for a locked holdout set + +**Deliverable:** Raw dataset file + collection notebook/script + documentation + +### Step 5 — Preprocess and Split + +**You must:** + +- Clean text (remove duplicates, empty, obvious spam) +- Create: train/test totaling 20,000; holdout = 1,200 (lock it in early—don't tune on it) + - **Beware of label imbalance!** May need to curate strategically! +- Report length stats and basic summaries +- Plan for imbalance (it's normal—just don't ignore it) + +**Deliverable:** Preprocessing notebook + split files + basic descriptive stats + +### Step 6 — Human Label the Holdout + +**Requirements:** + +- The 1,200 holdout must be labeled by humans +- At least 3 people must label each example +- Either independent labels + agreement report (Krippendorff's alpha recommended), or discussion-based consensus + documented process + +**Deliverable:** Holdout with 3+ labels per item (or consensus + notes) + reliability report + +### Step 7 — Benchmark GenAI Labeling (≥6 Models) + +**Goal:** Treat genAI as alternative "labelers" and compare. + +**Requirements:** + +- Run at least 6 models from at least 3 different suppliers (e.g., OpenAI, Anthropic, Meta, xAI, Google, FireworksAI, Deepseek, Moonshot) +- Fix prompts to make them comparable across models +- Track runtime, estimated cost, and reproducibility strategy + +**Metrics to report:** + +- Krippendorff's alpha (vs human labels) +- F1 (macro and per-class if possible) +- AUC (where applicable) +- MCC + +**Deliverable:** Benchmark notebook + results table + cost/time log + +### Step 8 — Select GenAI Labeling Strategy and Label Train/Test + +Choose the best single model or a combination (e.g., majority vote). Then label your 20,000 train/test set. + +**Deliverable:** Labeled train/test dataset + labeling script + total cost/time summary + +### Step 9 — Fine-Tune a Specialist Model + +**Goal:** Create a smaller, task-specific classifier that can match/exceed genAI labeling quality. + +**Minimum requirements:** + +- Fine-tune at least one pretrained model (RoBERTa or similar) +- Try at least four configurations (model choice or hyperparameters) +- Evaluate on holdout (only at the end) + +**Deliverable:** Training notebook + saved model + evaluation on holdout + +### Step 10 — Final Comparison: GenAI vs Specialist + +Your final analysis must answer: + +- Which is better on holdout and by how much? +- Which is cheaper per 1,000,000 texts? +- Which is faster? +- Which is more consistent and reproducible? +- What would you recommend a firm deploy, and why? + +**Deliverable:** Final comparison section + recommendation + +--- + +## What To Submit (Deliverables Checklist) + +### A) Executive Memo (Executive Style, Max 5 Pages) + +- Construct + why it matters + theoretically grounded and well documented +- Data source + governance/ethics +- Label schema overview +- Results summary: best genAI vs best specialist +- Cost/time/reproducibility comparison +- Recommendation for a real firm + +### B) Technical Appendix (Slides or PDF) + +- Pipeline diagram (data → labels → model → evaluation) +- Label codebook (or link/appendix) +- Benchmark table (≥6 genAI models) +- Fine-tuning experiments + results +- Error analysis: where does it fail and why? + +### C) Code + Artifacts (Datasets) + +- Colab notebooks (reproducible) +- **Datasets:** + - Holdout with human labels + - Train/test with genAI labels + - Any other data used + - All labels of all models for each run + majority labels +- Saved fine-tuned model + inference script (provide link to webspace/Google Drive/Dropbox; **do not upload to Canvas!**) +- Cost/time log + +### D) IGNITE Talk in Class + +- **20 PPTX Slides**, 15 sec per slide (automatic transitions), 5 min total +- Submit your PPTX slides (not PDF, not Google Slides, not other format!) set to auto transitions after 15 sec on Canvas +- I will have all slides ready for you to go in class +- **Key messages and insights only** +- **Every team member must present at least 2 slides** +- There will be a 3 min Q&A after each IGNITE talk + +**Learn more:** + +- https://en.wikipedia.org/wiki/Ignite_(event) +- https://www.ignitetalks.io/ +- https://robbiesenbach.com/deliver-successful-ignite-talk/ + +--- + +- **Start simple, then improve:** Your first codebook will be imperfect. Iterate. +- **Design labels for reliability:** If humans disagree a lot, the model will struggle. Fix definitions before scaling. +- **Lock the holdout early:** It counts 35% of your grade. Don't change direction repeatedly—fix it early and make sure it's representative for what you ultimately want to achieve from a business perspective. +- **Do a small pilot before spending money:** Test prompts and schema on 100–200 examples first. +- **Don't hide class imbalance:** Imbalance is normal. But it can damage evaluation and training. **You must solve this!** Use macro F1 and MCC, and document prevalence. + +## Team Structure (Recommended Roles) + +- **Project lead:** Scope, project plan, milestones, quality control +- **Data lead:** Collection, cleaning, dataset documentation +- **Labeling lead:** Codebook, human labeling workflow, reliability stats +- **GenAI lead:** API scripts, benchmarking, cost/time tracking +- **Modeling lead:** Fine-tuning, hyperparameters, reproducibility, evaluation +- **Delivery lead:** Assemble all outputs, organize write-ups, final quality control + +## Grading Rubric (100 Points) + +| Criterion | Points | +| ----------------------------------------- | ------ | +| Business framing & construct clarity | 20 | +| Data pipeline quality + documentation | 15 | +| Human labeling process + reliability | 15 | +| GenAI benchmarking rigor | 20 | +| Fine-tuning rigor + evaluation discipline | 20 | +| Final comparison + recommendation quality | 10 | + +### Minimum Requirements Per Letter Grade: + +#### C (- to +) + +- Fine-tuned model with F1 score > 0.80 +- Performance comparison genAI vs. fine-tuned model +- Labeled datasets +- Documentation +- Python notebook(s) to replicate pre-processing, training, and evaluation + +#### B (- to +) + +All of the above plus **at least three** of the following: + +- Cost, time, reproducibility analysis +- Comparison of 6 or more models from at least 3 different suppliers +- Contemporary data that you collected (not an off-the-shelf dataset) +- Compelling use-case for your classifier with complete business case + +#### A (- to A) + +All of the above plus **at least three** of the following: + +- Error-analysis (corner cases, rare or complex texts) +- Mitigation strategy and implementation to overcome identified model weaknesses +- Additional baselines (e.g., alternative classification approaches like dictionaries, topic models) +- Comparison to amateur labels + +## Guardrails (Ethics & Compliance) + +- Public data or approved de-identified data only +- Remove or avoid sensitive personal information +- Document limitations and potential bias +- If your construct is sensitive (health, safety, harassment), include a brief risk statement and mitigation steps + +## Estimated Effort + +| Task | Hours | +| --------------------------------------------------------- | ----- | +| Construct + codebook v1 + data source plan | 4 | +| Data collection + preprocessing + splits | 4 | +| 1.2K human labeling + reliability + codebook v2 | 8 | +| GenAI benchmarking (≥6 models) + choose labeling strategy | 3 | +| Label 20k + fine-tune specialist (2+ configs) | 2 | +| Final evaluation + memo + presentation | 3 | + +**Total: ~24 hours per student** + +## Best Work Featured in Vertical AI Paper + +I will select (and improve) the best 1–2 classifiers to be featured in my Vertical AI Paper where each team member will be acknowledged for their application of the synthetic expert/specialist approach. I will ask members whether they want to be named or not (choosing to remain anonymous has not impacted your grade). + +## Getting Started + +To help you with your Capstone, I wrote a full pipeline in a Python notebook that does all the key steps you need for your Capstone Project by example of classifying 10K sentences into business functions. This includes querying genAI via API at scale, creating holdout and training datasets, fine-tuning a pretrained LLM, and evaluating the performance of genAI and your fine-tuned (vertical AI) model. + +**What it does not do** is give you a construct of interest, collect your data, clean and preprocess your data, or draw conclusions and write reports for you. + +**Python notebook:** http://ringel.ai/UNC/2026/helpers/Ringel_2026_VerticalAI_Capstone_Pipeline_Example.ipynb + +**Zip file with outputs:** http://ringel.ai/UNC/2026/helpers/Ringel_2026_VerticalAI_Capstone_Pipeline_Example.zip (excludes the actual trained vertical AI because it is 1.5GB) + +All subfolders and datasets are included. This is a great blueprint for what data you need to deliver with your capstone on a shared drive (provide link to me) or uploaded if sufficiently small (less than 20MB). + +The contents of the zip file also help you see what the expected output is (by example of a multi-label classification problem). You will need to adapt this code to your problem. Use genAI (e.g., Claude Opus 4.6) for this. The pipeline gives you a solid base to work off. diff --git a/docs/reference/P3_SEC_Cybersecurity_Capstone.md b/docs/reference/P3_SEC_Cybersecurity_Capstone.md new file mode 100644 index 0000000..8fd17f2 --- /dev/null +++ b/docs/reference/P3_SEC_Cybersecurity_Capstone.md @@ -0,0 +1,688 @@ +# Project 3: SEC Cybersecurity Disclosure Quality Classifier + +## Capstone 2026 — BUSI488/COMP488 — Team Knowledge Transfer + +**Project:** Build a validated, reusable classifier that labels SEC cybersecurity disclosures by content category and specificity level, then fine-tune an open-weights model for deployment at scale. + +**Methodology:** Ringel (2023) "Synthetic Experts" pipeline — use frontier LLMs to generate training labels, then distill into a small open-weights encoder model. + +**Why this project:** No HuggingFace dataset of extracted Item 1C disclosures exists. No trained classifier for cybersecurity disclosure quality exists. No domain-adapted ModernBERT on SEC filings exists. The iXBRL CYD taxonomy just went live (Dec 2024). We produce **three publishable artifacts**: a novel dataset, a labeling methodology, and a SOTA classifier. + +--- + +## Table of Contents + +1. [Regulatory Background](#1-regulatory-background) +2. [Labeling Rubric](#2-labeling-rubric) +3. [Data Acquisition](#3-data-acquisition) +4. [GenAI Labeling Pipeline](#4-genai-labeling-pipeline) +5. [Model Strategy](#5-model-strategy) +6. [Evaluation & Validation](#6-evaluation--validation) +7. [Release Artifacts](#7-release-artifacts) +8. [3-Week Schedule (6 People)](#8-3-week-schedule-6-people) +9. [Budget](#9-budget) +10. [Reference Links](#10-reference-links) + +--- + +## 1. Regulatory Background + +### The Rule: SEC Release 33-11216 (July 2023) + +The SEC adopted final rules requiring public companies to disclose cybersecurity risk management, strategy, governance, and material incidents. This created a massive new text corpus with natural variation in quality — perfect for classification. + +Full rule PDF: +Fact sheet: + +### Item 1C — Annual Disclosure (10-K) + +Appears as **Regulation S-K Item 106**, reported in **Item 1C** of the 10-K. Two mandated subsections: + +**Item 106(b) — Risk Management and Strategy:** +1. Processes for assessing, identifying, and managing material cybersecurity risks +2. Whether/how cybersecurity processes integrate into overall enterprise risk management (ERM) +3. Whether the company engages external assessors, consultants, or auditors +4. Processes to oversee/identify risks from third-party service providers +5. Whether cybersecurity risks (including prior incidents) have materially affected or are reasonably likely to affect business strategy, results, or financial condition + +**Item 106(c) — Governance:** + +*Board Oversight (106(c)(1)):* +- Description of board's oversight of cybersecurity risks +- Identification of responsible board committee/subcommittee +- Processes by which the board/committee is informed about risks + +*Management's Role (106(c)(2)):* +- Which management positions/committees are responsible +- Relevant expertise of those persons +- How management monitors prevention, detection, mitigation, and remediation +- Whether and how frequently management reports to the board + +**Key design note:** The SEC uses "describe" — it does not prescribe specific items. The enumerated sub-items are non-exclusive suggestions. This principles-based approach creates natural variation in specificity and content, which is exactly what our rubric captures. + +### Item 1.05 — Incident Disclosure (8-K) + +Required within **4 business days** of determining a cybersecurity incident is material: +1. Material aspects of the nature, scope, and timing of the incident +2. Material impact or reasonably likely material impact on the registrant + +**Key nuances:** +- The 4-day clock starts at the **materiality determination**, not the incident itself +- Companies explicitly do NOT need to disclose technical details that would impede response/remediation +- The AG can delay disclosure up to 120 days for national security +- Companies must amend the 8-K when new material information becomes available + +**The May 2024 shift:** After SEC Director Erik Gerding clarified that Item 1.05 is only for *material* incidents, companies pivoted from Item 1.05 to Items 8.01/7.01 for non-material disclosures: +- Pre-guidance: 72% used Item 1.05, 28% used 8.01/7.01 +- Post-guidance: 34% used Item 1.05, 66% used 8.01/7.01 + +**Our extraction must capture all three item types.** + +### Compliance Timeline + +| Date | Milestone | +|------|-----------| +| Jul 26, 2023 | Rule adopted | +| Sep 5, 2023 | Rule effective | +| Dec 15, 2023 | Item 1C required in 10-Ks (FY ending on/after this date) | +| Dec 18, 2023 | Item 1.05 required in 8-Ks | +| Jun 15, 2024 | Item 1.05 required for smaller reporting companies | +| Dec 15, 2024 | iXBRL tagging of Item 106 (CYD taxonomy) required | +| Dec 18, 2024 | iXBRL tagging of 8-K Item 1.05 required | + +### iXBRL CYD Taxonomy + +The SEC published the **Cybersecurity Disclosure (CYD) Taxonomy** on Sep 16, 2024. Starting with filings after Dec 15, 2024, Item 1C disclosures are tagged in Inline XBRL using the `cyd` prefix. This means 2025 filings can be parsed programmatically via XBRL rather than HTML scraping. + +Taxonomy schema: `http://xbrl.sec.gov/cyd/2024` +Taxonomy guide: + +### Corpus Size + +| Filing Type | Estimated Count (as of early 2026) | +|-------------|-----------------------------------| +| 10-K with Item 1C (FY2023 cycle) | ~4,500 | +| 10-K with Item 1C (FY2024 cycle) | ~4,500 | +| 8-K cybersecurity incidents | ~80 filings (55 incidents + amendments) | +| **Total filings** | **~9,000-10,000** | +| **Estimated paragraphs** (from Item 1C) | **~50,000-80,000** | + +--- + +## 2. Labeling Rubric + +### Dimension 1: Content Category (single-label per paragraph) + +Derived directly from the SEC rule structure. Each paragraph receives exactly one category: + +| Category | SEC Basis | What It Covers | Example Markers | +|----------|-----------|----------------|-----------------| +| **Board Governance** | 106(c)(1) | Board/committee oversight, briefing frequency, board cyber expertise | "Audit Committee," "Board of Directors oversees," "quarterly briefings" | +| **Management Role** | 106(c)(2) | CISO/CTO identification, qualifications, reporting structure | "Chief Information Security Officer," "reports to," "years of experience" | +| **Risk Management Process** | 106(b) | Assessment/identification processes, ERM integration, framework references | "NIST CSF," "ISO 27001," "risk assessment," "vulnerability management" | +| **Third-Party Risk** | 106(b) | Vendor oversight, external assessors/consultants, supply chain risk | "third-party," "service providers," "penetration testing by," "external auditors" | +| **Incident Disclosure** | 8-K 1.05 | Nature/scope/timing of incidents, material impact, remediation | "unauthorized access," "detected," "incident," "remediation," "impacted" | +| **Strategy Integration** | 106(b)(2) | Material impact on business strategy, cyber insurance, resource allocation | "business strategy," "insurance," "investment," "material," "financial condition" | +| **None/Other** | — | Boilerplate intros, legal disclaimers, non-cybersecurity content | Forward-looking statement disclaimers, general risk language | + +### Dimension 2: Specificity (4-point ordinal per paragraph) + +Grounded in Berkman et al. (2018), Gibson Dunn surveys, and PwC quality tiers: + +| Level | Label | Definition | Decision Test | +|-------|-------|------------|---------------| +| **1** | **Generic Boilerplate** | Could apply to any company. Conditional language ("may," "could"). No named entities. Passive voice. | "Could I paste this into a different company's filing unchanged?" → Yes | +| **2** | **Sector-Adapted** | References industry context or named frameworks (NIST, ISO) but no firm-specific detail. | "Does this name something specific but not unique to THIS company?" → Yes | +| **3** | **Firm-Specific** | Names roles (CISO by name), committees, reporting lines, specific programs, or processes unique to the firm. Active voice with accountability. | "Does this contain at least one fact unique to THIS company?" → Yes | +| **4** | **Quantified-Verifiable** | Includes metrics, dollar amounts, dates, frequencies, third-party audit references, or independently verifiable facts. Multiple firm-specific facts with operational detail. | "Could an outsider verify a specific claim in this paragraph?" → Yes | + +**Boundary rules for annotators:** +- If torn between 1 and 2: "Does it name ANY framework, standard, or industry term?" → Yes = 2 +- If torn between 2 and 3: "Does it mention anything unique to THIS company?" → Yes = 3 +- If torn between 3 and 4: "Does it contain TWO OR MORE specific, verifiable facts?" → Yes = 4 + +**Important:** EvasionBench (Ma et al., 2026) found that a 5-level ordinal scale failed (kappa < 0.5) and had to be collapsed to 3 levels. **Pilot test this 4-level scale on 50 paragraphs early.** Be prepared to merge levels 1-2 or 3-4 if inter-annotator agreement is poor. + +### Boilerplate vs. Substantive Markers (from the literature) + +**Boilerplate indicators:** +- Conditional language: "may," "could," "might" +- Generic risk statements without company-specific context +- No named individuals, committees, or frameworks +- Identical language across same-industry filings (cosine similarity > 0.8) +- Passive voice: "cybersecurity risks are managed" + +**Substantive indicators:** +- Named roles and reporting structures ("Our CISO, Jane Smith, reports quarterly to the Audit Committee") +- Specific frameworks by name (NIST CSF, ISO 27001, SOC 2, PCI-DSS) +- Concrete processes (penetration testing frequency, tabletop exercises) +- Quantification (dollar investment, headcount, incident counts, training completion rates) +- Third-party names or types of assessments +- Temporal specificity (dates, frequencies, durations) + +### Mapping to NIST CSF 2.0 + +For academic grounding, our content categories map to NIST CSF 2.0 functions: + +| Our Category | NIST CSF 2.0 | +|-------------|-------------| +| Board Governance | GOVERN (GV.OV, GV.RR) | +| Management Role | GOVERN (GV.RR, GV.RM) | +| Risk Management Process | IDENTIFY (ID.RA), GOVERN (GV.RM), PROTECT (all) | +| Third-Party Risk | GOVERN (GV.SC) | +| Incident Disclosure | DETECT, RESPOND, RECOVER | +| Strategy Integration | GOVERN (GV.OC, GV.RM) | + +--- + +## 3. Data Acquisition + +### 3.1 Extracting 10-K Item 1C + +**Recommended pipeline:** + +``` +sec-edgar-downloader → edgar-crawler → paragraph segmentation → dataset + (bulk download) (parse Item 1C) (split into units) +``` + +**Tools:** + +| Tool | Purpose | Install | Notes | +|------|---------|---------|-------| +| `sec-edgar-downloader` | Bulk download 10-K filings by CIK | `pip install sec-edgar-downloader` | Pure downloader, no parsing | +| `edgar-crawler` | Extract specific item sections to JSON | `git clone github.com/lefterisloukas/edgar-crawler` | Best for bulk extraction; configure `['1C']` in items list | +| `edgartools` | Interactive exploration, XBRL parsing | `pip install edgartools` | `tenk['Item 1C']` accessor; great for prototyping | +| `sec-api` | Commercial API, zero parsing headaches | `pip install sec-api` | `extractorApi.get_section(url, "1C", "text")` — paid, free tier available | + +**EDGAR API requirements:** +- Rate limit: 10 requests/second +- Required: Custom `User-Agent` header with name and email (e.g., `"TeamName team@email.com"`) +- SEC blocks requests without proper User-Agent (returns 403) + +**For iXBRL-tagged filings (2025+):** Use `edgartools` XBRL parser to extract CYD taxonomy elements directly. This gives pre-structured data aligned with regulatory categories. + +**Fallback corpus:** `PleIAs/SEC` on HuggingFace (373K 10-K full texts, CC0 license) — but sections are NOT pre-parsed; you must extract Item 1C yourself. + +### 3.2 Extracting 8-K Incident Disclosures + +| Tool | Purpose | URL | +|------|---------|-----| +| `sec-8k-item105` | Extract Item 1.05 from 8-Ks, iXBRL + HTML fallback | `github.com/JMousqueton/sec-8k-item105` | +| `SECurityTr8Ker` | Monitor SEC RSS for new cyber 8-Ks, Slack/Teams alerts | `github.com/pancak3lullz/SECurityTr8Ker` | +| Debevoise 8-K Tracker | Curated list with filing links, dates, amendments | `debevoisedatablog.com/2024/03/06/cybersecurity-form-8-k-tracker/` | +| Board Cybersecurity Tracker | Links filings to MITRE ATT&CK, impact assessments | `board-cybersecurity.com/incidents/tracker` | + +**Critical:** Must capture Item 1.05 AND Items 8.01/7.01 (post-May 2024 shift). + +### 3.3 Paragraph Segmentation + +Once Item 1C text is extracted, segment into paragraphs: +- Split on double newlines or `

` tags (depending on extraction format) +- Minimum paragraph length: 20 words (filter out headers, whitespace) +- Maximum paragraph length: 500 words (split longer blocks at sentence boundaries) +- Preserve metadata: company name, CIK, ticker, filing date, fiscal year + +Expected yield: ~5-8 paragraphs per Item 1C disclosure × ~9,000 filings = **~50,000-70,000 paragraphs** + +### 3.4 Pre-Existing Datasets and Resources + +| Resource | What It Is | URL | +|----------|-----------|-----| +| PleIAs/SEC | 373K full 10-K texts (CC0) | `huggingface.co/datasets/PleIAs/SEC` | +| EDGAR-CORPUS | 220K filings with sections pre-parsed (Apache 2.0) | `huggingface.co/datasets/eloukas/edgar-corpus` | +| Board Cybersecurity 23-Feature Analysis | Regex-based extraction of 23 governance/security features from 4,538 10-Ks | `board-cybersecurity.com/research/insights/` | +| Gibson Dunn S&P 100 Survey | Detailed feature analysis of disclosure content | `corpgov.law.harvard.edu/2025/01/09/cybersecurity-disclosure-overview-...` | +| Florackis et al. (2023) "Cybersecurity Risk" | Firm-level cyber risk measure from 10-K text, RFS publication | SSRN: 3725130, data companion: 4319606 | +| zeroshot/cybersecurity-corpus | General cybersecurity text (not SEC-specific, useful for DAPT) | `huggingface.co/datasets/zeroshot/cybersecurity-corpus` | + +--- + +## 4. GenAI Labeling Pipeline + +### 4.1 Multi-Model Consensus (EvasionBench Architecture) + +We follow Ma et al. (2026, arXiv:2601.09142) — the EvasionBench pipeline designed for an almost identical task (ordinal classification of financial text). Their approach achieved Cohen's Kappa = 0.835 with human annotators. + +**Stage 1 — Dual Independent Annotation (all ~50K paragraphs):** +- Annotator A: **Claude Sonnet 4.6** (batch API — $1.50/$7.50 per M input/output tokens) +- Annotator B: **Gemini 2.5 Flash** ($0.30/$2.50 per M tokens) +- Architectural diversity (Anthropic vs. Google) minimizes correlated errors +- ~83% of paragraphs will have immediate agreement + +**Stage 2 — Judge Panel for Disagreements (~17% = ~8,500 cases):** +- Judge 1: **Claude Opus 4.6** (batch — $2.50/$12.50 per M tokens) +- Judge 2: **GPT-5** (batch — $0.63/$5.00 per M tokens) +- Judge 3: **Gemini 2.5 Pro** (~$2-4/$12-18 per M tokens) +- Majority vote (2/3) resolves disagreements +- Anti-bias: randomize label presentation order + +**Stage 3 — Active Learning Pass:** +- Cluster remaining low-confidence cases +- Human-review ~5% (~2,500 cases) to identify systematic errors +- Iterate rubric if needed, re-run affected subsets + +### 4.2 Prompt Template + +``` +SYSTEM PROMPT: +You are an expert annotator classifying paragraphs from SEC cybersecurity +disclosures (10-K Item 1C and 8-K Item 1.05 filings). + +For each paragraph, assign: +(a) content_category: exactly one of ["Board Governance", "Management Role", + "Risk Management Process", "Third-Party Risk", "Incident Disclosure", + "Strategy Integration", "None/Other"] +(b) specificity_level: integer 1-4 + +CONTENT CATEGORIES: +- Board Governance: Board/committee oversight of cybersecurity risks, briefing + frequency, board member cyber expertise +- Management Role: CISO/CTO/CIO identification, qualifications, reporting + structure, management committees +- Risk Management Process: Risk assessment methodology, framework adoption + (NIST, ISO, etc.), vulnerability management, monitoring, incident response + planning, tabletop exercises +- Third-Party Risk: Vendor/supplier risk oversight, external assessor engagement, + contractual security requirements, supply chain risk +- Incident Disclosure: Description of cybersecurity incidents, scope, timing, + impact, remediation actions +- Strategy Integration: Material impact on business strategy or financials, + cyber insurance, investment/resource allocation +- None/Other: Boilerplate introductions, legal disclaimers, forward-looking + statement warnings, non-cybersecurity content + +SPECIFICITY SCALE: +1 - Generic Boilerplate: Could apply to any company. Conditional language + ("may," "could"). No named entities. + Example: "We face cybersecurity risks that could materially affect our + business operations." + +2 - Sector-Adapted: References industry context or named frameworks but no + firm-specific details. + Example: "We employ a cybersecurity framework aligned with the NIST + Cybersecurity Framework to manage cyber risk." + +3 - Firm-Specific: Contains facts unique to this company — named roles, + committees, specific programs, reporting lines. + Example: "Our CISO reports quarterly to the Audit Committee on + cybersecurity risk posture and incident trends." + +4 - Quantified-Verifiable: Includes metrics, dollar amounts, dates, + frequencies, third-party audit references, or independently verifiable facts. + Example: "Following the March 2024 incident affecting our payment systems, + we engaged CrowdStrike and implemented network segmentation at a cost of + $4.2M, completing remediation in Q3 2024." + +BOUNDARY RULES: +- If torn between 1 and 2: "Does it name ANY framework, standard, or industry + term?" If yes → 2 +- If torn between 2 and 3: "Does it mention anything unique to THIS company?" + If yes → 3 +- If torn between 3 and 4: "Does it contain TWO OR MORE specific, verifiable + facts?" If yes → 4 + +Respond with valid JSON only. Include a brief reasoning field. + +USER PROMPT: +Company: {company_name} +Filing Date: {filing_date} +Paragraph: +{paragraph_text} +``` + +**Expected output:** +```json +{ + "content_category": "Board Governance", + "specificity_level": 3, + "reasoning": "Identifies Audit Committee by name and describes quarterly briefing cadence, both firm-specific facts." +} +``` + +### 4.3 Practical Labeling Notes + +- **Always use Batch API.** Both OpenAI and Anthropic offer 50% discount for async/batch processing (24-hour turnaround). No reason to use real-time. +- **Prompt caching:** The system prompt (~800 tokens) is identical for every request. With Anthropic's prompt caching, cached reads cost 10% of base price. Combined with batch discount = 5% of standard price. +- **Structured output mode:** Use JSON mode / structured outputs on all providers. Reduces parsing errors by ~90%. +- **Reasoning models (o3, extended thinking):** Use ONLY as judges for disagreement cases, not as primary annotators. They're overkill for clear-cut classification and expensive due to reasoning token consumption. + +### 4.4 Gold Set Protocol + +**Non-negotiable for publication quality.** + +1. Sample 300-500 paragraphs, stratified by: + - Expected content category (ensure all 7 represented) + - Expected specificity level (ensure all 4 represented) + - Industry (financial services, tech, healthcare, manufacturing) + - Filing year (FY2023 vs FY2024) + +2. Two team members independently label the full gold set + +3. Compute: + - Cohen's Kappa (binary/nominal categories) + - Krippendorff's Alpha (ordinal specificity scale) + - Per-class confusion matrices + - Target: Kappa > 0.75 ("substantial agreement") + +4. Adjudicate disagreements with a third team member + +5. Run the full MMC pipeline on the gold set and compare + +--- + +## 5. Model Strategy + +### 5.1 Primary: SEC-ModernBERT-large + +**This model does not exist publicly. Building it is a core contribution.** + +**Base model:** `answerdotai/ModernBERT-large` +- 395M parameters +- 8,192-token native context (vs. 512 for DeBERTa-v3-large) +- RoPE + alternating local/global attention + FlashAttention +- 2-4x faster than DeBERTa-v3-large +- Apache 2.0 license +- GLUE: 90.4 (only 1 point behind DeBERTa-v3-large's 91.4) + +**Step 1 — Domain-Adaptive Pre-Training (DAPT):** + +Continue MLM pre-training on SEC filing text to create "SEC-ModernBERT-large": + +- **Training corpus:** 200-500M tokens of SEC filings (from PleIAs/SEC or your own EDGAR download). Include 10-Ks, 10-Qs, 8-Ks, proxy statements. +- **MLM objective:** 30% masking rate (ModernBERT convention) +- **Learning rate:** ~5e-5 (much lower than from-scratch pre-training) +- **Hardware (RTX 3090):** bf16, gradient checkpointing, seq_len=1024-2048, batch_size=2-4 + gradient accumulation to effective batch 16-32 +- **VRAM estimate:** ~12-15GB at batch=4, seq=2048 with gradient checkpointing — fits on 3090 + +**Evidence DAPT works:** +- Gururangan et al. (2020): consistent improvements across all tested domains +- Patent domain ModernBERT (arXiv:2509.14926): +0.9 to +2.8 F1 from continued pre-training on 31.6B tokens +- Scaling-law analysis on SEC filings (arXiv:2512.12384): consistent improvement with largest gains in first 200M tokens +- Databricks customer report: 70% → 95% accuracy with domain-specific pre-training + +**Step 2 — Classification Fine-Tuning:** + +Fine-tune SEC-ModernBERT-large on the 50K labeled paragraphs: + +- **Sequence length:** 2048 tokens (captures full regulatory paragraphs that 512-token models truncate) +- **Two classification heads:** content_category (7-class softmax) + specificity_level (4-class ordinal or softmax) +- **Add supervised contrastive loss (SCL):** Combine standard cross-entropy with SCL that pulls same-class embeddings together. Gunel et al. (2020) showed +0.5-1.5% improvement, especially for rare/imbalanced classes. +- **VRAM:** ~11-13GB at batch=8, seq=2048 in bf16 — comfortable on 3090 +- **3090 supports bf16** natively via Ampere Tensor Cores. Use `bf16=True` in HuggingFace Trainer. No loss scaling needed (unlike fp16). + +### 5.2 Dark Horse: NeoBERT + +`chandar-lab/NeoBERT` +- **250M parameters** (100M fewer than ModernBERT-large, 185M fewer than DeBERTa-v3-large) +- 4,096-token context +- SwiGLU, RoPE, Pre-RMSNorm, FlashAttention +- GLUE: 89.0 (close to DeBERTa-v3-large's 91.4) +- MTEB: 51.3 (crushes everything else — ModernBERT-large is 46.9) +- MIT license +- Requires `trust_remote_code=True` +- Almost nobody is using it for domain-specific tasks + +Same DAPT + fine-tuning pipeline as ModernBERT-large, with even less VRAM. + +### 5.3 Baseline: DeBERTa-v3-large + +`microsoft/deberta-v3-large` +- 304M backbone + 131M embedding = ~435M total +- 512-token native context (can push to ~1024) +- Disentangled attention + ELECTRA-style RTD pre-training +- GLUE: **91.4** — still the highest among all encoders +- MIT license +- **Weakness:** no long context support, completely fails at retrieval tasks + +Include as baseline to show improvement from (a) long context and (b) DAPT. + +### 5.4 Ablation Design + +| Experiment | Model | Context | DAPT | SCL | Purpose | +|-----------|-------|---------|------|-----|---------| +| Baseline | DeBERTa-v3-large | 512 | No | No | "Standard" approach per syllabus | +| + Long context | ModernBERT-large | 2048 | No | No | Shows context window benefit | +| + Domain adapt | SEC-ModernBERT-large | 2048 | Yes | No | Shows DAPT benefit | +| + Contrastive | SEC-ModernBERT-large | 2048 | Yes | Yes | Shows SCL benefit | +| Efficiency | NeoBERT (+ DAPT) | 2048 | Yes | Yes | 40% fewer params, comparable? | +| **Ensemble** | SEC-ModernBERT + DeBERTa | mixed | mixed | — | Maximum performance | + +The ensemble averages logits from SEC-ModernBERT-large (long context, domain-adapted) and DeBERTa-v3-large (highest raw NLU). Their architecturally different attention mechanisms mean uncorrelated errors. + +### 5.5 Training Framework + +- **Encoder fine-tuning:** HuggingFace `transformers` + `Trainer` with `AutoModelForSequenceClassification` +- **DAPT continued pre-training:** HuggingFace `transformers` with `DataCollatorForLanguageModeling` +- **SCL implementation:** Custom training loop or modify Trainer with dual loss +- **Few-shot prototyping:** `SetFit` (sentence-transformers based) for rapid baseline in <30 seconds + +**Key reference:** Phil Schmid's ModernBERT fine-tuning tutorial: + +### 5.6 Domain-Specific Encoder Models (for comparison only) + +These exist but are all BERT-base (110M params, 512 context) — architecturally outdated: + +| Model | HuggingFace ID | Domain | Params | +|-------|---------------|--------|--------| +| SEC-BERT | `nlpaueb/sec-bert-base` | 260K 10-K filings | 110M | +| SEC-BERT-SHAPE | `nlpaueb/sec-bert-shape` | Same, with number normalization | 110M | +| FinBERT | `ProsusAI/finbert` | Financial sentiment | 110M | +| Legal-BERT | `nlpaueb/legal-bert-base-uncased` | 12GB legal text | 110M | +| SecureBERT | arXiv:2204.02685 | Cybersecurity text | 110M | + +Our DAPT approach on a modern architecture (ModernBERT-large or NeoBERT) will outperform all of these. Include SEC-BERT as an additional baseline if time permits. + +--- + +## 6. Evaluation & Validation + +### 6.1 Required Metrics (from syllabus) + +| Metric | Target | Notes | +|--------|--------|-------| +| **Macro-F1** on human holdout | Report per-class and overall | Minimum 1.2K holdout examples | +| **Per-class F1** | Identify weak categories | Expect "None/Other" to be noisiest | +| **Krippendorff's Alpha** | > 0.67 (adequate), > 0.75 (good) | GenAI labels vs. human gold set | +| **Calibration plots** | Reliability diagrams | For probabilistic outputs (softmax) | +| **Robustness splits** | Report by time period, industry, filing size | FY2023 vs FY2024; GICS sector; word count quartiles | + +### 6.2 Downstream Validity Tests + +These demonstrate that the classifier's predictions correlate with real-world outcomes: + +**Test 1 — Breach Prediction (strongest):** +- Do firms with lower specificity scores subsequently appear in breach databases? +- Cross-reference with: + - **Privacy Rights Clearinghouse** (80K+ breaches; Mendeley dataset provides ticker/CIK matching: `doi.org/10.17632/w33nhh3282.1`) + - **VCDB** (8K+ incidents, VERIS schema: `github.com/vz-risk/VCDB`) + - **Board Cybersecurity Incident Tracker** (direct SEC filing links: `board-cybersecurity.com/incidents/tracker`) + - **CISA KEV Catalog** (known exploited vulnerabilities: `cisa.gov/known-exploited-vulnerabilities-catalog`) + +**Test 2 — Market Reaction (if time permits):** +- Event study: abnormal returns in [-1, +3] window around 8-K Item 1.05 filing +- Does prior Item 1C disclosure quality predict magnitude of reaction? +- Small sample (~55 incidents) but high signal +- Regression: CAR = f(specificity_score, incident_severity, firm_size, industry) + +**Test 3 — Known-Groups Validity (easy, always include):** +- Do regulated industries (financial services under NYDFS, healthcare under HIPAA) produce systematically higher-specificity disclosures? +- Do larger firms (by market cap) have more specific disclosures? +- These are expected results — confirming them validates the measure + +**Test 4 — Boilerplate Index (easy, always include):** +- Compute cosine similarity of each company's Item 1C to the industry-median disclosure +- Does our specificity score inversely correlate with this similarity measure? +- This is an independent, construct-free validation of the "uniqueness" dimension + +### 6.3 External Benchmark + +Per syllabus: "include an external benchmark approach (i.e., previous best practice)." + +- **Board Cybersecurity's 23-feature regex extraction** is the natural benchmark. Their binary (present/absent) feature coding is the prior best practice. Our classifier should capture everything their regex captures plus the quality/specificity dimension they cannot measure. +- **Florackis et al. (2023) cybersecurity risk measure** from Item 1A text is another comparison — different section (1A vs 1C), different methodology (dictionary vs. classifier), different era (pre-rule vs. post-rule). + +--- + +## 7. Release Artifacts + +By project end, publish: + +1. **HuggingFace Dataset:** Extracted Item 1C paragraphs with labels — first public dataset of its kind +2. **SEC-ModernBERT-large:** Domain-adapted model weights — first SEC-specific ModernBERT +3. **Fine-tuned classifiers:** Content category + specificity models, ready to deploy +4. **Labeling rubric + prompt templates:** Reusable for future SEC disclosure research +5. **Extraction pipeline code:** EDGAR → structured paragraphs → labeled dataset +6. **Evaluation notebook:** All metrics, ablations, validation tests + +--- + +## 8. 3-Week Schedule (6 People) + +### Team Roles + +| Role | Person(s) | Primary Responsibility | +|------|-----------|----------------------| +| **Data Lead** | Person A | EDGAR extraction pipeline, paragraph segmentation, data cleaning | +| **Data Support** | Person B | 8-K extraction, breach database cross-referencing, dataset QA | +| **Labeling Lead** | Person C | Rubric refinement, GenAI prompt engineering, MMC pipeline orchestration | +| **Annotation** | Person D | Gold set human labeling, inter-rater reliability, active learning review | +| **Model Lead** | Person E | DAPT pre-training, classification fine-tuning, ablation experiments | +| **Eval & Writing** | Person F | Validation tests, metrics computation, final presentation, documentation | + +### Week 1: Data + Rubric + +| Day | Person A (Data Lead) | Person B (Data Support) | Person C (Labeling Lead) | Person D (Annotation) | Person E (Model Lead) | Person F (Eval & Writing) | +|-----|---------------------|------------------------|-------------------------|----------------------|----------------------|--------------------------| +| **Mon** | Set up EDGAR extraction pipeline (edgar-crawler + sec-edgar-downloader) | Set up 8-K extraction (sec-8k-item105) | Draft labeling rubric v1 from SEC rule | Read SEC rule + Gibson Dunn survey | Download ModernBERT-large, set up training env | Outline evaluation plan, identify breach databases | +| **Tue** | Begin bulk 10-K download (FY2023 cycle) | Extract all 8-K cyber filings (Items 1.05, 8.01, 7.01) | Pilot rubric on 30 paragraphs with Claude Opus | Pilot rubric on same 30 paragraphs independently | Download PleIAs/SEC corpus, prepare DAPT data | Download PRC Mendeley dataset, VCDB, set up cross-ref | +| **Wed** | Continue download (FY2024 cycle), begin Item 1C parsing | Build company metadata table (CIK → ticker → GICS sector → market cap) | Compare pilot labels with Person D, revise rubric boundary rules | Compute initial inter-rater agreement, flag problem areas | Begin DAPT pre-training (SEC-ModernBERT-large, ~2-3 days on 3090) | Map VCDB incidents to SEC filers by name matching | +| **Thu** | Paragraph segmentation pipeline, quality checks | Merge 8-K incidents with Board Cybersecurity Tracker data | Rubric v2 finalized; set up batch API calls for dual annotation | Begin gold set sampling (300-500 paragraphs, stratified) | DAPT continues (monitor loss, checkpoint) | Draft presentation outline | +| **Fri** | **Milestone: Full paragraph corpus ready (~50K+ paragraphs)** | **Milestone: 8-K incident dataset complete** | Launch Stage 1 dual annotation (Sonnet + Gemini Flash) on full corpus | Continue gold set labeling (target: finish 150/300) | DAPT continues | **Milestone: Evaluation framework + breach cross-ref ready** | + +### Week 2: Labeling + Training + +| Day | Person A | Person B | Person C | Person D | Person E | Person F | +|-----|----------|----------|----------|----------|----------|----------| +| **Mon** | Data cleaning — fix extraction errors, handle edge cases | Assist Person D with gold set labeling (second annotator) | Monitor dual annotation results (should be ~60% complete) | Continue gold set labeling, begin second pass | DAPT finishes; begin DeBERTa-v3-large baseline fine-tuning | Compute gold set inter-rater reliability (Kappa, Alpha) | +| **Tue** | Build train/holdout split logic (stratified by industry, year, specificity) | Continue gold set second-annotator pass | Dual annotation complete → extract disagreements (~17%) | Finish gold set, adjudicate disagreements with Person C | Baseline results in; begin ModernBERT-large (no DAPT) fine-tuning | Analyze gold set confusion patterns, recommend rubric tweaks | +| **Wed** | Final dataset assembly | Assist Person C with judge panel setup | Launch Stage 2 judge panel (Opus + GPT-5 + Gemini Pro) on disagreements | Run MMC pipeline on gold set, compare with human labels | ModernBERT-large done; begin SEC-ModernBERT-large fine-tuning | **Milestone: Gold set validated, Kappa computed** | +| **Thu** | Prepare HuggingFace dataset card | Begin active learning — cluster low-confidence cases | Judge panel results in; assemble final labeled dataset | Human-review ~500 low-confidence cases from active learning | SEC-ModernBERT-large done; begin NeoBERT experiment | Robustness split analysis (by industry, year, filing size) | +| **Fri** | **Milestone: Labeled dataset finalized (~50K paragraphs)** | **Milestone: Active learning pass complete** | QA final labels — spot-check 100 random samples | Assist Person E with evaluation | Begin ensemble experiment (SEC-ModernBERT + DeBERTa) | **Milestone: All baseline + ablation training complete** | + +### Week 3: Evaluation + Presentation + +| Day | Person A | Person B | Person C | Person D | Person E | Person F | +|-----|----------|----------|----------|----------|----------|----------| +| **Mon** | Publish dataset to HuggingFace | Run breach prediction validation (PRC + VCDB cross-ref) | Write labeling methodology section | Calibration plots for all models | Final ensemble tuning; publish model weights to HuggingFace | Compile all metrics into evaluation tables | +| **Tue** | Write data acquisition section | Run known-groups validity (industry, size effects) | Write GenAI labeling section | Boilerplate index validation (cosine similarity) | Write model strategy section | Draft full results section | +| **Wed** | Code cleanup, README for extraction pipeline | Market reaction analysis if feasible (optional) | Review/edit all written sections | Create figures: confusion matrices, calibration plots | Review/edit model section | Assemble presentation slides | +| **Thu** | **Full team: review presentation, rehearse, polish** | | | | | | +| **Fri** | **Presentation day** | | | | | | + +### Critical Path & Dependencies + +``` +Week 1: + Data extraction (A,B) ──────────────────┐ + Rubric design (C,D) ───→ Pilot test ───→ Rubric v2 ──→ GenAI labeling launch (Fri) + DAPT pre-training (E) ──────────────────────────────────→ (continues into Week 2) + Eval framework (F) ─────────────────────────────────────→ (ready for Week 2) + +Week 2: + GenAI labeling (C) ───→ Judge panel ───→ Active learning ───→ Final labels (Fri) + Gold set (D + B) ──────────────────────→ Validated (Wed) + Fine-tuning experiments (E) ───→ Baseline → ModernBERT → SEC-ModernBERT → NeoBERT → Ensemble + Metrics (F) ───────────────────→ Robustness splits + +Week 3: + Validation tests (B,D,F) ───→ Breach prediction, known-groups, boilerplate index + Writing (all) ──────────────→ Sections → Review → Presentation + Release (A,E) ──────────────→ HuggingFace dataset + model weights +``` + +--- + +## 9. Budget + +| Item | Cost | +|------|------| +| GenAI labeling — Stage 1 dual annotation (50K × 2 models, batch) | ~$115 | +| GenAI labeling — Stage 2 judge panel (~8.5K × 3 models, batch) | ~$55 | +| Prompt caching savings | -$30 to -$40 | +| SEC EDGAR data | $0 (public domain) | +| Breach databases (PRC open data, VCDB, CISA KEV) | $0 | +| Compute (RTX 3090, already owned) | $0 | +| **Total** | **~$130-170** | + +For comparison, human annotation at $0.50/label would cost $25,000+ for single-annotated, $75,000+ for triple-annotated. + +--- + +## 10. Reference Links + +### SEC Rule & Guidance +- [SEC Final Rule 33-11216 (PDF)](https://www.sec.gov/files/rules/final/2023/33-11216.pdf) +- [SEC Fact Sheet](https://www.sec.gov/files/33-11216-fact-sheet.pdf) +- [SEC Small Business Compliance Guide](https://www.sec.gov/resources-small-businesses/small-business-compliance-guides/cybersecurity-risk-management-strategy-governance-incident-disclosure) +- [CYD iXBRL Taxonomy Guide (PDF)](https://xbrl.sec.gov/cyd/2024/cyd-taxonomy-guide-2024-09-16.pdf) + +### Law Firm Surveys & Analysis +- [Gibson Dunn S&P 100 Survey (Harvard Law Forum)](https://corpgov.law.harvard.edu/2025/01/09/cybersecurity-disclosure-overview-a-survey-of-form-10-k-cybersecurity-disclosures-by-sp-100-companies/) +- [PwC First Wave of 10-K Cyber Disclosures](https://www.pwc.com/us/en/services/consulting/cybersecurity-risk-regulatory/sec-final-cybersecurity-disclosure-rules/sec-10-k-cyber-disclosures.html) +- [Debevoise 8-K Lessons Learned](https://www.debevoisedatablog.com/2024/03/06/cybersecurity-form-8-k-tracker/) +- [Greenberg Traurig 2025 Trends Update](https://www.gtlaw.com/en/insights/2025/2/sec-cybersecurity-disclosure-trends-2025-update-on-corporate-reporting-practices) +- [Known Trends: First Year of 8-K Filings](https://www.knowntrends.com/2025/02/snapshot-the-first-year-of-cybersecurity-incident-filings-on-form-8-k-since-adoption-of-new-rules/) +- [NYU: Lessons Learned from 8-K Reporting](https://wp.nyu.edu/compliance_enforcement/2025/03/25/lessons-learned-one-year-of-form-8-k-material-cybersecurity-incident-reporting/) + +### Data Extraction Tools +- [edgar-crawler (GitHub)](https://github.com/lefterisloukas/edgar-crawler) +- [edgartools (GitHub)](https://github.com/dgunning/edgartools) +- [sec-edgar-downloader (PyPI)](https://pypi.org/project/sec-edgar-downloader/) +- [sec-8k-item105 (GitHub)](https://github.com/JMousqueton/sec-8k-item105) +- [SECurityTr8Ker (GitHub)](https://github.com/pancak3lullz/SECurityTr8Ker) +- [SEC EDGAR APIs](https://www.sec.gov/search-filings/edgar-application-programming-interfaces) +- [SEC EDGAR Full-Text Search](https://efts.sec.gov/LATEST/search-index) + +### Datasets +- [PleIAs/SEC — 373K 10-K texts (HuggingFace, CC0)](https://huggingface.co/datasets/PleIAs/SEC) +- [EDGAR-CORPUS — 220K filings, sections parsed (HuggingFace, Apache 2.0)](https://huggingface.co/datasets/eloukas/edgar-corpus) +- [Board Cybersecurity 23-Feature Analysis](https://www.board-cybersecurity.com/research/insights/risk-frameworks-security-standards-in-10k-item-1c-cybersecurity-disclosures-through-2024-06-30/) +- [Board Cybersecurity Incident Tracker](https://www.board-cybersecurity.com/incidents/tracker) +- [PRC Mendeley Breach Dataset (with tickers)](http://dx.doi.org/10.17632/w33nhh3282.1) +- [VCDB (GitHub)](https://github.com/vz-risk/VCDB) +- [CISA KEV Catalog](https://www.cisa.gov/known-exploited-vulnerabilities-catalog) +- [zeroshot/cybersecurity-corpus (HuggingFace)](https://huggingface.co/datasets/zeroshot/cybersecurity-corpus) + +### Models +- [ModernBERT-large (HuggingFace, Apache 2.0)](https://huggingface.co/answerdotai/ModernBERT-large) +- [ModernBERT-base (HuggingFace, Apache 2.0)](https://huggingface.co/answerdotai/ModernBERT-base) +- [NeoBERT (HuggingFace, MIT)](https://huggingface.co/chandar-lab/NeoBERT) +- [DeBERTa-v3-large (HuggingFace, MIT)](https://huggingface.co/microsoft/deberta-v3-large) +- [SEC-BERT (HuggingFace)](https://huggingface.co/nlpaueb/sec-bert-base) +- [ProsusAI FinBERT (HuggingFace)](https://huggingface.co/ProsusAI/finbert) +- [EvasionBench Eva-4B-V2 (HuggingFace)](https://huggingface.co/FutureMa/Eva-4B-V2) + +### Key Papers +- Ringel (2023), "Creating Synthetic Experts with Generative AI" — [SSRN:4542949](https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4542949) +- Ludwig et al. (2026), "Extracting Consumer Insight from Text" — [arXiv:2602.15312](https://arxiv.org/abs/2602.15312) +- Ma et al. (2026), "EvasionBench" — [arXiv:2601.09142](https://arxiv.org/abs/2601.09142) +- Florackis et al. (2023), "Cybersecurity Risk" — [SSRN:3725130](https://papers.ssrn.com/sol3/papers.cfm?abstract_id=3725130) +- Gururangan et al. (2020), "Don't Stop Pretraining" — [arXiv:2004.10964](https://arxiv.org/abs/2004.10964) +- ModernBERT paper — [arXiv:2412.13663](https://arxiv.org/abs/2412.13663) +- NeoBERT paper — [arXiv:2502.19587](https://arxiv.org/abs/2502.19587) +- ModernBERT vs DeBERTa-v3 comparison — [arXiv:2504.08716](https://arxiv.org/abs/2504.08716) +- Patent domain ModernBERT DAPT — [arXiv:2509.14926](https://arxiv.org/abs/2509.14926) +- SEC filing scaling laws for continued pre-training — [arXiv:2512.12384](https://arxiv.org/abs/2512.12384) +- Gunel et al. (2020), Supervised Contrastive Learning for fine-tuning — [OpenReview](https://openreview.net/forum?id=cu7IUiOhujH) +- Phil Schmid, "Fine-tune classifier with ModernBERT in 2025" — [philschmid.de](https://www.philschmid.de/fine-tune-modern-bert-in-2025) +- Berkman et al. (2018), Cybersecurity disclosure quality scoring +- Li, No, and Boritz (2023), BERT-based classification of cybersecurity disclosures +- Scalable 10-K Analysis with LLMs — [arXiv:2409.17581](https://arxiv.org/abs/2409.17581) +- SecureBERT — [arXiv:2204.02685](https://arxiv.org/abs/2204.02685) +- Gilardi et al. (2023), "ChatGPT Outperforms Crowd-Workers" (PNAS) — [arXiv:2303.15056](https://arxiv.org/abs/2303.15056) +- Pangakis et al. (2023), "Automated Annotation Requires Validation" — [arXiv:2306.00176](https://arxiv.org/abs/2306.00176) + +### Methodological Playbook +- [Ringel 2026 Capstone Pipeline Example (ZIP)](http://ringel.ai/UNC/2026/helpers/Ringel_2026_VerticalAI_Capstone_Pipeline_Example.zip) +- [Class 21 Exemplary Presentation (PDF)](http://www.ringel.ai/UNC/2026/BUSI488/Class21/Ringel_488-2026_Class21.pdf) diff --git a/docs/reference/Ringel_2026_VerticalAI_Capstone_Pipeline_Example.ipynb b/docs/reference/Ringel_2026_VerticalAI_Capstone_Pipeline_Example.ipynb new file mode 100644 index 0000000..c17fc3e --- /dev/null +++ b/docs/reference/Ringel_2026_VerticalAI_Capstone_Pipeline_Example.ipynb @@ -0,0 +1,4433 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f2658472-1630-4aae-aaa4-01987b0eff1e", + "metadata": {}, + "source": [ + "# Training a Vertical AI for 10K Text Classification in Business Functions\n", + "Supports APIs of OpenAI (Responses API), Anthropic, xAI, Fireworks AI, Google, UNC Azure hosted\n", + "\n", + "February 4, 2026 \n", + "\n", + "*Version 1.0*\n", + "\n", + "Copyright 2026 Daniel M. Ringel \n", + "www.ringel.ai \n", + "\n", + "\n", + "**Please cite this paper** if you use any part or all of this code in a project - be it commercial or academic: \n", + "\n", + "> Ringel, Daniel, *Creating Synthetic Experts with Generative Artificial Intelligence* (December 5, 2023). Kenan Institute of Private Enterprise Research Paper No. 4542949, Available at SSRN: https://ssrn.com/abstract=4542949 or http://dx.doi.org/10.2139/ssrn.4542949 \n" + ] + }, + { + "cell_type": "markdown", + "id": "4f4233be-6c67-41db-bc89-7b47bde71d59", + "metadata": {}, + "source": [ + "Query various serverless genAI models to classify text by example of a multi-label classification problem\n", + "- Sentences from 10K filings of fortune 500 companies\n", + "- Construct of Interest: Business functions\n", + "- Valid labels: Marketing, Finance, Accounting, Operations, IT, HR" + ] + }, + { + "cell_type": "markdown", + "id": "eda002cb-c5ef-4a68-ba10-4d4ee65e836a", + "metadata": {}, + "source": [ + "> **IMPORTANT** Running this code will cost you API credits (and requires you to ahve accounts with the providers). You will need to supply your own API keys. Beware that you may be subject to rate limits (how many queries you can send per minute) and which models you can use (OpenAI, for example, requires you to verify your identidy with an ID to access many models). Regardless, every time you execute this code, you will drain your API credits = real money! Thus, make wise decisions about what and how much to label." + ] + }, + { + "cell_type": "markdown", + "id": "9228e155-f4b0-45f8-96ba-8cb99c0bb990", + "metadata": {}, + "source": [ + "## **Disclaimer**\n", + "\n", + "**USE AT YOUR OWN RISK**\n", + "\n", + "This code is provided \"as is\" without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose, and noninfringement.\n", + "\n", + "Under no circumstances shall Daniel Ringel be liable for any direct, indirect, incidental, special, exemplary, or consequential damages (including, but not limited to, procurement of substitute goods or services, loss of use, data, or profits, or business interruption) however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence or otherwise) arising in any way out of the use of this code, even if advised of the possibility of such damage.\n", + "\n", + "**Additional Notes:**\n", + "- API costs incurred from running this code are your sole responsibility\n", + "- Verify all API pricing before running at scale\n", + "- Test with small samples before processing large datasets\n", + "- The author makes no guarantees about the accuracy, reliability, or completeness of results\n", + "- This code is for educational and research purposes\n", + "\n", + "By using this code, you acknowledge that you have read this disclaimer and agree to its terms.\n" + ] + }, + { + "cell_type": "markdown", + "id": "4eed4986-cdb9-4285-80d4-7c2e415a5268", + "metadata": {}, + "source": [ + "# Installs and updates\n", + "- need to run only once if you are on your own computer\n", + "- on CoLab, you may need to run each time" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8714eefd-4af8-4116-a4fb-922bf1aa4667", + "metadata": {}, + "outputs": [], + "source": [ + "# Core\n", + "!pip install --upgrade pandas tqdm\n", + "# OpenAI (also used for xAI and Fireworks)\n", + "!pip install --upgrade openai\n", + "# Anthropic\n", + "!pip install --upgrade anthropic\n", + "# Google Gemini\n", + "!pip install --upgrade google-genai google-api-core\n", + "# For Label Agreement\n", + "!pip install -q -U krippendorff \n", + "# For Fine-tuning a pretrained LLM\n", + "!pip install -q -U transformers datasets accelerate scikit-learn\n", + "!pip install iterative-stratification" + ] + }, + { + "cell_type": "markdown", + "id": "9ef820f3-b6b2-4799-9af8-45e8e4ac5194", + "metadata": {}, + "source": [ + "# Set Environmental Variables (API Keys) on local Computer \n", + "### *(or on Colab in Secret Keys Tab)*\n", + "> **DO NOT SHARE API KEYS!!!** Delte them before sharing this notebook\n", + " \n", + "> On ***Colab***, use \"Secrets\" in Tab (key) on right. Make sure to use the exact same API Key names (e.g., \"OPENAI_API_KEY\") spelled and capitalized as shown!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c82d0c2c-8b19-49e0-80d1-1f7ebeb53f78", + "metadata": {}, + "outputs": [], + "source": [ + "# Put your API Keys here if you run this locally.\n", + "import os\n", + "os.environ[\"OPENAI_API_KEY\"] = \"\"\n", + "os.environ[\"ANTHROPIC_API_KEY\"] = \"\"\n", + "os.environ[\"FIREWORKS_API_KEY\"] = \"\"\n", + "os.environ[\"XAI_API_KEY\"] = \"\"\n", + "os.environ[\"GOOGLE_API_KEY\"] = \"\"\n", + "os.environ[\"AZURE_API_KEY\"] = \"\"" + ] + }, + { + "cell_type": "markdown", + "id": "8803ce06-67df-49d4-851e-2617cbf157be", + "metadata": {}, + "source": [ + "# Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0dbd424-0d89-4b2a-af13-aa98bc07f0a6", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "import time\n", + "import datetime\n", + "import pandas as pd\n", + "from tqdm import tqdm\n", + "\n", + "# Vendor SDKs\n", + "import openai\n", + "from openai import OpenAI, AzureOpenAI, RateLimitError, APIError, AuthenticationError\n", + "import anthropic\n", + "from google import genai\n", + "from google.genai import types\n", + "from google.api_core import exceptions as google_exceptions" + ] + }, + { + "cell_type": "markdown", + "id": "e3be8e22-dd79-4417-bf34-235900d5487a", + "metadata": {}, + "source": [ + "# Vendor/Model Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c180681-cd7c-4d27-a429-06a149d8933c", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------\n", + "# Vendor/Model Configuration\n", + "# All prices per 1M tokens (USD) ----> UPDATE THE PRICES!!!!\n", + "# supplier = who made the model (for when using 3rd party APIs like Fireworks)\n", + "# -----------------------------\n", + "VENDORS = {\n", + " \"openai\": {\n", + " \"models\": {\n", + " # GPT-5 family\n", + " \"gpt-5.2\": {\"supplier\": \"openai\", \"price_in\": 1.75, \"price_out\": 14.00},\n", + " \"gpt-5\": {\"supplier\": \"openai\", \"price_in\": 1.25, \"price_out\": 10.00},\n", + " \"gpt-5-mini\": {\"supplier\": \"openai\", \"price_in\": 0.30, \"price_out\": 2.50},\n", + " \"gpt-5-nano\": {\"supplier\": \"openai\", \"price_in\": 0.10, \"price_out\": 0.40},\n", + " # GPT-4.1 family\n", + " \"gpt-4.1\": {\"supplier\": \"openai\", \"price_in\": 2.00, \"price_out\": 8.00},\n", + " # GPT-4o family\n", + " \"gpt-4o\": {\"supplier\": \"openai\", \"price_in\": 2.50, \"price_out\": 10.00},\n", + " }\n", + " },\n", + " \"azure\": {\n", + " \"models\": {\n", + " \"gpt-4.1\": {\"supplier\": \"openai\", \"price_in\": 2.00, \"price_out\": 8.00},\n", + " \"gpt-4o\": {\"supplier\": \"openai\", \"price_in\": 2.50, \"price_out\": 8.00},\n", + " },\n", + " \"endpoint\": \"https://azureaiapi.cloud.unc.edu\",\n", + " \"api_version\": \"2025-04-01-preview\",\n", + " },\n", + " \"anthropic\": {\n", + " \"models\": {\n", + " # Claude 4.5 family\n", + " \"claude-opus-4-5-20251101\": {\"supplier\": \"anthropic\", \"price_in\": 5.00, \"price_out\": 25.00},\n", + " \"claude-sonnet-4-5-20250929\": {\"supplier\": \"anthropic\", \"price_in\": 3.00, \"price_out\": 15.00},\n", + " \"claude-haiku-4-5-20251001\": {\"supplier\": \"anthropic\", \"price_in\": 1.00, \"price_out\": 5.00},\n", + " # Claude 4 family\n", + " \"claude-sonnet-4-20250514\": {\"supplier\": \"anthropic\", \"price_in\": 3.00, \"price_out\": 15.00},\n", + " }\n", + " },\n", + " \"google\": {\n", + " \"models\": {\n", + " # Gemini 3 - latest flagship - preview\n", + " \"gemini-3-pro-preview\": {\"supplier\": \"google\", \"price_in\": 2.00, \"price_out\": 12.00},\n", + " \"gemini-3-flash-preview\": {\"supplier\": \"google\", \"price_in\": 0.50, \"price_out\": 3.00},\n", + " # Gemini 2.5 - stable\n", + " \"gemini-2.5-pro\": {\"supplier\": \"google\", \"price_in\": 1.25, \"price_out\": 10.00},\n", + " \"gemini-2.5-flash\": {\"supplier\": \"google\", \"price_in\": 0.30, \"price_out\": 2.50},\n", + " }\n", + " },\n", + " \"xai\": {\n", + " \"models\": {\n", + " \"grok-4\": {\"supplier\": \"xai\", \"price_in\": 3.00, \"price_out\": 15.00},\n", + " \"grok-4-1-fast-reasoning\": {\"supplier\": \"xai\", \"price_in\": 0.20, \"price_out\": 0.50},\n", + " \"grok-4-1-fast-non-reasoning\": {\"supplier\": \"xai\", \"price_in\": 0.20, \"price_out\": 0.50}\n", + " }\n", + " },\n", + " \"fireworks\": {\n", + " \"models\": {\n", + " \"deepseek-v3p2\": {\"supplier\": \"deepseek\", \"price_in\": 0.56, \"price_out\": 1.68},\n", + " \"qwen3-vl-235b-a22b-instruct\": {\"supplier\": \"alibaba\", \"price_in\": 0.22, \"price_out\": 0.88},\n", + " \"deepseek-r1-0528\": {\"supplier\": \"deepseek\", \"price_in\": 1.35, \"price_out\": 5.40},\n", + " \"qwen3-vl-235b-a22b-thinking\": {\"supplier\": \"alibaba\", \"price_in\": 0.22, \"price_out\": 0.88},\n", + " \"kimi-k2p5\": {\"supplier\": \"moonshot\", \"price_in\": 1.20, \"price_out\": 1.20}\n", + " }\n", + " } \n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "3c799dc8-e11d-43f5-b48a-61ce699dcfec", + "metadata": {}, + "source": [ + "# System Prompt\n", + "- Be clear\n", + "- Be specific\n", + "- Try RTF (role task format)\n", + "- Succinct and exhaustive construct definitions\n", + "- Could give examples (few-shot), but may create noise or focus model too much on these cases\n", + "- Explain tie-breakers or dricky cases\n", + "- Define output format." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b46f6ca7-d791-48b2-838d-039fd696780d", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------\n", + "# System Prompt\n", + "# -----------------------------\n", + "SYSTEM_PROMPT = \"\"\"You are a business analyst classifying sentences from 10-K filings.\n", + "\n", + "Classify the sentence into one or more business functions: Marketing, Finance, Accounting, Operations, IT, HR (human resources)\n", + "\n", + "Definitions:\n", + "- Marketing: Customers, markets, demand, branding, advertising, promotion, pricing strategy, market research, segmentation, positioning, sales strategy, sales channels.\n", + "- Finance: Capital structure, funding, liquidity, treasury, investing, valuation, financial risk management (interest rates, FX, hedging), dividends, buybacks, M&A, financing activities.\n", + "- Accounting: Financial reporting, disclosures, GAAP/IFRS, accounting policies, accounting estimates, revenue recognition, impairments, reserves, depreciation, amortization, audits, internal controls over financial reporting (ICFR), tax accounting.\n", + "- Operations: Production, service delivery, supply chain, logistics, procurement, inventory management, manufacturing, capacity planning, facilities, quality control, safety, process efficiency, fulfillment, operational infrastructure, IT systems.\n", + "- IT: Information technology systems, software, hardware, cybersecurity, data management, cloud computing, digital infrastructure, technology platforms, system integration, IT support, technology investments, data analytics infrastructure.\n", + "- HR: Human resources, workforce, hiring, recruitment, talent acquisition, employee benefits, compensation, training, professional development, labor relations, employee retention, workplace safety, organizational culture.\n", + "\n", + "Rules:\n", + "1. Assign labels only when there is clear, direct evidence in the sentence.\n", + "2. Assign multiple labels if clearly relevant to more than one field.\n", + "3. Tie-breakers: Reporting/policies/controls/disclosures → Accounting; Funding/treasury/hedging/M&A → Finance.\n", + "4. Out of scope: General corporate governance, board matters, executive compensation, legal proceedings → return empty array.\n", + "5. Output: Return ONLY a JSON array with exact spelling: \"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\". If none apply, return [].\n", + "6. DO NOT provide a reason or explanation for your labels.\"\"\"\n", + " \n", + "ALLOWED_LABELS = {\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"}\n", + "LABEL_ORDER = [\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"]" + ] + }, + { + "cell_type": "markdown", + "id": "09e63238-a066-4a80-95cd-1eacb59d19d5", + "metadata": {}, + "source": [ + "# Client Initialization & API Call Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2069211c-abff-4c85-8a5e-6de3539431dd", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------\n", + "# Client Initialization on a local PC (requires setting them at the beginning) or \n", + "# on Google CoLab (requires defining them in the secrets tab with key symbol)\n", + "# -----------------------------\n", + "def get_api_key(key_name: str) -> str:\n", + " \"\"\"Get API key from Colab secrets or environment variables.\"\"\"\n", + " try:\n", + " from google.colab import userdata\n", + " return userdata.get(key_name)\n", + " except (ImportError, ModuleNotFoundError):\n", + " # Not in Colab, use environment variables\n", + " return os.environ[key_name]\n", + "\n", + "\n", + "def init_client(vendor: str):\n", + " \"\"\"Initialize API client for vendor.\"\"\"\n", + " if vendor == \"openai\":\n", + " return OpenAI(api_key=get_api_key(\"OPENAI_API_KEY\"))\n", + " elif vendor == \"azure\":\n", + " return AzureOpenAI(\n", + " api_key=get_api_key(\"AZURE_API_KEY\"),\n", + " azure_endpoint=VENDORS[\"azure\"][\"endpoint\"],\n", + " api_version=VENDORS[\"azure\"][\"api_version\"],\n", + " ) \n", + " elif vendor == \"anthropic\":\n", + " return anthropic.Anthropic(api_key=get_api_key(\"ANTHROPIC_API_KEY\"))\n", + " elif vendor == \"google\":\n", + " return genai.Client(api_key=get_api_key(\"GOOGLE_API_KEY\"))\n", + " elif vendor == \"xai\":\n", + " return OpenAI(\n", + " api_key=get_api_key(\"XAI_API_KEY\"),\n", + " base_url=\"https://api.x.ai/v1\"\n", + " )\n", + " elif vendor == \"fireworks\":\n", + " return OpenAI(\n", + " api_key=get_api_key(\"FIREWORKS_API_KEY\"),\n", + " base_url=\"https://api.fireworks.ai/inference/v1\"\n", + " )\n", + " else:\n", + " raise ValueError(f\"Unknown vendor: {vendor}\")\n", + "\n", + "# -----------------------------\n", + "# API Call Functions\n", + "# -----------------------------\n", + "def call_openai(client, sentence: str, model: str, system_prompt: str,\n", + " max_tokens: int = 64, reasoning_effort: str = None) -> dict:\n", + " \"\"\"\n", + " Call OpenAI Responses API.\n", + " \n", + " Token billing:\n", + " - input_tokens: billed at input rate\n", + " - output_tokens: includes reasoning_tokens, all billed at output rate\n", + " - reasoning_tokens: subset of output_tokens (internal thinking)\n", + " \n", + " Note: For reasoning models, max_output_tokens must accommodate BOTH\n", + " reasoning tokens AND response tokens. We scale up accordingly.\n", + " \"\"\"\n", + " # For reasoning models, we need more output tokens to fit both reasoning + response\n", + " if reasoning_effort:\n", + " # Reasoning needs room: base response tokens + reasoning overhead\n", + " # low ~500, medium ~1000, high ~2000+ reasoning tokens typical\n", + " reasoning_overhead = {\"low\": 512, \"medium\": 1024, \"high\": 2048}.get(reasoning_effort, 512)\n", + " effective_max_tokens = max_tokens + reasoning_overhead\n", + " else:\n", + " effective_max_tokens = max_tokens\n", + " \n", + " params = {\n", + " \"model\": model,\n", + " \"instructions\": system_prompt,\n", + " \"input\": sentence,\n", + " \"max_output_tokens\": effective_max_tokens,\n", + " }\n", + " \n", + " if reasoning_effort:\n", + " params[\"reasoning\"] = {\"effort\": reasoning_effort}\n", + " else:\n", + " params[\"temperature\"] = 0\n", + " \n", + " resp = client.responses.create(**params)\n", + " \n", + " # Extract text - output_text may be empty for some response types\n", + " text = \"\"\n", + " if resp.output_text:\n", + " text = resp.output_text.strip()\n", + " else:\n", + " # Fallback: extract from output items\n", + " for item in resp.output:\n", + " if getattr(item, 'type', None) == 'message':\n", + " for block in getattr(item, 'content', []):\n", + " if getattr(block, 'type', None) == 'text':\n", + " text += getattr(block, 'text', '')\n", + " text = text.strip()\n", + " \n", + " # Extract token breakdown\n", + " input_tokens = resp.usage.input_tokens\n", + " output_tokens = resp.usage.output_tokens\n", + " \n", + " # Reasoning tokens are part of output_tokens (not additional)\n", + " reasoning_tokens = 0\n", + " if hasattr(resp.usage, 'output_tokens_details') and resp.usage.output_tokens_details:\n", + " reasoning_tokens = getattr(resp.usage.output_tokens_details, 'reasoning_tokens', 0) or 0\n", + " \n", + " return {\n", + " \"text\": text,\n", + " \"input_tokens\": input_tokens,\n", + " \"output_tokens\": output_tokens, # total output (includes reasoning)\n", + " \"reasoning_tokens\": reasoning_tokens, # internal thinking (subset of output)\n", + " \"response_tokens\": output_tokens - reasoning_tokens, # visible response\n", + " \"raw_response\": resp.model_dump() if hasattr(resp, 'model_dump') else str(resp),\n", + " }\n", + "\n", + "def call_azure(client, sentence: str, model: str, system_prompt: str,\n", + " max_tokens: int = 64) -> dict:\n", + " \"\"\"\n", + " Call Azure OpenAI API (chat completions).\n", + " \n", + " Azure uses the standard chat completions endpoint, not Responses API.\n", + " \"\"\"\n", + " params = {\n", + " \"model\": model,\n", + " \"messages\": [\n", + " {\"role\": \"system\", \"content\": system_prompt},\n", + " {\"role\": \"user\", \"content\": sentence}\n", + " ],\n", + " \"max_tokens\": max_tokens,\n", + " \"temperature\": 0, # Set temp to 0 to be deterministic\n", + " }\n", + " \n", + " resp = client.chat.completions.create(**params)\n", + " \n", + " text = (resp.choices[0].message.content or \"\").strip()\n", + " \n", + " return {\n", + " \"text\": text,\n", + " \"input_tokens\": resp.usage.prompt_tokens,\n", + " \"output_tokens\": resp.usage.completion_tokens,\n", + " \"reasoning_tokens\": 0,\n", + " \"response_tokens\": resp.usage.completion_tokens,\n", + " \"raw_response\": resp.model_dump() if hasattr(resp, 'model_dump') else str(resp),\n", + " }\n", + " \n", + "def call_anthropic(client, sentence: str, model: str, system_prompt: str,\n", + " max_tokens: int = 64, thinking_budget: int = None) -> dict:\n", + " \"\"\"\n", + " Call Anthropic Messages API.\n", + " \n", + " Token billing:\n", + " - input_tokens: billed at input rate\n", + " - output_tokens: includes thinking tokens, all billed at output rate\n", + " - NOTE: Anthropic doesn't provide separate thinking token count in usage.\n", + " The output_tokens is what's billed (includes full thinking, not summary).\n", + " \"\"\"\n", + " params = {\n", + " \"model\": model,\n", + " \"system\": system_prompt,\n", + " \"messages\": [{\"role\": \"user\", \"content\": sentence}],\n", + " }\n", + " \n", + " if thinking_budget:\n", + " thinking_budget = max(1024, thinking_budget)\n", + " params[\"thinking\"] = {\"type\": \"enabled\", \"budget_tokens\": thinking_budget}\n", + " params[\"max_tokens\"] = thinking_budget + max_tokens\n", + " else:\n", + " params[\"temperature\"] = 0\n", + " params[\"max_tokens\"] = max_tokens\n", + " \n", + " resp = client.messages.create(**params)\n", + " \n", + " # Extract text (skip thinking blocks - they're summarized in Claude 4)\n", + " text = \"\".join(b.text for b in resp.content if b.type == \"text\").strip()\n", + " \n", + " # output_tokens includes thinking tokens (billed amount)\n", + " # Anthropic doesn't provide breakdown like OpenAI does\n", + " output_tokens = resp.usage.output_tokens\n", + " \n", + " return {\n", + " \"text\": text,\n", + " \"input_tokens\": resp.usage.input_tokens,\n", + " \"output_tokens\": output_tokens, # total billed (includes thinking)\n", + " \"reasoning_tokens\": 0, # Anthropic doesn't provide breakdown\n", + " \"response_tokens\": output_tokens, # Can't separate, use total\n", + " \"raw_response\": resp.model_dump() if hasattr(resp, 'model_dump') else str(resp),\n", + " }\n", + "\n", + "def call_google(client, sentence: str, model: str, system_prompt: str,\n", + " max_tokens: int = 64, thinking_level: str = None) -> dict:\n", + " \"\"\"\n", + " Call Google Gemini API.\n", + " \n", + " Gemini 3 Pro: thinking_level \"low\", \"high\" (default)\n", + " Gemini 3 Flash: thinking_level \"minimal\", \"low\", \"medium\", \"high\" (default)\n", + " \n", + " Note: Even \"low\" still uses thinking tokens! Only Flash \"minimal\" truly minimizes.\n", + " \n", + " Token billing:\n", + " - output tokens include thinking tokens (no separate breakdown)\n", + " \"\"\"\n", + " from google.genai import types\n", + " \n", + " # Build config\n", + " config_params = {}\n", + " \n", + " # Determine token allocation based on thinking level: Even \"low\" uses thinking tokens (~50-200), so we need buffer\n", + " if thinking_level == \"minimal\":\n", + " # Only Flash supports minimal - truly minimal thinking\n", + " config_params[\"max_output_tokens\"] = max_tokens + 128\n", + " elif thinking_level == \"low\":\n", + " # Low still uses thinking tokens, need buffer\n", + " config_params[\"max_output_tokens\"] = max_tokens + 512\n", + " elif thinking_level == \"medium\":\n", + " config_params[\"max_output_tokens\"] = max_tokens + 2048\n", + " else:\n", + " # high or default (None) - full thinking\n", + " config_params[\"max_output_tokens\"] = max_tokens + 4096\n", + " \n", + " # Set thinking level if specified\n", + " if thinking_level:\n", + " config_params[\"thinking_config\"] = types.ThinkingConfig(\n", + " thinking_level=thinking_level\n", + " )\n", + " \n", + " # Gemini 3 recommends temperature=1.0 (default), don't override\n", + " \n", + " config = types.GenerateContentConfig(**config_params)\n", + " \n", + " # Combine system prompt and sentence\n", + " full_prompt = f\"{system_prompt}\\n{sentence}\"\n", + " \n", + " resp = client.models.generate_content(\n", + " model=model,\n", + " contents=full_prompt,\n", + " config=config,\n", + " )\n", + " \n", + " # Extract text - resp.text may be empty, need to check parts\n", + " text = \"\"\n", + " if resp.text:\n", + " text = resp.text.strip()\n", + " elif resp.candidates and resp.candidates[0].content and resp.candidates[0].content.parts:\n", + " # Extract text from parts, skipping thinking parts\n", + " for part in resp.candidates[0].content.parts:\n", + " # Skip thinking parts (they have 'thought' attribute set to True)\n", + " if hasattr(part, 'thought') and part.thought:\n", + " continue\n", + " if hasattr(part, 'text') and part.text:\n", + " text += part.text\n", + " text = text.strip()\n", + " \n", + " # Token usage\n", + " usage = resp.usage_metadata\n", + " input_tokens = usage.prompt_token_count\n", + " output_tokens = usage.candidates_token_count\n", + " # - candidates_token_count = actual response tokens\n", + " # - thoughts_token_count = thinking/reasoning tokens (billed as output)\n", + " thoughts_tokens = getattr(resp.usage_metadata, 'thoughts_token_count', 0) or 0\n", + " candidates_tokens = resp.usage_metadata.candidates_token_count or 0\n", + "\n", + " return {\n", + " \"text\": text,\n", + " \"input_tokens\": resp.usage_metadata.prompt_token_count,\n", + " \"output_tokens\": candidates_tokens + thoughts_tokens, # Total billed output\n", + " \"reasoning_tokens\": thoughts_tokens,\n", + " \"response_tokens\": candidates_tokens,\n", + " \"raw_response\": str(resp),\n", + " }\n", + " \n", + "def call_fireworks(client, sentence: str, model: str, system_prompt: str,\n", + " max_tokens: int = 64, reasoning_effort: str = None) -> dict:\n", + " \"\"\"\n", + " Call Fireworks API using OpenAI-compatible endpoint.\n", + " \n", + " Thinking models (*-thinking, r1, kimi) output reasoning in content,\n", + " so we allocate extra tokens and use recommended temperature.\n", + " \"\"\"\n", + " full_model_id = f\"accounts/fireworks/models/{model}\"\n", + " \n", + " # Thinking models need more tokens (thinking appears in content)\n", + " is_thinking_model = \"thinking\" in model or \"r1\" in model or model == \"kimi-k2p5\"\n", + " \n", + " if is_thinking_model:\n", + " effective_max_tokens = max_tokens + 4096\n", + " temperature = 0.6 # Recommended for thinking\n", + " else:\n", + " effective_max_tokens = max_tokens\n", + " temperature = 0 # Deterministic for chat\n", + " \n", + " params = {\n", + " \"model\": full_model_id,\n", + " \"messages\": [\n", + " {\"role\": \"system\", \"content\": system_prompt},\n", + " {\"role\": \"user\", \"content\": sentence}\n", + " ],\n", + " \"max_tokens\": effective_max_tokens,\n", + " \"temperature\": temperature,\n", + " }\n", + "\n", + " resp = client.chat.completions.create(**params)\n", + " \n", + " text = (resp.choices[0].message.content or \"\").strip()\n", + " \n", + " return {\n", + " \"text\": text,\n", + " \"input_tokens\": resp.usage.prompt_tokens,\n", + " \"output_tokens\": resp.usage.completion_tokens,\n", + " \"reasoning_tokens\": 0,\n", + " \"response_tokens\": resp.usage.completion_tokens,\n", + " \"raw_response\": resp.model_dump() if hasattr(resp, 'model_dump') else str(resp),\n", + " }\n", + "\n", + "def call_xai(client, sentence: str, model: str, system_prompt: str,\n", + " max_tokens: int = 64) -> dict:\n", + " \"\"\"\n", + " Call xAI API using OpenAI-compatible endpoint.\n", + " \n", + " grok-4 and *-reasoning models are always reasoning.\n", + " *-non-reasoning models are chat mode.\n", + " \n", + " Token billing:\n", + " - completion_tokens: visible response only\n", + " - reasoning_tokens: in completion_tokens_details, billed separately\n", + " - total output billed = completion_tokens + reasoning_tokens\n", + " \"\"\"\n", + " is_reasoning = \"non-reasoning\" not in model\n", + " \n", + " params = {\n", + " \"model\": model,\n", + " \"messages\": [\n", + " {\"role\": \"system\", \"content\": system_prompt},\n", + " {\"role\": \"user\", \"content\": sentence}\n", + " ],\n", + " \"max_tokens\": max_tokens + 4096 if is_reasoning else max_tokens,\n", + " \"temperature\": 1.0 if is_reasoning else 0,\n", + " }\n", + " \n", + " resp = client.chat.completions.create(**params)\n", + " text = (resp.choices[0].message.content or \"\").strip()\n", + " \n", + " # Extract reasoning tokens from completion_tokens_details\n", + " reasoning_tokens = 0\n", + " if hasattr(resp.usage, 'completion_tokens_details') and resp.usage.completion_tokens_details:\n", + " reasoning_tokens = getattr(resp.usage.completion_tokens_details, 'reasoning_tokens', 0) or 0\n", + " \n", + " # completion_tokens is visible response, reasoning_tokens is separate\n", + " # Both are billed at output rate\n", + " completion_tokens = resp.usage.completion_tokens\n", + " total_output = completion_tokens + reasoning_tokens\n", + " \n", + " return {\n", + " \"text\": text,\n", + " \"input_tokens\": resp.usage.prompt_tokens,\n", + " \"output_tokens\": total_output, # Total billed at output rate\n", + " \"reasoning_tokens\": reasoning_tokens,\n", + " \"response_tokens\": completion_tokens, # Visible response only\n", + " \"raw_response\": resp.model_dump() if hasattr(resp, 'model_dump') else str(resp),\n", + " }" + ] + }, + { + "cell_type": "markdown", + "id": "35a1a8ba-0a5d-4d1d-a10c-5404ca9a2919", + "metadata": {}, + "source": [ + "# Parsing & Cost" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae5e0d1f-be15-4cd1-84d7-3a05743e7c0a", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------\n", + "# Parsing & Cost\n", + "# -----------------------------\n", + "\n", + "def parse_labels(text: str) -> list:\n", + " \"\"\"Parse JSON array, validate labels. Handles markdown code fences and extra text.\"\"\"\n", + " text = text.strip()\n", + "\n", + " # Strip thinking tags for models that use them (e.g., DeepSeek R1, QwQ)\n", + " if \"\" in text:\n", + " text = text.split(\"\")[-1].strip()\n", + " elif \"\" in text and \"\" not in text:\n", + " # Incomplete thinking block - try to find JSON after it\n", + " pass\n", + " \n", + " # Strip markdown code fences if present\n", + " if \"```\" in text:\n", + " lines = text.split(\"\\n\")\n", + " lines = [l for l in lines if not l.strip().startswith(\"```\")]\n", + " text = \"\\n\".join(lines).strip()\n", + " \n", + " # Find the LAST JSON array (thinking models put reasoning before answer)\n", + " # Search backwards for the final [...] pattern\n", + " start = text.rfind(\"[\")\n", + " if start == -1:\n", + " raise ValueError(\"No JSON array found\")\n", + " \n", + " # Find matching closing bracket\n", + " depth = 0\n", + " end = start\n", + " for i, char in enumerate(text[start:], start):\n", + " if char == \"[\":\n", + " depth += 1\n", + " elif char == \"]\":\n", + " depth -= 1\n", + " if depth == 0:\n", + " end = i + 1\n", + " break\n", + " \n", + " json_str = text[start:end]\n", + " data = json.loads(json_str)\n", + " \n", + " if not isinstance(data, list):\n", + " raise ValueError(\"Not a JSON array\")\n", + " invalid = [x for x in data if x not in ALLOWED_LABELS]\n", + " if invalid:\n", + " raise ValueError(f\"Invalid labels: {invalid}\")\n", + " return [label for label in LABEL_ORDER if label in set(data)]\n", + "\n", + "\n", + "def compute_cost(model_info: dict, input_tokens: int, output_tokens: int) -> float:\n", + " \"\"\"\n", + " Compute cost in USD.\n", + " Note: output_tokens includes reasoning/thinking tokens, all billed at output rate.\n", + " \"\"\"\n", + " return (input_tokens / 1e6) * model_info[\"price_in\"] + (output_tokens / 1e6) * model_info[\"price_out\"]\n", + "\n", + "\n", + "def get_mode_string(vendor: str, model: str = None, reasoning_effort: str = None, \n", + " thinking_budget: int = None, thinking_level: str = None) -> str:\n", + " \"\"\"Get mode string for filenames.\"\"\"\n", + " if vendor == \"openai\":\n", + " return f\"reasoning-{reasoning_effort}\" if reasoning_effort else \"chat\"\n", + " elif vendor == \"azure\":\n", + " return \"chat\"\n", + " elif vendor == \"anthropic\":\n", + " return f\"thinking-{thinking_budget}\" if thinking_budget else \"chat\"\n", + " elif vendor == \"google\":\n", + " if thinking_level == \"minimal\":\n", + " return \"chat\" # Only Flash supports minimal\n", + " elif thinking_level == \"low\":\n", + " return \"thinking-low\" # Still uses some thinking\n", + " elif thinking_level:\n", + " return f\"thinking-{thinking_level}\"\n", + " return \"thinking\" # default for Gemini 3 \n", + " elif vendor == \"fireworks\":\n", + " if reasoning_effort:\n", + " return f\"reasoning-{reasoning_effort}\"\n", + " # Auto-detected thinking models\n", + " if model and (\"thinking\" in model or \"r1\" in model or model == \"kimi-k2p5\"):\n", + " return \"thinking\"\n", + " elif vendor == \"xai\":\n", + " return \"chat\" if model and \"non-reasoning\" in model else \"reasoning\"\n", + " return \"chat\"" + ] + }, + { + "cell_type": "markdown", + "id": "9b028d13-8bd2-47e7-9f5d-f8589136d6d2", + "metadata": {}, + "source": [ + "# Classify Sentence and Retry on Error" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3cf0a31f-7403-4861-b8ab-69e6fdd65434", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------\n", + "# Single Sentence Classification with Retry\n", + "# -----------------------------\n", + "def classify_sentence(client, sentence: str, vendor: str, model: str, model_info: dict,\n", + " system_prompt: str, max_tokens: int = 64, max_retries: int = 5,\n", + " reasoning_effort: str = None, thinking_budget: int = None,\n", + " thinking_level: str = None) -> dict:\n", + " \"\"\"\n", + " Classify a single sentence with retry logic.\n", + " \n", + " Retry wait times: 3s, 3s, 6s, 12s, 30s\n", + " Rate limit errors: 60s\n", + " \"\"\"\n", + " # Import vendor-specific exceptions only when needed\n", + " if vendor == \"openai\":\n", + " rate_limit_errors = (RateLimitError,)\n", + " api_errors = (APIError,)\n", + " auth_errors = (AuthenticationError,)\n", + " elif vendor == \"azure\":\n", + " rate_limit_errors = (RateLimitError,)\n", + " api_errors = (APIError,)\n", + " auth_errors = (AuthenticationError,)\n", + " elif vendor == \"anthropic\":\n", + " import anthropic\n", + " rate_limit_errors = (anthropic.RateLimitError,)\n", + " api_errors = (anthropic.APIError,)\n", + " auth_errors = (anthropic.AuthenticationError,)\n", + " elif vendor == \"fireworks\":\n", + " import openai\n", + " rate_limit_errors = (openai.RateLimitError,)\n", + " api_errors = (openai.APIError,)\n", + " auth_errors = (openai.AuthenticationError,)\n", + " elif vendor == \"xai\":\n", + " import openai\n", + " rate_limit_errors = (openai.RateLimitError,)\n", + " api_errors = (openai.APIError,)\n", + " auth_errors = (openai.AuthenticationError,)\n", + " elif vendor == \"google\":\n", + " from google.api_core import exceptions as google_exceptions\n", + " rate_limit_errors = (google_exceptions.ResourceExhausted,)\n", + " api_errors = (google_exceptions.GoogleAPIError,)\n", + " auth_errors = (google_exceptions.Unauthenticated, google_exceptions.PermissionDenied)\n", + " else:\n", + " rate_limit_errors = ()\n", + " api_errors = ()\n", + " auth_errors = ()\n", + " \n", + " last_error = None\n", + " error_details = None\n", + " wait_times = {1: 3, 2: 3, 3: 6, 4: 12, 5: 30}\n", + " \n", + " for attempt in range(1, max_retries + 1):\n", + " try:\n", + " t0 = time.perf_counter()\n", + " \n", + " if vendor == \"openai\":\n", + " result = call_openai(client, sentence, model, system_prompt, \n", + " max_tokens, reasoning_effort)\n", + " elif vendor == \"azure\": \n", + " result = call_azure(client, sentence, model, system_prompt,\n", + " max_tokens)\n", + " elif vendor == \"anthropic\":\n", + " result = call_anthropic(client, sentence, model, system_prompt,\n", + " max_tokens, thinking_budget)\n", + " elif vendor == \"fireworks\":\n", + " result = call_fireworks(client, sentence, model, system_prompt,\n", + " max_tokens, reasoning_effort)\n", + " elif vendor == \"xai\":\n", + " result = call_xai(client, sentence, model, system_prompt,\n", + " max_tokens)\n", + " elif vendor == \"google\":\n", + " result = call_google(client, sentence, model, system_prompt,\n", + " max_tokens, thinking_level)\n", + " else:\n", + " raise ValueError(f\"Unknown vendor: {vendor}\")\n", + " \n", + " latency = time.perf_counter() - t0\n", + " labels = parse_labels(result[\"text\"])\n", + " \n", + " return {\n", + " \"labels\": labels,\n", + " \"response_text\": result[\"text\"],\n", + " \"input_tokens\": result[\"input_tokens\"],\n", + " \"output_tokens\": result[\"output_tokens\"],\n", + " \"internal_tokens\": result.get(\"reasoning_tokens\") or result.get(\"thinking_tokens\", 0),\n", + " \"response_tokens\": result.get(\"response_tokens\", result[\"output_tokens\"]),\n", + " \"cost_usd\": compute_cost(model_info, result[\"input_tokens\"], result[\"output_tokens\"]),\n", + " \"latency_sec\": latency,\n", + " \"attempts\": attempt,\n", + " \"error\": None,\n", + " \"raw_response\": result.get(\"raw_response\"),\n", + " }\n", + " \n", + " except auth_errors as e:\n", + " # Don't retry auth errors - they won't resolve\n", + " error_details = {\"type\": \"AuthenticationError\", \"attempt\": attempt, \"message\": str(e)}\n", + " print(f\" Authentication error (not retrying): {e}\")\n", + " return {\n", + " \"labels\": None,\n", + " \"response_text\": None,\n", + " \"input_tokens\": 0,\n", + " \"output_tokens\": 0,\n", + " \"internal_tokens\": 0,\n", + " \"response_tokens\": 0,\n", + " \"cost_usd\": 0.0,\n", + " \"latency_sec\": None,\n", + " \"attempts\": attempt,\n", + " \"error\": error_details,\n", + " \"raw_response\": None,\n", + " }\n", + " \n", + " except rate_limit_errors as e:\n", + " last_error = f\"RateLimitError: {e}\"\n", + " error_details = {\"type\": \"RateLimitError\", \"attempt\": attempt}\n", + " print(f\" Rate limit hit, waiting 60s...\")\n", + " time.sleep(60)\n", + " \n", + " except api_errors as e:\n", + " last_error = f\"APIError: {e}\"\n", + " error_details = {\"type\": \"APIError\", \"attempt\": attempt, \"message\": str(e)}\n", + " wait = wait_times.get(attempt, 3)\n", + " print(f\" API error (attempt {attempt}), waiting {wait}s...\")\n", + " time.sleep(wait)\n", + " \n", + " except (json.JSONDecodeError, ValueError) as e:\n", + " last_error = f\"ParseError: {e}\"\n", + " response_preview = result.get(\"text\", \"\")[:100] if 'result' in dir() and result else \"\"\n", + " error_details = {\"type\": \"ParseError\", \"attempt\": attempt, \"message\": str(e), \"response_preview\": response_preview}\n", + " wait = wait_times.get(attempt, 3)\n", + " print(f\" Parse error (attempt {attempt}): got '{response_preview}', waiting {wait}s...\")\n", + " time.sleep(wait)\n", + " \n", + " except Exception as e:\n", + " last_error = f\"{type(e).__name__}: {e}\"\n", + " error_details = {\"type\": type(e).__name__, \"attempt\": attempt, \"message\": str(e)}\n", + " print(f\" Error (attempt {attempt}): {last_error}\")\n", + " if attempt < max_retries:\n", + " time.sleep(wait_times.get(attempt, 3))\n", + " \n", + " return {\n", + " \"labels\": None,\n", + " \"response_text\": None,\n", + " \"input_tokens\": 0,\n", + " \"output_tokens\": 0,\n", + " \"internal_tokens\": 0,\n", + " \"response_tokens\": 0,\n", + " \"cost_usd\": 0.0,\n", + " \"latency_sec\": None,\n", + " \"attempts\": max_retries,\n", + " \"error\": error_details,\n", + " \"raw_response\": None,\n", + " }" + ] + }, + { + "cell_type": "markdown", + "id": "b7d9948c-1278-42bb-84f9-311f5a9bcaac", + "metadata": {}, + "source": [ + "# DataFrame Processing with Checkpointing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5799882a-e057-4277-9583-94fbc1872481", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------\n", + "# DataFrame Processing with Checkpointing\n", + "# -----------------------------\n", + "def classify_dataframe(\n", + " df: pd.DataFrame,\n", + " sentence_col: str,\n", + " vendor: str,\n", + " model: str,\n", + " system_prompt: str = SYSTEM_PROMPT,\n", + " max_tokens: int = 64,\n", + " output_dir: str = \"./output\", ## default outout folder if none is passed\n", + " checkpoint_dir: str = \"./checkpoints\", ## default checkpoint folder if none is passed\n", + " save_interval: int = 100, ##### How often you want to save\n", + " run: int = 1,\n", + " reasoning_effort: str = None,\n", + " thinking_budget: int = None,\n", + " thinking_level: str = None,\n", + " max_consecutive_errors: int = 10,\n", + ") -> pd.DataFrame:\n", + " \"\"\"\n", + " Classify all sentences with checkpointing.\n", + " \n", + " Directory structure:\n", + " - checkpoint_dir/{vendor}_{model}_{mode}_run{run}.pkl (interim saves)\n", + " - output_dir/Run{run}/{vendor}_{model}_{mode}_run{run}.pkl (final output)\n", + " - output_dir/Run{run}/{vendor}_{model}_{mode}_run{run}.log.json (run log)\n", + " \"\"\"\n", + " # Validate\n", + " if sentence_col not in df.columns:\n", + " raise ValueError(f\"Column '{sentence_col}' not found\")\n", + " if vendor not in VENDORS:\n", + " raise ValueError(f\"Unknown vendor: {vendor}\")\n", + " if model not in VENDORS[vendor][\"models\"]:\n", + " raise ValueError(f\"Unknown model '{model}' for vendor '{vendor}'\")\n", + " \n", + " model_info = VENDORS[vendor][\"models\"][model]\n", + " mode_str = get_mode_string(vendor, model, reasoning_effort, thinking_budget, thinking_level)\n", + " # Derive dataset tag from output_dir (e.g., \"./output/holdout\" → \"holdout\")\n", + " dataset_tag = os.path.basename(os.path.normpath(output_dir)).lower()\n", + " base_name = f\"{dataset_tag}_{vendor}_{model}_{mode_str}_run{run}\"\n", + " \n", + " # Create directories\n", + " os.makedirs(checkpoint_dir, exist_ok=True)\n", + " run_dir = f\"{output_dir}/Run{run}\"\n", + " os.makedirs(run_dir, exist_ok=True)\n", + " \n", + " checkpoint_path = f\"{checkpoint_dir}/{base_name}.pkl\"\n", + " \n", + " print(f\"Checkpoint dir: {checkpoint_dir}\")\n", + " print(f\"Output dir: {run_dir}\")\n", + " print(f\"Checkpoint file: {checkpoint_path}\")\n", + " \n", + " # Load existing checkpoint\n", + " results = []\n", + " processed_ids = set()\n", + " \n", + " if os.path.exists(checkpoint_path):\n", + " try:\n", + " checkpoint_df = pd.read_pickle(checkpoint_path)\n", + " successful = checkpoint_df[checkpoint_df['error'].isna()]\n", + " results = successful.to_dict('records')\n", + " processed_ids = set(successful['id'].values) if 'id' in successful.columns else set()\n", + " failed_count = len(checkpoint_df) - len(successful)\n", + " print(f\"Resumed: {len(results)} processed, {failed_count} errors to retry\")\n", + " except Exception as e:\n", + " print(f\"Could not load checkpoint: {e}\")\n", + " \n", + " # Initialize client\n", + " client = init_client(vendor)\n", + " \n", + " # Track errors\n", + " consecutive_errors = 0\n", + " start_time = datetime.datetime.now()\n", + " \n", + " # Process\n", + " for idx, row in tqdm(df.iterrows(), total=len(df), desc=f\"Run {run}: {model} ({mode_str})\"):\n", + " row_id = row.get('id', idx)\n", + " \n", + " if row_id in processed_ids:\n", + " continue\n", + " \n", + " sentence = str(row[sentence_col])\n", + " \n", + " result = classify_sentence(\n", + " client, sentence, vendor, model, model_info, system_prompt, max_tokens,\n", + " reasoning_effort=reasoning_effort, thinking_budget=thinking_budget,\n", + " thinking_level=thinking_level\n", + " )\n", + " \n", + " ordered_result = {\n", + " 'id': row_id,\n", + " 'sentence': sentence,\n", + " **result\n", + " }\n", + " results.append(ordered_result)\n", + " \n", + " if result['error']:\n", + " consecutive_errors += 1\n", + " if consecutive_errors >= max_consecutive_errors:\n", + " print(f\"\\nStopping: {max_consecutive_errors} consecutive errors\")\n", + " break\n", + " else:\n", + " consecutive_errors = 0\n", + " processed_ids.add(row_id)\n", + " \n", + " # Checkpoint save\n", + " if len(results) % save_interval == 0:\n", + " pd.DataFrame(results).to_pickle(checkpoint_path)\n", + " print(f\"\\nCheckpoint saved: {len(results)} processed\")\n", + " \n", + " # Final results\n", + " results_df = pd.DataFrame(results)\n", + " end_time = datetime.datetime.now()\n", + " \n", + " # Save to checkpoint (for resume if needed)\n", + " results_df.to_pickle(checkpoint_path)\n", + " \n", + " # Save to run directory\n", + " results_df.to_pickle(f\"{run_dir}/{base_name}.pkl\")\n", + " results_df.to_csv(f\"{run_dir}/{base_name}.csv\", index=False)\n", + " \n", + " # Create log file\n", + " error_count = int(results_df['error'].notna().sum())\n", + " success_count = len(results_df) - error_count\n", + " \n", + " log = {\n", + " \"run\": run,\n", + " \"vendor\": vendor,\n", + " \"model\": model,\n", + " \"supplier\": model_info[\"supplier\"],\n", + " \"mode\": mode_str,\n", + " \"reasoning_effort\": reasoning_effort,\n", + " \"thinking_budget\": thinking_budget,\n", + " \"start_time\": str(start_time),\n", + " \"end_time\": str(end_time),\n", + " \"duration_sec\": (end_time - start_time).total_seconds(),\n", + " \"total_sentences\": len(df),\n", + " \"processed_sentences\": len(results_df),\n", + " \"successful\": success_count,\n", + " \"errors\": error_count,\n", + " \"tokens\": {\n", + " \"input\": int(results_df['input_tokens'].sum()),\n", + " \"output\": int(results_df['output_tokens'].sum()),\n", + " \"internal\": int(results_df['internal_tokens'].sum()),\n", + " \"response\": int(results_df['response_tokens'].sum()),\n", + " },\n", + " \"cost_usd\": float(results_df['cost_usd'].sum()),\n", + " \"pricing\": {\n", + " \"input_per_1m\": model_info[\"price_in\"],\n", + " \"output_per_1m\": model_info[\"price_out\"],\n", + " },\n", + " \"files\": {\n", + " \"checkpoint\": checkpoint_path,\n", + " \"output_pkl\": f\"{run_dir}/{base_name}.pkl\",\n", + " \"output_csv\": f\"{run_dir}/{base_name}.csv\",\n", + " }\n", + " }\n", + " \n", + " log_path = f\"{run_dir}/{base_name}.log.json\"\n", + " with open(log_path, 'w') as f:\n", + " json.dump(log, f, indent=2)\n", + " \n", + " print(f\"\\nRun {run} complete:\")\n", + " print(f\" Sentences: {success_count}/{len(results_df)}\")\n", + " print(f\" Tokens - Input: {log['tokens']['input']:,}, Output: {log['tokens']['output']:,} (Internal: {log['tokens']['internal']:,})\")\n", + " print(f\" Cost: ${log['cost_usd']:.4f}\")\n", + " print(f\" Log: {log_path}\")\n", + " \n", + " return results_df" + ] + }, + { + "cell_type": "markdown", + "id": "2e6d17c0-135a-42c6-8815-db44f8445894", + "metadata": {}, + "source": [ + "# Multi-Run Processing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fba35e3e-e3ab-4f23-9b46-749ccecbe66d", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------\n", + "# Multi-Run Processing\n", + "# -----------------------------\n", + "def run_classification(\n", + " df: pd.DataFrame,\n", + " sentence_col: str,\n", + " vendor: str,\n", + " model: str,\n", + " runs: list = [1, 2, 3], # default if nothing passed\n", + " output_dir: str = \"./output\", # default if nothing passed\n", + " checkpoint_dir: str = \"./checkpoints\", # default if nothing passed\n", + " system_prompt: str = SYSTEM_PROMPT,\n", + " max_tokens: int = 64,\n", + " save_interval: int = 100,\n", + " reasoning_effort: str = None,\n", + " thinking_budget: int = None,\n", + " thinking_level: str = None,\n", + ") -> dict:\n", + " \"\"\"\n", + " Run classification multiple times.\n", + " \n", + " Directory structure:\n", + " output_dir/\n", + " Run1/\n", + " {vendor}_{model}_{mode}_run1.pkl\n", + " {vendor}_{model}_{mode}_run1.csv\n", + " {vendor}_{model}_{mode}_run1.log.json\n", + " Run2/\n", + " ...\n", + " checkpoint_dir/\n", + " {vendor}_{model}_{mode}_run1.pkl\n", + " ...\n", + " \n", + " Returns dict of {run: results_df}\n", + " \"\"\"\n", + " model_info = VENDORS[vendor][\"models\"][model]\n", + " mode_str = get_mode_string(vendor, model, reasoning_effort, thinking_budget, thinking_level)\n", + " job_start_time = time.perf_counter()\n", + " \n", + " print(f\"\\n{'='*60}\")\n", + " print(f\"CLASSIFICATION JOB\")\n", + " print(f\"{'='*60}\")\n", + " print(f\"Vendor: {vendor}\")\n", + " print(f\"Model: {model}\")\n", + " print(f\"Supplier: {model_info['supplier']}\")\n", + " print(f\"Mode: {mode_str}\")\n", + " print(f\"Pricing: ${model_info['price_in']}/M in, ${model_info['price_out']}/M out\")\n", + " print(f\"Runs: {runs}\")\n", + " print(f\"Sentences: {len(df)}\")\n", + " print(f\"Output dir: {output_dir}\")\n", + " print(f\"Checkpoint dir: {checkpoint_dir}\")\n", + " print(f\"{'='*60}\\n\")\n", + " \n", + " all_results = {}\n", + " \n", + " for run in runs:\n", + " print(f\"\\n{'='*40}\")\n", + " print(f\"RUN {run}\")\n", + " print(f\"{'='*40}\")\n", + " \n", + " results_df = classify_dataframe(\n", + " df, sentence_col, vendor, model, system_prompt, max_tokens,\n", + " output_dir=output_dir,\n", + " checkpoint_dir=checkpoint_dir,\n", + " save_interval=save_interval,\n", + " run=run,\n", + " reasoning_effort=reasoning_effort,\n", + " thinking_budget=thinking_budget,\n", + " thinking_level=thinking_level,\n", + " )\n", + " \n", + " all_results[run] = results_df\n", + " \n", + " all_results[run] = results_df\n", + " \n", + " # Print summary\n", + " job_runtime = time.perf_counter() - job_start_time\n", + " print(f\"\\n{'='*60}\")\n", + " print(\"SUMMARY\")\n", + " print(f\"{'='*60}\")\n", + " total_cost = sum(r['cost_usd'].sum() for r in all_results.values())\n", + " print(f\"Total runs: {len(runs)}\")\n", + " print(f\"Total cost: ${total_cost:.4f}\")\n", + " print(f\"Total runtime: {job_runtime:.1f}s ({job_runtime/60:.1f}m)\")\n", + " print(f\"Output location: {output_dir}/Run*/\")\n", + " \n", + " return all_results" + ] + }, + { + "cell_type": "markdown", + "id": "11c8f621-6826-4e71-b38c-32b1bd8de480", + "metadata": {}, + "source": [ + "# Load Data\n", + "\n", + ">**If you are running this in your local computer:** Subfolders will be automaticallty created inside the folder that this notebook is in. All files will be saved in those local folders/subfolders\n", + "\n", + "> **If you are on Google CoLab:**: FIRST, you will need to connect your google drive and navigate to the folder that this noteobok is in. Then, the code will create subfodlers inside the folder you navigated to on your google drive. All files will be saved in those local folders/subfolders on your google drive." + ] + }, + { + "cell_type": "markdown", + "id": "a90bfe58-576c-4b7a-88ff-b37ee816a391", + "metadata": {}, + "source": [ + "##### **IMPORTANT for Google CoLab**\n", + "\n", + "If you want to save to your google drive, you have to connect it first and then navigate to this current folder:\n", + "\n", + "- Import google drive and connect\n", + "- Change dir to folder that this notebook is in\n", + "\n", + "***Copy this code in a new code cell, modify folder as needed (if you don't want \"Project\") and run:***\n", + "```\n", + "from google.colab import drive\n", + "drive.mount('/content/drive')\n", + "\n", + "import os\n", + "\n", + "# Construct the full path to the desired folder within Google Drive\n", + "drive_path = '/content/drive/My Drive/Project'\n", + "\n", + "# Create the directory if it doesn't exist (optional, but good practice)\n", + "if not os.path.exists(drive_path):\n", + " os.makedirs(drive_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "336f43ce-f1b2-4d2f-8130-737ef8eea75d", + "metadata": {}, + "outputs": [], + "source": [ + "# Example Sentences\n", + "# sentences = ['The annual growth in our provision for uncollectible accounts was primarily attributable to transitioning multiple external payors from recognizing revenue on a cash basis to an accrual basis, thereby aligning with customer-related payment practices.', 'We handle certain sales-type leases internally, particularly those associated with facilities serving U.S. government hospital clients.', 'We expect that this acquisition will position us to broaden our offerings in the private label accessories segment, catering to customers seeking value-oriented solutions.', 'Additionally, the Tax Cuts and Jobs Act was enacted into law by the President in December 2017.', 'Valuations derived from the Black-Scholes model can vary significantly depending on the assumptions made regarding volatility and the duration of the underlying instruments.', 'Our audience is engaged across multiple channels, such as digital properties, print publications, as well as broadcast and streaming media.', 'The forecast incorporates management’s most informed assumptions regarding anticipated economic and market trends for the duration of the projection, encompassing anticipated changes in sales growth, cost structures, operating margin expectations, and future cash outlays.', 'As of December 31, 2007, a single credit remains unassigned to a particular reserve, which, absent any favorable developments, is at risk of defaulting imminently and may lead to the filing of a claim.', \"Our model incorporates assumptions regarding anticipated stock price fluctuations, utilizing historical volatility metrics, the applicable risk-free rate derived from the treasury yield curve, the projected duration of equity awards informed by past exercise trends and termination actions post-vesting, along with anticipated dividend equivalents to be paid during the award's estimated term, given that our stock appreciation rights participate in dividends.\", 'The Company conducts comprehensive physical counts of inventory across all store and warehouse locations at least once per year, and correspondingly updates the reported merchandise inventory to reflect the results of these assessments.', 'Substantial growth was realized across categories such as basketball and lifestyle footwear, as well as branded apparel and specialized cleats for wrestling, volleyball, and soccer.', 'Actuarial gains realized in 2007 will diminish the unrecognized loss allocated across the projected average remaining service years of active plan members, with such periods differing by plan from 6 to 23 years.', 'Platform utilization expanded across enterprise segments, reflecting enhanced operational efficiency and streamlined integration of internal systems, during fiscal 2020 for the technology company.', 'The advances carried maturities at assorted dates up to 2014, with interest rates ranging from 0.3% to 3.4%.', 'Furthermore, significant investments were made to upgrade essential operational systems, establish a comprehensive disaster recovery strategy, and strengthen compliance protocols designed to better serve our customers.', 'Subscription offerings expanded during fiscal 2023, reflecting successful margin optimization and process improvements within our technology platform, as we continued to innovate and refine our approach.', 'Provisions for loss contingencies are determined by management’s assessment regarding the probability of adverse results and the estimated magnitude of potential losses.', 'Before Discover Bank acquired SLC on December 31, 2010, SLC maintained a contractual relationship with Citi that facilitated the origination and ongoing management of private student loans for individuals.', 'Overview of BorgWarner Inc.’s Financial Condition and Operating Results Management’s Discussion and Analysis', 'We attribute the rise in RELIC watch sales chiefly to strategic modifications in our product offerings and the incremental expansion of our customer base achieved toward the end of Fiscal 2003.', 'PECO Electric Operating Statistics and Revenue Summary The following outlines PECO’s electric sales metrics and associated revenue information: (a) Full service represents energy provided to customers receiving electricity under standard tariffed rates.', 'As part of the 2009 Supervisory Capital Assessment Program, the Federal Reserve Board enhanced its review of the capital sufficiency of selected major bank holding companies by utilizing an alternative measure of Tier 1 capital referred to as Tier 1 common equity.', 'During the years ended 2008, 2007, and 2006, none of our customers individually represented 10% or greater of our overall revenue.', 'Other than our operating leases, which mainly pertain to restaurant locations, we do not engage in any off-balance sheet transactions.', 'Policyholders have the option to discontinue coverage as a result of their departure from medical practice due to retirement, disability, or passing.', 'Red Robin Gourmet Burgers®, Inc., a Delaware entity along with its subsidiaries (referred to herein as “Red Robin,” the “Company,” or by similar terms), is engaged predominantly in the development, operation, and franchising of casual-dining restaurants, totaling 514 outlets across North America as of the close of the fiscal year on December 28, 2014.', 'As of December 31, 2008, 51% of the outstanding home equity lines of credit were collateralized by properties located in New York State, while 21% and 26% were backed by properties in Pennsylvania and the Mid-Atlantic region, respectively.', \"The decline in property catastrophe offerings was attributable to non-renewed contracts and decreased participation levels, influenced by prevailing market dynamics and an increased reliance on retrocessional arrangements, whereas non-catastrophe property business during 2014 comprised an incoming unearned premium portfolio transfer of $50.2 million from Gulf Reinsurance Limited ('Gulf Re'), which the Company acquired in 2015.\", 'IPL’s operating performance reflected a $149 million decline in earnings allocable to common shareholders during 2008 and a $118 million improvement in 2007, which was largely attributable to a post-tax gain of $123 million stemming from the divestiture of IPL’s electric transmission assets in 2007.', 'Incorporating remote work practices forms an integral aspect of our overall business continuity strategy.', 'Upon the sale of a gift card, we record a corresponding liability on our balance sheet.', 'Management’s Discussion and Analysis Procedures for the Company and its Subsidiaries.', 'Our Carter’s, Just One Year, and Child of Mine brands are made available to third parties through licensing agreements.', 'SPP has initiated approvals for the development of transmission infrastructure designed to transport renewable energy from the wind-producing regions of western Oklahoma, the Texas Panhandle, and western Kansas to major demand centers, as part of its strategy to expand transmission capacity in these areas.', 'Consequently, we record revenue at the point when ownership is conveyed to our customer.', 'The credits provided will fluctuate from one quarter to the next, and ultimately, the cumulative benefits passed on to customers in the form of reduced electric rates will align with the amounts applied toward federal income tax obligations.', 'Outcomes may vary, and the divergence from these projections could be significant depending on alternative assumptions or circumstances.', 'Accruals associated with chargebacks and rebates involve a significant degree of estimation within our sales processes.', 'In the quarter concluding on September 30, 2013, we finalized a partnership arrangement with LPC MM Monrovia, LLC, an independent third party, specifically to facilitate the acquisition of a real estate asset in Monrovia, California, and to initiate the development of construction documentation for enhancements to the site.', 'Periodic legal proceedings or regulatory inquiries arise against the Company as a normal aspect of conducting its insurance operations.', 'In the interim between lease termination and final real estate disposition, we provided significant loan-based support to the hospital operator, enabling them to meet operational cash needs while awaiting reimbursement for patient services from Medicare and other payors.', 'Pursuant to the CRS agreement, close to one-fourth of the contract’s total value becomes payable by the customer and is subject to collection solely after launch and delivery objectives have been achieved for each of the eight CRS missions.', 'In 2008, the Company completed toll processing of 126,000 ounces of PGMs, an increase compared to the 112,000 ounces processed under toll arrangements in the prior year.', 'Operations within our ready-mixed concrete, precast concrete, and ancillary concrete businesses experience fluctuations attributable to seasonal patterns.', 'Historically, our highest working capital requirements arise in the latter half of the year, as accounts receivable and inventory expand due to elevated activity during the holiday sales period, and inventory builds ahead of anticipated factory shutdowns for Chinese New Year observances.', 'Cost of sales primarily includes expenses related to products, such as major aircraft and engine components, along with direct labor, overhead, and costs associated with maintaining aircraft.', 'Our proportional interest in the mines under our management affords us an annual rated pellet production capacity of 22.9 million tons, which equates to roughly 28% of the aggregate pellet capacity available throughout North America.', 'For the fiscal year ended December 31, 2009, a total of 338 da Vinci Surgical Systems were sold, representing a slight increase from the 335 units sold in the prior year ended December 31, 2008.', 'During the 2009 fiscal year, the Company incurred costs attributable to flooding totaling $7.6 million, while recognizing $16.7 million in recoveries from insurance claims.', 'The rise was tempered in part by the enhanced redeployment incentives offered through our equipment lease initiative to new customers.', 'The Company provides defined benefit pension arrangements, primarily benefiting salaried and management employees, and oversees the associated post-retirement medical plan accounting.', 'The gross profit margin declined to 30.7% in 2013 as compared to 31.4% in the previous year, largely attributable to a rise in sales from private label and international segments, which traditionally yield lower margins.', 'A substantial portion of our multi-family lending activity is directed toward longstanding property owners whose apartment buildings operate under rent control, resulting in rental rates that are lower than prevailing market levels.', 'Growth in distillery product revenues was driven by elevated unit volumes and enhanced pricing for food grade alcohol used in both beverage and industrial sectors, alongside strengthened pricing for fuel grade alcohol.', 'In October 2005, we completed delivery of the system to the customer.', 'On July 29, 2014, the Company completed an issuance of Senior Notes totaling $300 million, maturing on February 1, 2025, with a fixed interest rate of 5.375% and sold at par value, referred to herein as the 2014 Notes.', 'Expenses associated with reactivating these subscribers were recognized, and these customers were categorized as gross new DISH TV subscriber additions for the year ended December 31, 2017, with related costs captured under “Subscriber acquisition costs” within our Consolidated Statements of Operations and Comprehensive Income (Loss) and/or as “Purchases of property and equipment” in our Consolidated Statements of Cash Flows.', \"Refer to 'Item 8' under the section 'Loans Receivable.'\", 'Valuation of swaps, interest rate swaptions, and option contracts is determined through commonly accepted industry models that estimate the present value of anticipated derivative cash flows, reflecting both prevailing and anticipated market conditions.', 'On February 5, 2010, a total of 68 workover rigs were staffed and were either in service or subject to ongoing marketing efforts.', 'As a result of our commitment to making education more affordable for our students, we project that the typical debt burden for graduates of the Art Institutes has declined by nearly 15% since 2010.', 'The market valuation of distressed inventory is determined utilizing prior sales patterns for specific product categories, prevailing market dynamics, overall economic factors, and the worth of outstanding in-house orders associated with prospective sales of such inventory.', 'The Company factors such discounts and rebates into its determination of transaction price, recording them as deductions from gross sales.', 'Within the Caribbean region, we operate vertically integrated utilities as well as generation facilities, each governed by multi-year agreements established in partnership with governmental entities.', 'Nonetheless, despite Synovus Bank meeting all required quantitative capital standards, regulatory provisions allow for reclassification of an institution to a less favorable capital category, dependent on supervisory considerations beyond capital metrics.', 'For most of the securities measured through dealers and pricing services, we source several independent valuations, all of which are not legally binding on either our company or our counterparties.', '59 Overseas Shipholding Group, Inc. serves as the registry for this section.', 'Our strategy is to consistently keep U.S. wholesaler inventories for our products at or below a one-month average, supporting reliable supply for our customers.', 'Aside from the $202.9 million gain recognized in 2011 from the sales of Big Sandy and Langley, the year-over-year decline also reflected roughly 25% reduced realized prices for sales volumes, a 23% increase in the rate of production depletion, and elevated other operating costs, which were partially counterbalanced by a 33% rise in production output, a 30% expansion in gathered volumes, and increased transmission income.', '• Establishing further criteria for evaluating applications for disability-related benefit claims.', 'For fiscal year 2016, revenue totaled $28.2 million, representing a decline of $11.3 million, or 28.5%, compared to the prior year’s revenue of $39.5 million.', 'During 2010, the Company divested its Orthopaedic Implant production site located in Caen, France, resulting in a $24.3 million gain that was recognized within selling, general, and administrative expenses.', 'Illustrative instances of indemnification obligations under the Separation and Distribution Agreement and associated ancillary agreements include Navient’s duty to compensate the Company and the Bank for any liabilities, costs, or expenses they may face resulting from actions or potential actions concerning pre-Spin-Off SLM and its subsidiaries’ servicing, operational, and collections activities related to Private Education Loans and FFELP Loans that were held as assets by either the Bank or Navient at the time of the Spin-Off, contingent upon written notice being furnished to Navient by April 30, 2017, the third anniversary of the Spin-Off.', 'Growth in platform usage was observed, driven by improved internal workflows and resource allocation, during fiscal 2022 for the telecommunications company.', 'Acquisitions pursuant to the Plan may occur through open market transactions or privately arranged agreements, conducted periodically and in accordance with relevant legal requirements, inclusive of Rule 10b-18 under the amended Securities Exchange Act of 1934.', 'We ended our bridge revolving credit facility, originally established on June 13, 2007, with five U.S. and international banking institutions in February 2008.', \"Our management team regularly assesses the company's cost structure to confirm that operational expenditures are managed effectively without compromising the quality of service provided to customers.\", 'The company’s reach broadened in fiscal 2021, reflecting enhanced productivity and optimization of internal processes in expanded territories.', 'A portion of these rising costs was absorbed by the company, resulting in an average sales price uptick to customers of just $0.32 per kilogram.', 'The company saw advances in payment collection efficiency, reflecting cost reduction initiatives and automation of reconciliation tasks.', 'Advancements in technology have the potential to mitigate these factors or expand the availability of mineral resources.', 'Each of our two plans designed for small office and home office use provides a single business fax line at no extra cost, with supplementary fax lines available for a monthly fee.', 'In September 2011, the European Commission approved VIBATIV® for use in adult patients with hospital-acquired pneumonia, including cases linked to ventilators, where MRSA is known or suspected and alternative therapies are deemed inappropriate, thereby expanding therapeutic options for these individuals.', 'These trends are subject to modification due to factors such as the establishment of additional schools, launch of innovative programs, rising adult student enrollment, or potential acquisitions.', 'Funding requests, which generally arise at intervals of four to eight weeks, trigger the execution of necessary inspections and formal evaluations.', 'Assessing our yearly tax obligations and analyzing our tax strategies necessitates considerable judgment and expertise.', 'On March 24, 2006, the Grindle plaintiffs retracted their request to intervene, doing so without affecting their rights.', 'At December 31, 2017, the balance of commercial and industrial loans, which includes owner-occupied commercial properties, grew by $0.3 billion, reaching a total of $1.6 billion, reflecting support for our business customers’ financing needs.', 'After the merger, ARRIS retained complete indirect ownership of both ARRIS Group and Pace as wholly-owned subsidiaries.', '•On August 17, 2012, we provided a first mortgage loan totaling $46.0 million, secured by a Hilton hotel consisting of 315 rooms located in Rockville, Maryland.', 'Our Management’s Discussion and Analysis of Financial Condition and Results of Operations highlights our position as a foremost provider of independent, technology-driven portfolio management solutions, investment guidance, and retirement income services, serving participants predominantly within employer-sponsored defined contribution plans, including 401(k) arrangements.', 'The structure of our per-minute charges is designed to encompass the entirety of services provided.', 'On June 30, 2009, nearly 83% of our holdings, excluding cash balances, possessed maturities shorter than one year.', 'Distribution of Portfolio by Geography: The following table presents the Company’s operating square footage segmented by region at December 31, 2011 (in thousands). Industry Credit Risk Profile: The subsequent data illustrates the composition of our tenant portfolio by industry as of December 31, 2011.', '2 Includes additional products such as consumer revolving credit, installment loans, and consumer lease finance options tailored to customer needs.', 'Currently anticipated capital expenditures for 2006 total approximately $568.9 million, encompassing $24.0 million allocated to startup activities, resolution of the thruster defect previously detailed, and modifications mandated by customers for the GSF Development Driller I, $237.3 million dedicated to construction of a new semisubmersible unit, $45.0 million earmarked for repairs to our rig fleet due to hurricane impacts, $124.0 million assigned for significant fleet enhancements, $107.8 million directed toward additional equipment acquisitions and replacements, $17.3 million related to capitalized interest, and $13.5 million (net of intersegment eliminations) for oil and gas operations.', 'The Company plans to maintain its practice of issuing quarterly dividends, contingent upon Board approval, sufficient capital resources, and an assessment that such distributions remain advantageous for its shareholders.', 'As of December 31, 2008, the Company had no borrowings outstanding under its credit facility, with a total available capacity of roughly $48.4 million, reflecting the deduction for a $7.0 million letter of credit issued to BCBSF/HOI.', 'The progression of our product offerings relies on the availability and consistent performance of software sourced from external vendors.', 'The Company establishes BESP for its product deliverables through a weighted average pricing methodology, initiated by an examination of historical data on standalone sales transactions.']\n", + "# df = pd.DataFrame({'sentences': sentences})\n", + "# df=df.head(3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20f2dab2-7738-435f-bdff-f9a9e3a2eb2c", + "metadata": {}, + "outputs": [], + "source": [ + "# Load Data: Preliminary Labeling of all available data\n", + "df = pd.read_csv(\"source_text/10K_unlabeled.csv\")\n", + "df = df.sample(n=10).reset_index(drop=True) # different sample each run\n", + "df" + ] + }, + { + "cell_type": "markdown", + "id": "dbc7642b-d0cd-4b73-836d-893c54531155", + "metadata": {}, + "source": [ + "# Set Output Paths\n", + "- Preliminary (\"OUTPUT_DIR\") vs. Holdout (\"HOLDOUT_DIR\") vs. Train (\"TRAIN_DIR\") sets\n", + "- ***Make sure not to accidentally overwrite your labeled data!***\n", + "- When you query a model (below in Run Models), you need to specifiy the output path." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a37fc60e-8234-4deb-b101-de383d378f09", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# OUTPUT CONFIGURATION - SET for Base, Holdout, Train\n", + "# =============================================================================\n", + "CHECKPOINT_DIR = \"./checkpoints\"\n", + "OUTPUT_DIR = \"./output\"\n", + "PRELIM_DIR = f\"{OUTPUT_DIR}/preliminary\"\n", + "HOLDOUT_DIR = f\"{OUTPUT_DIR}/holdout\"\n", + "TRAIN_DIR = f\"{OUTPUT_DIR}/train\"" + ] + }, + { + "cell_type": "markdown", + "id": "59c343be-dbbe-4c28-be41-44c518ed2526", + "metadata": {}, + "source": [ + "# Run Models\n", + "- Set desired parameters\n", + "- Set runs (currently set to 1 run per model for examples): runs=[1] vs runs=[1, 2, 3]\n", + "- Uncomment the models you what you want to run! \n", + "- Run different models? Make sure to define them above in the vendor/model configuration first.\n", + " - Not all models will run with this code.\n", + " - If you get errors, you need to debug and develop code further\n", + " - (I'd use Claude Opus 4.5 Thinking for that, but other genAI models may work as well / even better).\n", + "\n", + "> ***IMPORTANT*** If you want to rerun a model, make sure to delete its files from the checkpoints folder first. Otherwise, it will skip all examples (e.g., sentences) that the previous run already labeled (which could be all) and you don't get updated results.\n", + "\n", + "> **MORE IMPORTANT** Running this code will cost you API credits (and requires you to ahve accounts with the providers). You will need to supply your own API keys. Beware that you may be subject to rate limits (how many queries you can send per minute) and which models you can use (OpenAI, for example, requires you to verify your identidy with an ID to access many models). Regardless, every time you execute this code, you will drain your API credits = real money! Thus, make wise decisions about what and how much to label." + ] + }, + { + "cell_type": "markdown", + "id": "e4503bd1-0d03-4f09-8897-fc06be07d887", + "metadata": {}, + "source": [ + "## Labeling Approach:\n", + "\n", + "**What we ultimately need:**\n", + "- A holdout set (1000 examples)\n", + "- A training set (15k examples)\n", + "- The ability to test and train all classes / labels\n", + " - Need to have at least some balance in holdout set so that every class (label) is represented (e.g., at least 100 times)\n", + " - Need enough examples per class (label) in train set that fine-tuned model can learn them\n", + "> **Challenge**\n", + "> - How do we know that we have enough examples (i.e., sentences) per class (label)?\n", + "\n", + "> **Idea**\n", + "> - One possible approach is to label all data only once with a reasonably fast and inexpensive genAI model to get an idea about class (label) distribution. Depending how good the model is (which we cannot easily know yet unless we have a couple dozen examples that we manually constructed as preliminary evaluation set), we have at least a *directional idea* of which sentences may belong to which classes / have which labels.\n", + "> - We can then use this *directional idea* to ***construct a holdout set*** with at least N=100 (or more?) examples per class (label). This will be better than random sampling (unless classes (labels) are balanced in full data set, which is unlikely). It will still not be robust, but *at least give us some idea*.\n", + "> - We need to make sure to remove all examples that we put in the holdout set from the rest of our data (i.e., no holdout set example should also be in the training set to ***prevent leakage***). " + ] + }, + { + "cell_type": "markdown", + "id": "a2c07cc9-3961-4ff4-82db-d75006dfa14c", + "metadata": {}, + "source": [ + "### WARNING: MAKE SURE TO **PASS** THE CORRECT OUTPUT DIR!\n", + "What are you labeling? \n", + "- Prelimiary full data set (once to investigate class balance) --> PRELIM_DIR\n", + "- Holdout set --> HOLDOUT_DIR\n", + "- Train set --> TRAIN_DIR" + ] + }, + { + "cell_type": "markdown", + "id": "27552dca-0f9d-4feb-b2f0-0fe63df69d51", + "metadata": {}, + "source": [ + "### Google CoLab Runtime Timeouts\n", + "If you have a slow API and/or you are lableing many sentences (texts), Google CoLab may time-out or shut down your runtime. This will abort the labeling process. You can resume anytime, but you need to restart and run your notebook again. For this purpose, I've added checkpointing so that results are saved every N=100 sentences (texts) and the code will look for an intermis checkpoint and pickup from there (it will also find and retry errors where nothing valid was returned). *This may be less of a problem with CoLap PRO, which you can sign-up for free as a student.*\n", + "\n", + "> If you can, it may be well worth installing python and jupyter notebooks on your local computer. You won't face the timeout issues then and your code will query the APIs as long as your computer is connected to the internet (and running)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e892c00f-d577-4c86-b913-85dc75f51ff2", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# UNC Main Campus Computing on MS Azure\n", + "# Models: gpt-4.1, gpt-4o\n", + "# Modes: chat only!!! (temperature=0 set with openAI models) \n", + "# Notes: Experimental by Dr. D. To get odd number of lables per sentence across two models, I queried the first model 4 times (4 + 3 = 7)\n", + "# =============================================================================\n", + "\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"azure\",\n", + "# model=\"gpt-4.1\",\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )\n", + "\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"azure\",\n", + "# model=\"gpt-4o\",\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55eefc77-28e4-4fbe-8e21-80c26a6d041a", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# OPENAI\n", + "# Models: gpt-5.2, gpt-5, gpt-5-mini, gpt-5-nano, gpt-4.1, gpt-4o\n", + "# Modes: chat (temperature=0) or reasoning (reasoning_effort)\n", + "# reasoning_effort: \"low\", \"medium\", \"high\"\n", + "# =============================================================================\n", + "\n", + "# # OpenAI chat mode (temperature=0)\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"openai\",\n", + "# model=\"gpt-4.1\", # gpt-5.2, gpt-5, gpt-5-mini, gpt-5-nano, gpt-4.1, gpt-4o --> make sure these are defined with pricing in vendor/model configuration\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )\n", + "\n", + "# OpenAI reasoning mode\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"openai\",\n", + "# model=\"gpt-5.2\", # gpt-5.2, gpt-5, gpt-5-mini, gpt-5-nano, gpt-4.1, gpt-4o --> make sure these are defined with pricing in vendor/model configuration\n", + "# reasoning_effort=\"high\", # \"low\", \"medium\", \"high\"\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "935a4056-5e2d-47da-add1-f38a2735844d", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# ANTHROPIC\n", + "# Models: claude-opus-4-5-20251101, claude-sonnet-4-5-20250929, \n", + "# claude-haiku-4-5-20251001, claude-sonnet-4-20250514\n", + "# Modes: chat (temperature=0) or thinking (thinking_budget)\n", + "# thinking_budget: min 1024, billed as output tokens\n", + "# =============================================================================\n", + "\n", + "# # Anthropic chat mode (temperature=0)\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"anthropic\",\n", + "# model=\"claude-sonnet-4-5-20250929\", # opus-4-5, sonnet-4-5, haiku-4-5, sonnet-4\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )\n", + "\n", + "# # Anthropic thinking mode\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"anthropic\",\n", + "# model=\"claude-sonnet-4-5-20250929\", # opus-4-5, sonnet-4-5, haiku-4-5, sonnet-4\n", + "# thinking_budget=2048, # min 1024, billed as output tokens\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5174d2b0-01b8-4971-8ead-9fb23bd416f3", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# FIREWORKS\n", + "# Models: deepseek-v3p2, deepseek-r1-0528, qwen3-vl-235b-a22b-instruct,\n", + "# qwen3-vl-235b-a22b-thinking, kimi-k2p5\n", + "# Modes: auto-detected from model name\n", + "# - Chat: deepseek-v3p2, qwen3-vl-235b-a22b-instruct\n", + "# - Thinking: deepseek-r1-0528, qwen3-vl-235b-a22b-thinking, kimi-k2p5\n", + "# =============================================================================\n", + "\n", + "# # DeepSeek V3.2 (chat)\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"fireworks\",\n", + "# model=\"deepseek-v3p2\",\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )\n", + "\n", + "# # DeepSeek R1 (thinking)\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"fireworks\",\n", + "# model=\"deepseek-r1-0528\",\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )\n", + "\n", + "# # Qwen3 VL Instruct (chat)\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"fireworks\",\n", + "# model=\"qwen3-vl-235b-a22b-instruct\",\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )\n", + "\n", + "# # Qwen3 VL Thinking (thinking)\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"fireworks\",\n", + "# model=\"qwen3-vl-235b-a22b-thinking\",\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )\n", + "\n", + "# # Kimi K2.5 (thinking)\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"fireworks\",\n", + "# model=\"kimi-k2p5\",\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4c2689d-f405-4061-8844-a83b2a53c80d", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# XAI (GROK)\n", + "# Models: grok-4, grok-4-1-fast-reasoning, grok-4-1-fast-non-reasoning\n", + "# Modes: auto-detected from model name\n", + "# - Reasoning: grok-4, grok-4-1-fast-reasoning\n", + "# - Chat: grok-4-1-fast-non-reasoning\n", + "# =============================================================================\n", + "\n", + "# # Grok 4 (reasoning)\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"xai\",\n", + "# model=\"grok-4\", # grok-4, grok-4-1-fast-reasoning, grok-4-1-fast-non-reasoning\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )\n", + "\n", + "# # Grok 4.1 Fast Reasoning\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"xai\",\n", + "# model=\"grok-4-1-fast-reasoning\",\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )\n", + "\n", + "# # Grok 4.1 Fast Non-Reasoning (chat)\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"xai\",\n", + "# model=\"grok-4-1-fast-non-reasoning\",\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5746b7c9-c3d0-4b0f-92a9-5cf087509421", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# GOOGLE GEMINI\n", + "# Models: gemini-3-pro-preview, gemini-3-flash-preview, \n", + "# gemini-2.5-pro, gemini-2.5-flash\n", + "# Modes: controlled via thinking_level\n", + "# - Pro: \"low\", \"high\" (default if not specified)\n", + "# - Flash: \"minimal\", \"low\", \"medium\", \"high\" (default if not specified)\n", + "# Note: Even \"low\" uses some thinking tokens. Only Flash \"minimal\" truly minimizes.\n", + "# =============================================================================\n", + "\n", + "# # Gemini 3 Pro - high thinking (default)\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"google\",\n", + "# model=\"gemini-3-pro-preview\", # gemini-3-pro-preview, gemini-3-flash-preview, gemini-2.5-pro, gemini-2.5-flash\n", + "# thinking_level=\"high\", # Pro: \"low\", \"high\"\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )\n", + "\n", + "# # Gemini 3 Pro - low thinking\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"google\",\n", + "# model=\"gemini-3-pro-preview\",\n", + "# thinking_level=\"low\", # Pro: \"low\", \"high\"\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )\n", + "\n", + "# # Gemini 3 Flash - high thinking (default)\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"google\",\n", + "# model=\"gemini-3-flash-preview\",\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )\n", + "\n", + "# # Gemini 3 Flash - minimal thinking (closest to chat mode)\n", + "# results = run_classification(\n", + "# df,\n", + "# sentence_col=\"sentences\",\n", + "# vendor=\"google\",\n", + "# model=\"gemini-3-flash-preview\",\n", + "# thinking_level=\"minimal\", # Flash: \"minimal\", \"low\", \"medium\", \"high\"\n", + "# runs=[1],\n", + "# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + "# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + "# )" + ] + }, + { + "cell_type": "markdown", + "id": "b2e948ca-dafa-497d-8449-18af62945c33", + "metadata": {}, + "source": [ + "# Preliminary Labels and Holdout Set\n", + "Once we labeled out data ***one time*** with a model, we can get a first idea of class (label) distribution and construct our holdout set accordingly.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "74af47ed-ce00-4b45-beed-a0316a067589", + "metadata": {}, + "outputs": [], + "source": [ + "# I will use GPT 4.1 for this purpose that UNC hosted on Azure. \n", + "# You probably will not have access to this API and model.\n", + "# Pick a fast and not too expensive model that you feel confident can do an okay job (maybe get at least 70% correct)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0128128-bcc9-4098-8cfd-de0d716384a5", + "metadata": {}, + "outputs": [], + "source": [ + "results = run_classification(\n", + " df,\n", + " sentence_col=\"sentences\", # provide the column name in which the text is you want to classify (sentence vs sentences?)\n", + " vendor=\"azure\",\n", + " model=\"gpt-4.1\",\n", + " runs=[1],\n", + " output_dir=PRELIM_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + " checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "ee8af782-e8a6-43ce-ae24-8db0fe997a09", + "metadata": {}, + "source": [ + "## STEP 1: Load labeled data from a single genAI run" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14621c20-9da6-4217-b6a7-d6412e224531", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------\n", + "# Configuration\n", + "# -----------------------------\n", + "\n", + "# CHANGE THIS: Path to your labeled data (one run, one model)\n", + "LABELED_FILE = \"./output/preliminary/Run1/preliminary_azure_gpt-4.1_chat_run1.pkl\"\n", + "\n", + "# Functional areas\n", + "CLASSES = [\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b2d53f0-22f2-4e16-8a58-8a6cca24adc1", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import ast\n", + "\n", + "\n", + "# -----------------------------\n", + "# Load and Parse\n", + "# -----------------------------\n", + "\n", + "print(\"=\"*60)\n", + "print(\"LOADING LABELED DATA\")\n", + "print(\"=\"*60)\n", + "\n", + "df_raw = pd.read_pickle(LABELED_FILE)\n", + "print(f\"Source: {LABELED_FILE}\")\n", + "print(f\"Total sentences: {len(df_raw)}\")\n", + "\n", + "# Parse labels into binary columns\n", + "def parse_labels_safe(labels):\n", + " \"\"\"Parse labels from various formats.\"\"\"\n", + " if labels is None:\n", + " return []\n", + " if isinstance(labels, list):\n", + " return labels\n", + " if isinstance(labels, str):\n", + " try:\n", + " parsed = ast.literal_eval(labels)\n", + " return parsed if isinstance(parsed, list) else []\n", + " except:\n", + " return []\n", + " return []\n", + "\n", + "# Create binary columns for each class\n", + "df_labeled = df_raw[['id', 'sentence']].copy()\n", + "\n", + "parsed_labels = df_raw['labels'].apply(parse_labels_safe)\n", + "\n", + "for cls in CLASSES:\n", + " df_labeled[cls] = parsed_labels.apply(lambda x: 1 if cls in x else 0)\n", + "\n", + "# None = all classes are 0\n", + "df_labeled['None'] = (df_labeled[CLASSES].sum(axis=1) == 0).astype(int)\n", + "\n", + "# Remove rows where labeling failed (error column is not None/NaN)\n", + "if 'error' in df_raw.columns:\n", + " error_mask = df_raw['error'].notna()\n", + " n_errors = error_mask.sum()\n", + " if n_errors > 0:\n", + " df_labeled = df_labeled[~error_mask].reset_index(drop=True)\n", + " print(f\"Removed {n_errors} rows with labeling errors\")\n", + "\n", + "print(f\"Successfully labeled: {len(df_labeled)} sentences\")\n", + "\n", + "# Preview\n", + "print(f\"\\nPreview:\")\n", + "label_cols = CLASSES + ['None']\n", + "print(df_labeled[['sentence'] + label_cols].head(5))" + ] + }, + { + "cell_type": "markdown", + "id": "b6b4c4d7-2ed1-42a3-8d6b-09c7371bc9ba", + "metadata": {}, + "source": [ + "## STEP 2: Descriptives" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91d6ef4d-6498-4d36-837a-91b44d329cd0", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "label_cols = CLASSES + ['None']\n", + "\n", + "print(\"=\"*60)\n", + "print(\"DESCRIPTIVE STATISTICS\")\n", + "print(\"=\"*60)\n", + "print(f\"Total sentences: {len(df_labeled)}\")\n", + "\n", + "# Class distribution\n", + "print(f\"\\n{'='*60}\")\n", + "print(\"CLASS DISTRIBUTION\")\n", + "print(f\"{'='*60}\")\n", + "print(f\"{'Class':<15} {'Count':>8} {'Share':>10}\")\n", + "print(f\"{'-'*35}\")\n", + "\n", + "for cls in label_cols:\n", + " count = df_labeled[cls].sum()\n", + " share = count / len(df_labeled) * 100\n", + " print(f\"{cls:<15} {count:>8} {share:>9.1f}%\")\n", + "\n", + "# Multi-label distribution\n", + "df_labeled['num_labels'] = df_labeled[CLASSES].sum(axis=1)\n", + "\n", + "print(f\"\\n{'='*60}\")\n", + "print(\"MULTI-LABEL DISTRIBUTION\")\n", + "print(f\"{'='*60}\")\n", + "print(f\"{'# Labels':<15} {'Count':>8} {'Share':>10}\")\n", + "print(f\"{'-'*35}\")\n", + "\n", + "for n in range(7):\n", + " count = (df_labeled['num_labels'] == n).sum()\n", + " if count > 0:\n", + " share = count / len(df_labeled) * 100\n", + " label_text = f\"{n} label{'s' if n != 1 else ''}\"\n", + " print(f\"{label_text:<15} {count:>8} {share:>9.1f}%\")\n", + "\n", + "print(f\"{'-'*35}\")\n", + "print(f\"{'Mean labels':<15} {df_labeled['num_labels'].mean():>8.2f}\")\n", + "print(f\"{'Median labels':<15} {df_labeled['num_labels'].median():>8.0f}\")\n", + "\n", + "# Co-occurrence matrix\n", + "print(f\"\\n{'='*60}\")\n", + "print(\"CO-OCCURRENCE MATRIX\")\n", + "print(f\"{'='*60}\")\n", + "\n", + "cooccurrence = pd.DataFrame(index=CLASSES, columns=CLASSES, dtype=int)\n", + "for cls1 in CLASSES:\n", + " for cls2 in CLASSES:\n", + " cooccurrence.loc[cls1, cls2] = ((df_labeled[cls1] == 1) & (df_labeled[cls2] == 1)).sum()\n", + "\n", + "print(cooccurrence.to_string())\n", + "\n", + "# Top label combinations\n", + "print(f\"\\n{'='*60}\")\n", + "print(\"TOP 10 LABEL COMBINATIONS\")\n", + "print(f\"{'='*60}\")\n", + "\n", + "def get_label_combo(row):\n", + " labels = [cls for cls in CLASSES if row[cls] == 1]\n", + " return str(labels) if labels else \"[]\"\n", + "\n", + "combo_counts = df_labeled.apply(get_label_combo, axis=1).value_counts().head(10)\n", + "print(f\"{'Combination':<50} {'Count':>8} {'Share':>10}\")\n", + "print(f\"{'-'*70}\")\n", + "for combo, count in combo_counts.items():\n", + " share = count / len(df_labeled) * 100\n", + " print(f\"{combo:<50} {count:>8} {share:>9.1f}%\")\n", + "\n", + "# Clean up temp column\n", + "df_labeled.drop(columns=['num_labels'], inplace=True)" + ] + }, + { + "cell_type": "markdown", + "id": "82b12506-7b3a-4f83-b3c9-da733c89306e", + "metadata": {}, + "source": [ + "## STEP 3: Create Holdout and Train Sets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86bf25c6-10b5-4d2c-8043-46c62c6598bd", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# STEP 3: CREATE HOLDOUT AND TRAIN SETS\n", + "# =============================================================================\n", + "\n", + "import os\n", + "\n", + "# -----------------------------\n", + "# Configuration\n", + "# -----------------------------\n", + "\n", + "HOLDOUT_SIZE = 1000\n", + "MIN_PER_CLASS = 100\n", + "MAX_PER_CLASS = 333\n", + "MAX_NONE = 100 # Maximum None sentences in holdout\n", + "OUT_DIR = \"./Holdout_Train\"\n", + "\n", + "os.makedirs(OUT_DIR, exist_ok=True)\n", + "\n", + "np.random.seed(42)\n", + "\n", + "# -----------------------------\n", + "# Helper\n", + "# -----------------------------\n", + "\n", + "def get_labels(row):\n", + " \"\"\"Get list of functional labels for a row.\"\"\"\n", + " return [cls for cls in CLASSES if row[cls] == 1]\n", + "\n", + "# Add temp columns\n", + "df_labeled['labels_list'] = df_labeled.apply(get_labels, axis=1)\n", + "df_labeled['num_labels'] = df_labeled[CLASSES].sum(axis=1)\n", + "\n", + "# -----------------------------\n", + "# Stratified Sampling (6 functional classes only — None is derived)\n", + "# -----------------------------\n", + "\n", + "selected_indices = set()\n", + "class_counts = {cls: 0 for cls in CLASSES}\n", + "\n", + "print(\"=\"*60)\n", + "print(\"STRATIFIED HOLDOUT SAMPLING\")\n", + "print(\"=\"*60)\n", + "print(f\"Target holdout size: {HOLDOUT_SIZE}\")\n", + "print(f\"Min per class: {MIN_PER_CLASS} (6 functional classes)\")\n", + "print(f\"Max per class: {MAX_PER_CLASS}\")\n", + "print(f\"None: not constrained (derived, not independently predicted)\")\n", + "\n", + "# Phase 1: Ensure minimums, rarest classes first\n", + "class_freq = {cls: int(df_labeled[cls].sum()) for cls in CLASSES}\n", + "classes_by_rarity = sorted(CLASSES, key=lambda c: class_freq[c])\n", + "\n", + "print(f\"\\n--- Phase 1: Ensure minimums (rarest first) ---\")\n", + "print(f\" Class frequencies: {', '.join(f'{c}: {class_freq[c]}' for c in classes_by_rarity)}\")\n", + "\n", + "for rare_class in classes_by_rarity:\n", + " candidates = df_labeled[\n", + " (df_labeled[rare_class] == 1) & \n", + " (~df_labeled.index.isin(selected_indices))\n", + " ].sort_values('num_labels', ascending=False)\n", + " \n", + " for idx, row in candidates.iterrows():\n", + " if class_counts[rare_class] >= MIN_PER_CLASS:\n", + " break\n", + " labels = row['labels_list']\n", + " if len(labels) > 0 and any(class_counts[lbl] >= MAX_PER_CLASS for lbl in labels):\n", + " continue\n", + " selected_indices.add(idx)\n", + " for lbl in labels:\n", + " class_counts[lbl] += 1\n", + " \n", + " print(f\" {rare_class}: {class_counts[rare_class]} sentences\")\n", + "\n", + "print(f\"\\nAfter Phase 1: {len(selected_indices)} sentences selected\")\n", + "\n", + "# Phase 2: Multi-label sentences\n", + "print(f\"\\n--- Phase 2: Add multi-label sentences ---\")\n", + "\n", + "current_multilabel = df_labeled.loc[list(selected_indices), 'num_labels'].gt(1).sum()\n", + "target_multilabel = int(HOLDOUT_SIZE * 0.18)\n", + "\n", + "multilabel_candidates = df_labeled[\n", + " (df_labeled['num_labels'] >= 2) & \n", + " (~df_labeled.index.isin(selected_indices))\n", + "].sample(frac=1, random_state=42)\n", + "\n", + "for idx, row in multilabel_candidates.iterrows():\n", + " if current_multilabel >= target_multilabel:\n", + " break\n", + " if len(selected_indices) >= HOLDOUT_SIZE:\n", + " break\n", + " labels = row['labels_list']\n", + " if len(labels) > 0 and any(class_counts[lbl] >= MAX_PER_CLASS for lbl in labels):\n", + " continue\n", + " selected_indices.add(idx)\n", + " for lbl in labels:\n", + " class_counts[lbl] += 1\n", + " current_multilabel += 1\n", + "\n", + "print(f\" Multi-label sentences: {current_multilabel}\")\n", + "print(f\" Total selected: {len(selected_indices)}\")\n", + "\n", + "# Phase 3: Fill remaining (cap None sentences)\n", + "none_count = sum(1 for idx in selected_indices if df_labeled.loc[idx, CLASSES].sum() == 0)\n", + "\n", + "print(f\"\\n--- Phase 3: Fill to {HOLDOUT_SIZE} (max None: {MAX_NONE}) ---\")\n", + "print(f\" None so far: {none_count}\")\n", + "\n", + "if len(selected_indices) < HOLDOUT_SIZE:\n", + " fill_candidates = df_labeled[\n", + " ~df_labeled.index.isin(selected_indices)\n", + " ].sample(frac=1, random_state=42)\n", + " \n", + " for idx, row in fill_candidates.iterrows():\n", + " if len(selected_indices) >= HOLDOUT_SIZE:\n", + " break\n", + " labels = row['labels_list']\n", + " \n", + " # Skip if any functional class would exceed max\n", + " if len(labels) > 0 and any(class_counts[lbl] >= MAX_PER_CLASS for lbl in labels):\n", + " continue\n", + " \n", + " # Skip if this is a None sentence and we've hit the cap\n", + " if len(labels) == 0 and none_count >= MAX_NONE:\n", + " continue\n", + " \n", + " selected_indices.add(idx)\n", + " for lbl in labels:\n", + " class_counts[lbl] += 1\n", + " if len(labels) == 0:\n", + " none_count += 1\n", + "\n", + "print(f\" Final count: {len(selected_indices)}\")\n", + "\n", + "# -----------------------------\n", + "# Create Clean DataFrames (LABELS BLANKED OUT)\n", + "# -----------------------------\n", + "\n", + "# Holdout and train get sentences + empty label columns\n", + "df_holdout = df_labeled.loc[list(selected_indices), ['sentence']].copy()\n", + "df_train = df_labeled.loc[~df_labeled.index.isin(selected_indices), ['sentence']].copy()\n", + "\n", + "# Add empty label columns (students fill these in)\n", + "for cls in CLASSES + ['None']:\n", + " df_holdout[cls] = \"\"\n", + " df_train[cls] = \"\"\n", + "\n", + "# Shuffle both\n", + "df_holdout = df_holdout.sample(frac=1, random_state=42).reset_index(drop=True)\n", + "df_train = df_train.sample(frac=1, random_state=42).reset_index(drop=True)\n", + "\n", + "# Clean up temp columns from df_labeled\n", + "df_labeled.drop(columns=['labels_list', 'num_labels'], inplace=True, errors='ignore')\n", + "\n", + "# -----------------------------\n", + "# Verification (using original labels for reporting only)\n", + "# -----------------------------\n", + "\n", + "print(f\"\\n{'='*60}\")\n", + "print(\"VERIFICATION\")\n", + "print(f\"{'='*60}\")\n", + "\n", + "print(f\"\\nDataset sizes:\")\n", + "print(f\" Holdout: {len(df_holdout)}\")\n", + "print(f\" Train: {len(df_train)}\")\n", + "print(f\" Total: {len(df_holdout) + len(df_train)}\")\n", + "\n", + "# Verify no overlap\n", + "overlap = set(df_holdout['sentence']) & set(df_train['sentence'])\n", + "print(f\"Overlap check: {len(overlap)} sentences (should be 0)\")\n", + "\n", + "# Report holdout class distribution (from original labels, for our reference)\n", + "print(f\"\\n{'='*60}\")\n", + "print(\"HOLDOUT CLASS DISTRIBUTION (from LLM labels, for reference)\")\n", + "print(f\"{'='*60}\")\n", + "print(f\"{'Class':<15} {'Count':>8} {'Share':>10} {'Min OK':>10} {'Max OK':>10}\")\n", + "print(f\"{'-'*55}\")\n", + "\n", + "holdout_sentences = set(df_holdout['sentence'])\n", + "df_holdout_check = df_labeled[df_labeled['sentence'].isin(holdout_sentences)]\n", + "\n", + "for cls in CLASSES:\n", + " count = df_holdout_check[cls].sum()\n", + " share = count / len(df_holdout_check) * 100\n", + " min_ok = \"✓\" if count >= MIN_PER_CLASS else \"✗\"\n", + " max_ok = \"✓\" if count <= MAX_PER_CLASS else \"✗\"\n", + " print(f\"{cls:<15} {count:>8} {share:>9.1f}% {min_ok:>10} {max_ok:>10}\")\n", + "\n", + "# None (derived, for reference)\n", + "none_count = (df_holdout_check[CLASSES].sum(axis=1) == 0).sum()\n", + "none_share = none_count / len(df_holdout_check) * 100\n", + "none_ok = \"✓\" if none_count <= MAX_NONE else \"✗\"\n", + "print(f\"{'None (derived)':<15} {none_count:>8} {none_share:>9.1f}% {'':>10} {none_ok:>10} (max {MAX_NONE})\")\n", + "\n", + "# Multi-label distribution\n", + "print(f\"\\n{'='*60}\")\n", + "print(\"HOLDOUT MULTI-LABEL DISTRIBUTION (from LLM labels, for reference)\")\n", + "print(f\"{'='*60}\")\n", + "print(f\"{'# Labels':<15} {'Count':>8} {'Share':>10}\")\n", + "print(f\"{'-'*35}\")\n", + "\n", + "num_labels = df_holdout_check[CLASSES].sum(axis=1)\n", + "for n in range(7):\n", + " count = (num_labels == n).sum()\n", + " if count > 0:\n", + " share = count / len(df_holdout_check) * 100\n", + " print(f\"{n} label{'s' if n != 1 else '':<14} {count:>8} {share:>9.1f}%\")\n", + "\n", + "# Preview\n", + "print(f\"\\nHoldout preview (labels are blank for human experts to fill):\")\n", + "print(df_holdout.head(5))\n", + "print(f\"\\nTrain preview (labels are blank for genAI to fill):\")\n", + "print(df_train.head(5))\n", + "\n", + "# -----------------------------\n", + "# Save\n", + "# -----------------------------\n", + "\n", + "df_holdout.to_csv(f\"{OUT_DIR}/holdout.csv\", index=False)\n", + "df_holdout.to_pickle(f\"{OUT_DIR}/holdout.pkl\")\n", + "df_train.to_csv(f\"{OUT_DIR}/train.csv\", index=False)\n", + "df_train.to_pickle(f\"{OUT_DIR}/train.pkl\")\n", + "\n", + "print(f\"\\n{'='*60}\")\n", + "print(\"SAVED\")\n", + "print(f\"{'='*60}\")\n", + "print(f\" Holdout: {OUT_DIR}/holdout.csv (.pkl)\")\n", + "print(f\" Train: {OUT_DIR}/train.csv (.pkl)\")\n", + "print(f\"\\nColumns: {list(df_holdout.columns)}\")\n", + "print(f\"Label columns are EMPTY - human experts must fill these in\")" + ] + }, + { + "cell_type": "markdown", + "id": "189837d8-bf61-4924-981a-3d5f83ec9e78", + "metadata": {}, + "source": [ + "# **Human Labeling of Holdout for Ground Truth**\n", + "\n", + "- Now we have a more or less balanced holdout set\n", + "- We need to get this labeled by human experts, that is, domain experts / subject matter experts\n", + "- You will need at least three humans to evaluate each sentence independently,\n", + " *or* work as an expert team and discuss each sentence to assign it the appropriate class (or labels)\n", + "- Make sure to save the file with the labels AND ***be sure that the text (e.g., sentences) are not broken / contain formatting or ASCI or HTML errors***\n", + " - You might want to label in excel, but then merge the labels from the XLSX or CSV file to the the holdout.pkl file (which better preserved the actual texts)\n", + "- If you have independent human experts label the holdout, then you need to have a majority vote per label (this suggests you need an odd number of human expert labelers). Good practice is also to check inter-rater agreement on the labels." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fac02954-21d8-419b-9215-733639f8fdb6", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3abaaf59-24e0-49dc-aef9-47f88dbc4469", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e3aae9b-6565-413d-aa57-78113809acf5", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "546b0d8b-c049-4ae3-8f37-f471036a99bc", + "metadata": {}, + "source": [ + "# **genAI Labeling of Holdout**\n", + "> Now that you have a ground truth for your holdout set, you need to determine which genAI model performs best on it\n", + "\n", + "#### Approach: \n", + "1. **Label the holdout 3 times with a genAI mode** (use code from above, but instead of loading the \"10K_unlabeled.csv\" file for preliminary labeling, you want to load your holdout.pkl file (the one without your human labels).\n", + "2. **Check** the three runs for **label agreement** (krippendorff's alpha, for example)\n", + "3. Get the **majority vote** (for class or for labels - this example will be for ***labels*** that, unlike classes, are ***NOT mutually exclusive***\n", + "4. **Compare to human expert labels on holdout** set to measure genAI labeling **performance**\n", + "5. **Repeat** for other models (of other vendors)\n", + "6. **Find model** that does **best**, then use it to label the train set (3 times - same genAI API code from above but on different file: makre sure to adjust the paths where you store the data so you don't overwrite your holdout labels), get majority votes on train, then use those to fine-tune a pretrained and open-source LLM (like RoBERTa Large)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c81a6d50-b86b-4384-824c-3f7b1262d63e", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# OUTPUT CONFIGURATION - SET for Base, Holdout, Train\n", + "# =============================================================================\n", + "CHECKPOINT_DIR = \"./checkpoints\"\n", + "OUTPUT_DIR = \"./output\"\n", + "HOLDOUT_DIR = f\"{OUTPUT_DIR}/holdout\"\n", + "TRAIN_DIR = f\"{OUTPUT_DIR}/train\"\n", + "\n", + "# =============================================================================\n", + "# Load Holdout (not the human labeled, but the blank)\n", + "# =============================================================================\n", + "\n", + "# # Load Data: Labeling of (balanced) holdout data\n", + "df = pd.read_pickle(\"Holdout_Train/holdout.pkl\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7bea59bb-1faf-4469-819a-e5cbcf9d5425", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# Query genAI model via API\n", + "# =============================================================================\n", + "\n", + "# Here an example for gpt-4o via UNC Azure API\n", + "results = run_classification(\n", + " df,\n", + " sentence_col=\"sentence\", # --> Check if your file has the column as named here where the text is (sentence vs sentences vs text vs tweet vs ... )\n", + " vendor=\"azure\",\n", + " model=\"gpt-4o\",\n", + " runs=[1,2,3], # Doing 3 runs for conistency / replicability: will later take majority vote per lable across runs\n", + " output_dir=HOLDOUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + " checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "207a4e45-dc15-4df4-855c-d151c1a9f649", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# Query another genAI model via API\n", + "# =============================================================================\n", + "\n", + "# Here an example for gpt-4o via UNC Azure API\n", + "results = run_classification(\n", + " df,\n", + " sentence_col=\"sentence\", # --> Check if your file has the column as named here where the text is (sentence vs sentences vs text vs tweet vs ... )\n", + " vendor=\"azure\",\n", + " model=\"gpt-4.1\",\n", + " runs=[1,2,3], # Doing 3 runs for conistency / replicability: will later take majority vote per lable across runs\n", + " output_dir=HOLDOUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + " checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "158e92e8-1f1a-4b58-bb8e-a000f4fe9b59", + "metadata": {}, + "source": [ + "# genAI Label Agreement\n", + "\n", + "- How consistent is genAI in its labels?\n", + "- Test the extent to which genAI's labels agree across runs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d0c9f7c-d74d-4cc5-a16c-55e20dc0293e", + "metadata": {}, + "outputs": [], + "source": [ + "# pip install -q -U krippendorff # already done at the very beginning. Keeping this as a reminder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "793ba990-1e0d-4376-b2cc-887bbe633ee8", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# INTER-RATER AGREEMENT ACROSS RUNS (Krippendorff's Alpha)\n", + "# =============================================================================\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "import ast\n", + "import glob\n", + "import re\n", + "import krippendorff \n", + "\n", + "# -----------------------------\n", + "# Configuration\n", + "# -----------------------------\n", + "\n", + "HOLDOUT_DIR = \"./output/holdout\" # Set your holdout output directory\n", + "RUNS = [1, 2, 3] # Which runs to compare\n", + "CLASSES = [\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "063a4981-1524-4bb0-a91c-c937f345fc27", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------\n", + "# Helper Functions\n", + "# -----------------------------\n", + "\n", + "def parse_labels_safe(labels):\n", + " \"\"\"Parse labels from various formats.\"\"\"\n", + " if labels is None:\n", + " return []\n", + " if isinstance(labels, list):\n", + " return labels\n", + " if isinstance(labels, str):\n", + " try:\n", + " parsed = ast.literal_eval(labels)\n", + " return parsed if isinstance(parsed, list) else []\n", + " except:\n", + " return []\n", + " return []\n", + "\n", + "def parse_filename(filename):\n", + " \"\"\"\n", + " Parse filename like 'holdout_azure_gpt-4.1_chat_run1.pkl' \n", + " into (vendor, model, mode, run).\n", + " Handles optional dataset prefix (holdout_, train_, prelim_, output_).\n", + " \"\"\"\n", + " # Remove extension\n", + " name = filename.replace('.pkl', '').replace('.csv', '')\n", + " \n", + " # Strip known dataset prefixes\n", + " for prefix in ['holdout_', 'train_', 'prelim_', 'output_']:\n", + " if name.startswith(prefix):\n", + " name = name[len(prefix):]\n", + " break\n", + " \n", + " # Extract run number from end\n", + " match = re.search(r'_run(\\d+)$', name)\n", + " if not match:\n", + " return None\n", + " run = int(match.group(1))\n", + " name = name[:match.start()]\n", + " \n", + " # Extract mode (last part before run)\n", + " parts = name.rsplit('_', 1)\n", + " if len(parts) != 2:\n", + " return None\n", + " mode = parts[1]\n", + " remainder = parts[0]\n", + " \n", + " # Extract vendor (first part) and model (rest)\n", + " first_underscore = remainder.index('_')\n", + " vendor = remainder[:first_underscore]\n", + " model = remainder[first_underscore + 1:]\n", + " \n", + " return vendor, model, mode, run\n", + "\n", + "def discover_models(holdout_dir, runs):\n", + " \"\"\"\n", + " Discover all vendor_model_mode combinations that have ALL specified runs.\n", + " Returns dict: {(vendor, model, mode): {run: filepath}}\n", + " \"\"\"\n", + " models = {}\n", + " \n", + " for run in runs:\n", + " run_dir = f\"{holdout_dir}/Run{run}\"\n", + " pkl_files = glob.glob(f\"{run_dir}/*.pkl\")\n", + " \n", + " for filepath in pkl_files:\n", + " filename = filepath.split(\"/\")[-1]\n", + " parsed = parse_filename(filename)\n", + " if parsed is None:\n", + " continue\n", + " \n", + " vendor, model, mode, file_run = parsed\n", + " if file_run != run:\n", + " continue\n", + " \n", + " key = (vendor, model, mode)\n", + " if key not in models:\n", + " models[key] = {}\n", + " models[key][run] = filepath\n", + " \n", + " # Keep only models that have ALL specified runs\n", + " complete = {\n", + " key: paths for key, paths in models.items()\n", + " if all(r in paths for r in runs)\n", + " }\n", + " \n", + " return complete\n", + "\n", + "def load_and_binarize(filepath, classes):\n", + " \"\"\"Load pkl and convert labels to binary columns.\"\"\"\n", + " df = pd.read_pickle(filepath)\n", + " \n", + " parsed = df['labels'].apply(parse_labels_safe)\n", + " \n", + " binary = pd.DataFrame(index=df.index)\n", + " binary['id'] = df['id'] if 'id' in df.columns else df.index\n", + " binary['sentence'] = df['sentence']\n", + " \n", + " for cls in classes:\n", + " binary[cls] = parsed.apply(lambda x: 1 if cls in x else 0)\n", + " \n", + " # Remove error rows\n", + " if 'error' in df.columns:\n", + " binary = binary[df['error'].isna()].reset_index(drop=True)\n", + " \n", + " return binary\n", + "\n", + "def compute_agreement(model_runs, classes, runs):\n", + " \"\"\"\n", + " Compute Krippendorff's alpha per class and overall.\n", + " Only evaluates the 6 functional classes — None is derived, not labeled.\n", + " \n", + " Args:\n", + " model_runs: dict {run: filepath}\n", + " classes: list of class names (6 functional classes)\n", + " runs: list of run numbers\n", + " \n", + " Returns:\n", + " dict with per-class and overall alpha\n", + " \"\"\"\n", + " # Load all runs\n", + " run_dfs = {}\n", + " for run in runs:\n", + " run_dfs[run] = load_and_binarize(model_runs[run], classes)\n", + " \n", + " # Verify all runs have same sentences\n", + " n_sentences = len(run_dfs[runs[0]])\n", + " for run in runs:\n", + " assert len(run_dfs[run]) == n_sentences, \\\n", + " f\"Run {run} has {len(run_dfs[run])} sentences, expected {n_sentences}\"\n", + " \n", + " results = {}\n", + " \n", + " # Per-class alpha (6 functional classes only)\n", + " for cls in classes:\n", + " # Reliability matrix: rows = raters (runs), columns = units (sentences)\n", + " reliability_matrix = np.array([\n", + " run_dfs[run][cls].values for run in runs\n", + " ])\n", + " \n", + " # Krippendorff's alpha (nominal level for binary data)\n", + " try:\n", + " alpha = krippendorff.alpha(\n", + " reliability_data=reliability_matrix,\n", + " level_of_measurement='nominal',\n", + " )\n", + " except:\n", + " alpha = np.nan\n", + " \n", + " results[cls] = alpha\n", + " \n", + " # Overall alpha (flatten 6 functional classes into one reliability matrix)\n", + " overall_matrix = np.hstack([\n", + " np.array([run_dfs[run][cls].values for run in runs])\n", + " for cls in classes\n", + " ])\n", + " \n", + " try:\n", + " results['Overall'] = krippendorff.alpha(\n", + " reliability_data=overall_matrix,\n", + " level_of_measurement='nominal',\n", + " )\n", + " except:\n", + " results['Overall'] = np.nan\n", + " \n", + " # Pairwise agreement percentage per class (6 functional classes only)\n", + " pairwise = {}\n", + " for cls in classes:\n", + " agreements = []\n", + " for i, r1 in enumerate(runs):\n", + " for r2 in runs[i+1:]:\n", + " agree = (run_dfs[r1][cls].values == run_dfs[r2][cls].values).mean()\n", + " agreements.append(agree)\n", + " pairwise[cls] = np.mean(agreements)\n", + " pairwise['Overall'] = np.mean([pairwise[cls] for cls in classes])\n", + " \n", + " return results, pairwise, n_sentences\n", + "\n", + "\n", + "# -----------------------------\n", + "# Main Analysis\n", + "# -----------------------------\n", + "\n", + "print(\"=\"*70)\n", + "print(\"INTER-RATER AGREEMENT ANALYSIS (Krippendorff's Alpha)\")\n", + "print(\"=\"*70)\n", + "print(f\"Holdout dir: {HOLDOUT_DIR}\")\n", + "print(f\"Runs: {RUNS}\")\n", + "print(f\"Classes evaluated: {CLASSES} (None excluded — derived, not labeled)\")\n", + "\n", + "# Discover models\n", + "models = discover_models(HOLDOUT_DIR, RUNS)\n", + "\n", + "print(f\"\\nFound {len(models)} model(s) with all {len(RUNS)} runs:\")\n", + "for (vendor, model, mode), paths in models.items():\n", + " print(f\" {vendor} / {model} / {mode}\")\n", + " for run, path in sorted(paths.items()):\n", + " print(f\" Run {run}: {path}\")\n", + "\n", + "# Compute agreement for each model\n", + "all_results = {}\n", + "\n", + "for (vendor, model, mode), paths in models.items():\n", + " model_key = f\"{vendor}_{model}_{mode}\"\n", + " \n", + " print(f\"\\n{'='*70}\")\n", + " print(f\"MODEL: {vendor} / {model} / {mode}\")\n", + " print(f\"{'='*70}\")\n", + " \n", + " alphas, pairwise, n_sentences = compute_agreement(paths, CLASSES, RUNS)\n", + " all_results[model_key] = {'alphas': alphas, 'pairwise': pairwise}\n", + " \n", + " print(f\"Sentences: {n_sentences}\")\n", + " print(f\"Runs compared: {RUNS}\")\n", + " \n", + " # Per-class table (6 functional classes + Overall)\n", + " print(f\"\\n{'Class':<15} {'k-Alpha':>10} {'Agreement Share':>16}\")\n", + " print(\"-\"*45)\n", + " \n", + " for cls in CLASSES + ['Overall']:\n", + " alpha = alphas[cls]\n", + " agree = pairwise[cls]\n", + " \n", + " if cls == 'Overall':\n", + " print(\"-\"*45)\n", + " \n", + " print(f\"{cls:<15} {alpha:>10.4f} {agree:>15.1%}\")\n", + " \n", + " # Interpretation guide\n", + " print(f\"\\nKrippendorff's Alpha Interpretation:\")\n", + " print(f\" α ≥ 0.80 → Reliable agreement\")\n", + " print(f\" α ≥ 0.667 → Tentative agreement (acceptable for some purposes)\")\n", + " print(f\" α < 0.667 → Unreliable / insufficient agreement\")\n", + " print(f\"\\n Note: None class excluded — it is derived (all classes = 0),\")\n", + " print(f\" not independently labeled by the model.\")\n", + "\n", + "# -----------------------------\n", + "# Summary across models\n", + "# -----------------------------\n", + "\n", + "if len(all_results) > 1:\n", + " print(f\"\\n{'='*70}\")\n", + " print(\"SUMMARY ACROSS MODELS\")\n", + " print(f\"{'='*70}\")\n", + " \n", + " summary_rows = []\n", + " for model_key, data in all_results.items():\n", + " row = {'Model': model_key}\n", + " row['Overall_Alpha'] = data['alphas']['Overall']\n", + " row['Overall_Agreement'] = data['pairwise']['Overall']\n", + " for cls in CLASSES:\n", + " row[f'{cls}_Alpha'] = data['alphas'][cls]\n", + " summary_rows.append(row)\n", + " \n", + " df_summary = pd.DataFrame(summary_rows)\n", + " df_summary = df_summary.sort_values('Overall_Alpha', ascending=False)\n", + " \n", + " print(f\"\\n{'Model':<35} {'Overall α':>12} {'Agreement':>12}\")\n", + " print(\"-\"*60)\n", + " for _, row in df_summary.iterrows():\n", + " print(f\"{row['Model']:<35} {row['Overall_Alpha']:>12.4f} {row['Overall_Agreement']:>11.1%}\")\n", + " \n", + " # Save summary\n", + " df_summary.to_csv(f\"{HOLDOUT_DIR}/agreement_summary.csv\", index=False)\n", + " print(f\"\\nSaved summary to {HOLDOUT_DIR}/agreement_summary.csv\")" + ] + }, + { + "cell_type": "markdown", + "id": "b9b669eb-b3ea-4137-8d58-70455261198c", + "metadata": {}, + "source": [ + "# Majority Votes across Runs\n", + "\n", + "- Across runs\n", + "- Does not handle ties yet (assumes odd number of total runs)\n", + "- Assumes multi-label problem (one sentence can have multiple labels)\n", + "- Assumes that if majority is [] (none), then all other classes must be negative (0 or FALSE) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3526fb48-f557-4025-b228-f4f75915abc2", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# MAJORITY VOTE PER MODEL (across its own runs)\n", + "# =============================================================================\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "import glob\n", + "import ast\n", + "import re\n", + "import os\n", + "\n", + "# -----------------------------\n", + "# Configuration\n", + "# -----------------------------\n", + "\n", + "HOLDOUT_DIR = \"./output/holdout\" # Where Run1/, Run2/, Run3/ etc. are\n", + "RUNS = [1, 2, 3] # Which runs to aggregate\n", + "CLASSES = [\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95aa527c-daa4-4063-a5e1-fc438f2f76fb", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------\n", + "# Helper Functions\n", + "# -----------------------------\n", + "\n", + "def parse_labels_safe(labels):\n", + " \"\"\"Parse labels from various formats.\"\"\"\n", + " if labels is None:\n", + " return []\n", + " if isinstance(labels, list):\n", + " return labels\n", + " if isinstance(labels, str):\n", + " try:\n", + " parsed = ast.literal_eval(labels)\n", + " return parsed if isinstance(parsed, list) else []\n", + " except:\n", + " return []\n", + " return []\n", + "\n", + "def parse_filename(filename):\n", + " \"\"\"\n", + " Parse filename like 'holdout_azure_gpt-4.1_chat_run1.pkl'\n", + " into (vendor, model, mode, run).\n", + " Handles optional dataset prefix (holdout_, train_, prelim_, output_).\n", + " \"\"\"\n", + " name = filename.replace('.pkl', '').replace('.csv', '')\n", + " \n", + " # Strip known dataset prefixes\n", + " for prefix in ['holdout_', 'train_', 'prelim_', 'output_']:\n", + " if name.startswith(prefix):\n", + " name = name[len(prefix):]\n", + " break\n", + " \n", + " match = re.search(r'_run(\\d+)$', name)\n", + " if not match:\n", + " return None\n", + " run = int(match.group(1))\n", + " name = name[:match.start()]\n", + " parts = name.rsplit('_', 1)\n", + " if len(parts) != 2:\n", + " return None\n", + " mode = parts[1]\n", + " remainder = parts[0]\n", + " first_underscore = remainder.index('_')\n", + " vendor = remainder[:first_underscore]\n", + " model = remainder[first_underscore + 1:]\n", + " return vendor, model, mode, run\n", + "\n", + "def discover_models(holdout_dir, runs):\n", + " \"\"\"Find all vendor_model_mode combos that have ALL specified runs.\"\"\"\n", + " models = {}\n", + " for run in runs:\n", + " run_dir = f\"{holdout_dir}/Run{run}\"\n", + " if not os.path.exists(run_dir):\n", + " continue\n", + " for filepath in glob.glob(f\"{run_dir}/*.pkl\"):\n", + " filename = filepath.split(\"/\")[-1]\n", + " parsed = parse_filename(filename)\n", + " if parsed is None:\n", + " continue\n", + " vendor, model, mode, file_run = parsed\n", + " if file_run != run:\n", + " continue\n", + " key = (vendor, model, mode)\n", + " if key not in models:\n", + " models[key] = {}\n", + " models[key][run] = filepath\n", + " \n", + " # Keep only models with ALL specified runs\n", + " return {k: v for k, v in models.items() if all(r in v for r in runs)}\n", + "\n", + "def compute_majority_vote(model_runs, classes, runs):\n", + " \"\"\"\n", + " Compute majority vote for a single model across its runs.\n", + " \n", + " Votes are counted for the 6 functional classes only.\n", + " None is derived: if no functional class gets majority, None = 1.\n", + " None_votes tracks how many runs returned an empty label list (for reference).\n", + " \n", + " Returns DataFrame with sentence + binary labels + vote counts.\n", + " \"\"\"\n", + " # Load all runs\n", + " run_dfs = {}\n", + " for run in runs:\n", + " run_dfs[run] = pd.read_pickle(model_runs[run])\n", + " \n", + " # Collect votes per sentence\n", + " votes = {}\n", + " for run in runs:\n", + " df = run_dfs[run]\n", + " for _, row in df.iterrows():\n", + " sent_id = row['id']\n", + " sentence = row['sentence']\n", + " labels = parse_labels_safe(row['labels'])\n", + " \n", + " # Skip error rows\n", + " if 'error' in df.columns and pd.notna(row.get('error')):\n", + " continue\n", + " \n", + " if sent_id not in votes:\n", + " votes[sent_id] = {\n", + " 'sentence': sentence,\n", + " **{cls: 0 for cls in classes},\n", + " 'none_runs': 0,\n", + " 'total_runs': 0,\n", + " }\n", + " \n", + " votes[sent_id]['total_runs'] += 1\n", + " \n", + " if len(labels) == 0:\n", + " votes[sent_id]['none_runs'] += 1\n", + " else:\n", + " for label in labels:\n", + " if label in classes:\n", + " votes[sent_id][label] += 1\n", + " \n", + " # Calculate majority\n", + " results = []\n", + " for sent_id, data in votes.items():\n", + " total = data['total_runs']\n", + " threshold = total / 2 # > 50%\n", + " \n", + " row = {'id': sent_id, 'sentence': data['sentence']}\n", + " \n", + " # Majority vote on 6 functional classes\n", + " for cls in classes:\n", + " row[cls] = 1 if data[cls] > threshold else 0\n", + " \n", + " # None is derived: all functional classes = 0\n", + " row['None'] = 1 if all(row[cls] == 0 for cls in classes) else 0\n", + " \n", + " # Vote counts (for reference / pseudo-probabilities)\n", + " for cls in classes:\n", + " row[f'{cls}_votes'] = data[cls]\n", + " row['None_votes'] = data['none_runs']\n", + " row['total_runs'] = total\n", + " \n", + " results.append(row)\n", + " \n", + " df_result = pd.DataFrame(results)\n", + " label_cols = classes + ['None']\n", + " vote_cols = [f'{cls}_votes' for cls in classes] + ['None_votes', 'total_runs']\n", + " df_result = df_result[['id', 'sentence'] + label_cols + vote_cols]\n", + " df_result = df_result.sort_values('id').reset_index(drop=True)\n", + " \n", + " return df_result\n", + "\n", + "# -----------------------------\n", + "# Main\n", + "# -----------------------------\n", + "\n", + "print(\"=\"*70)\n", + "print(\"MAJORITY VOTE PER MODEL\")\n", + "print(\"=\"*70)\n", + "print(f\"Holdout dir: {HOLDOUT_DIR}\")\n", + "print(f\"Runs: {RUNS}\")\n", + "\n", + "# Discover models\n", + "models = discover_models(HOLDOUT_DIR, RUNS)\n", + "\n", + "print(f\"\\nFound {len(models)} model(s) with all {len(RUNS)} runs:\")\n", + "for (vendor, model, mode), paths in models.items():\n", + " print(f\" {vendor} / {model} / {mode}\")\n", + "\n", + "# Process each model\n", + "all_majority = {}\n", + "\n", + "for (vendor, model, mode), paths in models.items():\n", + " model_key = f\"{vendor}_{model}_{mode}\"\n", + " \n", + " print(f\"\\n{'='*70}\")\n", + " print(f\"MODEL: {vendor} / {model} / {mode}\")\n", + " print(f\"{'='*70}\")\n", + " \n", + " # Compute majority vote\n", + " df_mv = compute_majority_vote(paths, CLASSES, RUNS)\n", + " all_majority[model_key] = df_mv\n", + " \n", + " # Report\n", + " print(f\"Total sentences: {len(df_mv)}\")\n", + " print(f\"Runs per sentence: {df_mv['total_runs'].iloc[0]}\")\n", + " \n", + " print(f\"\\n{'Class':<15} {'Count':>8} {'Share':>10}\")\n", + " print(\"-\"*35)\n", + " for cls in CLASSES:\n", + " count = df_mv[cls].sum()\n", + " share = count / len(df_mv) * 100\n", + " print(f\"{cls:<15} {count:>8} {share:>9.1f}%\")\n", + " \n", + " # None (derived)\n", + " none_count = df_mv['None'].sum()\n", + " none_share = none_count / len(df_mv) * 100\n", + " print(\"-\"*35)\n", + " print(f\"{'None (derived)':<15} {none_count:>8} {none_share:>9.1f}%\")\n", + " \n", + " # Multi-label distribution\n", + " num_labels = df_mv[CLASSES].sum(axis=1)\n", + " print(f\"\\n{'# Labels':<15} {'Count':>8} {'Share':>10}\")\n", + " print(\"-\"*35)\n", + " for n in range(7):\n", + " count = (num_labels == n).sum()\n", + " if count > 0:\n", + " share = count / len(df_mv) * 100\n", + " label_text = f\"{n} label{'s' if n != 1 else ''}\"\n", + " print(f\"{label_text:<15} {count:>8} {share:>9.1f}%\")\n", + " \n", + " # Save (with dataset prefix for consistency)\n", + " dataset_tag = os.path.basename(os.path.normpath(HOLDOUT_DIR)).lower()\n", + " save_path = f\"{HOLDOUT_DIR}/{dataset_tag}_{model_key}_majority_vote\"\n", + " df_mv.to_csv(f\"{save_path}.csv\", index=False)\n", + " df_mv.to_pickle(f\"{save_path}.pkl\")\n", + " print(f\"\\nSaved: {save_path}.csv (.pkl)\")\n", + "\n", + "# -----------------------------\n", + "# Summary across models\n", + "# -----------------------------\n", + "\n", + "if len(all_majority) > 1:\n", + " print(f\"\\n{'='*70}\")\n", + " print(\"COMPARISON ACROSS MODELS\")\n", + " print(f\"{'='*70}\")\n", + " \n", + " # Header\n", + " model_keys = list(all_majority.keys())\n", + " header = f\"{'Class':<15}\"\n", + " for key in model_keys:\n", + " short = key.split('_', 1)[1] # Remove vendor prefix for display\n", + " header += f\" {short:>20}\"\n", + " print(header)\n", + " print(\"-\" * (15 + 21 * len(model_keys)))\n", + " \n", + " for cls in CLASSES:\n", + " row = f\"{cls:<15}\"\n", + " for key in model_keys:\n", + " df_mv = all_majority[key]\n", + " count = df_mv[cls].sum()\n", + " share = count / len(df_mv) * 100\n", + " row += f\" {count:>8} ({share:>5.1f}%)\"\n", + " print(row)\n", + " \n", + " # None row\n", + " row = f\"{'None (derived)':<15}\"\n", + " for key in model_keys:\n", + " df_mv = all_majority[key]\n", + " count = df_mv['None'].sum()\n", + " share = count / len(df_mv) * 100\n", + " row += f\" {count:>8} ({share:>5.1f}%)\"\n", + " print(\"-\" * (15 + 21 * len(model_keys)))\n", + " print(row)\n", + "\n", + "print(f\"\\n{'='*70}\")\n", + "print(\"DONE\")\n", + "print(f\"{'='*70}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4222ebe5-8ee4-444d-9d43-26a155112487", + "metadata": {}, + "source": [ + "# genAI Model Performance\n", + "\n", + "- Check genAI majority labels against ground truth from human experts" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "820abbf4-8938-4c3f-bfa0-a458a4841f5a", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# EVALUATE GenAI MODELS vs HUMAN LABELS ON HOLDOUT\n", + "# =============================================================================\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "import glob\n", + "import json\n", + "from sklearn.metrics import (\n", + " f1_score,\n", + " roc_auc_score,\n", + " matthews_corrcoef,\n", + " precision_score,\n", + " recall_score,\n", + " confusion_matrix,\n", + " hamming_loss,\n", + " jaccard_score,\n", + ")\n", + "\n", + "# -----------------------------\n", + "# Configuration\n", + "# -----------------------------\n", + "\n", + "HOLDOUT_DIR = \"./output/holdout\" # Where model majority vote files are\n", + "HUMAN_LABELS_PATH = \"./Holdout_Train/holdout_human.pkl\" # Ground truth\n", + "CLASSES = [\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4ffde335-fa5c-4944-a0ff-1b25fbd054d9", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------\n", + "# Load Human Labels (Ground Truth)\n", + "# -----------------------------\n", + "\n", + "print(\"=\"*70)\n", + "print(\"EVALUATE GenAI MODELS vs HUMAN LABELS\")\n", + "print(\"=\"*70)\n", + "\n", + "df_human = pd.read_pickle(HUMAN_LABELS_PATH)\n", + "print(f\"Human-labeled holdout: {len(df_human)} sentences\")\n", + "print(f\"Source: {HUMAN_LABELS_PATH}\")\n", + "\n", + "print(f\"\\nHuman label distribution:\")\n", + "for cls in CLASSES:\n", + " count = df_human[cls].sum()\n", + " share = count / len(df_human) * 100\n", + " print(f\" {cls}: {int(count)} ({share:.1f}%)\")\n", + "none_count = (df_human[CLASSES].sum(axis=1) == 0).sum()\n", + "print(f\" None (derived): {none_count} ({none_count / len(df_human) * 100:.1f}%)\")\n", + "\n", + "# -----------------------------\n", + "# Discover Model Majority Vote Files\n", + "# -----------------------------\n", + "\n", + "mv_files = glob.glob(f\"{HOLDOUT_DIR}/*_majority_vote.pkl\")\n", + "print(f\"\\nFound {len(mv_files)} model majority vote file(s):\")\n", + "for f in mv_files:\n", + " print(f\" {f.split('/')[-1]}\")\n", + "\n", + "# -----------------------------\n", + "# Evaluation Functions\n", + "# -----------------------------\n", + "\n", + "def evaluate_multilabel(y_true, y_pred, class_names):\n", + " \"\"\"\n", + " Comprehensive multi-label evaluation on the 6 functional classes.\n", + " \n", + " Each label is an independent binary decision.\n", + " None is derived (all classes = 0), not independently predicted,\n", + " so it is excluded from macro averaging but reported separately.\n", + " No AUC: requires continuous probabilities, not available for binary GenAI predictions.\n", + " \"\"\"\n", + " # Per-class metrics (6 functional classes)\n", + " per_class = {}\n", + " for i, cls in enumerate(class_names):\n", + " y_t = y_true[:, i]\n", + " y_p = y_pred[:, i]\n", + " \n", + " per_class[cls] = {\n", + " 'f1': f1_score(y_t, y_p, zero_division=0),\n", + " 'precision': precision_score(y_t, y_p, zero_division=0),\n", + " 'recall': recall_score(y_t, y_p, zero_division=0),\n", + " 'mcc': matthews_corrcoef(y_t, y_p),\n", + " 'support_true': int(y_t.sum()),\n", + " 'support_pred': int(y_p.sum()),\n", + " }\n", + " \n", + " # Macro averages (6 functional classes only)\n", + " macro = {\n", + " 'f1_macro': np.mean([per_class[c]['f1'] for c in class_names]),\n", + " 'precision_macro': np.mean([per_class[c]['precision'] for c in class_names]),\n", + " 'recall_macro': np.mean([per_class[c]['recall'] for c in class_names]),\n", + " 'mcc_macro': np.mean([per_class[c]['mcc'] for c in class_names]),\n", + " }\n", + " \n", + " # Micro averages (pooled across 6 classes)\n", + " micro = {\n", + " 'f1_micro': f1_score(y_true, y_pred, average='micro', zero_division=0),\n", + " 'precision_micro': precision_score(y_true, y_pred, average='micro', zero_division=0),\n", + " 'recall_micro': recall_score(y_true, y_pred, average='micro', zero_division=0),\n", + " }\n", + " \n", + " # Sample-based metrics (per sentence, 6 classes)\n", + " # Note: sentences where both true and predicted are all-zero (None) score 0/0.\n", + " # Using zero_division=1 so correct \"None\" predictions score 1.0, not 0.0.\n", + " sample = {\n", + " 'f1_samples': f1_score(y_true, y_pred, average='samples', zero_division=1),\n", + " 'jaccard_samples': jaccard_score(y_true, y_pred, average='samples', zero_division=1),\n", + " }\n", + " \n", + " # Overall metrics\n", + " overall = {\n", + " 'exact_match_ratio': (y_pred == y_true).all(axis=1).mean(),\n", + " 'hamming_loss': hamming_loss(y_true, y_pred),\n", + " }\n", + " \n", + " # None class (derived: all classes = 0, reported separately)\n", + " none_true = (y_true.sum(axis=1) == 0).astype(int)\n", + " none_pred = (y_pred.sum(axis=1) == 0).astype(int)\n", + " \n", + " none = {\n", + " 'f1': f1_score(none_true, none_pred, zero_division=0),\n", + " 'precision': precision_score(none_true, none_pred, zero_division=0),\n", + " 'recall': recall_score(none_true, none_pred, zero_division=0),\n", + " 'mcc': matthews_corrcoef(none_true, none_pred),\n", + " 'support_true': int(none_true.sum()),\n", + " 'support_pred': int(none_pred.sum()),\n", + " }\n", + " \n", + " return {\n", + " 'per_class': per_class,\n", + " 'macro': macro,\n", + " 'micro': micro,\n", + " 'sample': sample,\n", + " 'overall': overall,\n", + " 'none': none,\n", + " }\n", + "\n", + "\n", + "def print_evaluation_report(model_name, metrics, class_names):\n", + " \"\"\"Print formatted evaluation report for one model.\"\"\"\n", + " \n", + " print(f\"\\n{'='*70}\")\n", + " print(f\"MODEL: {model_name}\")\n", + " print(f\"{'='*70}\")\n", + " \n", + " # Per-class table\n", + " print(f\"\\n{'Class':<12} {'F1':>7} {'MCC':>7} {'Prec':>7} {'Rec':>7} {'True':>6} {'Pred':>6}\")\n", + " print(\"-\"*58)\n", + " \n", + " for cls in class_names:\n", + " m = metrics['per_class'][cls]\n", + " print(f\"{cls:<12} {m['f1']:>7.4f} {m['mcc']:>7.4f} \"\n", + " f\"{m['precision']:>7.4f} {m['recall']:>7.4f} {m['support_true']:>6} {m['support_pred']:>6}\")\n", + " \n", + " # None (derived, separate from macro)\n", + " print(\"-\"*58)\n", + " n = metrics['none']\n", + " print(f\"{'None':<12} {n['f1']:>7.4f} {n['mcc']:>7.4f} \"\n", + " f\"{n['precision']:>7.4f} {n['recall']:>7.4f} {n['support_true']:>6} {n['support_pred']:>6}\")\n", + " print(f\"{'':>12} (derived: all classes = 0, excluded from Macro)\")\n", + " \n", + " # Macro averages (6 functional classes)\n", + " print(\"-\"*58)\n", + " m = metrics['macro']\n", + " print(f\"{'Macro Avg':<12} {m['f1_macro']:>7.4f} {m['mcc_macro']:>7.4f} \"\n", + " f\"{m['precision_macro']:>7.4f} {m['recall_macro']:>7.4f}\")\n", + " print(f\"{'(6 classes)':>12}\")\n", + " \n", + " # Summary\n", + " print(\"\\n--- Summary Metrics ---\")\n", + " print(f\" Macro F1: {metrics['macro']['f1_macro']:.4f} ← Primary (class-balanced, 6 classes)\")\n", + " print(f\" Macro MCC: {metrics['macro']['mcc_macro']:.4f} ← Robust to imbalance\")\n", + " print(f\" Micro F1: {metrics['micro']['f1_micro']:.4f} ← Pooled across 6 classes\")\n", + " print(f\" Sample F1: {metrics['sample']['f1_samples']:.4f} ← Per-sentence average\")\n", + " print(f\" Sample Jaccard: {metrics['sample']['jaccard_samples']:.4f} ← Per-sentence |∩|/|∪|\")\n", + " print(f\" Exact Match: {metrics['overall']['exact_match_ratio']:.4f} ← All 6 labels correct\")\n", + " print(f\" Hamming Loss: {metrics['overall']['hamming_loss']:.4f} ← Fraction wrong (lower=better)\")\n", + "\n", + " # Interpretation\n", + " print(\"\\n--- Interpretation ---\")\n", + " gap = metrics['micro']['f1_micro'] - metrics['macro']['f1_macro']\n", + " print(f\" Macro F1 ({metrics['macro']['f1_macro']:.4f}) treats 6 functional classes equally;\")\n", + " print(f\" Micro F1 ({metrics['micro']['f1_micro']:.4f}) pools all labels, weighted by frequency.\")\n", + " if gap > 0.05:\n", + " print(f\" Gap of {gap:.4f} suggests rare classes (IT, HR) drag down Macro F1.\")\n", + " \n", + " print(f\" Sample F1 ({metrics['sample']['f1_samples']:.4f}) and Jaccard ({metrics['sample']['jaccard_samples']:.4f}):\")\n", + " print(f\" Per-sentence evaluation. Correct 'None' predictions (both empty) score 1.0.\")\n", + " print(f\" Exact Match ({metrics['overall']['exact_match_ratio']:.4f}): {metrics['overall']['exact_match_ratio']*100:.1f}% of sentences had ALL labels correct.\")\n", + " print(f\" Binary per sentence: one wrong label fails the entire sentence.\")\n", + " print(f\" Hamming Loss ({metrics['overall']['hamming_loss']:.4f}): {metrics['overall']['hamming_loss']*100:.1f}% of individual label decisions are wrong.\")\n", + " print(f\" MCC ({metrics['macro']['mcc_macro']:.4f}): Balanced metric even with class imbalance. >0.70 is good.\")\n", + " print(f\" None (derived): F1 = {metrics['none']['f1']:.4f}, MCC = {metrics['none']['mcc']:.4f}\")\n", + " print(f\" Not independently predicted — excluded from Macro to avoid inflating scores.\")\n", + "\n", + " # For paper\n", + " print(\"\\n--- For Paper/Report ---\")\n", + " print(f\" Primary: Macro F1 = {metrics['macro']['f1_macro']:.4f}\")\n", + " print(f\" Secondary: Macro MCC = {metrics['macro']['mcc_macro']:.4f}\")\n", + " print(f\" Overall: Exact Match = {metrics['overall']['exact_match_ratio']:.4f}, Hamming Loss = {metrics['overall']['hamming_loss']:.4f}\")\n", + "\n", + "\n", + "# -----------------------------\n", + "# Evaluate Each Model\n", + "# -----------------------------\n", + "\n", + "all_metrics = {}\n", + "\n", + "for mv_file in sorted(mv_files):\n", + " filename = mv_file.split('/')[-1]\n", + " #model_name = filename.replace('_majority_vote.pkl', '')\n", + " model_name = filename.replace('_majority_vote.pkl', '')\n", + " # Strip dataset prefix for clean display\n", + " for prefix in ['holdout_', 'train_', 'prelim_', 'output_']:\n", + " if model_name.startswith(prefix):\n", + " model_name = model_name[len(prefix):]\n", + " break\n", + " \n", + " # Load model predictions\n", + " df_model = pd.read_pickle(mv_file)\n", + " \n", + " # Merge with human labels on sentence (6 functional classes only)\n", + " df_merged = df_human[['sentence'] + CLASSES].merge(\n", + " df_model[['sentence'] + CLASSES],\n", + " on='sentence',\n", + " how='inner',\n", + " suffixes=('_true', '_pred'),\n", + " )\n", + " \n", + " # Check match rate\n", + " match_rate = len(df_merged) / len(df_human) * 100\n", + " print(f\"\\n{'─'*70}\")\n", + " print(f\"Model: {model_name}\")\n", + " print(f\"Matched sentences: {len(df_merged)}/{len(df_human)} ({match_rate:.1f}%)\")\n", + " \n", + " if len(df_merged) == 0:\n", + " print(\" WARNING: No matching sentences! Skipping.\")\n", + " continue\n", + " \n", + " # Extract true labels (human) — 6 functional classes\n", + " y_true = df_merged[[f'{cls}_true' for cls in CLASSES]].values.astype(int)\n", + " \n", + " # Extract predicted labels (model majority vote) — 6 functional classes\n", + " y_pred = df_merged[[f'{cls}_pred' for cls in CLASSES]].values.astype(int)\n", + " \n", + " # Evaluate\n", + " metrics = evaluate_multilabel(y_true, y_pred, CLASSES)\n", + " all_metrics[model_name] = metrics\n", + " \n", + " # Print report\n", + " print_evaluation_report(model_name, metrics, CLASSES)\n", + " \n", + " # Confusion matrices\n", + " print(f\"\\n--- Confusion Matrices ---\")\n", + " print(f\" {'Class':<12} {'TP':>6} {'TN':>6} {'FP':>6} {'FN':>6}\")\n", + " print(f\" {'-'*38}\")\n", + " \n", + " for i, cls in enumerate(CLASSES):\n", + " cm = confusion_matrix(y_true[:, i], y_pred[:, i])\n", + " tn, fp, fn, tp = cm.ravel()\n", + " print(f\" {cls:<12} {tp:>6} {tn:>6} {fp:>6} {fn:>6}\")\n", + " \n", + " # None confusion matrix (derived)\n", + " none_true = (y_true.sum(axis=1) == 0).astype(int)\n", + " none_pred = (y_pred.sum(axis=1) == 0).astype(int)\n", + " none_cm = confusion_matrix(none_true, none_pred)\n", + " none_tn, none_fp, none_fn, none_tp = none_cm.ravel()\n", + " print(f\" {'-'*38}\")\n", + " print(f\" {'None':<12} {none_tp:>6} {none_tn:>6} {none_fp:>6} {none_fn:>6}\")\n", + "\n", + "# -----------------------------\n", + "# Comparison Summary\n", + "# -----------------------------\n", + "\n", + "if len(all_metrics) > 1:\n", + " print(f\"\\n{'='*70}\")\n", + " print(\"MODEL COMPARISON SUMMARY\")\n", + " print(f\"{'='*70}\")\n", + " \n", + " # Summary table\n", + " print(f\"\\n{'Model':<40} {'F1 Mac':>8} {'MCC Mac':>8} {'Exact':>8} {'Hamming':>8}\")\n", + " print(\"-\"*68)\n", + " \n", + " summary_rows = []\n", + " for model_name, metrics in sorted(all_metrics.items(), \n", + " key=lambda x: x[1]['macro']['f1_macro'], \n", + " reverse=True):\n", + " m = metrics['macro']\n", + " o = metrics['overall']\n", + " print(f\"{model_name:<40} {m['f1_macro']:>8.4f} {m['mcc_macro']:>8.4f} \"\n", + " f\"{o['exact_match_ratio']:>8.4f} {o['hamming_loss']:>8.4f}\")\n", + " \n", + " summary_rows.append({\n", + " 'model': model_name,\n", + " 'f1_macro': m['f1_macro'],\n", + " 'mcc_macro': m['mcc_macro'],\n", + " 'f1_micro': metrics['micro']['f1_micro'],\n", + " 'f1_samples': metrics['sample']['f1_samples'],\n", + " 'jaccard_samples': metrics['sample']['jaccard_samples'],\n", + " 'exact_match': o['exact_match_ratio'],\n", + " 'hamming_loss': o['hamming_loss'],\n", + " 'none_f1': metrics['none']['f1'],\n", + " 'none_mcc': metrics['none']['mcc'],\n", + " })\n", + " \n", + " # Per-class F1 comparison\n", + " print(f\"\\n--- Per-Class F1 Comparison ---\")\n", + " header = f\"{'Class':<12}\"\n", + " for model_name in sorted(all_metrics.keys()):\n", + " short = model_name.replace('azure_', '').replace('fireworks_', 'fw_')\n", + " header += f\" {short:>20}\"\n", + " print(header)\n", + " print(\"-\" * (12 + 21 * len(all_metrics)))\n", + " \n", + " for cls in CLASSES + ['None', 'Macro Avg']:\n", + " row = f\"{cls:<12}\"\n", + " if cls == 'None':\n", + " row = f\"{'-'*12}\\n{cls:<12}\"\n", + " for model_name in sorted(all_metrics.keys()):\n", + " if cls == 'Macro Avg':\n", + " val = all_metrics[model_name]['macro']['f1_macro']\n", + " elif cls == 'None':\n", + " val = all_metrics[model_name]['none']['f1']\n", + " else:\n", + " val = all_metrics[model_name]['per_class'][cls]['f1']\n", + " row += f\" {val:>20.4f}\"\n", + " print(row)\n", + " \n", + " # Save comparison\n", + " df_summary = pd.DataFrame(summary_rows).sort_values('f1_macro', ascending=False)\n", + " df_summary.to_csv(f\"{HOLDOUT_DIR}/model_vs_human_comparison.csv\", index=False)\n", + " print(f\"\\nSaved comparison to {HOLDOUT_DIR}/model_vs_human_comparison.csv\")\n", + "\n", + "# -----------------------------\n", + "# Save all metrics\n", + "# -----------------------------\n", + "\n", + "def convert_for_json(obj):\n", + " \"\"\"Convert numpy types for JSON serialization.\"\"\"\n", + " if isinstance(obj, (np.integer,)):\n", + " return int(obj)\n", + " if isinstance(obj, (np.floating,)):\n", + " return float(obj)\n", + " if isinstance(obj, np.ndarray):\n", + " return obj.tolist()\n", + " if isinstance(obj, dict):\n", + " return {k: convert_for_json(v) for k, v in obj.items()}\n", + " return obj\n", + "\n", + "all_metrics_json = convert_for_json(all_metrics)\n", + "with open(f\"{HOLDOUT_DIR}/model_vs_human_metrics.json\", 'w') as f:\n", + " json.dump(all_metrics_json, f, indent=2)\n", + "print(f\"Saved all metrics to {HOLDOUT_DIR}/model_vs_human_metrics.json\")\n", + "\n", + "print(f\"\\n{'='*70}\")\n", + "print(\"DONE\")\n", + "print(f\"{'='*70}\")" + ] + }, + { + "cell_type": "markdown", + "id": "0af919a8-c95a-4598-847e-142fedee656c", + "metadata": {}, + "source": [ + "# Label Training Data with genAI\n", + "\n", + "- Determine which genAI model (or models) does on holdout\n", + "- Assume that this will also be the best model for labeling the train data\n", + "- Label train data at least 3 times (i.e. 3 runs: [1,2,3] with best genAI model (or models)\n", + "- Build final train set with majority votes across runs (and models if you want to pool genAI models for possibly better performance)\n", + "> I am giving an example here using GPT-4.1. ***This wil most likely not be the best model! Try others on the holdout!***" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45c897fb-16e2-4ebe-abaf-9ab6be1dffe2", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# OUTPUT CONFIGURATION - SET for Train\n", + "# =============================================================================\n", + "CHECKPOINT_DIR = \"./checkpoints\"\n", + "OUTPUT_DIR = \"./output\"\n", + "TRAIN_DIR = f\"{OUTPUT_DIR}/train\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72367ec1-6d23-44e3-81f5-826f53518666", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# Load Train dataset\n", + "# =============================================================================\n", + "\n", + "# # Load Data: Labeling of (balanced) holdout data\n", + "df = pd.read_pickle(\"Holdout_Train/train.pkl\")\n", + "df = df.head(101)\n", + "\n", + "# =============================================================================\n", + "# Query genAI model via API\n", + "# =============================================================================\n", + "\n", + "# Here an example for gpt-4.1 via UNC Azure API. \n", + "# Most likely, this will not be the best model! Try other vendors and other models.\n", + "results = run_classification(\n", + " df,\n", + " sentence_col=\"sentence\", # --> Check if your file has the column as named here where the text is (sentence vs sentences vs text vs tweet vs ... )\n", + " vendor=\"azure\",\n", + " model=\"gpt-4.1\",\n", + " runs=[1,2,3], # Doing 3 runs for conistency / replicability: will later take majority vote per lable across runs\n", + " output_dir=TRAIN_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n", + " checkpoint_dir=CHECKPOINT_DIR, # overrides default\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "16420fc3-a55a-4e90-a306-38f52ff8a3f6", + "metadata": {}, + "source": [ + "# Majority Votes\n", + "\n", + "- Across Runs\n", + "- If you are ***pooling models***, then you need to also do this ***Across Models***: THIS CODE DOES NOT DO THAT!\n", + "- Does not handle ties yet (assumes odd number of total runs across models)\n", + "- Assumes multi-label problem (one sentence can have multiple labels)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45172423-8146-478a-bfcf-7060e52f60c0", + "metadata": {}, + "outputs": [], + "source": [ + "# =============================================================================\n", + "# MAJORITY VOTE PER MODEL (across its own runs)\n", + "# =============================================================================\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "import glob\n", + "import ast\n", + "import re\n", + "import os\n", + "\n", + "# -----------------------------\n", + "# Configuration\n", + "# -----------------------------\n", + "\n", + "TRAIN_DIR = \"./output/train\" # Now pointing to train (all code is the same, but we are sing a differnt folder for the files\n", + "RUNS = [1, 2, 3]\n", + "CLASSES = [\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07170d67-f36c-45c6-a41c-2843cc5c52e0", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------\n", + "# Helper Functions\n", + "# -----------------------------\n", + "\n", + "def parse_labels_safe(labels):\n", + " \"\"\"Parse labels from various formats.\"\"\"\n", + " if labels is None:\n", + " return []\n", + " if isinstance(labels, list):\n", + " return labels\n", + " if isinstance(labels, str):\n", + " try:\n", + " parsed = ast.literal_eval(labels)\n", + " return parsed if isinstance(parsed, list) else []\n", + " except:\n", + " return []\n", + " return []\n", + "\n", + "def parse_filename(filename):\n", + " \"\"\"\n", + " Parse filename like 'train_azure_gpt-4.1_chat_run1.pkl'\n", + " into (vendor, model, mode, run).\n", + " Handles optional dataset prefix (holdout_, train_, prelim_, output_).\n", + " \"\"\"\n", + " name = filename.replace('.pkl', '').replace('.csv', '')\n", + " \n", + " # Strip known dataset prefixes\n", + " for prefix in ['holdout_', 'train_', 'prelim_', 'output_']:\n", + " if name.startswith(prefix):\n", + " name = name[len(prefix):]\n", + " break\n", + " \n", + " match = re.search(r'_run(\\d+)$', name)\n", + " if not match:\n", + " return None\n", + " run = int(match.group(1))\n", + " name = name[:match.start()]\n", + " parts = name.rsplit('_', 1)\n", + " if len(parts) != 2:\n", + " return None\n", + " mode = parts[1]\n", + " remainder = parts[0]\n", + " first_underscore = remainder.index('_')\n", + " vendor = remainder[:first_underscore]\n", + " model = remainder[first_underscore + 1:]\n", + " return vendor, model, mode, run\n", + "\n", + "def discover_models(train_dir, runs):\n", + " \"\"\"Find all vendor_model_mode combos that have ALL specified runs.\"\"\"\n", + " models = {}\n", + " for run in runs:\n", + " run_dir = f\"{train_dir}/Run{run}\"\n", + " if not os.path.exists(run_dir):\n", + " continue\n", + " for filepath in glob.glob(f\"{run_dir}/*.pkl\"):\n", + " filename = filepath.split(\"/\")[-1]\n", + " parsed = parse_filename(filename)\n", + " if parsed is None:\n", + " continue\n", + " vendor, model, mode, file_run = parsed\n", + " if file_run != run:\n", + " continue\n", + " key = (vendor, model, mode)\n", + " if key not in models:\n", + " models[key] = {}\n", + " models[key][run] = filepath\n", + " \n", + " # Keep only models with ALL specified runs\n", + " return {k: v for k, v in models.items() if all(r in v for r in runs)}\n", + "\n", + "def compute_majority_vote(model_runs, classes, runs):\n", + " \"\"\"\n", + " Compute majority vote for a single model across its runs.\n", + " \n", + " Votes are counted for the 6 functional classes only.\n", + " None is derived: if no functional class gets majority, None = 1.\n", + " None_votes tracks how many runs returned an empty label list (for reference).\n", + " \n", + " Returns DataFrame with sentence + binary labels + vote counts.\n", + " \"\"\"\n", + " # Load all runs\n", + " run_dfs = {}\n", + " for run in runs:\n", + " run_dfs[run] = pd.read_pickle(model_runs[run])\n", + " \n", + " # Collect votes per sentence\n", + " votes = {}\n", + " for run in runs:\n", + " df = run_dfs[run]\n", + " for _, row in df.iterrows():\n", + " sent_id = row['id']\n", + " sentence = row['sentence']\n", + " labels = parse_labels_safe(row['labels'])\n", + " \n", + " # Skip error rows\n", + " if 'error' in df.columns and pd.notna(row.get('error')):\n", + " continue\n", + " \n", + " if sent_id not in votes:\n", + " votes[sent_id] = {\n", + " 'sentence': sentence,\n", + " **{cls: 0 for cls in classes},\n", + " 'none_runs': 0,\n", + " 'total_runs': 0,\n", + " }\n", + " \n", + " votes[sent_id]['total_runs'] += 1\n", + " \n", + " if len(labels) == 0:\n", + " votes[sent_id]['none_runs'] += 1\n", + " else:\n", + " for label in labels:\n", + " if label in classes:\n", + " votes[sent_id][label] += 1\n", + " \n", + " # Calculate majority\n", + " results = []\n", + " for sent_id, data in votes.items():\n", + " total = data['total_runs']\n", + " threshold = total / 2 # > 50%\n", + " \n", + " row = {'id': sent_id, 'sentence': data['sentence']}\n", + " \n", + " # Majority vote on 6 functional classes\n", + " for cls in classes:\n", + " row[cls] = 1 if data[cls] > threshold else 0\n", + " \n", + " # None is derived: all functional classes = 0\n", + " row['None'] = 1 if all(row[cls] == 0 for cls in classes) else 0\n", + " \n", + " # Vote counts (for reference / pseudo-probabilities)\n", + " for cls in classes:\n", + " row[f'{cls}_votes'] = data[cls]\n", + " row['None_votes'] = data['none_runs']\n", + " row['total_runs'] = total\n", + " \n", + " results.append(row)\n", + " \n", + " df_result = pd.DataFrame(results)\n", + " label_cols = classes + ['None']\n", + " vote_cols = [f'{cls}_votes' for cls in classes] + ['None_votes', 'total_runs']\n", + " df_result = df_result[['id', 'sentence'] + label_cols + vote_cols]\n", + " df_result = df_result.sort_values('id').reset_index(drop=True)\n", + " \n", + " return df_result\n", + "\n", + "# -----------------------------\n", + "# Main\n", + "# -----------------------------\n", + "\n", + "print(\"=\"*70)\n", + "print(\"MAJORITY VOTE PER MODEL\")\n", + "print(\"=\"*70)\n", + "print(f\"Train dir: {TRAIN_DIR}\")\n", + "print(f\"Runs: {RUNS}\")\n", + "\n", + "# Discover models\n", + "models = discover_models(TRAIN_DIR, RUNS)\n", + "\n", + "print(f\"\\nFound {len(models)} model(s) with all {len(RUNS)} runs:\")\n", + "for (vendor, model, mode), paths in models.items():\n", + " print(f\" {vendor} / {model} / {mode}\")\n", + "\n", + "# Process each model\n", + "all_majority = {}\n", + "\n", + "for (vendor, model, mode), paths in models.items():\n", + " model_key = f\"{vendor}_{model}_{mode}\"\n", + " \n", + " print(f\"\\n{'='*70}\")\n", + " print(f\"MODEL: {vendor} / {model} / {mode}\")\n", + " print(f\"{'='*70}\")\n", + " \n", + " # Compute majority vote\n", + " df_mv = compute_majority_vote(paths, CLASSES, RUNS)\n", + " all_majority[model_key] = df_mv\n", + " \n", + " # Report\n", + " print(f\"Total sentences: {len(df_mv)}\")\n", + " print(f\"Runs per sentence: {df_mv['total_runs'].iloc[0]}\")\n", + " \n", + " print(f\"\\n{'Class':<15} {'Count':>8} {'Share':>10}\")\n", + " print(\"-\"*35)\n", + " for cls in CLASSES:\n", + " count = df_mv[cls].sum()\n", + " share = count / len(df_mv) * 100\n", + " print(f\"{cls:<15} {count:>8} {share:>9.1f}%\")\n", + " \n", + " # None (derived)\n", + " none_count = df_mv['None'].sum()\n", + " none_share = none_count / len(df_mv) * 100\n", + " print(\"-\"*35)\n", + " print(f\"{'None (derived)':<15} {none_count:>8} {none_share:>9.1f}%\")\n", + " \n", + " # Multi-label distribution\n", + " num_labels = df_mv[CLASSES].sum(axis=1)\n", + " print(f\"\\n{'# Labels':<15} {'Count':>8} {'Share':>10}\")\n", + " print(\"-\"*35)\n", + " for n in range(7):\n", + " count = (num_labels == n).sum()\n", + " if count > 0:\n", + " share = count / len(df_mv) * 100\n", + " label_text = f\"{n} label{'s' if n != 1 else ''}\"\n", + " print(f\"{label_text:<15} {count:>8} {share:>9.1f}%\")\n", + " \n", + " # Save (with dataset prefix for consistency)\n", + " dataset_tag = os.path.basename(os.path.normpath(TRAIN_DIR)).lower()\n", + " save_path = f\"{TRAIN_DIR}/{dataset_tag}_{model_key}_majority_vote\"\n", + " df_mv.to_csv(f\"{save_path}.csv\", index=False)\n", + " df_mv.to_pickle(f\"{save_path}.pkl\")\n", + " print(f\"\\nSaved: {save_path}.csv (.pkl)\")\n", + "\n", + "# -----------------------------\n", + "# Summary across models\n", + "# -----------------------------\n", + "\n", + "if len(all_majority) > 1:\n", + " print(f\"\\n{'='*70}\")\n", + " print(\"COMPARISON ACROSS MODELS\")\n", + " print(f\"{'='*70}\")\n", + " \n", + " # Header\n", + " model_keys = list(all_majority.keys())\n", + " header = f\"{'Class':<15}\"\n", + " for key in model_keys:\n", + " short = key.split('_', 1)[1] # Remove vendor prefix for display\n", + " header += f\" {short:>20}\"\n", + " print(header)\n", + " print(\"-\" * (15 + 21 * len(model_keys)))\n", + " \n", + " for cls in CLASSES:\n", + " row = f\"{cls:<15}\"\n", + " for key in model_keys:\n", + " df_mv = all_majority[key]\n", + " count = df_mv[cls].sum()\n", + " share = count / len(df_mv) * 100\n", + " row += f\" {count:>8} ({share:>5.1f}%)\"\n", + " print(row)\n", + " \n", + " # None row\n", + " row = f\"{'None (derived)':<15}\"\n", + " for key in model_keys:\n", + " df_mv = all_majority[key]\n", + " count = df_mv['None'].sum()\n", + " share = count / len(df_mv) * 100\n", + " row += f\" {count:>8} ({share:>5.1f}%)\"\n", + " print(\"-\" * (15 + 21 * len(model_keys)))\n", + " print(row)\n", + "\n", + "print(f\"\\n{'='*70}\")\n", + "print(\"DONE\")\n", + "print(f\"{'='*70}\")" + ] + }, + { + "cell_type": "markdown", + "id": "43c77538-5661-4801-ad7f-aaa06c6eb828", + "metadata": {}, + "source": [ + "# Fine-Tune a pretrained LLM to create a Vertical AI model\n", + "- We now have training data for fine-tuning a pretrained LLM (here, RoBERTa Large) to become a classifier for business functions\n", + "- You can experiment with different pretrained LLMs on Huggingface that may be more appropriate for your fine-tuning purpose:\n", + "> https://huggingface.co/models?pipeline_tag=fill-mask&library=transformers&sort=trending" + ] + }, + { + "cell_type": "markdown", + "id": "2ac54fc6-18a0-4577-941c-4aac33d2dd07", + "metadata": {}, + "source": [ + "### **IMPORTANT**: There is a hyperparameter that determines how long text (in tokens) can be at most for fine-tuning.\n", + "```\n", + "MAX_LENGTH = 128 # max length of sentences in tokens\n", + "```\n", + "> * I set it to 128. If your texts are longer, you need to change that because it will be truncated! \n", + "> * The larger the max length is, the slower the training process. \n", + "> * Why? See class 8 on Deep Learning. Hint: When you have more input tokens, more needs to be embedded and contextualized, which takes computer and RAM." + ] + }, + { + "cell_type": "markdown", + "id": "a912ff49-3530-42a4-bbc8-624b9b7b695f", + "metadata": {}, + "source": [ + "## Part 1: Imports and Configuration\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "42db0c31-14f2-4354-9e51-cc81a1ca85bc", + "metadata": {}, + "outputs": [], + "source": [ + "# Install if needed:\n", + "# pip install -q -U transformers datasets accelerate scikit-learn\n", + "\n", + "import os\n", + "import torch\n", + "import numpy as np\n", + "import pandas as pd\n", + "import json\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import (\n", + " f1_score,\n", + " roc_auc_score,\n", + " matthews_corrcoef,\n", + " precision_score,\n", + " recall_score,\n", + " classification_report,\n", + " confusion_matrix,\n", + " hamming_loss,\n", + " jaccard_score,\n", + ")\n", + "from transformers import (\n", + " RobertaTokenizer,\n", + " RobertaForSequenceClassification,\n", + " Trainer,\n", + " TrainingArguments,\n", + " EarlyStoppingCallback,\n", + ")\n", + "from datasets import Dataset\n", + "import warnings\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "# -----------------------------\n", + "# Configuration\n", + "# -----------------------------\n", + "\n", + "MODEL_NAME = \"roberta-large\"\n", + "CLASSES = [\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"]\n", + "NUM_LABELS = len(CLASSES)\n", + "\n", + "# Hyperparameters\n", + "LEARNING_RATE = 2e-5\n", + "WEIGHT_DECAY = 0.01\n", + "WARMUP_RATIO = 0.1\n", + "NUM_EPOCHS = 4\n", + "BATCH_SIZE = 32 # Reduce if OOM (out of memory)\n", + "GRADIENT_ACCUMULATION = 2 # Effective batch size = 32 * 2 = 54\n", + "MAX_LENGTH = 128 # max length of sentences in tokens\n", + "SEED = 42\n", + "\n", + "# Early stopping\n", + "EARLY_STOPPING_PATIENCE = 2\n", + "\n", + "# Threshold for binary predictions\n", + "THRESHOLD = 0.5\n", + "\n", + "# Mixed precision (True = faster, less memory; False = more stable)\n", + "USE_MIXED_PRECISION = True\n", + "\n", + "# -----------------------------\n", + "# Device Selection (CUDA > MPS > CPU)\n", + "# -----------------------------\n", + "\n", + "def get_device_and_precision():\n", + " \"\"\"Detect best available device and set precision flags.\"\"\"\n", + " if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")\n", + " device_name = torch.cuda.get_device_name(0)\n", + " fp16 = USE_MIXED_PRECISION\n", + " bf16 = False\n", + " use_mps = False\n", + " print(f\"Device: CUDA ({device_name})\")\n", + " print(f\" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n", + " \n", + " elif torch.backends.mps.is_available():\n", + " device = torch.device(\"mps\")\n", + " fp16 = False\n", + " bf16 = USE_MIXED_PRECISION\n", + " use_mps = True\n", + " print(f\"Device: Apple Silicon (MPS)\")\n", + " os.environ[\"PYTORCH_MPS_HIGH_WATERMARK_RATIO\"] = \"0.0\"\n", + " \n", + " else:\n", + " device = torch.device(\"cpu\")\n", + " fp16 = False\n", + " bf16 = False\n", + " use_mps = False\n", + " print(f\"Device: CPU (this will be slow!)\")\n", + " \n", + " precision = \"mixed (fp16)\" if fp16 else \"mixed (bf16)\" if bf16 else \"full (fp32)\"\n", + " print(f\" Precision: {precision}\")\n", + " \n", + " return device, fp16, bf16, use_mps\n", + "\n", + "device, use_fp16, use_bf16, use_mps = get_device_and_precision()\n", + "\n", + "print(f\"\\nConfiguration:\")\n", + "print(f\" Model: {MODEL_NAME}\")\n", + "print(f\" Classes: {CLASSES}\")\n", + "print(f\" Epochs: {NUM_EPOCHS}\")\n", + "print(f\" Batch size: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} = {BATCH_SIZE * GRADIENT_ACCUMULATION}\")\n", + "print(f\" Learning rate: {LEARNING_RATE}\")\n", + "print(f\" Max length: {MAX_LENGTH}\")" + ] + }, + { + "cell_type": "markdown", + "id": "c3c3e7e6-8c47-4690-a490-b49dd3c08edb", + "metadata": {}, + "source": [ + "## Part 2: Prepare Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54863fc9-94f2-44ae-8ac8-c80b459b10da", + "metadata": {}, + "outputs": [], + "source": [ + "#!pip install iterative-stratification" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81180705-cf64-4f5e-acd4-34e20e4cca39", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------\n", + "# Load Data\n", + "# -----------------------------\n", + "\n", + "print(\"=\"*60)\n", + "print(\"LOADING DATA\")\n", + "print(\"=\"*60)\n", + "\n", + "df_train_full = pd.read_pickle(\"./output/train/train_azure_gpt-4.1_chat_majority_vote.pkl\")\n", + "df_holdout = pd.read_pickle(\"./Holdout_Train/holdout_human.pkl\")\n", + "\n", + "print(f\"Train (full): {len(df_train_full)}\")\n", + "print(f\"Holdout: {len(df_holdout)}\")\n", + "\n", + "# -----------------------------\n", + "# Where to save model and results\n", + "# -----------------------------\n", + "model_save_path = \"./roberta_multilabel/best_model\"\n", + "\n", + "\n", + "# -----------------------------\n", + "# Train/Validation Split (90/10)\n", + "# -----------------------------\n", + "\n", + "# Option 1: Iterative stratification (best for multi-label)\n", + "# pip install iterative-stratification\n", + "try:\n", + " from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit\n", + " \n", + " msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)\n", + " train_idx, val_idx = next(msss.split(df_train_full, df_train_full[CLASSES].values))\n", + " \n", + " df_train = df_train_full.iloc[train_idx].copy()\n", + " df_val = df_train_full.iloc[val_idx].copy()\n", + " print(\"Using: Iterative multi-label stratification\")\n", + "\n", + "except ImportError:\n", + " # Option 2: Simple random split (no stratification)\n", + " df_train, df_val = train_test_split(\n", + " df_train_full, \n", + " test_size=0.2, \n", + " random_state=SEED,\n", + " )\n", + " print(\"Using: Random split (install 'iterative-stratification' for better stratification)\")\n", + "\n", + "print(f\"\\nSplit:\")\n", + "print(f\" Train: {len(df_train)}\")\n", + "print(f\" Validation: {len(df_val)}\")\n", + "print(f\" Holdout: {len(df_holdout)}\")\n", + "\n", + "# -----------------------------\n", + "# Extract Labels\n", + "# -----------------------------\n", + "\n", + "def get_labels_array(df):\n", + " \"\"\"Extract labels as numpy array.\"\"\"\n", + " return df[CLASSES].values.astype(np.float32)\n", + "\n", + "train_labels = get_labels_array(df_train)\n", + "val_labels = get_labels_array(df_val)\n", + "holdout_labels = get_labels_array(df_holdout)\n", + "\n", + "print(f\"\\nLabel shapes:\")\n", + "print(f\" Train: {train_labels.shape}\")\n", + "print(f\" Val: {val_labels.shape}\")\n", + "print(f\" Holdout: {holdout_labels.shape}\")\n", + "\n", + "# Class distribution\n", + "print(f\"\\nTrain class distribution:\")\n", + "for i, cls in enumerate(CLASSES):\n", + " count = train_labels[:, i].sum()\n", + " pct = count / len(train_labels) * 100\n", + " print(f\" {cls}: {int(count)} ({pct:.1f}%)\")\n", + "\n", + "# Verify validation has all classes\n", + "print(f\"\\nValidation class distribution:\")\n", + "for i, cls in enumerate(CLASSES):\n", + " count = val_labels[:, i].sum()\n", + " pct = count / len(val_labels) * 100\n", + " print(f\" {cls}: {int(count)} ({pct:.1f}%)\")\n", + "\n", + "# -----------------------------\n", + "# Tokenization\n", + "# -----------------------------\n", + "\n", + "print(f\"\\n{'='*60}\")\n", + "print(\"TOKENIZATION\")\n", + "print(\"=\"*60)\n", + "\n", + "tokenizer = RobertaTokenizer.from_pretrained(MODEL_NAME)\n", + "\n", + "def tokenize_data(texts, labels):\n", + " \"\"\"Tokenize texts and create HuggingFace Dataset.\"\"\"\n", + " encodings = tokenizer(\n", + " texts.tolist(),\n", + " truncation=True,\n", + " padding='max_length',\n", + " max_length=MAX_LENGTH,\n", + " return_tensors=None,\n", + " )\n", + " \n", + " dataset = Dataset.from_dict({\n", + " 'input_ids': encodings['input_ids'],\n", + " 'attention_mask': encodings['attention_mask'],\n", + " 'labels': labels.tolist(),\n", + " })\n", + " \n", + " return dataset\n", + "\n", + "print(\"Tokenizing datasets...\")\n", + "train_dataset = tokenize_data(df_train['sentence'].values, train_labels)\n", + "val_dataset = tokenize_data(df_val['sentence'].values, val_labels)\n", + "holdout_dataset = tokenize_data(df_holdout['sentence'].values, holdout_labels)\n", + "\n", + "print(f\" Train: {len(train_dataset)}\")\n", + "print(f\" Validation: {len(val_dataset)}\")\n", + "print(f\" Holdout: {len(holdout_dataset)}\")\n", + "\n", + "# Sequence length stats\n", + "train_lengths = [len(ids) for ids in train_dataset['input_ids']]\n", + "print(f\" Sequence lengths: min={min(train_lengths)}, max={max(train_lengths)}, mean={np.mean(train_lengths):.0f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "e9fabf48-8b21-4303-b8d4-cc2d1c4a17cd", + "metadata": {}, + "source": [ + "## PART 3: FINE-TUNE MODEL" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0bdfa218-edd5-41af-bdab-5ccced4510f7", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------\n", + "# Load Model\n", + "# -----------------------------\n", + "\n", + "print(\"=\"*60)\n", + "print(\"MODEL\")\n", + "print(\"=\"*60)\n", + "\n", + "model = RobertaForSequenceClassification.from_pretrained(\n", + " MODEL_NAME,\n", + " num_labels=NUM_LABELS,\n", + " problem_type=\"multi_label_classification\", # Enables BCE loss\n", + ")\n", + "\n", + "if not use_mps:\n", + " model.to(device)\n", + "\n", + "total_params = sum(p.numel() for p in model.parameters())\n", + "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "print(f\"Total parameters: {total_params:,}\")\n", + "print(f\"Trainable parameters: {trainable_params:,}\")\n", + "\n", + "# -----------------------------\n", + "# Metrics Function\n", + "# -----------------------------\n", + "\n", + "def compute_metrics(eval_pred):\n", + " \"\"\"Compute metrics for validation during training.\"\"\"\n", + " logits, labels = eval_pred\n", + " probs = torch.sigmoid(torch.tensor(logits)).numpy()\n", + " preds = (probs >= THRESHOLD).astype(int)\n", + " \n", + " f1_micro = f1_score(labels, preds, average='micro', zero_division=0)\n", + " f1_macro = f1_score(labels, preds, average='macro', zero_division=0)\n", + " f1_samples = f1_score(labels, preds, average='samples', zero_division=0)\n", + " \n", + " try:\n", + " auc_macro = roc_auc_score(labels, probs, average='macro')\n", + " except ValueError:\n", + " auc_macro = 0.0\n", + " \n", + " mcc_scores = []\n", + " for i in range(labels.shape[1]):\n", + " try:\n", + " mcc = matthews_corrcoef(labels[:, i], preds[:, i])\n", + " mcc_scores.append(mcc)\n", + " except:\n", + " mcc_scores.append(0.0)\n", + " mcc_macro = np.mean(mcc_scores)\n", + " \n", + " return {\n", + " 'f1_micro': f1_micro,\n", + " 'f1_macro': f1_macro,\n", + " 'f1_samples': f1_samples,\n", + " 'auc_macro': auc_macro,\n", + " 'mcc_macro': mcc_macro,\n", + " }\n", + "\n", + "# -----------------------------\n", + "# Training Arguments\n", + "# -----------------------------\n", + "\n", + "training_args = TrainingArguments(\n", + " output_dir=\"./roberta_multilabel\",\n", + " \n", + " # Training\n", + " num_train_epochs=NUM_EPOCHS,\n", + " per_device_train_batch_size=BATCH_SIZE,\n", + " per_device_eval_batch_size=BATCH_SIZE * 2,\n", + " gradient_accumulation_steps=GRADIENT_ACCUMULATION,\n", + " \n", + " # Optimizer\n", + " learning_rate=LEARNING_RATE,\n", + " weight_decay=WEIGHT_DECAY,\n", + " warmup_ratio=WARMUP_RATIO,\n", + " optim=\"adamw_torch\",\n", + " lr_scheduler_type=\"linear\",\n", + " max_grad_norm=1.0,\n", + " \n", + " # Evaluation & Saving\n", + " eval_strategy=\"epoch\",\n", + " save_strategy=\"epoch\",\n", + " load_best_model_at_end=True,\n", + " metric_for_best_model=\"f1_macro\",\n", + " greater_is_better=True,\n", + " save_total_limit=2,\n", + " \n", + " # Logging\n", + " logging_dir=\"./logs\",\n", + " logging_steps=50,\n", + " logging_first_step=True,\n", + " report_to=\"none\",\n", + " \n", + " # Precision\n", + " fp16=use_fp16,\n", + " bf16=use_bf16,\n", + " use_mps_device=use_mps,\n", + " \n", + " # Reproducibility\n", + " seed=SEED,\n", + " data_seed=SEED,\n", + " \n", + " # Performance\n", + " dataloader_num_workers=0 if use_mps else 2,\n", + " dataloader_pin_memory=not use_mps,\n", + " remove_unused_columns=False,\n", + ")\n", + "\n", + "# -----------------------------\n", + "# Trainer\n", + "# -----------------------------\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=val_dataset,\n", + " compute_metrics=compute_metrics,\n", + " callbacks=[\n", + " EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE),\n", + " ],\n", + ")\n", + "\n", + "# -----------------------------\n", + "# Train\n", + "# -----------------------------\n", + "\n", + "print(f\"\\n{'='*60}\")\n", + "print(\"TRAINING\")\n", + "print(\"=\"*60)\n", + "\n", + "train_result = trainer.train()\n", + "\n", + "print(f\"\\n{'='*60}\")\n", + "print(\"TRAINING COMPLETE\")\n", + "print(\"=\"*60)\n", + "print(f\"Runtime: {train_result.metrics['train_runtime']:.0f} seconds\")\n", + "print(f\"Samples/second: {train_result.metrics['train_samples_per_second']:.1f}\")\n", + "print(f\"Final loss: {train_result.metrics['train_loss']:.4f}\")\n", + "\n", + "# Validation results\n", + "print(f\"\\nValidation (Best Model):\")\n", + "val_results = trainer.evaluate()\n", + "for key, value in val_results.items():\n", + " if isinstance(value, float):\n", + " print(f\" {key}: {value:.4f}\")\n", + "\n", + "# -----------------------------\n", + "# Save Model\n", + "# -----------------------------\n", + "\n", + "trainer.save_model(model_save_path)\n", + "tokenizer.save_pretrained(model_save_path)\n", + "print(f\"\\nModel saved to: {model_save_path}\")" + ] + }, + { + "cell_type": "markdown", + "id": "44acf40a-da7d-4aa3-912c-2bdf883fdc54", + "metadata": {}, + "source": [ + "## Part 4: Evaluate on Holdout (ground truth human expert labels)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df813061-90b2-4076-87ac-5c5be10ee048", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------\n", + "# Set performance metric output path first! It is based in the model save path you defined earlier\n", + "# -----------------------------\n", + "\n", + "MODEL_DIR = model_save_path.replace(\"best_model\", \"performance\")\n", + "os.makedirs(MODEL_DIR, exist_ok=True) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0fede292-3cfc-4a69-8dfd-dd21e224572d", + "metadata": {}, + "outputs": [], + "source": [ + "# -----------------------------\n", + "# Evaluation Functions\n", + "# -----------------------------\n", + "\n", + "def evaluate_multilabel(y_true, y_pred, y_prob, class_names):\n", + " \"\"\"\n", + " Comprehensive multi-label evaluation.\n", + " \n", + " Each label is treated as an independent binary decision.\n", + " Macro averaging gives equal weight to each class (recommended for imbalanced data).\n", + " \"\"\"\n", + " n_classes = len(class_names)\n", + " \n", + " # Per-class metrics\n", + " per_class = {}\n", + " for i, cls in enumerate(class_names):\n", + " y_t = y_true[:, i]\n", + " y_p = y_pred[:, i]\n", + " y_pr = y_prob[:, i]\n", + " \n", + " try:\n", + " auc = roc_auc_score(y_t, y_pr)\n", + " except:\n", + " auc = 0.0\n", + " \n", + " per_class[cls] = {\n", + " 'f1': f1_score(y_t, y_p, zero_division=0),\n", + " 'precision': precision_score(y_t, y_p, zero_division=0),\n", + " 'recall': recall_score(y_t, y_p, zero_division=0),\n", + " 'mcc': matthews_corrcoef(y_t, y_p),\n", + " 'auc': auc,\n", + " 'support': int(y_t.sum()),\n", + " }\n", + " \n", + " # Macro averages (equal weight to each class)\n", + " macro = {\n", + " 'f1_macro': np.mean([per_class[c]['f1'] for c in class_names]),\n", + " 'precision_macro': np.mean([per_class[c]['precision'] for c in class_names]),\n", + " 'recall_macro': np.mean([per_class[c]['recall'] for c in class_names]),\n", + " 'mcc_macro': np.mean([per_class[c]['mcc'] for c in class_names]),\n", + " 'auc_macro': np.mean([per_class[c]['auc'] for c in class_names]),\n", + " }\n", + " \n", + " # Micro averages (pooled across all classes)\n", + " micro = {\n", + " 'f1_micro': f1_score(y_true, y_pred, average='micro', zero_division=0),\n", + " 'precision_micro': precision_score(y_true, y_pred, average='micro', zero_division=0),\n", + " 'recall_micro': recall_score(y_true, y_pred, average='micro', zero_division=0),\n", + " }\n", + " \n", + " # Sample-based metrics\n", + " # zero_division=1: correct \"None\" predictions (both empty) score 1.0, not 0.0\n", + " sample = {\n", + " 'f1_samples': f1_score(y_true, y_pred, average='samples', zero_division=1),\n", + " 'jaccard_samples': jaccard_score(y_true, y_pred, average='samples', zero_division=1),\n", + " }\n", + " \n", + " # Overall metrics\n", + " overall = {\n", + " 'exact_match_ratio': (y_pred == y_true).all(axis=1).mean(),\n", + " 'hamming_loss': hamming_loss(y_true, y_pred),\n", + " }\n", + " \n", + " return {\n", + " 'per_class': per_class,\n", + " 'macro': macro,\n", + " 'micro': micro,\n", + " 'sample': sample,\n", + " 'overall': overall,\n", + " }\n", + "\n", + "\n", + "def print_evaluation_report(metrics, class_names):\n", + " \"\"\"Print formatted evaluation report.\"\"\"\n", + " \n", + " print(\"=\"*70)\n", + " print(\"HOLDOUT EVALUATION REPORT\")\n", + " print(\"=\"*70)\n", + " \n", + " # Per-class table\n", + " print(\"\\n--- Per-Class Metrics (Each Label = Independent Binary Decision) ---\")\n", + " print(f\"{'Class':<12} {'F1':>7} {'AUC':>7} {'MCC':>7} {'Prec':>7} {'Rec':>7} {'Support':>8}\")\n", + " print(\"-\"*60)\n", + " \n", + " for cls in class_names:\n", + " m = metrics['per_class'][cls]\n", + " print(f\"{cls:<12} {m['f1']:>7.4f} {m['auc']:>7.4f} {m['mcc']:>7.4f} \"\n", + " f\"{m['precision']:>7.4f} {m['recall']:>7.4f} {m['support']:>8}\")\n", + " \n", + " # Macro averages\n", + " print(\"-\"*60)\n", + " m = metrics['macro']\n", + " print(f\"{'Macro Avg':<12} {m['f1_macro']:>7.4f} {m['auc_macro']:>7.4f} {m['mcc_macro']:>7.4f} \"\n", + " f\"{m['precision_macro']:>7.4f} {m['recall_macro']:>7.4f}\")\n", + " \n", + " # Summary\n", + " print(\"\\n--- Summary Metrics ---\")\n", + " print(f\" Macro F1: {metrics['macro']['f1_macro']:.4f} ← Primary (class-balanced)\")\n", + " print(f\" Macro AUC: {metrics['macro']['auc_macro']:.4f} ← Threshold-independent\")\n", + " print(f\" Macro MCC: {metrics['macro']['mcc_macro']:.4f} ← Robust to imbalance\")\n", + " print(f\" Micro F1: {metrics['micro']['f1_micro']:.4f} ← Dominated by frequent classes\")\n", + " print(f\" Sample F1: {metrics['sample']['f1_samples']:.4f} ← Per-instance average\")\n", + " print(f\" Exact Match: {metrics['overall']['exact_match_ratio']:.4f} ← All labels correct\")\n", + " print(f\" Hamming Loss: {metrics['overall']['hamming_loss']:.4f} ← Fraction wrong (lower=better)\")\n", + " \n", + " # For paper\n", + " print(\"\\n--- For Report ---\")\n", + " print(f\" Primary: Macro F1 = {metrics['macro']['f1_macro']:.4f}\")\n", + " print(f\" Secondary: Macro AUC = {metrics['macro']['auc_macro']:.4f}, Macro MCC = {metrics['macro']['mcc_macro']:.4f}\")\n", + " print(f\" Overall: Exact Match = {metrics['overall']['exact_match_ratio']:.4f}\")\n", + "\n", + "\n", + "# -----------------------------\n", + "# Get Predictions\n", + "# -----------------------------\n", + "\n", + "print(\"=\"*60)\n", + "print(\"HOLDOUT PREDICTIONS\")\n", + "print(\"=\"*60)\n", + "\n", + "holdout_output = trainer.predict(holdout_dataset)\n", + "holdout_logits = holdout_output.predictions\n", + "holdout_probs = torch.sigmoid(torch.tensor(holdout_logits)).numpy()\n", + "holdout_preds = (holdout_probs >= THRESHOLD).astype(int)\n", + "\n", + "print(f\"Holdout samples: {len(holdout_labels)}\")\n", + "print(f\"Predictions shape: {holdout_preds.shape}\")\n", + "\n", + "# -----------------------------\n", + "# Evaluate\n", + "# -----------------------------\n", + "\n", + "metrics = evaluate_multilabel(\n", + " y_true=holdout_labels,\n", + " y_pred=holdout_preds,\n", + " y_prob=holdout_probs,\n", + " class_names=CLASSES,\n", + ")\n", + "\n", + "print_evaluation_report(metrics, CLASSES)\n", + "\n", + "# -----------------------------\n", + "# None Class (Derived)\n", + "# -----------------------------\n", + "\n", + "print(\"\\n--- 'None' Class (Derived: all classes = 0) ---\")\n", + "\n", + "holdout_none_true = (holdout_labels.sum(axis=1) == 0).astype(int)\n", + "holdout_none_pred = (holdout_preds.sum(axis=1) == 0).astype(int)\n", + "holdout_none_prob = 1 - holdout_probs.max(axis=1)\n", + "\n", + "none_f1 = f1_score(holdout_none_true, holdout_none_pred, zero_division=0)\n", + "none_mcc = matthews_corrcoef(holdout_none_true, holdout_none_pred)\n", + "try:\n", + " none_auc = roc_auc_score(holdout_none_true, holdout_none_prob)\n", + "except:\n", + " none_auc = 0.0\n", + "none_precision = precision_score(holdout_none_true, holdout_none_pred, zero_division=0)\n", + "none_recall = recall_score(holdout_none_true, holdout_none_pred, zero_division=0)\n", + "none_support = int(holdout_none_true.sum())\n", + "\n", + "print(f\"{'None':<12} {none_f1:>7.4f} {none_auc:>7.4f} {none_mcc:>7.4f} \"\n", + " f\"{none_precision:>7.4f} {none_recall:>7.4f} {none_support:>8}\")\n", + "\n", + "# -----------------------------\n", + "# Confusion Matrices\n", + "# -----------------------------\n", + "\n", + "print(\"\\n--- Confusion Matrices (per class) ---\")\n", + "print(f\"{'Class':<12} {'TP':>6} {'TN':>6} {'FP':>6} {'FN':>6}\")\n", + "print(\"-\"*40)\n", + "\n", + "for i, cls in enumerate(CLASSES):\n", + " y_true = holdout_labels[:, i]\n", + " y_pred = holdout_preds[:, i]\n", + " cm = confusion_matrix(y_true, y_pred)\n", + " tn, fp, fn, tp = cm.ravel()\n", + " print(f\"{cls:<12} {tp:>6} {tn:>6} {fp:>6} {fn:>6}\")\n", + "\n", + "none_cm = confusion_matrix(holdout_none_true, holdout_none_pred)\n", + "none_tn, none_fp, none_fn, none_tp = none_cm.ravel()\n", + "print(f\"{'None':<12} {none_tp:>6} {none_tn:>6} {none_fp:>6} {none_fn:>6}\")\n", + "\n", + "# -----------------------------\n", + "# Multi-label Analysis\n", + "# -----------------------------\n", + "\n", + "print(\"\\n--- Multi-label Analysis ---\")\n", + "\n", + "true_label_counts = holdout_labels.sum(axis=1)\n", + "pred_label_counts = holdout_preds.sum(axis=1)\n", + "\n", + "print(f\"Average labels per sentence:\")\n", + "print(f\" True: {true_label_counts.mean():.2f}\")\n", + "print(f\" Predicted: {pred_label_counts.mean():.2f}\")\n", + "\n", + "print(f\"\\nLabel count distribution:\")\n", + "print(f\"{'# Labels':<10} {'True':>8} {'Predicted':>10}\")\n", + "print(\"-\"*30)\n", + "for n in range(7):\n", + " true_count = (true_label_counts == n).sum()\n", + " pred_count = (pred_label_counts == n).sum()\n", + " print(f\"{n:<10} {true_count:>8} {pred_count:>10}\")\n", + "\n", + "# -----------------------------\n", + "# Save Results\n", + "# -----------------------------\n", + "\n", + "print(\"\\n\" + \"=\"*60)\n", + "print(\"SAVING RESULTS\")\n", + "print(\"=\"*60)\n", + "\n", + "# Save predictions\n", + "df_holdout_results = df_holdout.copy()\n", + "for i, cls in enumerate(CLASSES):\n", + " df_holdout_results[f'{cls}_prob'] = holdout_probs[:, i]\n", + " df_holdout_results[f'{cls}_pred'] = holdout_preds[:, i]\n", + "df_holdout_results['None_prob'] = holdout_none_prob\n", + "df_holdout_results['None_pred'] = holdout_none_pred\n", + "\n", + "df_holdout_results.to_csv(f\"{MODEL_DIR}/holdout_predictions.csv\", index=False)\n", + "df_holdout_results.to_pickle(f\"{MODEL_DIR}/holdout_predictions.pkl\")\n", + "print(f\" Predictions: {MODEL_DIR}/holdout_predictions.csv\")\n", + "\n", + "# Save metrics\n", + "metrics_summary = {\n", + " 'model': MODEL_NAME,\n", + " 'threshold': THRESHOLD,\n", + " 'holdout_size': len(df_holdout),\n", + " 'macro': metrics['macro'],\n", + " 'micro': metrics['micro'],\n", + " 'sample': metrics['sample'],\n", + " 'overall': metrics['overall'],\n", + " 'per_class': metrics['per_class'],\n", + " 'none_class': {\n", + " 'f1': none_f1, 'auc': none_auc, 'mcc': none_mcc,\n", + " 'precision': none_precision, 'recall': none_recall, 'support': none_support,\n", + " },\n", + " 'hyperparameters': {\n", + " 'learning_rate': LEARNING_RATE,\n", + " 'weight_decay': WEIGHT_DECAY,\n", + " 'warmup_ratio': WARMUP_RATIO,\n", + " 'num_epochs': NUM_EPOCHS,\n", + " 'batch_size': BATCH_SIZE,\n", + " 'gradient_accumulation': GRADIENT_ACCUMULATION,\n", + " 'max_length': MAX_LENGTH,\n", + " },\n", + "}\n", + "\n", + "with open(f\"{MODEL_DIR}/holdout_metrics.json\", 'w') as f:\n", + " json.dump(metrics_summary, f, indent=2, default=float)\n", + "print(f\" Metrics: {MODEL_DIR}/holdout_metrics.json\")\n", + "\n", + "# Save training history\n", + "with open(f\"{MODEL_DIR}/training_history.json\", 'w') as f:\n", + " json.dump(trainer.state.log_history, f, indent=2)\n", + "print(f\" History: {MODEL_DIR}/training_history.json\")\n", + "\n", + "print(\"\\n\" + \"=\"*60)\n", + "print(\"DONE\")\n", + "print(\"=\"*60)\n", + "print(f\"\\nFinal Results:\")\n", + "print(f\" Macro F1: {metrics['macro']['f1_macro']:.4f}\")\n", + "print(f\" Macro AUC: {metrics['macro']['auc_macro']:.4f}\")\n", + "print(f\" Macro MCC: {metrics['macro']['mcc_macro']:.4f}\")\n", + "print(f\" Exact Match: {metrics['overall']['exact_match_ratio']:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "91936e6d-3591-4a86-956c-9090d7e46bb6", + "metadata": {}, + "source": [ + "**Please cite this paper** if you use any part or all of this code in a project - be it commercial or academic: \n", + "\n", + "> Ringel, Daniel, *Creating Synthetic Experts with Generative Artificial Intelligence* (December 5, 2023). Kenan Institute of Private Enterprise Research Paper No. 4542949, Available at SSRN: https://ssrn.com/abstract=4542949 or http://dx.doi.org/10.2139/ssrn.4542949 \n" + ] + }, + { + "cell_type": "markdown", + "id": "5e5e8344-2674-4daf-b468-6afbf7aa15bf", + "metadata": {}, + "source": [ + "*This notebook was developed by Daniel M. Ringel in January 2026 with the help of the various vendor's API documentation (and examples) as well as genAI models from OpenAI and Anthropic.*" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/reference/Ringel_488-2026_Capstone_Constructs.pdf b/docs/reference/Ringel_488-2026_Capstone_Constructs.pdf new file mode 100644 index 0000000..af9e381 Binary files /dev/null and b/docs/reference/Ringel_488-2026_Capstone_Constructs.pdf differ diff --git a/docs/reference/Signoff_assn_instructions.md b/docs/reference/Signoff_assn_instructions.md new file mode 100644 index 0000000..4c9dae5 --- /dev/null +++ b/docs/reference/Signoff_assn_instructions.md @@ -0,0 +1,71 @@ +# CAPSTONE: Construct of Interest and Data Sign-off by Dr. D. + +**Due:** April 7 by 11:59pm | **Points:** 0 | **Submission:** File upload + +--- + +## Overview + +Your capstone project requires you to classify a construct of interest in data at scale. + +You must get approval for your construct of interest and data from your instructor. + +This is an ungraded assignment. However, it is the prerequisite to your capstone project. + +## Goal + +Pick a well-documented, theoretically founded construct of interest. Explain why a firm would want to classify this construct at scale. + +### Your Construct Must Be: + +- **Business-relevant** — addresses a real business decision +- **Theoretically grounded** — anchored in established literature +- **Well documented** — clearly defined in academic or industry sources +- **Observable in text** — detectable in your data source +- **Definable with clear rules** — specific enough for reliable labeling +- **Complex & nuanced** — more than just sentiment (not positive/negative) + +You must pick one of the seven provided constructs of interest from here: https://www.ringel.ai/UNC/2026/BUSI488/Class23/Ringel_488-2026_Capstone_Constructs.pdf + +## Define Your Construct Precisely + +Turn the concept into labels humans can apply consistently. + +### You Must Create: + +- **Label set** (classes/categories) +- **Clear definitions and decision rules** for each label +- **Borderline cases** — guidance for unclear examples +- **None/Other policy** — if applicable (multi-class yes, multi-label no) +- **2-3 example texts** per label (your own examples) +- **Decision:** multi-class (one label per item) vs multi-label (multiple labels can apply) + +### Consider Your Data Source + +Before finalizing, ask yourself: + +- Do these data, when classified, inform and improve a business decision? +- Can the construct of interest (all its labels/classes) be sufficiently found in these data? +- Are these data abundantly available and do they need to be analyzed frequently and/or at scale to justify building a vertical AI? + +## Important Guardrails + +- **Choose a data source** that fits your construct and is realistically useful to a firm +- **Use public data or properly de-identified data only** — no sensitive internal company data +- **Pilot test first:** Before committing, do a quick manual pilot on 100–200 texts in the developer platform playground or ChatGPT to confirm your construct appears in the source and that your labels are workable + +## Deliverable for This Assignment + +**Two-page maximum, double-spaced documentation** containing: + +1. **Definition** of your construct of interest and its labels/classes + +2. **Sources & Citations** that support your construct of interest (and its classes/labels), demonstrating that it is: + - Theoretically founded + - Well-established in literature + - Meaningful to decision makers + +3. **Data Description** explaining: + - What data you will identify it in + - How you will acquire these data + - Why identifying your construct at scale/frequently in these data is valuable (justifies the need for a vertical AI)