{
"cells": [
{
"cell_type": "markdown",
"id": "f2658472-1630-4aae-aaa4-01987b0eff1e",
"metadata": {},
"source": [
"# Training a Vertical AI for 10K Text Classification in Business Functions\n",
"Supports APIs of OpenAI (Responses API), Anthropic, xAI, Fireworks AI, Google, UNC Azure hosted\n",
"\n",
"February 4, 2026 \n",
"\n",
"*Version 1.0*\n",
"\n",
"Copyright 2026 Daniel M. Ringel \n",
"www.ringel.ai \n",
"\n",
"\n",
"**Please cite this paper** if you use any part or all of this code in a project - be it commercial or academic: \n",
"\n",
"> Ringel, Daniel, *Creating Synthetic Experts with Generative Artificial Intelligence* (December 5, 2023). Kenan Institute of Private Enterprise Research Paper No. 4542949, Available at SSRN: https://ssrn.com/abstract=4542949 or http://dx.doi.org/10.2139/ssrn.4542949 \n"
]
},
{
"cell_type": "markdown",
"id": "4f4233be-6c67-41db-bc89-7b47bde71d59",
"metadata": {},
"source": [
"Query various serverless genAI models to classify text by example of a multi-label classification problem\n",
"- Sentences from 10K filings of fortune 500 companies\n",
"- Construct of Interest: Business functions\n",
"- Valid labels: Marketing, Finance, Accounting, Operations, IT, HR"
]
},
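{
"cell_type": "markdown",
"id": "sketch-multi-hot-md",
"metadata": {},
"source": [
"A sentence can receive several labels at once, or none at all, so each classification can be stored as a multi-hot vector over the six valid labels. The next cell is a minimal sketch under that assumption. `to_multi_hot` and `from_multi_hot` are hypothetical helpers written for this illustration, and the label order simply mirrors the list above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-multi-hot-code",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helpers (illustration only): encode/decode label lists as multi-hot vectors.\n",
"LABELS = [\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"]\n",
"\n",
"def to_multi_hot(labels: list) -> list:\n",
"    \"\"\"Return a 0/1 vector with one slot per label, in LABELS order.\"\"\"\n",
"    chosen = set(labels)\n",
"    return [1 if label in chosen else 0 for label in LABELS]\n",
"\n",
"def from_multi_hot(vector: list) -> list:\n",
"    \"\"\"Invert to_multi_hot: keep the labels whose slot is 1.\"\"\"\n",
"    return [label for label, flag in zip(LABELS, vector) if flag]\n",
"\n",
"print(to_multi_hot([\"Finance\", \"Accounting\"]))  # [0, 1, 1, 0, 0, 0]\n",
"print(to_multi_hot([]))                          # [0, 0, 0, 0, 0, 0]\n",
"print(from_multi_hot([0, 0, 0, 1, 1, 0]))        # ['Operations', 'IT']"
]
},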
{
"cell_type": "markdown",
"id": "eda002cb-c5ef-4a68-ba10-4d4ee65e836a",
"metadata": {},
"source": [
"> **IMPORTANT** Running this code will cost you API credits (and requires you to ahve accounts with the providers). You will need to supply your own API keys. Beware that you may be subject to rate limits (how many queries you can send per minute) and which models you can use (OpenAI, for example, requires you to verify your identidy with an ID to access many models). Regardless, every time you execute this code, you will drain your API credits = real money! Thus, make wise decisions about what and how much to label."
]
},
{
"cell_type": "markdown",
"id": "9228e155-f4b0-45f8-96ba-8cb99c0bb990",
"metadata": {},
"source": [
"## **Disclaimer**\n",
"\n",
"**USE AT YOUR OWN RISK**\n",
"\n",
"This code is provided \"as is\" without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose, and noninfringement.\n",
"\n",
"Under no circumstances shall Daniel Ringel be liable for any direct, indirect, incidental, special, exemplary, or consequential damages (including, but not limited to, procurement of substitute goods or services, loss of use, data, or profits, or business interruption) however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence or otherwise) arising in any way out of the use of this code, even if advised of the possibility of such damage.\n",
"\n",
"**Additional Notes:**\n",
"- API costs incurred from running this code are your sole responsibility\n",
"- Verify all API pricing before running at scale\n",
"- Test with small samples before processing large datasets\n",
"- The author makes no guarantees about the accuracy, reliability, or completeness of results\n",
"- This code is for educational and research purposes\n",
"\n",
"By using this code, you acknowledge that you have read this disclaimer and agree to its terms.\n"
]
},
{
"cell_type": "markdown",
"id": "4eed4986-cdb9-4285-80d4-7c2e415a5268",
"metadata": {},
"source": [
"# Installs and updates\n",
"- need to run only once if you are on your own computer\n",
"- on CoLab, you may need to run each time"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8714eefd-4af8-4116-a4fb-922bf1aa4667",
"metadata": {},
"outputs": [],
"source": [
"# Core\n",
"!pip install --upgrade pandas tqdm\n",
"# OpenAI (also used for xAI and Fireworks)\n",
"!pip install --upgrade openai\n",
"# Anthropic\n",
"!pip install --upgrade anthropic\n",
"# Google Gemini\n",
"!pip install --upgrade google-genai google-api-core\n",
"# For Label Agreement\n",
"!pip install -q -U krippendorff \n",
"# For Fine-tuning a pretrained LLM\n",
"!pip install -q -U transformers datasets accelerate scikit-learn\n",
"!pip install iterative-stratification"
]
},
{
"cell_type": "markdown",
"id": "9ef820f3-b6b2-4799-9af8-45e8e4ac5194",
"metadata": {},
"source": [
"# Set Environmental Variables (API Keys) on local Computer \n",
"### *(or on Colab in Secret Keys Tab)*\n",
"> **DO NOT SHARE API KEYS!!!** Delte them before sharing this notebook\n",
" \n",
"> On ***Colab***, use \"Secrets\" in Tab (key) on right. Make sure to use the exact same API Key names (e.g., \"OPENAI_API_KEY\") spelled and capitalized as shown!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c82d0c2c-8b19-49e0-80d1-1f7ebeb53f78",
"metadata": {},
"outputs": [],
"source": [
"# Put your API Keys here if you run this locally.\n",
"import os\n",
"os.environ[\"OPENAI_API_KEY\"] = \"\"\n",
"os.environ[\"ANTHROPIC_API_KEY\"] = \"\"\n",
"os.environ[\"FIREWORKS_API_KEY\"] = \"\"\n",
"os.environ[\"XAI_API_KEY\"] = \"\"\n",
"os.environ[\"GOOGLE_API_KEY\"] = \"\"\n",
"os.environ[\"AZURE_API_KEY\"] = \"\""
]
},
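{
"cell_type": "markdown",
"id": "sketch-key-check-md",
"metadata": {},
"source": [
"Before you spend any credits, you can confirm which keys are available without printing the secrets themselves. The next cell is a sketch built around a hypothetical helper, `report_api_keys`. It only inspects environment variables, so it does not see Colab Secrets, and it does not test whether a key is valid."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-key-check-code",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Hypothetical helper: report which API keys are set, never their values.\n",
"KEY_NAMES = [\"OPENAI_API_KEY\", \"ANTHROPIC_API_KEY\", \"FIREWORKS_API_KEY\",\n",
"             \"XAI_API_KEY\", \"GOOGLE_API_KEY\", \"AZURE_API_KEY\"]\n",
"\n",
"def report_api_keys(names=KEY_NAMES) -> dict:\n",
"    \"\"\"Map each key name to True if a non-empty value is set in os.environ.\"\"\"\n",
"    return {name: bool(os.environ.get(name, \"\").strip()) for name in names}\n",
"\n",
"for name, present in report_api_keys().items():\n",
"    print(f\"{name:<20} {'set' if present else 'MISSING'}\")"
]
},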
{
"cell_type": "markdown",
"id": "8803ce06-67df-49d4-851e-2617cbf157be",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0dbd424-0d89-4b2a-af13-aa98bc07f0a6",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"import time\n",
"import datetime\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"\n",
"# Vendor SDKs\n",
"import openai\n",
"from openai import OpenAI, AzureOpenAI, RateLimitError, APIError, AuthenticationError\n",
"import anthropic\n",
"from google import genai\n",
"from google.genai import types\n",
"from google.api_core import exceptions as google_exceptions"
]
},
{
"cell_type": "markdown",
"id": "e3be8e22-dd79-4417-bf34-235900d5487a",
"metadata": {},
"source": [
"# Vendor/Model Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c180681-cd7c-4d27-a429-06a149d8933c",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------\n",
"# Vendor/Model Configuration\n",
"# All prices per 1M tokens (USD) ----> UPDATE THE PRICES!!!!\n",
"# supplier = who made the model (for when using 3rd party APIs like Fireworks)\n",
"# -----------------------------\n",
"VENDORS = {\n",
" \"openai\": {\n",
" \"models\": {\n",
" # GPT-5 family\n",
" \"gpt-5.2\": {\"supplier\": \"openai\", \"price_in\": 1.75, \"price_out\": 14.00},\n",
" \"gpt-5\": {\"supplier\": \"openai\", \"price_in\": 1.25, \"price_out\": 10.00},\n",
" \"gpt-5-mini\": {\"supplier\": \"openai\", \"price_in\": 0.30, \"price_out\": 2.50},\n",
" \"gpt-5-nano\": {\"supplier\": \"openai\", \"price_in\": 0.10, \"price_out\": 0.40},\n",
" # GPT-4.1 family\n",
" \"gpt-4.1\": {\"supplier\": \"openai\", \"price_in\": 2.00, \"price_out\": 8.00},\n",
" # GPT-4o family\n",
" \"gpt-4o\": {\"supplier\": \"openai\", \"price_in\": 2.50, \"price_out\": 10.00},\n",
" }\n",
" },\n",
" \"azure\": {\n",
" \"models\": {\n",
" \"gpt-4.1\": {\"supplier\": \"openai\", \"price_in\": 2.00, \"price_out\": 8.00},\n",
" \"gpt-4o\": {\"supplier\": \"openai\", \"price_in\": 2.50, \"price_out\": 8.00},\n",
" },\n",
" \"endpoint\": \"https://azureaiapi.cloud.unc.edu\",\n",
" \"api_version\": \"2025-04-01-preview\",\n",
" },\n",
" \"anthropic\": {\n",
" \"models\": {\n",
" # Claude 4.5 family\n",
" \"claude-opus-4-5-20251101\": {\"supplier\": \"anthropic\", \"price_in\": 5.00, \"price_out\": 25.00},\n",
" \"claude-sonnet-4-5-20250929\": {\"supplier\": \"anthropic\", \"price_in\": 3.00, \"price_out\": 15.00},\n",
" \"claude-haiku-4-5-20251001\": {\"supplier\": \"anthropic\", \"price_in\": 1.00, \"price_out\": 5.00},\n",
" # Claude 4 family\n",
" \"claude-sonnet-4-20250514\": {\"supplier\": \"anthropic\", \"price_in\": 3.00, \"price_out\": 15.00},\n",
" }\n",
" },\n",
" \"google\": {\n",
" \"models\": {\n",
" # Gemini 3 - latest flagship - preview\n",
" \"gemini-3-pro-preview\": {\"supplier\": \"google\", \"price_in\": 2.00, \"price_out\": 12.00},\n",
" \"gemini-3-flash-preview\": {\"supplier\": \"google\", \"price_in\": 0.50, \"price_out\": 3.00},\n",
" # Gemini 2.5 - stable\n",
" \"gemini-2.5-pro\": {\"supplier\": \"google\", \"price_in\": 1.25, \"price_out\": 10.00},\n",
" \"gemini-2.5-flash\": {\"supplier\": \"google\", \"price_in\": 0.30, \"price_out\": 2.50},\n",
" }\n",
" },\n",
" \"xai\": {\n",
" \"models\": {\n",
" \"grok-4\": {\"supplier\": \"xai\", \"price_in\": 3.00, \"price_out\": 15.00},\n",
" \"grok-4-1-fast-reasoning\": {\"supplier\": \"xai\", \"price_in\": 0.20, \"price_out\": 0.50},\n",
" \"grok-4-1-fast-non-reasoning\": {\"supplier\": \"xai\", \"price_in\": 0.20, \"price_out\": 0.50}\n",
" }\n",
" },\n",
" \"fireworks\": {\n",
" \"models\": {\n",
" \"deepseek-v3p2\": {\"supplier\": \"deepseek\", \"price_in\": 0.56, \"price_out\": 1.68},\n",
" \"qwen3-vl-235b-a22b-instruct\": {\"supplier\": \"alibaba\", \"price_in\": 0.22, \"price_out\": 0.88},\n",
" \"deepseek-r1-0528\": {\"supplier\": \"deepseek\", \"price_in\": 1.35, \"price_out\": 5.40},\n",
" \"qwen3-vl-235b-a22b-thinking\": {\"supplier\": \"alibaba\", \"price_in\": 0.22, \"price_out\": 0.88},\n",
" \"kimi-k2p5\": {\"supplier\": \"moonshot\", \"price_in\": 1.20, \"price_out\": 1.20}\n",
" }\n",
" } \n",
"}"
]
},
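{
"cell_type": "markdown",
"id": "sketch-price-table-md",
"metadata": {},
"source": [
"To compare candidate models by price, you can flatten the nested configuration above into one table. The next cell is a sketch. It copies three entries from the configuration into a hypothetical `vendors_excerpt` so the cell runs on its own. In this notebook you could pass `VENDORS` directly; extra Azure keys such as `endpoint` are ignored because only `models` is read."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-price-table-code",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Excerpt copied from VENDORS above (prices per 1M tokens, USD); verify current prices before use.\n",
"vendors_excerpt = {\n",
"    \"openai\": {\"models\": {\"gpt-5-mini\": {\"supplier\": \"openai\", \"price_in\": 0.30, \"price_out\": 2.50}}},\n",
"    \"anthropic\": {\"models\": {\"claude-haiku-4-5-20251001\": {\"supplier\": \"anthropic\", \"price_in\": 1.00, \"price_out\": 5.00}}},\n",
"    \"fireworks\": {\"models\": {\"deepseek-v3p2\": {\"supplier\": \"deepseek\", \"price_in\": 0.56, \"price_out\": 1.68}}},\n",
"}\n",
"\n",
"def price_table(vendors: dict) -> pd.DataFrame:\n",
"    \"\"\"Flatten a VENDORS-style dict into one row per (vendor, model), cheapest output first.\"\"\"\n",
"    rows = [\n",
"        {\"vendor\": vendor, \"model\": model, **info}\n",
"        for vendor, cfg in vendors.items()\n",
"        for model, info in cfg[\"models\"].items()\n",
"    ]\n",
"    return pd.DataFrame(rows).sort_values(\"price_out\").reset_index(drop=True)\n",
"\n",
"print(price_table(vendors_excerpt))"
]
},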
{
"cell_type": "markdown",
"id": "3c799dc8-e11d-43f5-b48a-61ce699dcfec",
"metadata": {},
"source": [
"# System Prompt\n",
"- Be clear\n",
"- Be specific\n",
"- Try RTF (role task format)\n",
"- Succinct and exhaustive construct definitions\n",
"- Could give examples (few-shot), but may create noise or focus model too much on these cases\n",
"- Explain tie-breakers or dricky cases\n",
"- Define output format."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b46f6ca7-d791-48b2-838d-039fd696780d",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------\n",
"# System Prompt\n",
"# -----------------------------\n",
"SYSTEM_PROMPT = \"\"\"You are a business analyst classifying sentences from 10-K filings.\n",
"\n",
"Classify the sentence into one or more business functions: Marketing, Finance, Accounting, Operations, IT, HR (human resources)\n",
"\n",
"Definitions:\n",
"- Marketing: Customers, markets, demand, branding, advertising, promotion, pricing strategy, market research, segmentation, positioning, sales strategy, sales channels.\n",
"- Finance: Capital structure, funding, liquidity, treasury, investing, valuation, financial risk management (interest rates, FX, hedging), dividends, buybacks, M&A, financing activities.\n",
"- Accounting: Financial reporting, disclosures, GAAP/IFRS, accounting policies, accounting estimates, revenue recognition, impairments, reserves, depreciation, amortization, audits, internal controls over financial reporting (ICFR), tax accounting.\n",
"- Operations: Production, service delivery, supply chain, logistics, procurement, inventory management, manufacturing, capacity planning, facilities, quality control, safety, process efficiency, fulfillment, operational infrastructure, IT systems.\n",
"- IT: Information technology systems, software, hardware, cybersecurity, data management, cloud computing, digital infrastructure, technology platforms, system integration, IT support, technology investments, data analytics infrastructure.\n",
"- HR: Human resources, workforce, hiring, recruitment, talent acquisition, employee benefits, compensation, training, professional development, labor relations, employee retention, workplace safety, organizational culture.\n",
"\n",
"Rules:\n",
"1. Assign labels only when there is clear, direct evidence in the sentence.\n",
"2. Assign multiple labels if clearly relevant to more than one field.\n",
"3. Tie-breakers: Reporting/policies/controls/disclosures → Accounting; Funding/treasury/hedging/M&A → Finance.\n",
"4. Out of scope: General corporate governance, board matters, executive compensation, legal proceedings → return empty array.\n",
"5. Output: Return ONLY a JSON array with exact spelling: \"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\". If none apply, return [].\n",
"6. DO NOT provide a reason or explanation for your labels.\"\"\"\n",
" \n",
"ALLOWED_LABELS = {\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"}\n",
"LABEL_ORDER = [\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"]"
]
},
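{
"cell_type": "markdown",
"id": "sketch-few-shot-md",
"metadata": {},
"source": [
"The guidelines above list few-shot examples as an option. One way to try them without editing `SYSTEM_PROMPT` is to append a short block of sentence/label pairs. The next cell is a sketch only. `build_few_shot_prompt` is a hypothetical helper, the base prompt is shortened for the demo, and the two example sentences are invented for illustration rather than drawn from the labeled data."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-few-shot-code",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"def build_few_shot_prompt(system_prompt: str, examples: list) -> str:\n",
"    \"\"\"Hypothetical helper: append (sentence, labels) pairs as worked examples.\"\"\"\n",
"    lines = [system_prompt, \"\", \"Examples:\"]\n",
"    for sentence, labels in examples:\n",
"        lines.append(f\"Sentence: {sentence}\")\n",
"        # json.dumps yields the same JSON-array format the prompt asks for\n",
"        lines.append(f\"Output: {json.dumps(labels)}\")\n",
"    return \"\\n\".join(lines)\n",
"\n",
"# Invented illustrative examples (not from the labeled data)\n",
"examples = [\n",
"    (\"We repurchased shares of our common stock during the year.\", [\"Finance\"]),\n",
"    (\"Our manufacturing plants operated near full capacity.\", [\"Operations\"]),\n",
"]\n",
"base_prompt = \"You are a business analyst classifying sentences from 10-K filings.\"\n",
"print(build_few_shot_prompt(base_prompt, examples))"
]
},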
{
"cell_type": "markdown",
"id": "09e63238-a066-4a80-95cd-1eacb59d19d5",
"metadata": {},
"source": [
"# Client Initialization & API Call Functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2069211c-abff-4c85-8a5e-6de3539431dd",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------\n",
"# Client Initialization on a local PC (requires setting them at the beginning) or \n",
"# on Google CoLab (requires defining them in the secrets tab with key symbol)\n",
"# -----------------------------\n",
"def get_api_key(key_name: str) -> str:\n",
" \"\"\"Get API key from Colab secrets or environment variables.\"\"\"\n",
" try:\n",
" from google.colab import userdata\n",
" return userdata.get(key_name)\n",
" except (ImportError, ModuleNotFoundError):\n",
" # Not in Colab, use environment variables\n",
" return os.environ[key_name]\n",
"\n",
"\n",
"def init_client(vendor: str):\n",
" \"\"\"Initialize API client for vendor.\"\"\"\n",
" if vendor == \"openai\":\n",
" return OpenAI(api_key=get_api_key(\"OPENAI_API_KEY\"))\n",
" elif vendor == \"azure\":\n",
" return AzureOpenAI(\n",
" api_key=get_api_key(\"AZURE_API_KEY\"),\n",
" azure_endpoint=VENDORS[\"azure\"][\"endpoint\"],\n",
" api_version=VENDORS[\"azure\"][\"api_version\"],\n",
" ) \n",
" elif vendor == \"anthropic\":\n",
" return anthropic.Anthropic(api_key=get_api_key(\"ANTHROPIC_API_KEY\"))\n",
" elif vendor == \"google\":\n",
" return genai.Client(api_key=get_api_key(\"GOOGLE_API_KEY\"))\n",
" elif vendor == \"xai\":\n",
" return OpenAI(\n",
" api_key=get_api_key(\"XAI_API_KEY\"),\n",
" base_url=\"https://api.x.ai/v1\"\n",
" )\n",
" elif vendor == \"fireworks\":\n",
" return OpenAI(\n",
" api_key=get_api_key(\"FIREWORKS_API_KEY\"),\n",
" base_url=\"https://api.fireworks.ai/inference/v1\"\n",
" )\n",
" else:\n",
" raise ValueError(f\"Unknown vendor: {vendor}\")\n",
"\n",
"# -----------------------------\n",
"# API Call Functions\n",
"# -----------------------------\n",
"def call_openai(client, sentence: str, model: str, system_prompt: str,\n",
" max_tokens: int = 64, reasoning_effort: str = None) -> dict:\n",
" \"\"\"\n",
" Call OpenAI Responses API.\n",
" \n",
" Token billing:\n",
" - input_tokens: billed at input rate\n",
" - output_tokens: includes reasoning_tokens, all billed at output rate\n",
" - reasoning_tokens: subset of output_tokens (internal thinking)\n",
" \n",
" Note: For reasoning models, max_output_tokens must accommodate BOTH\n",
" reasoning tokens AND response tokens. We scale up accordingly.\n",
" \"\"\"\n",
" # For reasoning models, we need more output tokens to fit both reasoning + response\n",
" if reasoning_effort:\n",
" # Reasoning needs room: base response tokens + reasoning overhead\n",
" # low ~500, medium ~1000, high ~2000+ reasoning tokens typical\n",
" reasoning_overhead = {\"low\": 512, \"medium\": 1024, \"high\": 2048}.get(reasoning_effort, 512)\n",
" effective_max_tokens = max_tokens + reasoning_overhead\n",
" else:\n",
" effective_max_tokens = max_tokens\n",
" \n",
" params = {\n",
" \"model\": model,\n",
" \"instructions\": system_prompt,\n",
" \"input\": sentence,\n",
" \"max_output_tokens\": effective_max_tokens,\n",
" }\n",
" \n",
" if reasoning_effort:\n",
" params[\"reasoning\"] = {\"effort\": reasoning_effort}\n",
" else:\n",
" params[\"temperature\"] = 0\n",
" \n",
" resp = client.responses.create(**params)\n",
" \n",
" # Extract text - output_text may be empty for some response types\n",
" text = \"\"\n",
" if resp.output_text:\n",
" text = resp.output_text.strip()\n",
" else:\n",
" # Fallback: extract from output items\n",
" for item in resp.output:\n",
" if getattr(item, 'type', None) == 'message':\n",
" for block in getattr(item, 'content', []):\n",
" if getattr(block, 'type', None) == 'text':\n",
" text += getattr(block, 'text', '')\n",
" text = text.strip()\n",
" \n",
" # Extract token breakdown\n",
" input_tokens = resp.usage.input_tokens\n",
" output_tokens = resp.usage.output_tokens\n",
" \n",
" # Reasoning tokens are part of output_tokens (not additional)\n",
" reasoning_tokens = 0\n",
" if hasattr(resp.usage, 'output_tokens_details') and resp.usage.output_tokens_details:\n",
" reasoning_tokens = getattr(resp.usage.output_tokens_details, 'reasoning_tokens', 0) or 0\n",
" \n",
" return {\n",
" \"text\": text,\n",
" \"input_tokens\": input_tokens,\n",
" \"output_tokens\": output_tokens, # total output (includes reasoning)\n",
" \"reasoning_tokens\": reasoning_tokens, # internal thinking (subset of output)\n",
" \"response_tokens\": output_tokens - reasoning_tokens, # visible response\n",
" \"raw_response\": resp.model_dump() if hasattr(resp, 'model_dump') else str(resp),\n",
" }\n",
"\n",
"def call_azure(client, sentence: str, model: str, system_prompt: str,\n",
" max_tokens: int = 64) -> dict:\n",
" \"\"\"\n",
" Call Azure OpenAI API (chat completions).\n",
" \n",
" Azure uses the standard chat completions endpoint, not Responses API.\n",
" \"\"\"\n",
" params = {\n",
" \"model\": model,\n",
" \"messages\": [\n",
" {\"role\": \"system\", \"content\": system_prompt},\n",
" {\"role\": \"user\", \"content\": sentence}\n",
" ],\n",
" \"max_tokens\": max_tokens,\n",
" \"temperature\": 0, # Set temp to 0 to be deterministic\n",
" }\n",
" \n",
" resp = client.chat.completions.create(**params)\n",
" \n",
" text = (resp.choices[0].message.content or \"\").strip()\n",
" \n",
" return {\n",
" \"text\": text,\n",
" \"input_tokens\": resp.usage.prompt_tokens,\n",
" \"output_tokens\": resp.usage.completion_tokens,\n",
" \"reasoning_tokens\": 0,\n",
" \"response_tokens\": resp.usage.completion_tokens,\n",
" \"raw_response\": resp.model_dump() if hasattr(resp, 'model_dump') else str(resp),\n",
" }\n",
" \n",
"def call_anthropic(client, sentence: str, model: str, system_prompt: str,\n",
" max_tokens: int = 64, thinking_budget: int = None) -> dict:\n",
" \"\"\"\n",
" Call Anthropic Messages API.\n",
" \n",
" Token billing:\n",
" - input_tokens: billed at input rate\n",
" - output_tokens: includes thinking tokens, all billed at output rate\n",
" - NOTE: Anthropic doesn't provide separate thinking token count in usage.\n",
" The output_tokens is what's billed (includes full thinking, not summary).\n",
" \"\"\"\n",
" params = {\n",
" \"model\": model,\n",
" \"system\": system_prompt,\n",
" \"messages\": [{\"role\": \"user\", \"content\": sentence}],\n",
" }\n",
" \n",
" if thinking_budget:\n",
" thinking_budget = max(1024, thinking_budget)\n",
" params[\"thinking\"] = {\"type\": \"enabled\", \"budget_tokens\": thinking_budget}\n",
" params[\"max_tokens\"] = thinking_budget + max_tokens\n",
" else:\n",
" params[\"temperature\"] = 0\n",
" params[\"max_tokens\"] = max_tokens\n",
" \n",
" resp = client.messages.create(**params)\n",
" \n",
" # Extract text (skip thinking blocks - they're summarized in Claude 4)\n",
" text = \"\".join(b.text for b in resp.content if b.type == \"text\").strip()\n",
" \n",
" # output_tokens includes thinking tokens (billed amount)\n",
" # Anthropic doesn't provide breakdown like OpenAI does\n",
" output_tokens = resp.usage.output_tokens\n",
" \n",
" return {\n",
" \"text\": text,\n",
" \"input_tokens\": resp.usage.input_tokens,\n",
" \"output_tokens\": output_tokens, # total billed (includes thinking)\n",
" \"reasoning_tokens\": 0, # Anthropic doesn't provide breakdown\n",
" \"response_tokens\": output_tokens, # Can't separate, use total\n",
" \"raw_response\": resp.model_dump() if hasattr(resp, 'model_dump') else str(resp),\n",
" }\n",
"\n",
"def call_google(client, sentence: str, model: str, system_prompt: str,\n",
" max_tokens: int = 64, thinking_level: str = None) -> dict:\n",
" \"\"\"\n",
" Call Google Gemini API.\n",
" \n",
" Gemini 3 Pro: thinking_level \"low\", \"high\" (default)\n",
" Gemini 3 Flash: thinking_level \"minimal\", \"low\", \"medium\", \"high\" (default)\n",
" \n",
" Note: Even \"low\" still uses thinking tokens! Only Flash \"minimal\" truly minimizes.\n",
" \n",
" Token billing:\n",
" - output tokens include thinking tokens (no separate breakdown)\n",
" \"\"\"\n",
" from google.genai import types\n",
" \n",
" # Build config\n",
" config_params = {}\n",
" \n",
" # Determine token allocation based on thinking level: Even \"low\" uses thinking tokens (~50-200), so we need buffer\n",
" if thinking_level == \"minimal\":\n",
" # Only Flash supports minimal - truly minimal thinking\n",
" config_params[\"max_output_tokens\"] = max_tokens + 128\n",
" elif thinking_level == \"low\":\n",
" # Low still uses thinking tokens, need buffer\n",
" config_params[\"max_output_tokens\"] = max_tokens + 512\n",
" elif thinking_level == \"medium\":\n",
" config_params[\"max_output_tokens\"] = max_tokens + 2048\n",
" else:\n",
" # high or default (None) - full thinking\n",
" config_params[\"max_output_tokens\"] = max_tokens + 4096\n",
" \n",
" # Set thinking level if specified\n",
" if thinking_level:\n",
" config_params[\"thinking_config\"] = types.ThinkingConfig(\n",
" thinking_level=thinking_level\n",
" )\n",
" \n",
" # Gemini 3 recommends temperature=1.0 (default), don't override\n",
" \n",
" config = types.GenerateContentConfig(**config_params)\n",
" \n",
" # Combine system prompt and sentence\n",
" full_prompt = f\"{system_prompt}\\n{sentence}\"\n",
" \n",
" resp = client.models.generate_content(\n",
" model=model,\n",
" contents=full_prompt,\n",
" config=config,\n",
" )\n",
" \n",
" # Extract text - resp.text may be empty, need to check parts\n",
" text = \"\"\n",
" if resp.text:\n",
" text = resp.text.strip()\n",
" elif resp.candidates and resp.candidates[0].content and resp.candidates[0].content.parts:\n",
" # Extract text from parts, skipping thinking parts\n",
" for part in resp.candidates[0].content.parts:\n",
" # Skip thinking parts (they have 'thought' attribute set to True)\n",
" if hasattr(part, 'thought') and part.thought:\n",
" continue\n",
" if hasattr(part, 'text') and part.text:\n",
" text += part.text\n",
" text = text.strip()\n",
" \n",
" # Token usage\n",
" usage = resp.usage_metadata\n",
" input_tokens = usage.prompt_token_count\n",
" output_tokens = usage.candidates_token_count\n",
" # - candidates_token_count = actual response tokens\n",
" # - thoughts_token_count = thinking/reasoning tokens (billed as output)\n",
" thoughts_tokens = getattr(resp.usage_metadata, 'thoughts_token_count', 0) or 0\n",
" candidates_tokens = resp.usage_metadata.candidates_token_count or 0\n",
"\n",
" return {\n",
" \"text\": text,\n",
" \"input_tokens\": resp.usage_metadata.prompt_token_count,\n",
" \"output_tokens\": candidates_tokens + thoughts_tokens, # Total billed output\n",
" \"reasoning_tokens\": thoughts_tokens,\n",
" \"response_tokens\": candidates_tokens,\n",
" \"raw_response\": str(resp),\n",
" }\n",
" \n",
"def call_fireworks(client, sentence: str, model: str, system_prompt: str,\n",
" max_tokens: int = 64, reasoning_effort: str = None) -> dict:\n",
" \"\"\"\n",
" Call Fireworks API using OpenAI-compatible endpoint.\n",
" \n",
" Thinking models (*-thinking, r1, kimi) output reasoning in content,\n",
" so we allocate extra tokens and use recommended temperature.\n",
" \"\"\"\n",
" full_model_id = f\"accounts/fireworks/models/{model}\"\n",
" \n",
" # Thinking models need more tokens (thinking appears in content)\n",
" is_thinking_model = \"thinking\" in model or \"r1\" in model or model == \"kimi-k2p5\"\n",
" \n",
" if is_thinking_model:\n",
" effective_max_tokens = max_tokens + 4096\n",
" temperature = 0.6 # Recommended for thinking\n",
" else:\n",
" effective_max_tokens = max_tokens\n",
" temperature = 0 # Deterministic for chat\n",
" \n",
" params = {\n",
" \"model\": full_model_id,\n",
" \"messages\": [\n",
" {\"role\": \"system\", \"content\": system_prompt},\n",
" {\"role\": \"user\", \"content\": sentence}\n",
" ],\n",
" \"max_tokens\": effective_max_tokens,\n",
" \"temperature\": temperature,\n",
" }\n",
"\n",
" resp = client.chat.completions.create(**params)\n",
" \n",
" text = (resp.choices[0].message.content or \"\").strip()\n",
" \n",
" return {\n",
" \"text\": text,\n",
" \"input_tokens\": resp.usage.prompt_tokens,\n",
" \"output_tokens\": resp.usage.completion_tokens,\n",
" \"reasoning_tokens\": 0,\n",
" \"response_tokens\": resp.usage.completion_tokens,\n",
" \"raw_response\": resp.model_dump() if hasattr(resp, 'model_dump') else str(resp),\n",
" }\n",
"\n",
"def call_xai(client, sentence: str, model: str, system_prompt: str,\n",
" max_tokens: int = 64) -> dict:\n",
" \"\"\"\n",
" Call xAI API using OpenAI-compatible endpoint.\n",
" \n",
" grok-4 and *-reasoning models are always reasoning.\n",
" *-non-reasoning models are chat mode.\n",
" \n",
" Token billing:\n",
" - completion_tokens: visible response only\n",
" - reasoning_tokens: in completion_tokens_details, billed separately\n",
" - total output billed = completion_tokens + reasoning_tokens\n",
" \"\"\"\n",
" is_reasoning = \"non-reasoning\" not in model\n",
" \n",
" params = {\n",
" \"model\": model,\n",
" \"messages\": [\n",
" {\"role\": \"system\", \"content\": system_prompt},\n",
" {\"role\": \"user\", \"content\": sentence}\n",
" ],\n",
" \"max_tokens\": max_tokens + 4096 if is_reasoning else max_tokens,\n",
" \"temperature\": 1.0 if is_reasoning else 0,\n",
" }\n",
" \n",
" resp = client.chat.completions.create(**params)\n",
" text = (resp.choices[0].message.content or \"\").strip()\n",
" \n",
" # Extract reasoning tokens from completion_tokens_details\n",
" reasoning_tokens = 0\n",
" if hasattr(resp.usage, 'completion_tokens_details') and resp.usage.completion_tokens_details:\n",
" reasoning_tokens = getattr(resp.usage.completion_tokens_details, 'reasoning_tokens', 0) or 0\n",
" \n",
" # completion_tokens is visible response, reasoning_tokens is separate\n",
" # Both are billed at output rate\n",
" completion_tokens = resp.usage.completion_tokens\n",
" total_output = completion_tokens + reasoning_tokens\n",
" \n",
" return {\n",
" \"text\": text,\n",
" \"input_tokens\": resp.usage.prompt_tokens,\n",
" \"output_tokens\": total_output, # Total billed at output rate\n",
" \"reasoning_tokens\": reasoning_tokens,\n",
" \"response_tokens\": completion_tokens, # Visible response only\n",
" \"raw_response\": resp.model_dump() if hasattr(resp, 'model_dump') else str(resp),\n",
" }"
]
},
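{
"cell_type": "markdown",
"id": "sketch-billed-tokens-md",
"metadata": {},
"source": [
"The call functions above fold two different usage conventions into one `output_tokens` figure. On OpenAI, per the `call_openai` docstring, reasoning tokens are already counted inside `output_tokens`. On xAI, per the `call_xai` docstring, they are reported next to `completion_tokens` and must be added, and `call_google` adds Gemini's `thoughts_token_count` in the same way. The next cell is a self-contained sketch of that rule. It uses a hypothetical helper, `billed_output_tokens`, and made-up token counts."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-billed-tokens-code",
"metadata": {},
"outputs": [],
"source": [
"def billed_output_tokens(reported_output: int, reasoning: int, reasoning_included: bool) -> int:\n",
"    \"\"\"Hypothetical helper: output tokens billed at the output rate under either convention.\"\"\"\n",
"    # OpenAI-style: reasoning is already inside the reported output count.\n",
"    # xAI/Gemini-style (per the docstrings above): reasoning is reported separately and must be added.\n",
"    return reported_output if reasoning_included else reported_output + reasoning\n",
"\n",
"# Made-up counts for illustration: the same call seen through both conventions\n",
"print(billed_output_tokens(reported_output=520, reasoning=500, reasoning_included=True))   # 520\n",
"print(billed_output_tokens(reported_output=20, reasoning=500, reasoning_included=False))  # 520"
]
},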
{
"cell_type": "markdown",
"id": "35a1a8ba-0a5d-4d1d-a10c-5404ca9a2919",
"metadata": {},
"source": [
"# Parsing & Cost"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ae5e0d1f-be15-4cd1-84d7-3a05743e7c0a",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------\n",
"# Parsing & Cost\n",
"# -----------------------------\n",
"\n",
"def parse_labels(text: str) -> list:\n",
" \"\"\"Parse JSON array, validate labels. Handles markdown code fences and extra text.\"\"\"\n",
" text = text.strip()\n",
"\n",
" # Strip thinking tags for models that use them (e.g., DeepSeek R1, QwQ)\n",
" if \"</think>\" in text:\n",
" text = text.split(\"</think>\")[-1].strip()\n",
" elif \"<think>\" in text and \"</think>\" not in text:\n",
" # Incomplete thinking block - try to find JSON after it\n",
" pass\n",
" \n",
" # Strip markdown code fences if present\n",
" if \"```\" in text:\n",
" lines = text.split(\"\\n\")\n",
" lines = [l for l in lines if not l.strip().startswith(\"```\")]\n",
" text = \"\\n\".join(lines).strip()\n",
" \n",
" # Find the LAST JSON array (thinking models put reasoning before answer)\n",
" # Search backwards for the final [...] pattern\n",
" start = text.rfind(\"[\")\n",
" if start == -1:\n",
" raise ValueError(\"No JSON array found\")\n",
" \n",
" # Find matching closing bracket\n",
" depth = 0\n",
" end = start\n",
" for i, char in enumerate(text[start:], start):\n",
" if char == \"[\":\n",
" depth += 1\n",
" elif char == \"]\":\n",
" depth -= 1\n",
" if depth == 0:\n",
" end = i + 1\n",
" break\n",
" \n",
" json_str = text[start:end]\n",
" data = json.loads(json_str)\n",
" \n",
" if not isinstance(data, list):\n",
" raise ValueError(\"Not a JSON array\")\n",
" invalid = [x for x in data if x not in ALLOWED_LABELS]\n",
" if invalid:\n",
" raise ValueError(f\"Invalid labels: {invalid}\")\n",
" return [label for label in LABEL_ORDER if label in set(data)]\n",
"\n",
"\n",
"def compute_cost(model_info: dict, input_tokens: int, output_tokens: int) -> float:\n",
" \"\"\"\n",
" Compute cost in USD.\n",
" Note: output_tokens includes reasoning/thinking tokens, all billed at output rate.\n",
" \"\"\"\n",
" return (input_tokens / 1e6) * model_info[\"price_in\"] + (output_tokens / 1e6) * model_info[\"price_out\"]\n",
"\n",
"\n",
"def get_mode_string(vendor: str, model: str = None, reasoning_effort: str = None, \n",
" thinking_budget: int = None, thinking_level: str = None) -> str:\n",
" \"\"\"Get mode string for filenames.\"\"\"\n",
" if vendor == \"openai\":\n",
" return f\"reasoning-{reasoning_effort}\" if reasoning_effort else \"chat\"\n",
" elif vendor == \"azure\":\n",
" return \"chat\"\n",
" elif vendor == \"anthropic\":\n",
" return f\"thinking-{thinking_budget}\" if thinking_budget else \"chat\"\n",
" elif vendor == \"google\":\n",
" if thinking_level == \"minimal\":\n",
" return \"chat\" # Only Flash supports minimal\n",
" elif thinking_level == \"low\":\n",
" return \"thinking-low\" # Still uses some thinking\n",
" elif thinking_level:\n",
" return f\"thinking-{thinking_level}\"\n",
" return \"thinking\" # default for Gemini 3 \n",
" elif vendor == \"fireworks\":\n",
" if reasoning_effort:\n",
" return f\"reasoning-{reasoning_effort}\"\n",
" # Auto-detected thinking models\n",
" if model and (\"thinking\" in model or \"r1\" in model or model == \"kimi-k2p5\"):\n",
" return \"thinking\"\n",
" elif vendor == \"xai\":\n",
" return \"chat\" if model and \"non-reasoning\" in model else \"reasoning\"\n",
" return \"chat\""
]
},
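{
"cell_type": "markdown",
"id": "sketch-cost-estimate-md",
"metadata": {},
"source": [
"The disclaimer recommends testing on small samples before running at scale. One way to act on that is to measure average token counts on a small batch and extrapolate with the same per-1M-token formula that `compute_cost` uses. The next cell is a sketch built around a hypothetical helper, `estimate_run_cost`. The token averages in the example call are placeholders, not measured values, and the prices are the gpt-5-mini entries from the configuration above. Reasoning models can produce far more output tokens than chat models, so measure them separately."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-cost-estimate-code",
"metadata": {},
"outputs": [],
"source": [
"def estimate_run_cost(n_sentences: int, avg_in: float, avg_out: float,\n",
"                      price_in: float, price_out: float) -> float:\n",
"    \"\"\"Hypothetical helper: expected USD cost using per-1M-token prices.\"\"\"\n",
"    per_call = (avg_in / 1e6) * price_in + (avg_out / 1e6) * price_out\n",
"    return n_sentences * per_call\n",
"\n",
"# Placeholder averages; replace with values measured on a small pilot batch.\n",
"cost = estimate_run_cost(n_sentences=1000, avg_in=600, avg_out=10,\n",
"                         price_in=0.30, price_out=2.50)\n",
"print(f\"Estimated cost: ${cost:.4f}\")"
]
},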
{
"cell_type": "markdown",
"id": "9b028d13-8bd2-47e7-9f5d-f8589136d6d2",
"metadata": {},
"source": [
"# Classify Sentence and Retry on Error"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3cf0a31f-7403-4861-b8ab-69e6fdd65434",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------\n",
"# Single Sentence Classification with Retry\n",
"# -----------------------------\n",
"def classify_sentence(client, sentence: str, vendor: str, model: str, model_info: dict,\n",
" system_prompt: str, max_tokens: int = 64, max_retries: int = 5,\n",
" reasoning_effort: str = None, thinking_budget: int = None,\n",
" thinking_level: str = None) -> dict:\n",
" \"\"\"\n",
" Classify a single sentence with retry logic.\n",
" \n",
" Retry wait times: 3s, 3s, 6s, 12s, 30s\n",
" Rate limit errors: 60s\n",
" \"\"\"\n",
" # Import vendor-specific exceptions only when needed\n",
" if vendor == \"openai\":\n",
" rate_limit_errors = (RateLimitError,)\n",
" api_errors = (APIError,)\n",
" auth_errors = (AuthenticationError,)\n",
" elif vendor == \"azure\":\n",
" rate_limit_errors = (RateLimitError,)\n",
" api_errors = (APIError,)\n",
" auth_errors = (AuthenticationError,)\n",
" elif vendor == \"anthropic\":\n",
" import anthropic\n",
" rate_limit_errors = (anthropic.RateLimitError,)\n",
" api_errors = (anthropic.APIError,)\n",
" auth_errors = (anthropic.AuthenticationError,)\n",
" elif vendor == \"fireworks\":\n",
" import openai\n",
" rate_limit_errors = (openai.RateLimitError,)\n",
" api_errors = (openai.APIError,)\n",
" auth_errors = (openai.AuthenticationError,)\n",
" elif vendor == \"xai\":\n",
" import openai\n",
" rate_limit_errors = (openai.RateLimitError,)\n",
" api_errors = (openai.APIError,)\n",
" auth_errors = (openai.AuthenticationError,)\n",
" elif vendor == \"google\":\n",
" from google.api_core import exceptions as google_exceptions\n",
" rate_limit_errors = (google_exceptions.ResourceExhausted,)\n",
" api_errors = (google_exceptions.GoogleAPIError,)\n",
" auth_errors = (google_exceptions.Unauthenticated, google_exceptions.PermissionDenied)\n",
" else:\n",
" rate_limit_errors = ()\n",
" api_errors = ()\n",
" auth_errors = ()\n",
" \n",
" last_error = None\n",
" error_details = None\n",
" wait_times = {1: 3, 2: 3, 3: 6, 4: 12, 5: 30}\n",
" \n",
" for attempt in range(1, max_retries + 1):\n",
" try:\n",
" t0 = time.perf_counter()\n",
" \n",
" if vendor == \"openai\":\n",
" result = call_openai(client, sentence, model, system_prompt, \n",
" max_tokens, reasoning_effort)\n",
" elif vendor == \"azure\": \n",
" result = call_azure(client, sentence, model, system_prompt,\n",
" max_tokens)\n",
" elif vendor == \"anthropic\":\n",
" result = call_anthropic(client, sentence, model, system_prompt,\n",
" max_tokens, thinking_budget)\n",
" elif vendor == \"fireworks\":\n",
" result = call_fireworks(client, sentence, model, system_prompt,\n",
" max_tokens, reasoning_effort)\n",
" elif vendor == \"xai\":\n",
" result = call_xai(client, sentence, model, system_prompt,\n",
" max_tokens)\n",
" elif vendor == \"google\":\n",
" result = call_google(client, sentence, model, system_prompt,\n",
" max_tokens, thinking_level)\n",
" else:\n",
" raise ValueError(f\"Unknown vendor: {vendor}\")\n",
" \n",
" latency = time.perf_counter() - t0\n",
" labels = parse_labels(result[\"text\"])\n",
" \n",
" return {\n",
" \"labels\": labels,\n",
" \"response_text\": result[\"text\"],\n",
" \"input_tokens\": result[\"input_tokens\"],\n",
" \"output_tokens\": result[\"output_tokens\"],\n",
" \"internal_tokens\": result.get(\"reasoning_tokens\") or result.get(\"thinking_tokens\", 0),\n",
" \"response_tokens\": result.get(\"response_tokens\", result[\"output_tokens\"]),\n",
" \"cost_usd\": compute_cost(model_info, result[\"input_tokens\"], result[\"output_tokens\"]),\n",
" \"latency_sec\": latency,\n",
" \"attempts\": attempt,\n",
" \"error\": None,\n",
" \"raw_response\": result.get(\"raw_response\"),\n",
" }\n",
" \n",
" except auth_errors as e:\n",
" # Don't retry auth errors - they won't resolve\n",
" error_details = {\"type\": \"AuthenticationError\", \"attempt\": attempt, \"message\": str(e)}\n",
" print(f\" Authentication error (not retrying): {e}\")\n",
" return {\n",
" \"labels\": None,\n",
" \"response_text\": None,\n",
" \"input_tokens\": 0,\n",
" \"output_tokens\": 0,\n",
" \"internal_tokens\": 0,\n",
" \"response_tokens\": 0,\n",
" \"cost_usd\": 0.0,\n",
" \"latency_sec\": None,\n",
" \"attempts\": attempt,\n",
" \"error\": error_details,\n",
" \"raw_response\": None,\n",
" }\n",
" \n",
" except rate_limit_errors as e:\n",
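" last_error = f\"RateLimitError: {e}\"\n",
" error_details = {\"type\": \"RateLimitError\", \"attempt\": attempt, \"message\": str(e)}\n",
" wait = 60 if attempt < max_retries else 0  # no wait after the final attempt\n",
" print(f\" Rate limit hit (attempt {attempt}), waiting {wait}s...\")\n",
" time.sleep(wait)\n",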
" \n",
" except api_errors as e:\n",
" last_error = f\"APIError: {e}\"\n",
" error_details = {\"type\": \"APIError\", \"attempt\": attempt, \"message\": str(e)}\n",
" wait = wait_times.get(attempt, 3) if attempt < max_retries else 0  # no wait after the final attempt\n",
" print(f\" API error (attempt {attempt}), waiting {wait}s...\")\n",
" time.sleep(wait)\n",
" \n",
" except (json.JSONDecodeError, ValueError) as e:\n",
" last_error = f\"ParseError: {e}\"\n",
" # result may be unbound (the vendor call raised) or its text may be None\n",
" response_preview = ((locals().get(\"result\") or {}).get(\"text\") or \"\")[:100]\n",
" error_details = {\"type\": \"ParseError\", \"attempt\": attempt, \"message\": str(e), \"response_preview\": response_preview}\n",
" wait = wait_times.get(attempt, 3) if attempt < max_retries else 0  # no wait after the final attempt\n",
" print(f\" Parse error (attempt {attempt}): got '{response_preview}', waiting {wait}s...\")\n",
" time.sleep(wait)\n",
" \n",
" except Exception as e:\n",
" last_error = f\"{type(e).__name__}: {e}\"\n",
" error_details = {\"type\": type(e).__name__, \"attempt\": attempt, \"message\": str(e)}\n",
" print(f\" Error (attempt {attempt}): {last_error}\")\n",
" if attempt < max_retries:\n",
" time.sleep(wait_times.get(attempt, 3))\n",
" \n",
" return {\n",
" \"labels\": None,\n",
" \"response_text\": None,\n",
" \"input_tokens\": 0,\n",
" \"output_tokens\": 0,\n",
" \"internal_tokens\": 0,\n",
" \"response_tokens\": 0,\n",
" \"cost_usd\": 0.0,\n",
" \"latency_sec\": None,\n",
" \"attempts\": max_retries,\n",
" \"error\": error_details,\n",
" \"raw_response\": None,\n",
" }"
]
},
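{
"cell_type": "markdown",
"id": "e3c1a7b2-5d4f-4a8e-9b6c-0f1e2d3c4b5a",
"metadata": {},
"source": [
"The retry policy in `classify_sentence` can be tried offline, without spending API credits. The cell below is a minimal sketch that the pipeline does not use: `TransientError`, `FatalAuthError`, `retry_with_backoff`, and `flaky_call` are hypothetical stand-ins for a vendor call and its exceptions, and `sleep` is injected so that waits are recorded instead of slept. Like `classify_sentence`, it returns immediately on an authentication error and otherwise waits according to the `{1: 3, 2: 3, 3: 6, 4: 12, 5: 30}` schedule; this sketch also skips the wait after the final attempt, since no retry follows it."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4d2b8c3-6e5a-4b9f-8c7d-1a2b3c4d5e6f",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical offline demo of retry with backoff (no API calls, no real waiting)\n",
"import time\n",
"\n",
"class TransientError(Exception):\n",
"    pass  # stands in for rate-limit and other retryable API errors\n",
"\n",
"class FatalAuthError(Exception):\n",
"    pass  # stands in for authentication errors\n",
"\n",
"WAIT_TIMES = {1: 3, 2: 3, 3: 6, 4: 12, 5: 30}  # seconds to wait after failed attempt n\n",
"\n",
"def retry_with_backoff(call, max_retries=5, sleep=time.sleep):\n",
"    # Returns (result, attempts, error); error is None on success\n",
"    error = None\n",
"    for attempt in range(1, max_retries + 1):\n",
"        try:\n",
"            return call(), attempt, None\n",
"        except FatalAuthError as e:\n",
"            return None, attempt, repr(e)  # never retried\n",
"        except TransientError as e:\n",
"            error = repr(e)\n",
"            if attempt < max_retries:  # no wait after the last attempt\n",
"                sleep(WAIT_TIMES.get(attempt, 3))\n",
"    return None, max_retries, error\n",
"\n",
"# A fake vendor call that fails twice, then returns a label on the third attempt\n",
"outcomes = iter([TransientError('503'), TransientError('503'), 'Finance'])\n",
"\n",
"def flaky_call():\n",
"    item = next(outcomes)\n",
"    if isinstance(item, Exception):\n",
"        raise item\n",
"    return item\n",
"\n",
"demo_waits = []\n",
"demo_result, demo_attempts, demo_error = retry_with_backoff(flaky_call, sleep=demo_waits.append)\n",
"print(demo_result, demo_attempts, demo_waits)  # Finance 3 [3, 3]"
]
},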
{
"cell_type": "markdown",
"id": "b7d9948c-1278-42bb-84f9-311f5a9bcaac",
"metadata": {},
"source": [
"# DataFrame Processing with Checkpointing"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5799882a-e057-4277-9583-94fbc1872481",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------\n",
"# DataFrame Processing with Checkpointing\n",
"# -----------------------------\n",
"def classify_dataframe(\n",
" df: pd.DataFrame,\n",
" sentence_col: str,\n",
" vendor: str,\n",
" model: str,\n",
" system_prompt: str = SYSTEM_PROMPT,\n",
" max_tokens: int = 64,\n",
" output_dir: str = \"./output\",  # default output folder if none is passed\n",
" checkpoint_dir: str = \"./checkpoints\",  # default checkpoint folder if none is passed\n",
" save_interval: int = 100,  # save a checkpoint every N result rows\n",
" run: int = 1,\n",
" reasoning_effort: str = None,\n",
" thinking_budget: int = None,\n",
" thinking_level: str = None,\n",
" max_consecutive_errors: int = 10,\n",
") -> pd.DataFrame:\n",
" \"\"\"\n",
" Classify all sentences with checkpointing.\n",
" \n",
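" Directory structure ({tag} = last folder name of output_dir, lowercased):\n",
" - checkpoint_dir/{tag}_{vendor}_{model}_{mode}_run{run}.pkl (interim saves)\n",
" - output_dir/Run{run}/{tag}_{vendor}_{model}_{mode}_run{run}.pkl and .csv (final output)\n",
" - output_dir/Run{run}/{tag}_{vendor}_{model}_{mode}_run{run}.log.json (run log)\n",
" \n",
" When resuming from a checkpoint, rows that previously errored are retried.\n",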
" \"\"\"\n",
" # Validate\n",
" if sentence_col not in df.columns:\n",
" raise ValueError(f\"Column '{sentence_col}' not found\")\n",
" if vendor not in VENDORS:\n",
" raise ValueError(f\"Unknown vendor: {vendor}\")\n",
" if model not in VENDORS[vendor][\"models\"]:\n",
" raise ValueError(f\"Unknown model '{model}' for vendor '{vendor}'\")\n",
" \n",
" model_info = VENDORS[vendor][\"models\"][model]\n",
" mode_str = get_mode_string(vendor, model, reasoning_effort, thinking_budget, thinking_level)\n",
" # Derive dataset tag from output_dir (e.g., \"./output/holdout\" → \"holdout\")\n",
" dataset_tag = os.path.basename(os.path.normpath(output_dir)).lower()\n",
" base_name = f\"{dataset_tag}_{vendor}_{model}_{mode_str}_run{run}\"\n",
" \n",
" # Create directories\n",
" os.makedirs(checkpoint_dir, exist_ok=True)\n",
" run_dir = f\"{output_dir}/Run{run}\"\n",
" os.makedirs(run_dir, exist_ok=True)\n",
" \n",
" checkpoint_path = f\"{checkpoint_dir}/{base_name}.pkl\"\n",
" \n",
" print(f\"Checkpoint dir: {checkpoint_dir}\")\n",
" print(f\"Output dir: {run_dir}\")\n",
" print(f\"Checkpoint file: {checkpoint_path}\")\n",
" \n",
" # Load existing checkpoint\n",
" results = []\n",
" processed_ids = set()\n",
" \n",
" if os.path.exists(checkpoint_path):\n",
" try:\n",
" checkpoint_df = pd.read_pickle(checkpoint_path)\n",
" successful = checkpoint_df[checkpoint_df['error'].isna()]\n",
" results = successful.to_dict('records')\n",
" processed_ids = set(successful['id'].values) if 'id' in successful.columns else set()\n",
" failed_count = len(checkpoint_df) - len(successful)\n",
" print(f\"Resumed: {len(results)} processed, {failed_count} errors to retry\")\n",
" except Exception as e:\n",
" print(f\"Could not load checkpoint: {e}\")\n",
" \n",
" # Initialize client\n",
" client = init_client(vendor)\n",
" \n",
" # Track errors\n",
" consecutive_errors = 0\n",
" start_time = datetime.datetime.now()\n",
" \n",
" # Process\n",
" for idx, row in tqdm(df.iterrows(), total=len(df), desc=f\"Run {run}: {model} ({mode_str})\"):\n",
" row_id = row.get('id', idx)\n",
" \n",
" if row_id in processed_ids:\n",
" continue\n",
" \n",
" sentence = str(row[sentence_col])\n",
" \n",
" result = classify_sentence(\n",
" client, sentence, vendor, model, model_info, system_prompt, max_tokens,\n",
" reasoning_effort=reasoning_effort, thinking_budget=thinking_budget,\n",
" thinking_level=thinking_level\n",
" )\n",
" \n",
" ordered_result = {\n",
" 'id': row_id,\n",
" 'sentence': sentence,\n",
" **result\n",
" }\n",
" results.append(ordered_result)\n",
" \n",
" if result['error']:\n",
" consecutive_errors += 1\n",
" if consecutive_errors >= max_consecutive_errors:\n",
" print(f\"\\nStopping: {max_consecutive_errors} consecutive errors\")\n",
" break\n",
" else:\n",
" consecutive_errors = 0\n",
" processed_ids.add(row_id)\n",
" \n",
" # Checkpoint save\n",
" if len(results) % save_interval == 0:\n",
" pd.DataFrame(results).to_pickle(checkpoint_path)\n",
" print(f\"\\nCheckpoint saved: {len(results)} processed\")\n",
" \n",
" # Final results\n",
" results_df = pd.DataFrame(results)\n",
" end_time = datetime.datetime.now()\n",
" \n",
" # Save to checkpoint (for resume if needed)\n",
" results_df.to_pickle(checkpoint_path)\n",
" \n",
" # Save to run directory\n",
" results_df.to_pickle(f\"{run_dir}/{base_name}.pkl\")\n",
" results_df.to_csv(f\"{run_dir}/{base_name}.csv\", index=False)\n",
" \n",
" # Create log file\n",
" error_count = int(results_df['error'].notna().sum())\n",
" success_count = len(results_df) - error_count\n",
" \n",
" log = {\n",
" \"run\": run,\n",
" \"vendor\": vendor,\n",
" \"model\": model,\n",
" \"supplier\": model_info[\"supplier\"],\n",
" \"mode\": mode_str,\n",
" \"reasoning_effort\": reasoning_effort,\n",
" \"thinking_budget\": thinking_budget,\n",
" \"thinking_level\": thinking_level,\n",
" \"start_time\": str(start_time),\n",
" \"end_time\": str(end_time),\n",
" \"duration_sec\": (end_time - start_time).total_seconds(),\n",
" \"total_sentences\": len(df),\n",
" \"processed_sentences\": len(results_df),\n",
" \"successful\": success_count,\n",
" \"errors\": error_count,\n",
" \"tokens\": {\n",
" \"input\": int(results_df['input_tokens'].sum()),\n",
" \"output\": int(results_df['output_tokens'].sum()),\n",
" \"internal\": int(results_df['internal_tokens'].sum()),\n",
" \"response\": int(results_df['response_tokens'].sum()),\n",
" },\n",
" \"cost_usd\": float(results_df['cost_usd'].sum()),\n",
" \"pricing\": {\n",
" \"input_per_1m\": model_info[\"price_in\"],\n",
" \"output_per_1m\": model_info[\"price_out\"],\n",
" },\n",
" \"files\": {\n",
" \"checkpoint\": checkpoint_path,\n",
" \"output_pkl\": f\"{run_dir}/{base_name}.pkl\",\n",
" \"output_csv\": f\"{run_dir}/{base_name}.csv\",\n",
" }\n",
" }\n",
" \n",
" log_path = f\"{run_dir}/{base_name}.log.json\"\n",
" with open(log_path, 'w') as f:\n",
" json.dump(log, f, indent=2)\n",
" \n",
" print(f\"\\nRun {run} complete:\")\n",
" print(f\" Sentences: {success_count}/{len(results_df)}\")\n",
" print(f\" Tokens - Input: {log['tokens']['input']:,}, Output: {log['tokens']['output']:,} (Internal: {log['tokens']['internal']:,})\")\n",
" print(f\" Cost: ${log['cost_usd']:.4f}\")\n",
" print(f\" Log: {log_path}\")\n",
" \n",
" return results_df"
]
},
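{
"cell_type": "markdown",
"id": "a5e3c9d4-7f6b-4c0a-9d8e-2b3c4d5e6f70",
"metadata": {},
"source": [
"The resume rule in `classify_dataframe` can also be checked offline. The sketch below is illustrative only: it writes a toy checkpoint (the file name `demo_run1.pkl` is hypothetical and lives in a temporary folder) in which the sentence with id 2 failed, then applies the same filter the function uses when resuming: rows whose `error` is missing count as done, and failed or never-reached ids are queued again."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6f4d0e5-8a7c-4d1b-8e9f-3c4d5e6f7081",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical offline demo of the checkpoint/resume rule (toy data, temporary folder)\n",
"import os\n",
"import tempfile\n",
"\n",
"import pandas as pd\n",
"\n",
"with tempfile.TemporaryDirectory() as tmp_dir:\n",
"    demo_checkpoint = os.path.join(tmp_dir, 'demo_run1.pkl')\n",
"\n",
"    # An interrupted run: id 2 failed, ids 4 and 5 were never reached\n",
"    pd.DataFrame({\n",
"        'id': [1, 2, 3],\n",
"        'labels': [['Finance'], None, ['HR']],\n",
"        'error': [None, {'type': 'APIError'}, None],\n",
"    }).to_pickle(demo_checkpoint)\n",
"\n",
"    # Resume: keep successful rows only, so errored ids are retried\n",
"    saved = pd.read_pickle(demo_checkpoint)\n",
"    successful = saved[saved['error'].isna()]\n",
"    done_ids = {int(i) for i in successful['id']}\n",
"\n",
"    all_ids = [1, 2, 3, 4, 5]\n",
"    to_process = [i for i in all_ids if i not in done_ids]\n",
"    print(sorted(done_ids), to_process)  # [1, 3] [2, 4, 5]"
]
},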
{
"cell_type": "markdown",
"id": "2e6d17c0-135a-42c6-8815-db44f8445894",
"metadata": {},
"source": [
"# Multi-Run Processing"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fba35e3e-e3ab-4f23-9b46-749ccecbe66d",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------\n",
"# Multi-Run Processing\n",
"# -----------------------------\n",
"def run_classification(\n",
" df: pd.DataFrame,\n",
" sentence_col: str,\n",
" vendor: str,\n",
" model: str,\n",
" runs: list = [1, 2, 3], # default if nothing passed\n",
" output_dir: str = \"./output\", # default if nothing passed\n",
" checkpoint_dir: str = \"./checkpoints\", # default if nothing passed\n",
" system_prompt: str = SYSTEM_PROMPT,\n",
" max_tokens: int = 64,\n",
" save_interval: int = 100,\n",
" reasoning_effort: str = None,\n",
" thinking_budget: int = None,\n",
" thinking_level: str = None,\n",
") -> dict:\n",
" \"\"\"\n",
" Run classification multiple times.\n",
" \n",
" Directory structure ({tag} = last folder name of output_dir, lowercased):\n",
" output_dir/\n",
" Run1/\n",
" {tag}_{vendor}_{model}_{mode}_run1.pkl\n",
" {tag}_{vendor}_{model}_{mode}_run1.csv\n",
" {tag}_{vendor}_{model}_{mode}_run1.log.json\n",
" Run2/\n",
" ...\n",
" checkpoint_dir/\n",
" {tag}_{vendor}_{model}_{mode}_run1.pkl\n",
" ...\n",
" \n",
" Returns dict of {run: results_df}\n",
" \"\"\"\n",
" model_info = VENDORS[vendor][\"models\"][model]\n",
" mode_str = get_mode_string(vendor, model, reasoning_effort, thinking_budget, thinking_level)\n",
" job_start_time = time.perf_counter()\n",
" \n",
" print(f\"\\n{'='*60}\")\n",
" print(f\"CLASSIFICATION JOB\")\n",
" print(f\"{'='*60}\")\n",
" print(f\"Vendor: {vendor}\")\n",
" print(f\"Model: {model}\")\n",
" print(f\"Supplier: {model_info['supplier']}\")\n",
" print(f\"Mode: {mode_str}\")\n",
" print(f\"Pricing: ${model_info['price_in']}/M in, ${model_info['price_out']}/M out\")\n",
" print(f\"Runs: {runs}\")\n",
" print(f\"Sentences: {len(df)}\")\n",
" print(f\"Output dir: {output_dir}\")\n",
" print(f\"Checkpoint dir: {checkpoint_dir}\")\n",
" print(f\"{'='*60}\\n\")\n",
" \n",
" all_results = {}\n",
" \n",
" for run in runs:\n",
" print(f\"\\n{'='*40}\")\n",
" print(f\"RUN {run}\")\n",
" print(f\"{'='*40}\")\n",
" \n",
" results_df = classify_dataframe(\n",
" df, sentence_col, vendor, model, system_prompt, max_tokens,\n",
" output_dir=output_dir,\n",
" checkpoint_dir=checkpoint_dir,\n",
" save_interval=save_interval,\n",
" run=run,\n",
" reasoning_effort=reasoning_effort,\n",
" thinking_budget=thinking_budget,\n",
" thinking_level=thinking_level,\n",
" )\n",
" \n",
" all_results[run] = results_df\n",
" \n",
" # Print summary\n",
" job_runtime = time.perf_counter() - job_start_time\n",
" print(f\"\\n{'='*60}\")\n",
" print(\"SUMMARY\")\n",
" print(f\"{'='*60}\")\n",
" total_cost = sum(r['cost_usd'].sum() for r in all_results.values())\n",
" print(f\"Total runs: {len(runs)}\")\n",
" print(f\"Total cost: ${total_cost:.4f}\")\n",
" print(f\"Total runtime: {job_runtime:.1f}s ({job_runtime/60:.1f}m)\")\n",
" print(f\"Output location: {output_dir}/Run*/\")\n",
" \n",
" return all_results"
]
},
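{
"cell_type": "markdown",
"id": "c7a5e1f6-9b8d-4e2c-9f0a-4d5e6f708192",
"metadata": {},
"source": [
"Because the dataset tag is derived from the last folder of `output_dir`, the same model and run write to different files for different datasets, while the default `./output` yields the tag `output`. The cell below is a small sketch that re-implements that naming rule in a hypothetical helper, `build_base_name`, with placeholder vendor, model, and mode values (`example-model` is not a real model name), so you can preview file names before spending API credits."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d8b6f2a7-0c9e-4f3d-8a1b-5e6f708192a3",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper that mirrors the base_name rule in classify_dataframe\n",
"import os\n",
"\n",
"def build_base_name(output_dir, vendor, model, mode_str, run):\n",
"    # Dataset tag = last folder name of output_dir, lowercased\n",
"    dataset_tag = os.path.basename(os.path.normpath(output_dir)).lower()\n",
"    return f'{dataset_tag}_{vendor}_{model}_{mode_str}_run{run}'\n",
"\n",
"print(build_base_name('./output/Holdout', 'openai', 'example-model', 'chat', 1))\n",
"# holdout_openai_example-model_chat_run1\n",
"print(build_base_name('./output', 'openai', 'example-model', 'chat', 2))\n",
"# output_openai_example-model_chat_run2"
]
},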
{
"cell_type": "markdown",
"id": "11c8f621-6826-4e71-b38c-32b1bd8de480",
"metadata": {},
"source": [
"# Load Data\n",
"\n",
">**If you are running this on your local computer:** Subfolders will be created automatically inside the folder that contains this notebook, and all files will be saved in those folders.\n",
"\n",
"> **If you are on Google Colab:** FIRST, connect your Google Drive and navigate to the folder that contains this notebook. The code will then create subfolders inside that folder, and all files will be saved in those subfolders on your Google Drive."
]
},
{
"cell_type": "markdown",
"id": "a90bfe58-576c-4b7a-88ff-b37ee816a391",
"metadata": {},
"source": [
"##### **IMPORTANT for Google Colab**\n",
"\n",
"If you want to save to your Google Drive, you first have to mount it and then change into the folder that contains this notebook:\n",
"\n",
"- Import the Google Drive module and mount your drive\n",
"- Change the working directory to the folder that contains this notebook\n",
"\n",
"***Copy this code into a new code cell, change the folder name if yours is not \"Project\", and run it:***\n",
"```\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')\n",
"\n",
"import os\n",
"\n",
"# Construct the full path to the desired folder within Google Drive\n",
"drive_path = '/content/drive/My Drive/Project'\n",
"\n",
"# Create the directory if it doesn't exist (optional, but good practice)\n",
"if not os.path.exists(drive_path):\n",
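" os.makedirs(drive_path)\n",
"\n",
"# Change into that folder so relative paths such as ./output and ./checkpoints resolve there\n",
"os.chdir(drive_path)\n",
"```"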
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "336f43ce-f1b2-4d2f-8130-737ef8eea75d",
"metadata": {},
"outputs": [],
"source": [
"# Example Sentences\n",
"# sentences = ['The annual growth in our provision for uncollectible accounts was primarily attributable to transitioning multiple external payors from recognizing revenue on a cash basis to an accrual basis, thereby aligning with customer-related payment practices.', 'We handle certain sales-type leases internally, particularly those associated with facilities serving U.S. government hospital clients.', 'We expect that this acquisition will position us to broaden our offerings in the private label accessories segment, catering to customers seeking value-oriented solutions.', 'Additionally, the Tax Cuts and Jobs Act was enacted into law by the President in December 2017.', 'Valuations derived from the Black-Scholes model can vary significantly depending on the assumptions made regarding volatility and the duration of the underlying instruments.', 'Our audience is engaged across multiple channels, such as digital properties, print publications, as well as broadcast and streaming media.', 'The forecast incorporates managements most informed assumptions regarding anticipated economic and market trends for the duration of the projection, encompassing anticipated changes in sales growth, cost structures, operating margin expectations, and future cash outlays.', 'As of December 31, 2007, a single credit remains unassigned to a particular reserve, which, absent any favorable developments, is at risk of defaulting imminently and may lead to the filing of a claim.', \"Our model incorporates assumptions regarding anticipated stock price fluctuations, utilizing historical volatility metrics, the applicable risk-free rate derived from the treasury yield curve, the projected duration of equity awards informed by past exercise trends and termination actions post-vesting, along with anticipated dividend equivalents to be paid during the award's estimated term, given that our stock appreciation rights participate in dividends.\", 'The Company conducts comprehensive physical 
counts of inventory across all store and warehouse locations at least once per year, and correspondingly updates the reported merchandise inventory to reflect the results of these assessments.', 'Substantial growth was realized across categories such as basketball and lifestyle footwear, as well as branded apparel and specialized cleats for wrestling, volleyball, and soccer.', 'Actuarial gains realized in 2007 will diminish the unrecognized loss allocated across the projected average remaining service years of active plan members, with such periods differing by plan from 6 to 23 years.', 'Platform utilization expanded across enterprise segments, reflecting enhanced operational efficiency and streamlined integration of internal systems, during fiscal 2020 for the technology company.', 'The advances carried maturities at assorted dates up to 2014, with interest rates ranging from 0.3% to 3.4%.', 'Furthermore, significant investments were made to upgrade essential operational systems, establish a comprehensive disaster recovery strategy, and strengthen compliance protocols designed to better serve our customers.', 'Subscription offerings expanded during fiscal 2023, reflecting successful margin optimization and process improvements within our technology platform, as we continued to innovate and refine our approach.', 'Provisions for loss contingencies are determined by managements assessment regarding the probability of adverse results and the estimated magnitude of potential losses.', 'Before Discover Bank acquired SLC on December 31, 2010, SLC maintained a contractual relationship with Citi that facilitated the origination and ongoing management of private student loans for individuals.', 'Overview of BorgWarner Inc.s Financial Condition and Operating Results Managements Discussion and Analysis', 'We attribute the rise in RELIC watch sales chiefly to strategic modifications in our product offerings and the incremental expansion of our customer base achieved toward 
the end of Fiscal 2003.', 'PECO Electric Operating Statistics and Revenue Summary The following outlines PECOs electric sales metrics and associated revenue information: (a) Full service represents energy provided to customers receiving electricity under standard tariffed rates.', 'As part of the 2009 Supervisory Capital Assessment Program, the Federal Reserve Board enhanced its review of the capital sufficiency of selected major bank holding companies by utilizing an alternative measure of Tier 1 capital referred to as Tier 1 common equity.', 'During the years ended 2008, 2007, and 2006, none of our customers individually represented 10% or greater of our overall revenue.', 'Other than our operating leases, which mainly pertain to restaurant locations, we do not engage in any off-balance sheet transactions.', 'Policyholders have the option to discontinue coverage as a result of their departure from medical practice due to retirement, disability, or passing.', 'Red Robin Gourmet Burgers®, Inc., a Delaware entity along with its subsidiaries (referred to herein as “Red Robin,” the “Company,” or by similar terms), is engaged predominantly in the development, operation, and franchising of casual-dining restaurants, totaling 514 outlets across North America as of the close of the fiscal year on December 28, 2014.', 'As of December 31, 2008, 51% of the outstanding home equity lines of credit were collateralized by properties located in New York State, while 21% and 26% were backed by properties in Pennsylvania and the Mid-Atlantic region, respectively.', \"The decline in property catastrophe offerings was attributable to non-renewed contracts and decreased participation levels, influenced by prevailing market dynamics and an increased reliance on retrocessional arrangements, whereas non-catastrophe property business during 2014 comprised an incoming unearned premium portfolio transfer of $50.2 million from Gulf Reinsurance Limited ('Gulf Re'), which the Company acquired 
in 2015.\", 'IPLs operating performance reflected a $149 million decline in earnings allocable to common shareholders during 2008 and a $118 million improvement in 2007, which was largely attributable to a post-tax gain of $123 million stemming from the divestiture of IPLs electric transmission assets in 2007.', 'Incorporating remote work practices forms an integral aspect of our overall business continuity strategy.', 'Upon the sale of a gift card, we record a corresponding liability on our balance sheet.', 'Managements Discussion and Analysis Procedures for the Company and its Subsidiaries.', 'Our Carters, Just One Year, and Child of Mine brands are made available to third parties through licensing agreements.', 'SPP has initiated approvals for the development of transmission infrastructure designed to transport renewable energy from the wind-producing regions of western Oklahoma, the Texas Panhandle, and western Kansas to major demand centers, as part of its strategy to expand transmission capacity in these areas.', 'Consequently, we record revenue at the point when ownership is conveyed to our customer.', 'The credits provided will fluctuate from one quarter to the next, and ultimately, the cumulative benefits passed on to customers in the form of reduced electric rates will align with the amounts applied toward federal income tax obligations.', 'Outcomes may vary, and the divergence from these projections could be significant depending on alternative assumptions or circumstances.', 'Accruals associated with chargebacks and rebates involve a significant degree of estimation within our sales processes.', 'In the quarter concluding on September 30, 2013, we finalized a partnership arrangement with LPC MM Monrovia, LLC, an independent third party, specifically to facilitate the acquisition of a real estate asset in Monrovia, California, and to initiate the development of construction documentation for enhancements to the site.', 'Periodic legal proceedings or 
regulatory inquiries arise against the Company as a normal aspect of conducting its insurance operations.', 'In the interim between lease termination and final real estate disposition, we provided significant loan-based support to the hospital operator, enabling them to meet operational cash needs while awaiting reimbursement for patient services from Medicare and other payors.', 'Pursuant to the CRS agreement, close to one-fourth of the contracts total value becomes payable by the customer and is subject to collection solely after launch and delivery objectives have been achieved for each of the eight CRS missions.', 'In 2008, the Company completed toll processing of 126,000 ounces of PGMs, an increase compared to the 112,000 ounces processed under toll arrangements in the prior year.', 'Operations within our ready-mixed concrete, precast concrete, and ancillary concrete businesses experience fluctuations attributable to seasonal patterns.', 'Historically, our highest working capital requirements arise in the latter half of the year, as accounts receivable and inventory expand due to elevated activity during the holiday sales period, and inventory builds ahead of anticipated factory shutdowns for Chinese New Year observances.', 'Cost of sales primarily includes expenses related to products, such as major aircraft and engine components, along with direct labor, overhead, and costs associated with maintaining aircraft.', 'Our proportional interest in the mines under our management affords us an annual rated pellet production capacity of 22.9 million tons, which equates to roughly 28% of the aggregate pellet capacity available throughout North America.', 'For the fiscal year ended December 31, 2009, a total of 338 da Vinci Surgical Systems were sold, representing a slight increase from the 335 units sold in the prior year ended December 31, 2008.', 'During the 2009 fiscal year, the Company incurred costs attributable to flooding totaling $7.6 million, while 
recognizing $16.7 million in recoveries from insurance claims.', 'The rise was tempered in part by the enhanced redeployment incentives offered through our equipment lease initiative to new customers.', 'The Company provides defined benefit pension arrangements, primarily benefiting salaried and management employees, and oversees the associated post-retirement medical plan accounting.', 'The gross profit margin declined to 30.7% in 2013 as compared to 31.4% in the previous year, largely attributable to a rise in sales from private label and international segments, which traditionally yield lower margins.', 'A substantial portion of our multi-family lending activity is directed toward longstanding property owners whose apartment buildings operate under rent control, resulting in rental rates that are lower than prevailing market levels.', 'Growth in distillery product revenues was driven by elevated unit volumes and enhanced pricing for food grade alcohol used in both beverage and industrial sectors, alongside strengthened pricing for fuel grade alcohol.', 'In October 2005, we completed delivery of the system to the customer.', 'On July 29, 2014, the Company completed an issuance of Senior Notes totaling $300 million, maturing on February 1, 2025, with a fixed interest rate of 5.375% and sold at par value, referred to herein as the 2014 Notes.', 'Expenses associated with reactivating these subscribers were recognized, and these customers were categorized as gross new DISH TV subscriber additions for the year ended December 31, 2017, with related costs captured under “Subscriber acquisition costs” within our Consolidated Statements of Operations and Comprehensive Income (Loss) and/or as “Purchases of property and equipment” in our Consolidated Statements of Cash Flows.', \"Refer to 'Item 8' under the section 'Loans Receivable.'\", 'Valuation of swaps, interest rate swaptions, and option contracts is determined through commonly accepted industry models that estimate 
the present value of anticipated derivative cash flows, reflecting both prevailing and anticipated market conditions.', 'On February 5, 2010, a total of 68 workover rigs were staffed and were either in service or subject to ongoing marketing efforts.', 'As a result of our commitment to making education more affordable for our students, we project that the typical debt burden for graduates of the Art Institutes has declined by nearly 15% since 2010.', 'The market valuation of distressed inventory is determined utilizing prior sales patterns for specific product categories, prevailing market dynamics, overall economic factors, and the worth of outstanding in-house orders associated with prospective sales of such inventory.', 'The Company factors such discounts and rebates into its determination of transaction price, recording them as deductions from gross sales.', 'Within the Caribbean region, we operate vertically integrated utilities as well as generation facilities, each governed by multi-year agreements established in partnership with governmental entities.', 'Nonetheless, despite Synovus Bank meeting all required quantitative capital standards, regulatory provisions allow for reclassification of an institution to a less favorable capital category, dependent on supervisory considerations beyond capital metrics.', 'For most of the securities measured through dealers and pricing services, we source several independent valuations, all of which are not legally binding on either our company or our counterparties.', '59 Overseas Shipholding Group, Inc. serves as the registry for this section.', 'Our strategy is to consistently keep U.S. 
wholesaler inventories for our products at or below a one-month average, supporting reliable supply for our customers.', 'Aside from the $202.9 million gain recognized in 2011 from the sales of Big Sandy and Langley, the year-over-year decline also reflected roughly 25% reduced realized prices for sales volumes, a 23% increase in the rate of production depletion, and elevated other operating costs, which were partially counterbalanced by a 33% rise in production output, a 30% expansion in gathered volumes, and increased transmission income.', '• Establishing further criteria for evaluating applications for disability-related benefit claims.', 'For fiscal year 2016, revenue totaled $28.2 million, representing a decline of $11.3 million, or 28.5%, compared to the prior years revenue of $39.5 million.', 'During 2010, the Company divested its Orthopaedic Implant production site located in Caen, France, resulting in a $24.3 million gain that was recognized within selling, general, and administrative expenses.', 'Illustrative instances of indemnification obligations under the Separation and Distribution Agreement and associated ancillary agreements include Navients duty to compensate the Company and the Bank for any liabilities, costs, or expenses they may face resulting from actions or potential actions concerning pre-Spin-Off SLM and its subsidiaries servicing, operational, and collections activities related to Private Education Loans and FFELP Loans that were held as assets by either the Bank or Navient at the time of the Spin-Off, contingent upon written notice being furnished to Navient by April 30, 2017, the third anniversary of the Spin-Off.', 'Growth in platform usage was observed, driven by improved internal workflows and resource allocation, during fiscal 2022 for the telecommunications company.', 'Acquisitions pursuant to the Plan may occur through open market transactions or privately arranged agreements, conducted periodically and in accordance with relevant 
legal requirements, inclusive of Rule 10b-18 under the amended Securities Exchange Act of 1934.', 'We ended our bridge revolving credit facility, originally established on June 13, 2007, with five U.S. and international banking institutions in February 2008.', \"Our management team regularly assesses the company's cost structure to confirm that operational expenditures are managed effectively without compromising the quality of service provided to customers.\", 'The companys reach broadened in fiscal 2021, reflecting enhanced productivity and optimization of internal processes in expanded territories.', 'A portion of these rising costs was absorbed by the company, resulting in an average sales price uptick to customers of just $0.32 per kilogram.', 'The company saw advances in payment collection efficiency, reflecting cost reduction initiatives and automation of reconciliation tasks.', 'Advancements in technology have the potential to mitigate these factors or expand the availability of mineral resources.', 'Each of our two plans designed for small office and home office use provides a single business fax line at no extra cost, with supplementary fax lines available for a monthly fee.', 'In September 2011, the European Commission approved VIBATIV® for use in adult patients with hospital-acquired pneumonia, including cases linked to ventilators, where MRSA is known or suspected and alternative therapies are deemed inappropriate, thereby expanding therapeutic options for these individuals.', 'These trends are subject to modification due to factors such as the establishment of additional schools, launch of innovative programs, rising adult student enrollment, or potential acquisitions.', 'Funding requests, which generally arise at intervals of four to eight weeks, trigger the execution of necessary inspections and formal evaluations.', 'Assessing our yearly tax obligations and analyzing our tax strategies necessitates considerable judgment and expertise.', 'On March 
24, 2006, the Grindle plaintiffs retracted their request to intervene, doing so without affecting their rights.', 'At December 31, 2017, the balance of commercial and industrial loans, which includes owner-occupied commercial properties, grew by $0.3 billion, reaching a total of $1.6 billion, reflecting support for our business customers financing needs.', 'After the merger, ARRIS retained complete indirect ownership of both ARRIS Group and Pace as wholly-owned subsidiaries.', '•On August 17, 2012, we provided a first mortgage loan totaling $46.0 million, secured by a Hilton hotel consisting of 315 rooms located in Rockville, Maryland.', 'Our Managements Discussion and Analysis of Financial Condition and Results of Operations highlights our position as a foremost provider of independent, technology-driven portfolio management solutions, investment guidance, and retirement income services, serving participants predominantly within employer-sponsored defined contribution plans, including 401(k) arrangements.', 'The structure of our per-minute charges is designed to encompass the entirety of services provided.', 'On June 30, 2009, nearly 83% of our holdings, excluding cash balances, possessed maturities shorter than one year.', 'Distribution of Portfolio by Geography: The following table presents the Companys operating square footage segmented by region at December 31, 2011 (in thousands). 
Industry Credit Risk Profile: The subsequent data illustrates the composition of our tenant portfolio by industry as of December 31, 2011.', '2 Includes additional products such as consumer revolving credit, installment loans, and consumer lease finance options tailored to customer needs.', 'Currently anticipated capital expenditures for 2006 total approximately $568.9 million, encompassing $24.0 million allocated to startup activities, resolution of the thruster defect previously detailed, and modifications mandated by customers for the GSF Development Driller I, $237.3 million dedicated to construction of a new semisubmersible unit, $45.0 million earmarked for repairs to our rig fleet due to hurricane impacts, $124.0 million assigned for significant fleet enhancements, $107.8 million directed toward additional equipment acquisitions and replacements, $17.3 million related to capitalized interest, and $13.5 million (net of intersegment eliminations) for oil and gas operations.', 'The Company plans to maintain its practice of issuing quarterly dividends, contingent upon Board approval, sufficient capital resources, and an assessment that such distributions remain advantageous for its shareholders.', 'As of December 31, 2008, the Company had no borrowings outstanding under its credit facility, with a total available capacity of roughly $48.4 million, reflecting the deduction for a $7.0 million letter of credit issued to BCBSF/HOI.', 'The progression of our product offerings relies on the availability and consistent performance of software sourced from external vendors.', 'The Company establishes BESP for its product deliverables through a weighted average pricing methodology, initiated by an examination of historical data on standalone sales transactions.']\n",
"# df = pd.DataFrame({'sentences': sentences})\n",
"# df=df.head(3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "20f2dab2-7738-435f-bdff-f9a9e3a2eb2c",
"metadata": {},
"outputs": [],
"source": [
"# Load Data: Preliminary Labeling of all available data\n",
"df = pd.read_csv(\"source_text/10K_unlabeled.csv\")\n",
"df = df.sample(n=10).reset_index(drop=True) # small random test sample (differs each run); remove this line to label all data\n",
"df"
]
},
{
"cell_type": "markdown",
"id": "dbc7642b-d0cd-4b73-836d-893c54531155",
"metadata": {},
"source": [
"# Set Output Paths\n",
"- Preliminary (\"PRELIM_DIR\") vs. Holdout (\"HOLDOUT_DIR\") vs. Train (\"TRAIN_DIR\") sets, all inside the base folder \"OUTPUT_DIR\"\n",
"- ***Make sure not to accidentally overwrite your labeled data!***\n",
"- When you query a model (below in Run Models), you need to specify the output path."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a37fc60e-8234-4deb-b101-de383d378f09",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# OUTPUT CONFIGURATION - SET for Base, Holdout, Train\n",
"# =============================================================================\n",
"CHECKPOINT_DIR = \"./checkpoints\"\n",
"OUTPUT_DIR = \"./output\"\n",
"PRELIM_DIR = f\"{OUTPUT_DIR}/preliminary\"\n",
"HOLDOUT_DIR = f\"{OUTPUT_DIR}/holdout\"\n",
"TRAIN_DIR = f\"{OUTPUT_DIR}/train\""
]
},
{
"cell_type": "markdown",
"id": "59c343be-dbbe-4c28-be41-44c518ed2526",
"metadata": {},
"source": [
"# Run Models\n",
"- Set desired parameters\n",
"- Set runs (currently set to 1 run per model for examples): runs=[1] vs runs=[1, 2, 3]\n",
"- Uncomment the models you want to run!\n",
"- Run different models? Make sure to define them above in the vendor/model configuration first.\n",
" - Not all models will run with this code.\n",
" - If you get errors, you need to debug and develop code further\n",
" - (I'd use Claude Opus 4.5 Thinking for that, but other genAI models may work as well / even better).\n",
"\n",
"> ***IMPORTANT*** If you want to rerun a model, make sure to delete its files from the checkpoints folder first. Otherwise, it will skip all examples (e.g., sentences) that the previous run already labeled (which could be all) and you don't get updated results.\n",
"\n",
"> **MORE IMPORTANT** Running this code will cost you API credits (and requires you to have accounts with the providers). You will need to supply your own API keys. Beware that you may be subject to rate limits (how many queries you can send per minute) and to restrictions on which models you can use (OpenAI, for example, requires you to verify your identity with an ID to access many models). Regardless, every time you execute this code, you will drain your API credits = real money! Thus, make wise decisions about what and how much to label."
]
},
{
"cell_type": "markdown",
"id": "e4503bd1-0d03-4f09-8897-fc06be07d887",
"metadata": {},
"source": [
"## Labeling Approach:\n",
"\n",
"**What we ultimately need:**\n",
"- A holdout set (1000 examples)\n",
"- A training set (15k examples)\n",
"- The ability to test and train all classes / labels\n",
" - Need to have at least some balance in holdout set so that every class (label) is represented (e.g., at least 100 times)\n",
" - Need enough examples per class (label) in train set that fine-tuned model can learn them\n",
"> **Challenge**\n",
"> - How do we know that we have enough examples (i.e., sentences) per class (label)?\n",
"\n",
"> **Idea**\n",
"> - One possible approach is to label all data only once with a reasonably fast and inexpensive genAI model to get an idea about class (label) distribution. Depending on how good the model is (which we cannot easily know yet unless we have a couple dozen manually labeled examples to use as a preliminary evaluation set), we have at least a *directional idea* of which sentences may belong to which classes / have which labels.\n",
"> - We can then use this *directional idea* to ***construct a holdout set*** with at least N=100 (or more?) examples per class (label). This will be better than random sampling (unless classes (labels) are balanced in full data set, which is unlikely). It will still not be robust, but *at least give us some idea*.\n",
"> - We need to make sure to remove all examples that we put in the holdout set from the rest of our data (i.e., no holdout set example should also be in the training set to ***prevent leakage***). "
]
},
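{
"cell_type": "markdown",
"id": "c1a7e3f0-5b2d-4c8e-9f10-2a3b4c5d6e01",
"metadata": {},
"source": [
"The *directional idea* above comes down to simple arithmetic. If a class appears in a share $p$ of the preliminary labels, you need to draw roughly $N / p$ sentences to expect $N$ examples of it. The cell below is a minimal sketch of that calculation on a small, made-up label table. The toy data and `TARGET_PER_CLASS` are illustrative assumptions, not outputs of this notebook. Swap in the binary label columns from your own preliminary run once you have them."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1a7e3f0-5b2d-4c8e-9f10-2a3b4c5d6e02",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import pandas as pd\n",
"\n",
"# Hypothetical preliminary labels (1 = class assigned by the LLM); replace with your own\n",
"prelim = pd.DataFrame({\n",
"    'Marketing': [1, 0, 0, 1, 0, 0, 0, 0, 0, 0],\n",
"    'HR':        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],\n",
"})\n",
"\n",
"TARGET_PER_CLASS = 100  # assumed minimum number of examples wanted per class\n",
"n_total = len(prelim)\n",
"\n",
"for cls in prelim.columns:\n",
"    count = int(prelim[cls].sum())\n",
"    if count == 0:\n",
"        print(f'{cls}: never assigned, cannot estimate')\n",
"        continue\n",
"    # Expected draws needed = target / prevalence = target * n_total / count\n",
"    needed = math.ceil(TARGET_PER_CLASS * n_total / count)\n",
"    print(f'{cls}: prevalence {count / n_total:.1%}, about {needed} sentences needed')"
]
},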
{
"cell_type": "markdown",
"id": "a2c07cc9-3961-4ff4-82db-d75006dfa14c",
"metadata": {},
"source": [
"### WARNING: MAKE SURE TO **PASS** THE CORRECT OUTPUT DIR!\n",
"What are you labeling? \n",
"- Preliminary full data set (labeled once to investigate class balance) --> PRELIM_DIR\n",
"- Holdout set --> HOLDOUT_DIR\n",
"- Train set --> TRAIN_DIR"
]
},
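{
"cell_type": "markdown",
"id": "d2b8f4a1-6c3e-4d9f-8a21-3b4c5d6e7f01",
"metadata": {},
"source": [
"You can make this check harder to forget by calling a small guard before a paid run. The helper below is a hypothetical sketch: `check_output_dir` is not defined elsewhere in this notebook. It raises an error if the last folder in the path does not match the split you meant to label, and it prints a warning if `.pkl` results already exist there. It assumes the folder names used in this notebook (`preliminary`, `holdout`, `train`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d2b8f4a1-6c3e-4d9f-8a21-3b4c5d6e7f02",
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"import os\n",
"\n",
"def check_output_dir(output_dir, expected_split):\n",
"    # Hypothetical guard: the last path component should name the intended split\n",
"    name = os.path.basename(os.path.normpath(output_dir))\n",
"    if name != expected_split:\n",
"        raise ValueError(f'{output_dir} does not look like the {expected_split} directory')\n",
"    # Warn if earlier results (.pkl files, including in subfolders) are already present\n",
"    existing = glob.glob(os.path.join(output_dir, '**', '*.pkl'), recursive=True)\n",
"    if existing:\n",
"        print(f'Warning: {len(existing)} result file(s) already in {output_dir}')\n",
"    return output_dir\n",
"\n",
"# Example: passes for the preliminary folder (a mismatch such as 'holdout' would raise)\n",
"check_output_dir('./output/preliminary', 'preliminary')"
]
},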
{
"cell_type": "markdown",
"id": "27552dca-0f9d-4feb-b2f0-0fe63df69d51",
"metadata": {},
"source": [
"### Google Colab Runtime Timeouts\n",
"If you have a slow API and/or you are labeling many sentences (texts), Google Colab may time out or shut down your runtime. This will abort the labeling process. You can resume anytime, but you need to restart and run your notebook again. For this purpose, I've added checkpointing: results are saved every N=100 sentences (texts), and the code looks for an interim checkpoint and picks up from there (it also finds and retries errors where nothing valid was returned). *This may be less of a problem with Colab Pro, which you can sign up for free as a student.*\n",
"\n",
"> If you can, it may be well worth installing Python and Jupyter Notebook on your local computer. You won't face the timeout issues then, and your code will query the APIs as long as your computer is connected to the internet (and running)."
]
},
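{
"cell_type": "markdown",
"id": "e3c9a5b2-7d4f-4e0a-9b32-4c5d6e7f8a01",
"metadata": {},
"source": [
"The checkpoint-and-resume idea can be sketched in a few lines of pandas. The cell below is a hypothetical, simplified illustration, not the internals of `run_classification`. The function `label_with_checkpoints`, the `id`/`sentence`/`labels`/`error` columns, and the fake classifier are all assumptions. The function saves results to a pickle every `every` rows. On the next call, it drops failed rows, skips ids that already succeeded, and labels only what is left. The demo writes to a temporary folder and makes no API calls."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3c9a5b2-7d4f-4e0a-9b32-4c5d6e7f8a02",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import tempfile\n",
"import pandas as pd\n",
"\n",
"def label_with_checkpoints(df, classify_fn, ckpt_path, every=100):\n",
"    # Hypothetical sketch of checkpoint-and-resume; not run_classification's actual code\n",
"    if os.path.exists(ckpt_path):\n",
"        done = pd.read_pickle(ckpt_path)\n",
"        done = done[done['error'].isna()]  # drop failed rows so they get retried\n",
"    else:\n",
"        done = pd.DataFrame(columns=['id', 'labels', 'error'])\n",
"    rows = done.to_dict('records')\n",
"    todo = df[~df['id'].isin(set(done['id']))]  # only ids without a successful result\n",
"    for i, (_, row) in enumerate(todo.iterrows(), start=1):\n",
"        try:\n",
"            rows.append({'id': row['id'], 'labels': classify_fn(row['sentence']), 'error': None})\n",
"        except Exception as exc:  # record the failure instead of aborting the whole run\n",
"            rows.append({'id': row['id'], 'labels': None, 'error': str(exc)})\n",
"        if i % every == 0:\n",
"            pd.DataFrame(rows).to_pickle(ckpt_path)  # interim checkpoint\n",
"    result = pd.DataFrame(rows)\n",
"    result.to_pickle(ckpt_path)  # final save\n",
"    return result\n",
"\n",
"# Demo with a fake classifier: the second call finds all ids done and relabels nothing\n",
"demo = pd.DataFrame({'id': [1, 2, 3], 'sentence': ['a', 'b', 'c']})\n",
"path = os.path.join(tempfile.mkdtemp(), 'demo_ckpt.pkl')\n",
"first = label_with_checkpoints(demo, lambda s: ['Finance'], path, every=2)\n",
"second = label_with_checkpoints(demo, lambda s: ['IT'], path)\n",
"print(second['labels'].tolist())  # still the labels from the first run"
]
},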
{
"cell_type": "code",
"execution_count": 1,
"id": "e892c00f-d577-4c86-b913-85dc75f51ff2",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# UNC Main Campus Computing on MS Azure\n",
"# Models: gpt-4.1, gpt-4o\n",
"# Modes: chat only!!! (temperature=0 set for OpenAI models)\n",
"# Notes: Experimental by Dr. D. To get an odd number of labels per sentence across two models, I queried the first model 4 times (4 + 3 = 7)\n",
"# =============================================================================\n",
"\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"azure\",\n",
"# model=\"gpt-4.1\",\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )\n",
"\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"azure\",\n",
"# model=\"gpt-4o\",\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "55eefc77-28e4-4fbe-8e21-80c26a6d041a",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# OPENAI\n",
"# Models: gpt-5.2, gpt-5, gpt-5-mini, gpt-5-nano, gpt-4.1, gpt-4o\n",
"# Modes: chat (temperature=0) or reasoning (reasoning_effort)\n",
"# reasoning_effort: \"low\", \"medium\", \"high\"\n",
"# =============================================================================\n",
"\n",
"# # OpenAI chat mode (temperature=0)\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"openai\",\n",
"# model=\"gpt-4.1\", # gpt-5.2, gpt-5, gpt-5-mini, gpt-5-nano, gpt-4.1, gpt-4o --> make sure these are defined with pricing in vendor/model configuration\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )\n",
"\n",
"# # OpenAI reasoning mode\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"openai\",\n",
"# model=\"gpt-5.2\", # gpt-5.2, gpt-5, gpt-5-mini, gpt-5-nano, gpt-4.1, gpt-4o --> make sure these are defined with pricing in vendor/model configuration\n",
"# reasoning_effort=\"high\", # \"low\", \"medium\", \"high\"\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "935a4056-5e2d-47da-add1-f38a2735844d",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# ANTHROPIC\n",
"# Models: claude-opus-4-5-20251101, claude-sonnet-4-5-20250929, \n",
"# claude-haiku-4-5-20251001, claude-sonnet-4-20250514\n",
"# Modes: chat (temperature=0) or thinking (thinking_budget)\n",
"# thinking_budget: min 1024, billed as output tokens\n",
"# =============================================================================\n",
"\n",
"# # Anthropic chat mode (temperature=0)\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"anthropic\",\n",
"# model=\"claude-sonnet-4-5-20250929\", # opus-4-5, sonnet-4-5, haiku-4-5, sonnet-4\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )\n",
"\n",
"# # Anthropic thinking mode\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"anthropic\",\n",
"# model=\"claude-sonnet-4-5-20250929\", # opus-4-5, sonnet-4-5, haiku-4-5, sonnet-4\n",
"# thinking_budget=2048, # min 1024, billed as output tokens\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5174d2b0-01b8-4971-8ead-9fb23bd416f3",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# FIREWORKS\n",
"# Models: deepseek-v3p2, deepseek-r1-0528, qwen3-vl-235b-a22b-instruct,\n",
"# qwen3-vl-235b-a22b-thinking, kimi-k2p5\n",
"# Modes: auto-detected from model name\n",
"# - Chat: deepseek-v3p2, qwen3-vl-235b-a22b-instruct\n",
"# - Thinking: deepseek-r1-0528, qwen3-vl-235b-a22b-thinking, kimi-k2p5\n",
"# =============================================================================\n",
"\n",
"# # DeepSeek V3.2 (chat)\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"fireworks\",\n",
"# model=\"deepseek-v3p2\",\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )\n",
"\n",
"# # DeepSeek R1 (thinking)\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"fireworks\",\n",
"# model=\"deepseek-r1-0528\",\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )\n",
"\n",
"# # Qwen3 VL Instruct (chat)\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"fireworks\",\n",
"# model=\"qwen3-vl-235b-a22b-instruct\",\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )\n",
"\n",
"# # Qwen3 VL Thinking (thinking)\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"fireworks\",\n",
"# model=\"qwen3-vl-235b-a22b-thinking\",\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )\n",
"\n",
"# # Kimi K2.5 (thinking)\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"fireworks\",\n",
"# model=\"kimi-k2p5\",\n",
"# runs=[1],\n",
"#     output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4c2689d-f405-4061-8844-a83b2a53c80d",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# XAI (GROK)\n",
"# Models: grok-4, grok-4-1-fast-reasoning, grok-4-1-fast-non-reasoning\n",
"# Modes: auto-detected from model name\n",
"# - Reasoning: grok-4, grok-4-1-fast-reasoning\n",
"# - Chat: grok-4-1-fast-non-reasoning\n",
"# =============================================================================\n",
"\n",
"# # Grok 4 (reasoning)\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"xai\",\n",
"# model=\"grok-4\", # grok-4, grok-4-1-fast-reasoning, grok-4-1-fast-non-reasoning\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )\n",
"\n",
"# # Grok 4.1 Fast Reasoning\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"xai\",\n",
"# model=\"grok-4-1-fast-reasoning\",\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )\n",
"\n",
"# # Grok 4.1 Fast Non-Reasoning (chat)\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"xai\",\n",
"# model=\"grok-4-1-fast-non-reasoning\",\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5746b7c9-c3d0-4b0f-92a9-5cf087509421",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# GOOGLE GEMINI\n",
"# Models: gemini-3-pro-preview, gemini-3-flash-preview, \n",
"# gemini-2.5-pro, gemini-2.5-flash\n",
"# Modes: controlled via thinking_level\n",
"# - Pro: \"low\", \"high\" (default if not specified)\n",
"# - Flash: \"minimal\", \"low\", \"medium\", \"high\" (default if not specified)\n",
"# Note: Even \"low\" uses some thinking tokens. Only Flash \"minimal\" truly minimizes.\n",
"# =============================================================================\n",
"\n",
"# # Gemini 3 Pro - high thinking (default)\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"google\",\n",
"# model=\"gemini-3-pro-preview\", # gemini-3-pro-preview, gemini-3-flash-preview, gemini-2.5-pro, gemini-2.5-flash\n",
"# thinking_level=\"high\", # Pro: \"low\", \"high\"\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )\n",
"\n",
"# # Gemini 3 Pro - low thinking\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"google\",\n",
"# model=\"gemini-3-pro-preview\",\n",
"# thinking_level=\"low\", # Pro: \"low\", \"high\"\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )\n",
"\n",
"# # Gemini 3 Flash - high thinking (default)\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"google\",\n",
"# model=\"gemini-3-flash-preview\",\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )\n",
"\n",
"# # Gemini 3 Flash - minimal thinking (closest to chat mode)\n",
"# results = run_classification(\n",
"# df,\n",
"# sentence_col=\"sentences\",\n",
"# vendor=\"google\",\n",
"# model=\"gemini-3-flash-preview\",\n",
"# thinking_level=\"minimal\", # Flash: \"minimal\", \"low\", \"medium\", \"high\"\n",
"# runs=[1],\n",
"# output_dir=OUTPUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
"# checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
"# )"
]
},
{
"cell_type": "markdown",
"id": "b2e948ca-dafa-497d-8449-18af62945c33",
"metadata": {},
"source": [
"# Preliminary Labels and Holdout Set\n",
"Once we have labeled our data ***one time*** with a model, we can get a first idea of the class (label) distribution and construct our holdout set accordingly.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74af47ed-ce00-4b45-beed-a0316a067589",
"metadata": {},
"outputs": [],
"source": [
"# For this step I use GPT-4.1, hosted by UNC on Azure.\n",
"# You probably will not have access to this API and model.\n",
"# Pick a fast, inexpensive model that you are confident does an okay job (e.g., gets at least 70% correct)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e0128128-bcc9-4098-8cfd-de0d716384a5",
"metadata": {},
"outputs": [],
"source": [
"results = run_classification(\n",
" df,\n",
"    sentence_col=\"sentences\", # name of the column holding the text to classify (check: sentence vs. sentences)\n",
" vendor=\"azure\",\n",
" model=\"gpt-4.1\",\n",
" runs=[1],\n",
" output_dir=PRELIM_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
" checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ee8af782-e8a6-43ce-ae24-8db0fe997a09",
"metadata": {},
"source": [
"## STEP 1: Load labeled data from a single genAI run"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14621c20-9da6-4217-b6a7-d6412e224531",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------\n",
"# Configuration\n",
"# -----------------------------\n",
"\n",
"# CHANGE THIS: Path to your labeled data (one run, one model)\n",
"LABELED_FILE = \"./output/preliminary/Run1/preliminary_azure_gpt-4.1_chat_run1.pkl\"\n",
"\n",
"# Functional areas\n",
"CLASSES = [\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7b2d53f0-22f2-4e16-8a58-8a6cca24adc1",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import ast\n",
"\n",
"\n",
"# -----------------------------\n",
"# Load and Parse\n",
"# -----------------------------\n",
"\n",
"print(\"=\"*60)\n",
"print(\"LOADING LABELED DATA\")\n",
"print(\"=\"*60)\n",
"\n",
"df_raw = pd.read_pickle(LABELED_FILE)\n",
"print(f\"Source: {LABELED_FILE}\")\n",
"print(f\"Total sentences: {len(df_raw)}\")\n",
"\n",
"# Parse labels into binary columns\n",
"def parse_labels_safe(labels):\n",
"    \"\"\"Parse labels from a list or a string repr of a list; return [] otherwise.\"\"\"\n",
"    if labels is None:\n",
"        return []\n",
"    if isinstance(labels, list):\n",
"        return labels\n",
"    if isinstance(labels, str):\n",
"        try:\n",
"            parsed = ast.literal_eval(labels)\n",
"            return parsed if isinstance(parsed, list) else []\n",
"        except Exception:  # malformed string -> treat as unlabeled (avoids a bare except)\n",
"            return []\n",
"    return []\n",
"\n",
"# Create binary columns for each class\n",
"df_labeled = df_raw[['id', 'sentence']].copy()\n",
"\n",
"parsed_labels = df_raw['labels'].apply(parse_labels_safe)\n",
"\n",
"for cls in CLASSES:\n",
" df_labeled[cls] = parsed_labels.apply(lambda x: 1 if cls in x else 0)\n",
"\n",
"# None = all classes are 0\n",
"df_labeled['None'] = (df_labeled[CLASSES].sum(axis=1) == 0).astype(int)\n",
"\n",
"# Remove rows where labeling failed (error column is not None/NaN)\n",
"if 'error' in df_raw.columns:\n",
" error_mask = df_raw['error'].notna()\n",
" n_errors = error_mask.sum()\n",
" if n_errors > 0:\n",
" df_labeled = df_labeled[~error_mask].reset_index(drop=True)\n",
" print(f\"Removed {n_errors} rows with labeling errors\")\n",
"\n",
"print(f\"Successfully labeled: {len(df_labeled)} sentences\")\n",
"\n",
"# Preview\n",
"print(f\"\\nPreview:\")\n",
"label_cols = CLASSES + ['None']\n",
"print(df_labeled[['sentence'] + label_cols].head(5))"
]
},
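{
"cell_type": "markdown",
"id": "f4d0b6c3-8e5a-4f1b-8c43-5d6e7f8a9b01",
"metadata": {},
"source": [
"The loop above builds one binary column per class. A vectorized alternative is sketched below on a made-up `demo_parsed` series. It joins each label list into a `|`-delimited string and lets pandas one-hot encode it with `Series.str.get_dummies`. Then `reindex` guarantees exactly the six expected columns and silently drops any label outside the class list, such as a misspelled class a model might return. This assumes no class name contains `|`. Treat it as an optional alternative, not a required step. The `demo_` names avoid overwriting this notebook's variables."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4d0b6c3-8e5a-4f1b-8c43-5d6e7f8a9b02",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"DEMO_CLASSES = ['Marketing', 'Finance', 'Accounting', 'Operations', 'IT', 'HR']\n",
"\n",
"# Hypothetical parsed labels, one list per sentence ('Sales' is not a valid class)\n",
"demo_parsed = pd.Series([['Finance', 'Accounting'], [], ['IT', 'Sales']])\n",
"\n",
"demo_binary = (\n",
"    demo_parsed.str.join('|')                     # 'Finance|Accounting', '', 'IT|Sales'\n",
"    .str.get_dummies(sep='|')                     # one 0/1 column per label seen\n",
"    .reindex(columns=DEMO_CLASSES, fill_value=0)  # exactly the six classes; 'Sales' dropped\n",
")\n",
"demo_binary['None'] = (demo_binary[DEMO_CLASSES].sum(axis=1) == 0).astype(int)\n",
"print(demo_binary)"
]
},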
{
"cell_type": "markdown",
"id": "b6b4c4d7-2ed1-42a3-8d6b-09c7371bc9ba",
"metadata": {},
"source": [
"## STEP 2: Descriptives"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "91d6ef4d-6498-4d36-837a-91b44d329cd0",
"metadata": {},
"outputs": [],
"source": [
"\n",
"label_cols = CLASSES + ['None']\n",
"\n",
"print(\"=\"*60)\n",
"print(\"DESCRIPTIVE STATISTICS\")\n",
"print(\"=\"*60)\n",
"print(f\"Total sentences: {len(df_labeled)}\")\n",
"\n",
"# Class distribution\n",
"print(f\"\\n{'='*60}\")\n",
"print(\"CLASS DISTRIBUTION\")\n",
"print(f\"{'='*60}\")\n",
"print(f\"{'Class':<15} {'Count':>8} {'Share':>10}\")\n",
"print(f\"{'-'*35}\")\n",
"\n",
"for cls in label_cols:\n",
" count = df_labeled[cls].sum()\n",
" share = count / len(df_labeled) * 100\n",
" print(f\"{cls:<15} {count:>8} {share:>9.1f}%\")\n",
"\n",
"# Multi-label distribution\n",
"df_labeled['num_labels'] = df_labeled[CLASSES].sum(axis=1)\n",
"\n",
"print(f\"\\n{'='*60}\")\n",
"print(\"MULTI-LABEL DISTRIBUTION\")\n",
"print(f\"{'='*60}\")\n",
"print(f\"{'# Labels':<15} {'Count':>8} {'Share':>10}\")\n",
"print(f\"{'-'*35}\")\n",
"\n",
"for n in range(7):\n",
" count = (df_labeled['num_labels'] == n).sum()\n",
" if count > 0:\n",
" share = count / len(df_labeled) * 100\n",
" label_text = f\"{n} label{'s' if n != 1 else ''}\"\n",
" print(f\"{label_text:<15} {count:>8} {share:>9.1f}%\")\n",
"\n",
"print(f\"{'-'*35}\")\n",
"print(f\"{'Mean labels':<15} {df_labeled['num_labels'].mean():>8.2f}\")\n",
"print(f\"{'Median labels':<15} {df_labeled['num_labels'].median():>8.0f}\")\n",
"\n",
"# Co-occurrence matrix\n",
"print(f\"\\n{'='*60}\")\n",
"print(\"CO-OCCURRENCE MATRIX\")\n",
"print(f\"{'='*60}\")\n",
"\n",
"cooccurrence = pd.DataFrame(0, index=CLASSES, columns=CLASSES, dtype=int)  # start from zeros so counts stay integers\n",
"for cls1 in CLASSES:\n",
" for cls2 in CLASSES:\n",
" cooccurrence.loc[cls1, cls2] = ((df_labeled[cls1] == 1) & (df_labeled[cls2] == 1)).sum()\n",
"\n",
"print(cooccurrence.to_string())\n",
"\n",
"# Top label combinations\n",
"print(f\"\\n{'='*60}\")\n",
"print(\"TOP 10 LABEL COMBINATIONS\")\n",
"print(f\"{'='*60}\")\n",
"\n",
"def get_label_combo(row):\n",
" labels = [cls for cls in CLASSES if row[cls] == 1]\n",
" return str(labels) if labels else \"[]\"\n",
"\n",
"combo_counts = df_labeled.apply(get_label_combo, axis=1).value_counts().head(10)\n",
"print(f\"{'Combination':<50} {'Count':>8} {'Share':>10}\")\n",
"print(f\"{'-'*70}\")\n",
"for combo, count in combo_counts.items():\n",
" share = count / len(df_labeled) * 100\n",
" print(f\"{combo:<50} {count:>8} {share:>9.1f}%\")\n",
"\n",
"# Clean up temp column\n",
"df_labeled.drop(columns=['num_labels'], inplace=True)"
]
},
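{
"cell_type": "markdown",
"id": "a5e1c7d4-9f6b-4a2c-9d54-6e7f8a9b0c01",
"metadata": {},
"source": [
"The nested loop above fills the co-occurrence matrix one cell at a time. Because the label columns are 0/1, the same matrix is the product $X^T X$ of the binary label matrix with itself. Entry $(i, j)$ counts the sentences labeled with both class $i$ and class $j$, and the diagonal holds the per-class totals. Below is a minimal sketch on toy data; the `demo` frame is invented for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5e1c7d4-9f6b-4a2c-9d54-6e7f8a9b0c02",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Toy binary label matrix: rows = sentences, columns = classes (hypothetical data)\n",
"demo = pd.DataFrame({\n",
"    'Finance':    [1, 1, 0, 1],\n",
"    'Accounting': [1, 0, 0, 1],\n",
"    'HR':         [0, 0, 1, 0],\n",
"})\n",
"\n",
"X = demo.astype(int)\n",
"cooc = X.T @ X  # (i, j) = number of sentences carrying both label i and label j\n",
"print(cooc)\n",
"# Same idea on this notebook's data: df_labeled[CLASSES].T @ df_labeled[CLASSES]"
]
},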
{
"cell_type": "markdown",
"id": "82b12506-7b3a-4f83-b3c9-da733c89306e",
"metadata": {},
"source": [
"## STEP 3: Create Holdout and Train Sets"
]
},
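{
"cell_type": "markdown",
"id": "b6f2d8e5-0a7c-4b3d-8e65-7f8a9b0c1d01",
"metadata": {},
"source": [
"The next cell selects the holdout set in three phases: minimums for the rarest classes first, then multi-label sentences, then a capped fill. As a warm-up, here is a stripped-down sketch of Phase 1 only, on invented data, to make the greedy logic easy to follow. Classes are visited from rarest to most common, and multi-label sentences are tried first. A sentence is taken only if none of its labels has reached the per-class cap. `MIN_K`, `CAP`, and `demo_labels` are illustrative assumptions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6f2d8e5-0a7c-4b3d-8e65-7f8a9b0c1d02",
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"# Hypothetical label sets for eight sentences (the last one has no label, i.e. None)\n",
"demo_labels = [\n",
"    {'HR'}, {'HR', 'Finance'}, {'Finance'}, {'Finance'},\n",
"    {'IT'}, {'IT', 'Finance'}, {'Finance'}, set(),\n",
"]\n",
"MIN_K, CAP = 2, 3  # illustrative minimum and maximum per class\n",
"\n",
"freq = Counter(lbl for labels in demo_labels for lbl in labels)\n",
"chosen, counts = [], Counter()\n",
"for cls in sorted(freq, key=freq.get):  # rarest class first\n",
"    # Candidates carrying this class, multi-label sentences first\n",
"    pool = [i for i, labels in enumerate(demo_labels) if cls in labels and i not in chosen]\n",
"    pool.sort(key=lambda i: len(demo_labels[i]), reverse=True)\n",
"    for i in pool:\n",
"        if counts[cls] >= MIN_K:\n",
"            break\n",
"        if any(counts[lbl] >= CAP for lbl in demo_labels[i]):\n",
"            continue  # taking it would push some class past the cap\n",
"        chosen.append(i)\n",
"        counts.update(demo_labels[i])\n",
"print('selected:', sorted(chosen))\n",
"print('per-class counts:', dict(sorted(counts.items())))"
]
},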
{
"cell_type": "code",
"execution_count": null,
"id": "86bf25c6-10b5-4d2c-8043-46c62c6598bd",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# STEP 3: CREATE HOLDOUT AND TRAIN SETS\n",
"# =============================================================================\n",
"\n",
"import os\n",
"\n",
"# -----------------------------\n",
"# Configuration\n",
"# -----------------------------\n",
"\n",
"HOLDOUT_SIZE = 1000\n",
"MIN_PER_CLASS = 100\n",
"MAX_PER_CLASS = 333\n",
"MAX_NONE = 100 # Maximum None sentences in holdout\n",
"OUT_DIR = \"./Holdout_Train\"\n",
"\n",
"os.makedirs(OUT_DIR, exist_ok=True)\n",
"\n",
"np.random.seed(42)\n",
"\n",
"# -----------------------------\n",
"# Helper\n",
"# -----------------------------\n",
"\n",
"def get_labels(row):\n",
" \"\"\"Get list of functional labels for a row.\"\"\"\n",
" return [cls for cls in CLASSES if row[cls] == 1]\n",
"\n",
"# Add temp columns\n",
"df_labeled['labels_list'] = df_labeled.apply(get_labels, axis=1)\n",
"df_labeled['num_labels'] = df_labeled[CLASSES].sum(axis=1)\n",
"\n",
"# -----------------------------\n",
"# Stratified Sampling (6 functional classes only — None is derived)\n",
"# -----------------------------\n",
"\n",
"selected_indices = set()\n",
"class_counts = {cls: 0 for cls in CLASSES}\n",
"\n",
"print(\"=\"*60)\n",
"print(\"STRATIFIED HOLDOUT SAMPLING\")\n",
"print(\"=\"*60)\n",
"print(f\"Target holdout size: {HOLDOUT_SIZE}\")\n",
"print(f\"Min per class: {MIN_PER_CLASS} (6 functional classes)\")\n",
"print(f\"Max per class: {MAX_PER_CLASS}\")\n",
"print(f\"None: capped at {MAX_NONE}, added only in Phase 3 (derived, not independently predicted)\")\n",
"\n",
"# Phase 1: Ensure minimums, rarest classes first\n",
"class_freq = {cls: int(df_labeled[cls].sum()) for cls in CLASSES}\n",
"classes_by_rarity = sorted(CLASSES, key=lambda c: class_freq[c])\n",
"\n",
"print(f\"\\n--- Phase 1: Ensure minimums (rarest first) ---\")\n",
"print(f\" Class frequencies: {', '.join(f'{c}: {class_freq[c]}' for c in classes_by_rarity)}\")\n",
"\n",
"for rare_class in classes_by_rarity:\n",
" candidates = df_labeled[\n",
" (df_labeled[rare_class] == 1) & \n",
" (~df_labeled.index.isin(selected_indices))\n",
" ].sort_values('num_labels', ascending=False)\n",
" \n",
" for idx, row in candidates.iterrows():\n",
" if class_counts[rare_class] >= MIN_PER_CLASS:\n",
" break\n",
" labels = row['labels_list']\n",
" if len(labels) > 0 and any(class_counts[lbl] >= MAX_PER_CLASS for lbl in labels):\n",
" continue\n",
" selected_indices.add(idx)\n",
" for lbl in labels:\n",
" class_counts[lbl] += 1\n",
" \n",
" print(f\" {rare_class}: {class_counts[rare_class]} sentences\")\n",
"\n",
"print(f\"\\nAfter Phase 1: {len(selected_indices)} sentences selected\")\n",
"\n",
"# Phase 2: Multi-label sentences\n",
"print(f\"\\n--- Phase 2: Add multi-label sentences ---\")\n",
"\n",
"current_multilabel = df_labeled.loc[list(selected_indices), 'num_labels'].gt(1).sum()\n",
"target_multilabel = int(HOLDOUT_SIZE * 0.18)  # aim for about 18% multi-label sentences in the holdout\n",
"\n",
"multilabel_candidates = df_labeled[\n",
" (df_labeled['num_labels'] >= 2) & \n",
" (~df_labeled.index.isin(selected_indices))\n",
"].sample(frac=1, random_state=42)\n",
"\n",
"for idx, row in multilabel_candidates.iterrows():\n",
" if current_multilabel >= target_multilabel:\n",
" break\n",
" if len(selected_indices) >= HOLDOUT_SIZE:\n",
" break\n",
" labels = row['labels_list']\n",
" if len(labels) > 0 and any(class_counts[lbl] >= MAX_PER_CLASS for lbl in labels):\n",
" continue\n",
" selected_indices.add(idx)\n",
" for lbl in labels:\n",
" class_counts[lbl] += 1\n",
" current_multilabel += 1\n",
"\n",
"print(f\" Multi-label sentences: {current_multilabel}\")\n",
"print(f\" Total selected: {len(selected_indices)}\")\n",
"\n",
"# Phase 3: Fill remaining (cap None sentences)\n",
"none_count = sum(1 for idx in selected_indices if df_labeled.loc[idx, CLASSES].sum() == 0)\n",
"\n",
"print(f\"\\n--- Phase 3: Fill to {HOLDOUT_SIZE} (max None: {MAX_NONE}) ---\")\n",
"print(f\" None so far: {none_count}\")\n",
"\n",
"if len(selected_indices) < HOLDOUT_SIZE:\n",
" fill_candidates = df_labeled[\n",
" ~df_labeled.index.isin(selected_indices)\n",
" ].sample(frac=1, random_state=42)\n",
" \n",
" for idx, row in fill_candidates.iterrows():\n",
" if len(selected_indices) >= HOLDOUT_SIZE:\n",
" break\n",
" labels = row['labels_list']\n",
" \n",
" # Skip if any functional class would exceed max\n",
" if len(labels) > 0 and any(class_counts[lbl] >= MAX_PER_CLASS for lbl in labels):\n",
" continue\n",
" \n",
" # Skip if this is a None sentence and we've hit the cap\n",
" if len(labels) == 0 and none_count >= MAX_NONE:\n",
" continue\n",
" \n",
" selected_indices.add(idx)\n",
" for lbl in labels:\n",
" class_counts[lbl] += 1\n",
" if len(labels) == 0:\n",
" none_count += 1\n",
"\n",
"print(f\" Final count: {len(selected_indices)}\")\n",
"\n",
"# -----------------------------\n",
"# Create Clean DataFrames (LABELS BLANKED OUT)\n",
"# -----------------------------\n",
"\n",
"# Holdout and train get sentences + empty label columns\n",
"df_holdout = df_labeled.loc[list(selected_indices), ['sentence']].copy()\n",
"df_train = df_labeled.loc[~df_labeled.index.isin(selected_indices), ['sentence']].copy()\n",
"\n",
"# Add empty label columns (students fill these in)\n",
"for cls in CLASSES + ['None']:\n",
" df_holdout[cls] = \"\"\n",
" df_train[cls] = \"\"\n",
"\n",
"# Shuffle both\n",
"df_holdout = df_holdout.sample(frac=1, random_state=42).reset_index(drop=True)\n",
"df_train = df_train.sample(frac=1, random_state=42).reset_index(drop=True)\n",
"\n",
"# Clean up temp columns from df_labeled\n",
"df_labeled.drop(columns=['labels_list', 'num_labels'], inplace=True, errors='ignore')\n",
"\n",
"# -----------------------------\n",
"# Verification (using original labels for reporting only)\n",
"# -----------------------------\n",
"\n",
"print(f\"\\n{'='*60}\")\n",
"print(\"VERIFICATION\")\n",
"print(f\"{'='*60}\")\n",
"\n",
"print(f\"\\nDataset sizes:\")\n",
"print(f\" Holdout: {len(df_holdout)}\")\n",
"print(f\" Train: {len(df_train)}\")\n",
"print(f\" Total: {len(df_holdout) + len(df_train)}\")\n",
"\n",
"# Verify no overlap\n",
"overlap = set(df_holdout['sentence']) & set(df_train['sentence'])\n",
"print(f\"Overlap check: {len(overlap)} sentences (should be 0)\")\n",
"\n",
"# Report holdout class distribution (from original labels, for our reference)\n",
"print(f\"\\n{'='*60}\")\n",
"print(\"HOLDOUT CLASS DISTRIBUTION (from LLM labels, for reference)\")\n",
"print(f\"{'='*60}\")\n",
"print(f\"{'Class':<15} {'Count':>8} {'Share':>10} {'Min OK':>10} {'Max OK':>10}\")\n",
"print(f\"{'-'*55}\")\n",
"\n",
"holdout_sentences = set(df_holdout['sentence'])\n",
"df_holdout_check = df_labeled[df_labeled['sentence'].isin(holdout_sentences)]\n",
"\n",
"for cls in CLASSES:\n",
" count = df_holdout_check[cls].sum()\n",
" share = count / len(df_holdout_check) * 100\n",
" min_ok = \"✓\" if count >= MIN_PER_CLASS else \"✗\"\n",
" max_ok = \"✓\" if count <= MAX_PER_CLASS else \"✗\"\n",
" print(f\"{cls:<15} {count:>8} {share:>9.1f}% {min_ok:>10} {max_ok:>10}\")\n",
"\n",
"# None (derived, for reference)\n",
"none_count = (df_holdout_check[CLASSES].sum(axis=1) == 0).sum()\n",
"none_share = none_count / len(df_holdout_check) * 100\n",
"none_ok = \"✓\" if none_count <= MAX_NONE else \"✗\"\n",
"print(f\"{'None (derived)':<15} {none_count:>8} {none_share:>9.1f}% {'':>10} {none_ok:>10} (max {MAX_NONE})\")\n",
"\n",
"# Multi-label distribution\n",
"print(f\"\\n{'='*60}\")\n",
"print(\"HOLDOUT MULTI-LABEL DISTRIBUTION (from LLM labels, for reference)\")\n",
"print(f\"{'='*60}\")\n",
"print(f\"{'# Labels':<15} {'Count':>8} {'Share':>10}\")\n",
"print(f\"{'-'*35}\")\n",
"\n",
"num_labels = df_holdout_check[CLASSES].sum(axis=1)\n",
"for n in range(7):\n",
" count = (num_labels == n).sum()\n",
" if count > 0:\n",
" share = count / len(df_holdout_check) * 100\n",
" label_text = f\"{n} label{'s' if n != 1 else ''}\"\n",
" print(f\"{label_text:<15} {count:>8} {share:>9.1f}%\")\n",
"\n",
"# Preview\n",
"print(f\"\\nHoldout preview (labels are blank for human experts to fill):\")\n",
"print(df_holdout.head(5))\n",
"print(f\"\\nTrain preview (labels are blank for genAI to fill):\")\n",
"print(df_train.head(5))\n",
"\n",
"# -----------------------------\n",
"# Save\n",
"# -----------------------------\n",
"\n",
"df_holdout.to_csv(f\"{OUT_DIR}/holdout.csv\", index=False)\n",
"df_holdout.to_pickle(f\"{OUT_DIR}/holdout.pkl\")\n",
"df_train.to_csv(f\"{OUT_DIR}/train.csv\", index=False)\n",
"df_train.to_pickle(f\"{OUT_DIR}/train.pkl\")\n",
"\n",
"print(f\"\\n{'='*60}\")\n",
"print(\"SAVED\")\n",
"print(f\"{'='*60}\")\n",
"print(f\" Holdout: {OUT_DIR}/holdout.csv (.pkl)\")\n",
"print(f\" Train: {OUT_DIR}/train.csv (.pkl)\")\n",
"print(f\"\\nColumns: {list(df_holdout.columns)}\")\n",
"print(f\"Label columns are EMPTY - human experts must fill these in\")"
]
},
{
"cell_type": "markdown",
"id": "189837d8-bf61-4924-981a-3d5f83ec9e78",
"metadata": {},
"source": [
"# **Human Labeling of Holdout for Ground Truth**\n",
"\n",
"- We now have a roughly balanced holdout set\n",
"- It needs to be labeled by human experts, that is, domain experts / subject matter experts\n",
"- Have at least three humans evaluate each sentence independently,\n",
" *or* have an expert team discuss each sentence and assign the appropriate class (or labels)\n",
"- Save the file with the labels AND ***make sure that the text (e.g., the sentences) is not broken and contains no formatting, ASCII, or HTML errors***\n",
" - You might want to label in Excel, but then merge the labels from the XLSX or CSV file into the holdout.pkl file (which better preserves the original texts)\n",
"- If independent human experts label the holdout, take a majority vote per label (an odd number of labelers avoids ties). It is also good practice to check inter-rater agreement on the labels."
]
},
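{
"cell_type": "markdown",
"id": "human-majority-sketch-md",
"metadata": {},
"source": [
"A minimal sketch of the per-label majority vote for independent human labelers. It assumes three things. First, each annotator's labels are 0/1 columns named after the six classes. Second, every annotator file keeps the holdout's row order. Third, the toy data below is hypothetical. A label is kept when more than half of the annotators assign it, and `None` is derived when no class wins. The file paths in the comments are assumptions about your setup; `holdout_human.pkl` is the path that the evaluation section below reads."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "human-majority-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"CLASSES = ['Marketing', 'Finance', 'Accounting', 'Operations', 'IT', 'HR']\n",
"\n",
"# Hypothetical toy labels from three annotators for the same three holdout rows (same row order)\n",
"annotators = [\n",
"    pd.DataFrame([[1, 0, 0, 0, 0, 0], [0, 1, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]], columns=CLASSES),\n",
"    pd.DataFrame([[1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 1, 0]], columns=CLASSES),\n",
"    pd.DataFrame([[1, 1, 0, 0, 0, 0], [0, 1, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0]], columns=CLASSES),\n",
"]\n",
"\n",
"def human_majority(annotators, classes):\n",
"    # Per-label majority vote across annotators; None is derived when no class wins\n",
"    stacked = np.stack([a[classes].to_numpy() for a in annotators])  # shape: (annotators, rows, classes)\n",
"    majority = (stacked.sum(axis=0) > len(annotators) / 2).astype(int)\n",
"    out = pd.DataFrame(majority, columns=classes)\n",
"    out['None'] = (out[classes].sum(axis=1) == 0).astype(int)\n",
"    return out\n",
"\n",
"df_gt = human_majority(annotators, CLASSES)\n",
"print(df_gt)  # third row: one IT vote out of three is not a majority, so it is derived as None\n",
"\n",
"# With real data (assumed paths): take the sentences from holdout.pkl rather than the Excel/CSV export,\n",
"# so the original text is preserved, then save the ground truth for the evaluation step:\n",
"# df_gt.insert(0, 'sentence', pd.read_pickle('Holdout_Train/holdout.pkl')['sentence'])\n",
"# df_gt.to_pickle('Holdout_Train/holdout_human.pkl')"
]
},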
{
"cell_type": "code",
"execution_count": null,
"id": "fac02954-21d8-419b-9215-733639f8fdb6",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "3abaaf59-24e0-49dc-aef9-47f88dbc4469",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "0e3aae9b-6565-413d-aa57-78113809acf5",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "546b0d8b-c049-4ae3-8f37-f471036a99bc",
"metadata": {},
"source": [
"# **genAI Labeling of Holdout**\n",
"> Now that you have a ground truth for your holdout set, you need to determine which genAI model performs best on it\n",
"\n",
"#### Approach: \n",
"1. **Label the holdout 3 times with a genAI model** (use the code from above, but instead of loading the \"10K_unlabeled.csv\" file for preliminary labeling, load your holdout.pkl file, i.e., the one without your human labels).\n",
"2. **Check** the three runs for **label agreement** (e.g., Krippendorff's alpha)\n",
"3. Get the **majority vote** (per class or per label; this example uses ***labels***, which, unlike classes, are ***NOT mutually exclusive***)\n",
"4. **Compare to the human expert labels on the holdout** set to measure genAI labeling **performance**\n",
"5. **Repeat** for other models (from other vendors)\n",
"6. **Find the model** that does **best**, then use it to label the train set 3 times (same genAI API code from above but on a different file; make sure to adjust the paths where you store the data so you don't overwrite your holdout labels). Get the majority votes on train, then use those to fine-tune a pretrained, open-source language model (like RoBERTa Large)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c81a6d50-b86b-4384-824c-3f7b1262d63e",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# OUTPUT CONFIGURATION - SET for Base, Holdout, Train\n",
"# =============================================================================\n",
"CHECKPOINT_DIR = \"./checkpoints\"\n",
"OUTPUT_DIR = \"./output\"\n",
"HOLDOUT_DIR = f\"{OUTPUT_DIR}/holdout\"\n",
"TRAIN_DIR = f\"{OUTPUT_DIR}/train\"\n",
"\n",
"# =============================================================================\n",
"# Load the blank holdout (no labels), not the human-labeled version\n",
"# =============================================================================\n",
"\n",
"# Load data: (balanced) holdout to be labeled by genAI\n",
"df = pd.read_pickle(\"Holdout_Train/holdout.pkl\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7bea59bb-1faf-4469-819a-e5cbcf9d5425",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# Query genAI model via API\n",
"# =============================================================================\n",
"\n",
"# Here an example for gpt-4o via UNC Azure API\n",
"results = run_classification(\n",
" df,\n",
" sentence_col=\"sentence\", # --> Check if your file has the column as named here where the text is (sentence vs sentences vs text vs tweet vs ... )\n",
" vendor=\"azure\",\n",
" model=\"gpt-4o\",\n",
" runs=[1,2,3], # Doing 3 runs for consistency / replicability: will later take majority vote per label across runs\n",
" output_dir=HOLDOUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
" checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "207a4e45-dc15-4df4-855c-d151c1a9f649",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# Query another genAI model via API\n",
"# =============================================================================\n",
"\n",
"# Here an example for gpt-4.1 via UNC Azure API\n",
"results = run_classification(\n",
" df,\n",
" sentence_col=\"sentence\", # --> Check if your file has the column as named here where the text is (sentence vs sentences vs text vs tweet vs ... )\n",
" vendor=\"azure\",\n",
" model=\"gpt-4.1\",\n",
" runs=[1,2,3], # Doing 3 runs for consistency / replicability: will later take majority vote per label across runs\n",
" output_dir=HOLDOUT_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
" checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
")"
]
},
{
"cell_type": "markdown",
"id": "158e92e8-1f1a-4b58-bb8e-a000f4fe9b59",
"metadata": {},
"source": [
"# genAI Label Agreement\n",
"\n",
"- How consistent is genAI in its labels?\n",
"- Test the extent to which genAI's labels agree across runs"
]
},
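{
"cell_type": "markdown",
"id": "nominal-alpha-sketch-md",
"metadata": {},
"source": [
"To make the alpha values below less of a black box, here is a minimal sketch of Krippendorff's alpha for nominal data computed from the coincidence matrix $o_{ck}$ with marginals $n_c$ and total $n$:\n",
"\n",
"$\\alpha = 1 - (n - 1)\\,\\frac{\\sum_{c \\neq k} o_{ck}}{\\sum_{c \\neq k} n_c\\, n_k}$\n",
"\n",
"It assumes no missing values (every run labeled every sentence). The function name `nominal_alpha` and the toy matrix are illustrative only. For complete data it should agree with the `krippendorff.alpha(..., level_of_measurement='nominal')` call used below. The example also shows why alpha can sit well below the raw agreement share: alpha corrects for the agreement expected by chance."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "nominal-alpha-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def nominal_alpha(reliability):\n",
"    # reliability: rows = raters (runs), columns = units (sentences); assumes no missing values\n",
"    data = np.asarray(reliability)\n",
"    values = np.unique(data)\n",
"    index = {v: i for i, v in enumerate(values)}\n",
"    m = data.shape[0]  # raters per unit (needs at least 2)\n",
"    o = np.zeros((len(values), len(values)))  # coincidence matrix o_ck\n",
"    for unit in data.T:\n",
"        counts = np.zeros(len(values))\n",
"        for v in unit:\n",
"            counts[index[v]] += 1\n",
"        # ordered pairs of distinct raters: n_c * n_k off the diagonal, n_c * (n_c - 1) on it\n",
"        o += (np.outer(counts, counts) - np.diag(counts)) / (m - 1)\n",
"    n_c = o.sum(axis=1)\n",
"    n = n_c.sum()\n",
"    expected = n ** 2 - (n_c ** 2).sum()  # sum over c != k of n_c * n_k\n",
"    if expected == 0:\n",
"        return float('nan')  # only one value observed: alpha is undefined\n",
"    observed = o.sum() - np.trace(o)  # off-diagonal coincidences (disagreements)\n",
"    return 1.0 - (n - 1) * observed / expected\n",
"\n",
"# Hypothetical toy data: 3 runs x 5 sentences for one binary class\n",
"runs = [[1, 1, 0, 0, 1],\n",
"        [1, 0, 0, 0, 1],\n",
"        [1, 1, 0, 1, 1]]\n",
"print(round(nominal_alpha(runs), 4))  # 13/27, about 0.4815\n",
"# For comparison, the mean pairwise agreement share of these runs is 11/15, about 0.733"
]
},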
{
"cell_type": "code",
"execution_count": null,
"id": "4d0c9f7c-d74d-4cc5-a16c-55e20dc0293e",
"metadata": {},
"outputs": [],
"source": [
"# pip install -q -U krippendorff # already done at the very beginning. Keeping this as a reminder"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "793ba990-1e0d-4376-b2cc-887bbe633ee8",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# INTER-RATER AGREEMENT ACROSS RUNS (Krippendorff's Alpha)\n",
"# =============================================================================\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"import ast\n",
"import glob\n",
"import re\n",
"import krippendorff \n",
"\n",
"# -----------------------------\n",
"# Configuration\n",
"# -----------------------------\n",
"\n",
"HOLDOUT_DIR = \"./output/holdout\" # Set your holdout output directory\n",
"RUNS = [1, 2, 3] # Which runs to compare\n",
"CLASSES = [\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "063a4981-1524-4bb0-a91c-c937f345fc27",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------\n",
"# Helper Functions\n",
"# -----------------------------\n",
"\n",
"def parse_labels_safe(labels):\n",
" \"\"\"Parse labels from various formats.\"\"\"\n",
" if labels is None:\n",
" return []\n",
" if isinstance(labels, list):\n",
" return labels\n",
" if isinstance(labels, str):\n",
" try:\n",
" parsed = ast.literal_eval(labels)\n",
" return parsed if isinstance(parsed, list) else []\n",
" except (ValueError, SyntaxError):\n",
" return []\n",
" return []\n",
"\n",
"def parse_filename(filename):\n",
" \"\"\"\n",
" Parse filename like 'holdout_azure_gpt-4.1_chat_run1.pkl' \n",
" into (vendor, model, mode, run).\n",
" Handles optional dataset prefix (holdout_, train_, prelim_, output_).\n",
" \"\"\"\n",
" # Remove extension\n",
" name = filename.replace('.pkl', '').replace('.csv', '')\n",
" \n",
" # Strip known dataset prefixes\n",
" for prefix in ['holdout_', 'train_', 'prelim_', 'output_']:\n",
" if name.startswith(prefix):\n",
" name = name[len(prefix):]\n",
" break\n",
" \n",
" # Extract run number from end\n",
" match = re.search(r'_run(\\d+)$', name)\n",
" if not match:\n",
" return None\n",
" run = int(match.group(1))\n",
" name = name[:match.start()]\n",
" \n",
" # Extract mode (last part before run)\n",
" parts = name.rsplit('_', 1)\n",
" if len(parts) != 2:\n",
" return None\n",
" mode = parts[1]\n",
" remainder = parts[0]\n",
" \n",
" # Extract vendor (first part) and model (rest)\n",
" first_underscore = remainder.find('_')\n",
" if first_underscore == -1: # no vendor/model separator: not a results file\n",
" return None\n",
" vendor = remainder[:first_underscore]\n",
" model = remainder[first_underscore + 1:]\n",
" \n",
" return vendor, model, mode, run\n",
"\n",
"def discover_models(holdout_dir, runs):\n",
" \"\"\"\n",
" Discover all vendor_model_mode combinations that have ALL specified runs.\n",
" Returns dict: {(vendor, model, mode): {run: filepath}}\n",
" \"\"\"\n",
" models = {}\n",
" \n",
" for run in runs:\n",
" run_dir = f\"{holdout_dir}/Run{run}\"\n",
" pkl_files = glob.glob(f\"{run_dir}/*.pkl\")\n",
" \n",
" for filepath in pkl_files:\n",
" filename = filepath.split(\"/\")[-1]\n",
" parsed = parse_filename(filename)\n",
" if parsed is None:\n",
" continue\n",
" \n",
" vendor, model, mode, file_run = parsed\n",
" if file_run != run:\n",
" continue\n",
" \n",
" key = (vendor, model, mode)\n",
" if key not in models:\n",
" models[key] = {}\n",
" models[key][run] = filepath\n",
" \n",
" # Keep only models that have ALL specified runs\n",
" complete = {\n",
" key: paths for key, paths in models.items()\n",
" if all(r in paths for r in runs)\n",
" }\n",
" \n",
" return complete\n",
"\n",
"def load_and_binarize(filepath, classes):\n",
" \"\"\"Load pkl and convert labels to binary columns.\"\"\"\n",
" df = pd.read_pickle(filepath)\n",
" \n",
" parsed = df['labels'].apply(parse_labels_safe)\n",
" \n",
" binary = pd.DataFrame(index=df.index)\n",
" binary['id'] = df['id'] if 'id' in df.columns else df.index\n",
" binary['sentence'] = df['sentence']\n",
" \n",
" for cls in classes:\n",
" binary[cls] = parsed.apply(lambda x: 1 if cls in x else 0)\n",
" \n",
" # Remove error rows\n",
" if 'error' in df.columns:\n",
" binary = binary[df['error'].isna()].reset_index(drop=True)\n",
" \n",
" return binary\n",
"\n",
"def compute_agreement(model_runs, classes, runs):\n",
" \"\"\"\n",
" Compute Krippendorff's alpha per class and overall.\n",
" Only evaluates the 6 functional classes — None is derived, not labeled.\n",
" \n",
" Args:\n",
" model_runs: dict {run: filepath}\n",
" classes: list of class names (6 functional classes)\n",
" runs: list of run numbers\n",
" \n",
" Returns:\n",
" dict with per-class and overall alpha\n",
" \"\"\"\n",
" # Load all runs\n",
" run_dfs = {}\n",
" for run in runs:\n",
" run_dfs[run] = load_and_binarize(model_runs[run], classes)\n",
" \n",
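" # Align runs on the sentence ids that succeeded in every run (error rows were dropped per run),\n",
" # so that column j of the reliability matrix refers to the same sentence in every run\n",
" common_ids = set.intersection(*(set(run_dfs[run]['id']) for run in runs))\n",
" for run in runs:\n",
" df_run = run_dfs[run]\n",
" run_dfs[run] = df_run[df_run['id'].isin(common_ids)].sort_values('id').reset_index(drop=True)\n",
" n_sentences = len(run_dfs[runs[0]])\n",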
" \n",
" results = {}\n",
" \n",
" # Per-class alpha (6 functional classes only)\n",
" for cls in classes:\n",
" # Reliability matrix: rows = raters (runs), columns = units (sentences)\n",
" reliability_matrix = np.array([\n",
" run_dfs[run][cls].values for run in runs\n",
" ])\n",
" \n",
" # Krippendorff's alpha (nominal level for binary data)\n",
" try:\n",
" alpha = krippendorff.alpha(\n",
" reliability_data=reliability_matrix,\n",
" level_of_measurement='nominal',\n",
" )\n",
" except Exception: # e.g., alpha is undefined if every run gives the same value to every sentence\n",
" alpha = np.nan\n",
" \n",
" results[cls] = alpha\n",
" \n",
" # Overall alpha (flatten 6 functional classes into one reliability matrix)\n",
" overall_matrix = np.hstack([\n",
" np.array([run_dfs[run][cls].values for run in runs])\n",
" for cls in classes\n",
" ])\n",
" \n",
" try:\n",
" results['Overall'] = krippendorff.alpha(\n",
" reliability_data=overall_matrix,\n",
" level_of_measurement='nominal',\n",
" )\n",
" except Exception: # e.g., alpha is undefined if every run gives the same value to every sentence\n",
" results['Overall'] = np.nan\n",
" \n",
" # Pairwise agreement percentage per class (6 functional classes only)\n",
" pairwise = {}\n",
" for cls in classes:\n",
" agreements = []\n",
" for i, r1 in enumerate(runs):\n",
" for r2 in runs[i+1:]:\n",
" agree = (run_dfs[r1][cls].values == run_dfs[r2][cls].values).mean()\n",
" agreements.append(agree)\n",
" pairwise[cls] = np.mean(agreements)\n",
" pairwise['Overall'] = np.mean([pairwise[cls] for cls in classes])\n",
" \n",
" return results, pairwise, n_sentences\n",
"\n",
"\n",
"# -----------------------------\n",
"# Main Analysis\n",
"# -----------------------------\n",
"\n",
"print(\"=\"*70)\n",
"print(\"INTER-RATER AGREEMENT ANALYSIS (Krippendorff's Alpha)\")\n",
"print(\"=\"*70)\n",
"print(f\"Holdout dir: {HOLDOUT_DIR}\")\n",
"print(f\"Runs: {RUNS}\")\n",
"print(f\"Classes evaluated: {CLASSES} (None excluded — derived, not labeled)\")\n",
"\n",
"# Discover models\n",
"models = discover_models(HOLDOUT_DIR, RUNS)\n",
"\n",
"print(f\"\\nFound {len(models)} model(s) with all {len(RUNS)} runs:\")\n",
"for (vendor, model, mode), paths in models.items():\n",
" print(f\" {vendor} / {model} / {mode}\")\n",
" for run, path in sorted(paths.items()):\n",
" print(f\" Run {run}: {path}\")\n",
"\n",
"# Compute agreement for each model\n",
"all_results = {}\n",
"\n",
"for (vendor, model, mode), paths in models.items():\n",
" model_key = f\"{vendor}_{model}_{mode}\"\n",
" \n",
" print(f\"\\n{'='*70}\")\n",
" print(f\"MODEL: {vendor} / {model} / {mode}\")\n",
" print(f\"{'='*70}\")\n",
" \n",
" alphas, pairwise, n_sentences = compute_agreement(paths, CLASSES, RUNS)\n",
" all_results[model_key] = {'alphas': alphas, 'pairwise': pairwise}\n",
" \n",
" print(f\"Sentences: {n_sentences}\")\n",
" print(f\"Runs compared: {RUNS}\")\n",
" \n",
" # Per-class table (6 functional classes + Overall)\n",
" print(f\"\\n{'Class':<15} {'k-Alpha':>10} {'Agreement Share':>16}\")\n",
" print(\"-\"*45)\n",
" \n",
" for cls in CLASSES + ['Overall']:\n",
" alpha = alphas[cls]\n",
" agree = pairwise[cls]\n",
" \n",
" if cls == 'Overall':\n",
" print(\"-\"*45)\n",
" \n",
" print(f\"{cls:<15} {alpha:>10.4f} {agree:>15.1%}\")\n",
" \n",
" # Interpretation guide\n",
" print(f\"\\nKrippendorff's Alpha Interpretation:\")\n",
" print(f\" α ≥ 0.80 → Reliable agreement\")\n",
" print(f\" α ≥ 0.667 → Tentative agreement (acceptable for some purposes)\")\n",
" print(f\" α < 0.667 → Unreliable / insufficient agreement\")\n",
" print(f\"\\n Note: None class excluded — it is derived (all classes = 0),\")\n",
" print(f\" not independently labeled by the model.\")\n",
"\n",
"# -----------------------------\n",
"# Summary across models\n",
"# -----------------------------\n",
"\n",
"if len(all_results) > 1:\n",
" print(f\"\\n{'='*70}\")\n",
" print(\"SUMMARY ACROSS MODELS\")\n",
" print(f\"{'='*70}\")\n",
" \n",
" summary_rows = []\n",
" for model_key, data in all_results.items():\n",
" row = {'Model': model_key}\n",
" row['Overall_Alpha'] = data['alphas']['Overall']\n",
" row['Overall_Agreement'] = data['pairwise']['Overall']\n",
" for cls in CLASSES:\n",
" row[f'{cls}_Alpha'] = data['alphas'][cls]\n",
" summary_rows.append(row)\n",
" \n",
" df_summary = pd.DataFrame(summary_rows)\n",
" df_summary = df_summary.sort_values('Overall_Alpha', ascending=False)\n",
" \n",
" print(f\"\\n{'Model':<35} {'Overall α':>12} {'Agreement':>12}\")\n",
" print(\"-\"*60)\n",
" for _, row in df_summary.iterrows():\n",
" print(f\"{row['Model']:<35} {row['Overall_Alpha']:>12.4f} {row['Overall_Agreement']:>11.1%}\")\n",
" \n",
" # Save summary\n",
" df_summary.to_csv(f\"{HOLDOUT_DIR}/agreement_summary.csv\", index=False)\n",
" print(f\"\\nSaved summary to {HOLDOUT_DIR}/agreement_summary.csv\")"
]
},
{
"cell_type": "markdown",
"id": "b9b669eb-b3ea-4137-8d58-70455261198c",
"metadata": {},
"source": [
"# Majority Votes across Runs\n",
"\n",
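"- Majority vote per label across the runs of the same model\n",
"- A label needs strictly more than half of the valid runs, so ties (possible with an even number of runs, or when a run returned an error for a sentence) resolve to 0; use an odd number of runs\n",
"- Assumes a multi-label problem (one sentence can have multiple labels)\n",
"- None is derived, not voted on: a sentence is None only if no functional class wins a majority, in which case all other classes are negative (0 or FALSE)"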
]
},
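{
"cell_type": "markdown",
"id": "majority-threshold-sketch-md",
"metadata": {},
"source": [
"A minimal vectorized sketch of the same voting rule. It assumes the run labels are already stacked into a 0/1 array of shape (runs, sentences, classes), and the toy votes are hypothetical. It shows the strict-majority threshold, the derived None, and what happens to a tie when only two runs are available."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "majority-threshold-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical votes: 3 runs x 4 sentences x 3 classes (1 = the run assigned that label)\n",
"votes = np.array([\n",
"    [[1, 0, 0], [0, 1, 0], [0, 0, 0], [1, 1, 0]],  # run 1\n",
"    [[1, 0, 0], [0, 0, 0], [0, 0, 0], [1, 0, 0]],  # run 2\n",
"    [[0, 0, 1], [0, 1, 0], [0, 0, 1], [1, 1, 0]],  # run 3\n",
"])\n",
"\n",
"def majority_labels(votes):\n",
"    n_runs = votes.shape[0]\n",
"    counts = votes.sum(axis=0)                      # votes per sentence and class\n",
"    majority = (counts > n_runs / 2).astype(int)    # strict majority; a tie resolves to 0\n",
"    none = (majority.sum(axis=1) == 0).astype(int)  # None derived: no class won\n",
"    return majority, none\n",
"\n",
"majority, none = majority_labels(votes)\n",
"print(majority.tolist())  # [[1, 0, 0], [0, 1, 0], [0, 0, 0], [1, 1, 0]]\n",
"print(none.tolist())      # [0, 0, 1, 0]\n",
"\n",
"# With only runs 1 and 2, the second sentence's single vote is a 1-1 tie and resolves to 0 (None)\n",
"majority2, none2 = majority_labels(votes[:2])\n",
"print(majority2.tolist())  # [[1, 0, 0], [0, 0, 0], [0, 0, 0], [1, 0, 0]]\n",
"print(none2.tolist())      # [0, 1, 1, 0]"
]
},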
{
"cell_type": "code",
"execution_count": null,
"id": "3526fb48-f557-4025-b228-f4f75915abc2",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# MAJORITY VOTE PER MODEL (across its own runs)\n",
"# =============================================================================\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"import glob\n",
"import ast\n",
"import re\n",
"import os\n",
"\n",
"# -----------------------------\n",
"# Configuration\n",
"# -----------------------------\n",
"\n",
"HOLDOUT_DIR = \"./output/holdout\" # Where Run1/, Run2/, Run3/ etc. are\n",
"RUNS = [1, 2, 3] # Which runs to aggregate\n",
"CLASSES = [\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95aa527c-daa4-4063-a5e1-fc438f2f76fb",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------\n",
"# Helper Functions\n",
"# -----------------------------\n",
"\n",
"def parse_labels_safe(labels):\n",
" \"\"\"Parse labels from various formats.\"\"\"\n",
" if labels is None:\n",
" return []\n",
" if isinstance(labels, list):\n",
" return labels\n",
" if isinstance(labels, str):\n",
" try:\n",
" parsed = ast.literal_eval(labels)\n",
" return parsed if isinstance(parsed, list) else []\n",
" except (ValueError, SyntaxError):\n",
" return []\n",
" return []\n",
"\n",
"def parse_filename(filename):\n",
" \"\"\"\n",
" Parse filename like 'holdout_azure_gpt-4.1_chat_run1.pkl'\n",
" into (vendor, model, mode, run).\n",
" Handles optional dataset prefix (holdout_, train_, prelim_, output_).\n",
" \"\"\"\n",
" name = filename.replace('.pkl', '').replace('.csv', '')\n",
" \n",
" # Strip known dataset prefixes\n",
" for prefix in ['holdout_', 'train_', 'prelim_', 'output_']:\n",
" if name.startswith(prefix):\n",
" name = name[len(prefix):]\n",
" break\n",
" \n",
" match = re.search(r'_run(\\d+)$', name)\n",
" if not match:\n",
" return None\n",
" run = int(match.group(1))\n",
" name = name[:match.start()]\n",
" parts = name.rsplit('_', 1)\n",
" if len(parts) != 2:\n",
" return None\n",
" mode = parts[1]\n",
" remainder = parts[0]\n",
" first_underscore = remainder.find('_')\n",
" if first_underscore == -1: # no vendor/model separator: not a results file\n",
" return None\n",
" vendor = remainder[:first_underscore]\n",
" model = remainder[first_underscore + 1:]\n",
" return vendor, model, mode, run\n",
"\n",
"def discover_models(holdout_dir, runs):\n",
" \"\"\"Find all vendor_model_mode combos that have ALL specified runs.\"\"\"\n",
" models = {}\n",
" for run in runs:\n",
" run_dir = f\"{holdout_dir}/Run{run}\"\n",
" if not os.path.exists(run_dir):\n",
" continue\n",
" for filepath in glob.glob(f\"{run_dir}/*.pkl\"):\n",
" filename = filepath.split(\"/\")[-1]\n",
" parsed = parse_filename(filename)\n",
" if parsed is None:\n",
" continue\n",
" vendor, model, mode, file_run = parsed\n",
" if file_run != run:\n",
" continue\n",
" key = (vendor, model, mode)\n",
" if key not in models:\n",
" models[key] = {}\n",
" models[key][run] = filepath\n",
" \n",
" # Keep only models with ALL specified runs\n",
" return {k: v for k, v in models.items() if all(r in v for r in runs)}\n",
"\n",
"def compute_majority_vote(model_runs, classes, runs):\n",
" \"\"\"\n",
" Compute majority vote for a single model across its runs.\n",
" \n",
" Votes are counted for the 6 functional classes only.\n",
" None is derived: if no functional class gets majority, None = 1.\n",
" None_votes tracks how many runs returned an empty label list (for reference).\n",
" \n",
" Returns DataFrame with sentence + binary labels + vote counts.\n",
" \"\"\"\n",
" # Load all runs\n",
" run_dfs = {}\n",
" for run in runs:\n",
" run_dfs[run] = pd.read_pickle(model_runs[run])\n",
" \n",
" # Collect votes per sentence\n",
" votes = {}\n",
" for run in runs:\n",
" df = run_dfs[run]\n",
" for _, row in df.iterrows():\n",
" sent_id = row['id']\n",
" sentence = row['sentence']\n",
" labels = parse_labels_safe(row['labels'])\n",
" \n",
" # Skip error rows\n",
" if 'error' in df.columns and pd.notna(row.get('error')):\n",
" continue\n",
" \n",
" if sent_id not in votes:\n",
" votes[sent_id] = {\n",
" 'sentence': sentence,\n",
" **{cls: 0 for cls in classes},\n",
" 'none_runs': 0,\n",
" 'total_runs': 0,\n",
" }\n",
" \n",
" votes[sent_id]['total_runs'] += 1\n",
" \n",
" if len(labels) == 0:\n",
" votes[sent_id]['none_runs'] += 1\n",
" else:\n",
" for label in labels:\n",
" if label in classes:\n",
" votes[sent_id][label] += 1\n",
" \n",
" # Calculate majority\n",
" results = []\n",
" for sent_id, data in votes.items():\n",
" total = data['total_runs']\n",
" threshold = total / 2 # > 50%\n",
" \n",
" row = {'id': sent_id, 'sentence': data['sentence']}\n",
" \n",
" # Majority vote on 6 functional classes\n",
" for cls in classes:\n",
" row[cls] = 1 if data[cls] > threshold else 0\n",
" \n",
" # None is derived: all functional classes = 0\n",
" row['None'] = 1 if all(row[cls] == 0 for cls in classes) else 0\n",
" \n",
" # Vote counts (for reference / pseudo-probabilities)\n",
" for cls in classes:\n",
" row[f'{cls}_votes'] = data[cls]\n",
" row['None_votes'] = data['none_runs']\n",
" row['total_runs'] = total\n",
" \n",
" results.append(row)\n",
" \n",
" df_result = pd.DataFrame(results)\n",
" label_cols = classes + ['None']\n",
" vote_cols = [f'{cls}_votes' for cls in classes] + ['None_votes', 'total_runs']\n",
" df_result = df_result[['id', 'sentence'] + label_cols + vote_cols]\n",
" df_result = df_result.sort_values('id').reset_index(drop=True)\n",
" \n",
" return df_result\n",
"\n",
"# -----------------------------\n",
"# Main\n",
"# -----------------------------\n",
"\n",
"print(\"=\"*70)\n",
"print(\"MAJORITY VOTE PER MODEL\")\n",
"print(\"=\"*70)\n",
"print(f\"Holdout dir: {HOLDOUT_DIR}\")\n",
"print(f\"Runs: {RUNS}\")\n",
"\n",
"# Discover models\n",
"models = discover_models(HOLDOUT_DIR, RUNS)\n",
"\n",
"print(f\"\\nFound {len(models)} model(s) with all {len(RUNS)} runs:\")\n",
"for (vendor, model, mode), paths in models.items():\n",
" print(f\" {vendor} / {model} / {mode}\")\n",
"\n",
"# Process each model\n",
"all_majority = {}\n",
"\n",
"for (vendor, model, mode), paths in models.items():\n",
" model_key = f\"{vendor}_{model}_{mode}\"\n",
" \n",
" print(f\"\\n{'='*70}\")\n",
" print(f\"MODEL: {vendor} / {model} / {mode}\")\n",
" print(f\"{'='*70}\")\n",
" \n",
" # Compute majority vote\n",
" df_mv = compute_majority_vote(paths, CLASSES, RUNS)\n",
" all_majority[model_key] = df_mv\n",
" \n",
" # Report\n",
" print(f\"Total sentences: {len(df_mv)}\")\n",
" print(f\"Runs per sentence: {df_mv['total_runs'].iloc[0]}\")\n",
" \n",
" print(f\"\\n{'Class':<15} {'Count':>8} {'Share':>10}\")\n",
" print(\"-\"*35)\n",
" for cls in CLASSES:\n",
" count = df_mv[cls].sum()\n",
" share = count / len(df_mv) * 100\n",
" print(f\"{cls:<15} {count:>8} {share:>9.1f}%\")\n",
" \n",
" # None (derived)\n",
" none_count = df_mv['None'].sum()\n",
" none_share = none_count / len(df_mv) * 100\n",
" print(\"-\"*35)\n",
" print(f\"{'None (derived)':<15} {none_count:>8} {none_share:>9.1f}%\")\n",
" \n",
" # Multi-label distribution\n",
" num_labels = df_mv[CLASSES].sum(axis=1)\n",
" print(f\"\\n{'# Labels':<15} {'Count':>8} {'Share':>10}\")\n",
" print(\"-\"*35)\n",
" for n in range(7):\n",
" count = (num_labels == n).sum()\n",
" if count > 0:\n",
" share = count / len(df_mv) * 100\n",
" label_text = f\"{n} label{'s' if n != 1 else ''}\"\n",
" print(f\"{label_text:<15} {count:>8} {share:>9.1f}%\")\n",
" \n",
" # Save (with dataset prefix for consistency)\n",
" dataset_tag = os.path.basename(os.path.normpath(HOLDOUT_DIR)).lower()\n",
" save_path = f\"{HOLDOUT_DIR}/{dataset_tag}_{model_key}_majority_vote\"\n",
" df_mv.to_csv(f\"{save_path}.csv\", index=False)\n",
" df_mv.to_pickle(f\"{save_path}.pkl\")\n",
" print(f\"\\nSaved: {save_path}.csv (.pkl)\")\n",
"\n",
"# -----------------------------\n",
"# Summary across models\n",
"# -----------------------------\n",
"\n",
"if len(all_majority) > 1:\n",
" print(f\"\\n{'='*70}\")\n",
" print(\"COMPARISON ACROSS MODELS\")\n",
" print(f\"{'='*70}\")\n",
" \n",
" # Header\n",
" model_keys = list(all_majority.keys())\n",
" header = f\"{'Class':<15}\"\n",
" for key in model_keys:\n",
" short = key.split('_', 1)[1] # Remove vendor prefix for display\n",
" header += f\" {short:>20}\"\n",
" print(header)\n",
" print(\"-\" * (15 + 21 * len(model_keys)))\n",
" \n",
" for cls in CLASSES:\n",
" row = f\"{cls:<15}\"\n",
" for key in model_keys:\n",
" df_mv = all_majority[key]\n",
" count = df_mv[cls].sum()\n",
" share = count / len(df_mv) * 100\n",
" row += f\" {count:>8} ({share:>5.1f}%)\"\n",
" print(row)\n",
" \n",
" # None row\n",
" row = f\"{'None (derived)':<15}\"\n",
" for key in model_keys:\n",
" df_mv = all_majority[key]\n",
" count = df_mv['None'].sum()\n",
" share = count / len(df_mv) * 100\n",
" row += f\" {count:>8} ({share:>5.1f}%)\"\n",
" print(\"-\" * (15 + 21 * len(model_keys)))\n",
" print(row)\n",
"\n",
"print(f\"\\n{'='*70}\")\n",
"print(\"DONE\")\n",
"print(f\"{'='*70}\")"
]
},
{
"cell_type": "markdown",
"id": "4222ebe5-8ee4-444d-9d43-26a155112487",
"metadata": {},
"source": [
"# genAI Model Performance\n",
"\n",
"- Check genAI majority labels against ground truth from human experts"
]
},
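{
"cell_type": "markdown",
"id": "averaging-metrics-sketch-md",
"metadata": {},
"source": [
"A small worked example, on hypothetical toy arrays with 3 classes, of how the averaging choices used below differ. Macro F1 averages the per-class F1 scores. Micro F1 pools all label decisions. Sample-based F1 averages per sentence. The last point explains why the evaluation code passes `zero_division=1` for sample-based metrics: a sentence where both the truth and the prediction are empty (None) would otherwise score 0 instead of 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "averaging-metrics-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import f1_score, hamming_loss\n",
"\n",
"# Hypothetical labels: 4 sentences x 3 classes; the third sentence is None in both truth and prediction\n",
"y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 0], [1, 1, 0]])\n",
"y_pred = np.array([[1, 0, 0], [0, 0, 0], [0, 0, 0], [1, 1, 1]])\n",
"\n",
"print(round(f1_score(y_true, y_pred, average='macro', zero_division=0), 4))    # 0.5556 = (1 + 2/3 + 0) / 3\n",
"print(round(f1_score(y_true, y_pred, average='micro', zero_division=0), 4))    # 0.75: TP=3, FP=1, FN=1\n",
"print(round(f1_score(y_true, y_pred, average='samples', zero_division=1), 4))  # 0.7 = (1 + 0 + 1 + 0.8) / 4\n",
"print(round(f1_score(y_true, y_pred, average='samples', zero_division=0), 4))  # 0.45: the empty row now scores 0\n",
"print((y_pred == y_true).all(axis=1).mean())                                   # exact match ratio: 0.5\n",
"print(round(hamming_loss(y_true, y_pred), 4))                                  # 2 wrong cells out of 12: 0.1667"
]
},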
{
"cell_type": "code",
"execution_count": null,
"id": "820abbf4-8938-4c3f-bfa0-a458a4841f5a",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# EVALUATE GenAI MODELS vs HUMAN LABELS ON HOLDOUT\n",
"# =============================================================================\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"import glob\n",
"import json\n",
"from sklearn.metrics import (\n",
" f1_score,\n",
" roc_auc_score,\n",
" matthews_corrcoef,\n",
" precision_score,\n",
" recall_score,\n",
" confusion_matrix,\n",
" hamming_loss,\n",
" jaccard_score,\n",
")\n",
"\n",
"# -----------------------------\n",
"# Configuration\n",
"# -----------------------------\n",
"\n",
"HOLDOUT_DIR = \"./output/holdout\" # Where model majority vote files are\n",
"HUMAN_LABELS_PATH = \"./Holdout_Train/holdout_human.pkl\" # Ground truth\n",
"CLASSES = [\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4ffde335-fa5c-4944-a0ff-1b25fbd054d9",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------\n",
"# Load Human Labels (Ground Truth)\n",
"# -----------------------------\n",
"\n",
"print(\"=\"*70)\n",
"print(\"EVALUATE GenAI MODELS vs HUMAN LABELS\")\n",
"print(\"=\"*70)\n",
"\n",
"df_human = pd.read_pickle(HUMAN_LABELS_PATH)\n",
"print(f\"Human-labeled holdout: {len(df_human)} sentences\")\n",
"print(f\"Source: {HUMAN_LABELS_PATH}\")\n",
"\n",
"print(f\"\\nHuman label distribution:\")\n",
"for cls in CLASSES:\n",
" count = df_human[cls].sum()\n",
" share = count / len(df_human) * 100\n",
" print(f\" {cls}: {int(count)} ({share:.1f}%)\")\n",
"none_count = (df_human[CLASSES].sum(axis=1) == 0).sum()\n",
"print(f\" None (derived): {none_count} ({none_count / len(df_human) * 100:.1f}%)\")\n",
"\n",
"# -----------------------------\n",
"# Discover Model Majority Vote Files\n",
"# -----------------------------\n",
"\n",
"mv_files = glob.glob(f\"{HOLDOUT_DIR}/*_majority_vote.pkl\")\n",
"print(f\"\\nFound {len(mv_files)} model majority vote file(s):\")\n",
"for f in mv_files:\n",
" print(f\" {f.split('/')[-1]}\")\n",
"\n",
"# -----------------------------\n",
"# Evaluation Functions\n",
"# -----------------------------\n",
"\n",
"def evaluate_multilabel(y_true, y_pred, class_names):\n",
" \"\"\"\n",
" Comprehensive multi-label evaluation on the 6 functional classes.\n",
" \n",
" Each label is an independent binary decision.\n",
" None is derived (all classes = 0), not independently predicted,\n",
" so it is excluded from macro averaging but reported separately.\n",
" No AUC: requires continuous probabilities, not available for binary GenAI predictions.\n",
" \"\"\"\n",
" # Per-class metrics (6 functional classes)\n",
" per_class = {}\n",
" for i, cls in enumerate(class_names):\n",
" y_t = y_true[:, i]\n",
" y_p = y_pred[:, i]\n",
" \n",
" per_class[cls] = {\n",
" 'f1': f1_score(y_t, y_p, zero_division=0),\n",
" 'precision': precision_score(y_t, y_p, zero_division=0),\n",
" 'recall': recall_score(y_t, y_p, zero_division=0),\n",
" 'mcc': matthews_corrcoef(y_t, y_p),\n",
" 'support_true': int(y_t.sum()),\n",
" 'support_pred': int(y_p.sum()),\n",
" }\n",
" \n",
" # Macro averages (6 functional classes only)\n",
" macro = {\n",
" 'f1_macro': np.mean([per_class[c]['f1'] for c in class_names]),\n",
" 'precision_macro': np.mean([per_class[c]['precision'] for c in class_names]),\n",
" 'recall_macro': np.mean([per_class[c]['recall'] for c in class_names]),\n",
" 'mcc_macro': np.mean([per_class[c]['mcc'] for c in class_names]),\n",
" }\n",
" \n",
" # Micro averages (pooled across 6 classes)\n",
" micro = {\n",
" 'f1_micro': f1_score(y_true, y_pred, average='micro', zero_division=0),\n",
" 'precision_micro': precision_score(y_true, y_pred, average='micro', zero_division=0),\n",
" 'recall_micro': recall_score(y_true, y_pred, average='micro', zero_division=0),\n",
" }\n",
" \n",
" # Sample-based metrics (per sentence, 6 classes)\n",
" # Note: sentences where both true and predicted are all-zero (None) score 0/0.\n",
" # Using zero_division=1 so correct \"None\" predictions score 1.0, not 0.0.\n",
" sample = {\n",
" 'f1_samples': f1_score(y_true, y_pred, average='samples', zero_division=1),\n",
" 'jaccard_samples': jaccard_score(y_true, y_pred, average='samples', zero_division=1),\n",
" }\n",
" \n",
" # Overall metrics\n",
" overall = {\n",
" 'exact_match_ratio': (y_pred == y_true).all(axis=1).mean(),\n",
" 'hamming_loss': hamming_loss(y_true, y_pred),\n",
" }\n",
" \n",
" # None class (derived: all classes = 0, reported separately)\n",
" none_true = (y_true.sum(axis=1) == 0).astype(int)\n",
" none_pred = (y_pred.sum(axis=1) == 0).astype(int)\n",
" \n",
" none = {\n",
" 'f1': f1_score(none_true, none_pred, zero_division=0),\n",
" 'precision': precision_score(none_true, none_pred, zero_division=0),\n",
" 'recall': recall_score(none_true, none_pred, zero_division=0),\n",
" 'mcc': matthews_corrcoef(none_true, none_pred),\n",
" 'support_true': int(none_true.sum()),\n",
" 'support_pred': int(none_pred.sum()),\n",
" }\n",
" \n",
" return {\n",
" 'per_class': per_class,\n",
" 'macro': macro,\n",
" 'micro': micro,\n",
" 'sample': sample,\n",
" 'overall': overall,\n",
" 'none': none,\n",
" }\n",
"\n",
"\n",
"def print_evaluation_report(model_name, metrics, class_names):\n",
" \"\"\"Print formatted evaluation report for one model.\"\"\"\n",
" \n",
" print(f\"\\n{'='*70}\")\n",
" print(f\"MODEL: {model_name}\")\n",
" print(f\"{'='*70}\")\n",
" \n",
" # Per-class table\n",
" print(f\"\\n{'Class':<12} {'F1':>7} {'MCC':>7} {'Prec':>7} {'Rec':>7} {'True':>6} {'Pred':>6}\")\n",
" print(\"-\"*58)\n",
" \n",
" for cls in class_names:\n",
" m = metrics['per_class'][cls]\n",
" print(f\"{cls:<12} {m['f1']:>7.4f} {m['mcc']:>7.4f} \"\n",
" f\"{m['precision']:>7.4f} {m['recall']:>7.4f} {m['support_true']:>6} {m['support_pred']:>6}\")\n",
" \n",
" # None (derived, separate from macro)\n",
" print(\"-\"*58)\n",
" n = metrics['none']\n",
" print(f\"{'None':<12} {n['f1']:>7.4f} {n['mcc']:>7.4f} \"\n",
" f\"{n['precision']:>7.4f} {n['recall']:>7.4f} {n['support_true']:>6} {n['support_pred']:>6}\")\n",
" print(f\"{'':>12} (derived: all classes = 0, excluded from Macro)\")\n",
" \n",
" # Macro averages (6 functional classes)\n",
" print(\"-\"*58)\n",
" m = metrics['macro']\n",
" print(f\"{'Macro Avg':<12} {m['f1_macro']:>7.4f} {m['mcc_macro']:>7.4f} \"\n",
" f\"{m['precision_macro']:>7.4f} {m['recall_macro']:>7.4f}\")\n",
" print(f\"{'(6 classes)':>12}\")\n",
" \n",
" # Summary\n",
" print(\"\\n--- Summary Metrics ---\")\n",
" print(f\" Macro F1: {metrics['macro']['f1_macro']:.4f} ← Primary (class-balanced, 6 classes)\")\n",
" print(f\" Macro MCC: {metrics['macro']['mcc_macro']:.4f} ← Robust to imbalance\")\n",
" print(f\" Micro F1: {metrics['micro']['f1_micro']:.4f} ← Pooled across 6 classes\")\n",
" print(f\" Sample F1: {metrics['sample']['f1_samples']:.4f} ← Per-sentence average\")\n",
"    print(f\"  Sample Jaccard: {metrics['sample']['jaccard_samples']:.4f} ← Per-sentence |T∩P|/|T∪P|\")\n",
" print(f\" Exact Match: {metrics['overall']['exact_match_ratio']:.4f} ← All 6 labels correct\")\n",
" print(f\" Hamming Loss: {metrics['overall']['hamming_loss']:.4f} ← Fraction wrong (lower=better)\")\n",
"\n",
" # Interpretation\n",
" print(\"\\n--- Interpretation ---\")\n",
" gap = metrics['micro']['f1_micro'] - metrics['macro']['f1_macro']\n",
" print(f\" Macro F1 ({metrics['macro']['f1_macro']:.4f}) treats 6 functional classes equally;\")\n",
" print(f\" Micro F1 ({metrics['micro']['f1_micro']:.4f}) pools all labels, weighted by frequency.\")\n",
" if gap > 0.05:\n",
"        print(f\"  Gap of {gap:.4f} suggests that lower F1 on less frequent classes (e.g., IT, HR) pulls down Macro F1.\")\n",
" \n",
" print(f\" Sample F1 ({metrics['sample']['f1_samples']:.4f}) and Jaccard ({metrics['sample']['jaccard_samples']:.4f}):\")\n",
" print(f\" Per-sentence evaluation. Correct 'None' predictions (both empty) score 1.0.\")\n",
" print(f\" Exact Match ({metrics['overall']['exact_match_ratio']:.4f}): {metrics['overall']['exact_match_ratio']*100:.1f}% of sentences had ALL labels correct.\")\n",
" print(f\" Binary per sentence: one wrong label fails the entire sentence.\")\n",
" print(f\" Hamming Loss ({metrics['overall']['hamming_loss']:.4f}): {metrics['overall']['hamming_loss']*100:.1f}% of individual label decisions are wrong.\")\n",
"    print(f\"  MCC ({metrics['macro']['mcc_macro']:.4f}): Balanced even under class imbalance; as a rough rule of thumb, >0.70 is good.\")\n",
" print(f\" None (derived): F1 = {metrics['none']['f1']:.4f}, MCC = {metrics['none']['mcc']:.4f}\")\n",
" print(f\" Not independently predicted — excluded from Macro to avoid inflating scores.\")\n",
"\n",
" # For paper\n",
" print(\"\\n--- For Paper/Report ---\")\n",
" print(f\" Primary: Macro F1 = {metrics['macro']['f1_macro']:.4f}\")\n",
" print(f\" Secondary: Macro MCC = {metrics['macro']['mcc_macro']:.4f}\")\n",
" print(f\" Overall: Exact Match = {metrics['overall']['exact_match_ratio']:.4f}, Hamming Loss = {metrics['overall']['hamming_loss']:.4f}\")\n",
"\n",
"\n",
"# -----------------------------\n",
"# Evaluate Each Model\n",
"# -----------------------------\n",
"\n",
"all_metrics = {}\n",
"\n",
"for mv_file in sorted(mv_files):\n",
" filename = mv_file.split('/')[-1]\n",
" model_name = filename.replace('_majority_vote.pkl', '')\n",
" # Strip dataset prefix for clean display\n",
" for prefix in ['holdout_', 'train_', 'prelim_', 'output_']:\n",
" if model_name.startswith(prefix):\n",
" model_name = model_name[len(prefix):]\n",
" break\n",
" \n",
" # Load model predictions\n",
" df_model = pd.read_pickle(mv_file)\n",
" \n",
" # Merge with human labels on sentence (6 functional classes only)\n",
" df_merged = df_human[['sentence'] + CLASSES].merge(\n",
" df_model[['sentence'] + CLASSES],\n",
" on='sentence',\n",
" how='inner',\n",
" suffixes=('_true', '_pred'),\n",
" )\n",
" \n",
" # Check match rate\n",
" match_rate = len(df_merged) / len(df_human) * 100\n",
" print(f\"\\n{'─'*70}\")\n",
" print(f\"Model: {model_name}\")\n",
" print(f\"Matched sentences: {len(df_merged)}/{len(df_human)} ({match_rate:.1f}%)\")\n",
" \n",
" if len(df_merged) == 0:\n",
" print(\" WARNING: No matching sentences! Skipping.\")\n",
" continue\n",
" \n",
" # Extract true labels (human) — 6 functional classes\n",
" y_true = df_merged[[f'{cls}_true' for cls in CLASSES]].values.astype(int)\n",
" \n",
" # Extract predicted labels (model majority vote) — 6 functional classes\n",
" y_pred = df_merged[[f'{cls}_pred' for cls in CLASSES]].values.astype(int)\n",
" \n",
" # Evaluate\n",
" metrics = evaluate_multilabel(y_true, y_pred, CLASSES)\n",
" all_metrics[model_name] = metrics\n",
" \n",
" # Print report\n",
" print_evaluation_report(model_name, metrics, CLASSES)\n",
" \n",
" # Confusion matrices\n",
" print(f\"\\n--- Confusion Matrices ---\")\n",
" print(f\" {'Class':<12} {'TP':>6} {'TN':>6} {'FP':>6} {'FN':>6}\")\n",
" print(f\" {'-'*38}\")\n",
" \n",
" for i, cls in enumerate(CLASSES):\n",
"        cm = confusion_matrix(y_true[:, i], y_pred[:, i], labels=[0, 1])  # labels=[0, 1] keeps a 2x2 matrix even if a class never occurs\n",
" tn, fp, fn, tp = cm.ravel()\n",
" print(f\" {cls:<12} {tp:>6} {tn:>6} {fp:>6} {fn:>6}\")\n",
" \n",
" # None confusion matrix (derived)\n",
" none_true = (y_true.sum(axis=1) == 0).astype(int)\n",
" none_pred = (y_pred.sum(axis=1) == 0).astype(int)\n",
"    none_cm = confusion_matrix(none_true, none_pred, labels=[0, 1])  # always 2x2, so ravel() unpacks safely\n",
" none_tn, none_fp, none_fn, none_tp = none_cm.ravel()\n",
" print(f\" {'-'*38}\")\n",
" print(f\" {'None':<12} {none_tp:>6} {none_tn:>6} {none_fp:>6} {none_fn:>6}\")\n",
"\n",
"# -----------------------------\n",
"# Comparison Summary\n",
"# -----------------------------\n",
"\n",
"if len(all_metrics) > 1:\n",
" print(f\"\\n{'='*70}\")\n",
" print(\"MODEL COMPARISON SUMMARY\")\n",
" print(f\"{'='*70}\")\n",
" \n",
" # Summary table\n",
" print(f\"\\n{'Model':<40} {'F1 Mac':>8} {'MCC Mac':>8} {'Exact':>8} {'Hamming':>8}\")\n",
" print(\"-\"*68)\n",
" \n",
" summary_rows = []\n",
" for model_name, metrics in sorted(all_metrics.items(), \n",
" key=lambda x: x[1]['macro']['f1_macro'], \n",
" reverse=True):\n",
" m = metrics['macro']\n",
" o = metrics['overall']\n",
" print(f\"{model_name:<40} {m['f1_macro']:>8.4f} {m['mcc_macro']:>8.4f} \"\n",
" f\"{o['exact_match_ratio']:>8.4f} {o['hamming_loss']:>8.4f}\")\n",
" \n",
" summary_rows.append({\n",
" 'model': model_name,\n",
" 'f1_macro': m['f1_macro'],\n",
" 'mcc_macro': m['mcc_macro'],\n",
" 'f1_micro': metrics['micro']['f1_micro'],\n",
" 'f1_samples': metrics['sample']['f1_samples'],\n",
" 'jaccard_samples': metrics['sample']['jaccard_samples'],\n",
" 'exact_match': o['exact_match_ratio'],\n",
" 'hamming_loss': o['hamming_loss'],\n",
" 'none_f1': metrics['none']['f1'],\n",
" 'none_mcc': metrics['none']['mcc'],\n",
" })\n",
" \n",
" # Per-class F1 comparison\n",
" print(f\"\\n--- Per-Class F1 Comparison ---\")\n",
" header = f\"{'Class':<12}\"\n",
" for model_name in sorted(all_metrics.keys()):\n",
" short = model_name.replace('azure_', '').replace('fireworks_', 'fw_')\n",
" header += f\" {short:>20}\"\n",
" print(header)\n",
" print(\"-\" * (12 + 21 * len(all_metrics)))\n",
" \n",
" for cls in CLASSES + ['None', 'Macro Avg']:\n",
" row = f\"{cls:<12}\"\n",
" if cls == 'None':\n",
" row = f\"{'-'*12}\\n{cls:<12}\"\n",
" for model_name in sorted(all_metrics.keys()):\n",
" if cls == 'Macro Avg':\n",
" val = all_metrics[model_name]['macro']['f1_macro']\n",
" elif cls == 'None':\n",
" val = all_metrics[model_name]['none']['f1']\n",
" else:\n",
" val = all_metrics[model_name]['per_class'][cls]['f1']\n",
" row += f\" {val:>20.4f}\"\n",
" print(row)\n",
" \n",
" # Save comparison\n",
" df_summary = pd.DataFrame(summary_rows).sort_values('f1_macro', ascending=False)\n",
" df_summary.to_csv(f\"{HOLDOUT_DIR}/model_vs_human_comparison.csv\", index=False)\n",
" print(f\"\\nSaved comparison to {HOLDOUT_DIR}/model_vs_human_comparison.csv\")\n",
"\n",
"# -----------------------------\n",
"# Save all metrics\n",
"# -----------------------------\n",
"\n",
"def convert_for_json(obj):\n",
" \"\"\"Convert numpy types for JSON serialization.\"\"\"\n",
" if isinstance(obj, (np.integer,)):\n",
" return int(obj)\n",
" if isinstance(obj, (np.floating,)):\n",
" return float(obj)\n",
" if isinstance(obj, np.ndarray):\n",
" return obj.tolist()\n",
" if isinstance(obj, dict):\n",
" return {k: convert_for_json(v) for k, v in obj.items()}\n",
" return obj\n",
"\n",
"all_metrics_json = convert_for_json(all_metrics)\n",
"with open(f\"{HOLDOUT_DIR}/model_vs_human_metrics.json\", 'w') as f:\n",
" json.dump(all_metrics_json, f, indent=2)\n",
"print(f\"Saved all metrics to {HOLDOUT_DIR}/model_vs_human_metrics.json\")\n",
"\n",
"print(f\"\\n{'='*70}\")\n",
"print(\"DONE\")\n",
"print(f\"{'='*70}\")"
]
},
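{
"cell_type": "markdown",
"id": "7b1d2e40-3c5a-4f8e-9d61-2a4b6c8e0f11",
"metadata": {},
"source": [
"To make the summary metrics in the evaluation cell concrete, here is a small, hand-checkable sketch in plain Python (no scikit-learn) on three hypothetical sentences and three classes. It follows the conventions of `evaluate_multilabel`: a sentence whose true and predicted label sets are both empty scores 1.0 for sample Jaccard (as with `zero_division=1`), and *None* is derived from all-zero rows. The toy data and helper names are illustrative assumptions, not notebook output.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"# Hypothetical toy data: 3 sentences x 3 classes (1 = label present)\n",
"y_true = [[1, 0, 0], [0, 1, 1], [0, 0, 0]]\n",
"y_pred = [[1, 0, 0], [0, 1, 0], [0, 0, 0]]\n",
"\n",
"def exact_match(t, p):\n",
"    # Share of sentences where every label decision is correct\n",
"    return sum(rt == rp for rt, rp in zip(t, p)) / len(t)\n",
"\n",
"def hamming(t, p):\n",
"    # Share of individual (sentence, class) decisions that are wrong\n",
"    wrong = sum(a != b for rt, rp in zip(t, p) for a, b in zip(rt, rp))\n",
"    return wrong / (len(t) * len(t[0]))\n",
"\n",
"def sample_jaccard(t, p):\n",
"    # Per-sentence |T ∩ P| / |T ∪ P|; empty vs. empty scores 1.0 (like zero_division=1)\n",
"    scores = []\n",
"    for rt, rp in zip(t, p):\n",
"        inter = sum(a and b for a, b in zip(rt, rp))\n",
"        union = sum(a or b for a, b in zip(rt, rp))\n",
"        scores.append(1.0 if union == 0 else inter / union)\n",
"    return sum(scores) / len(scores)\n",
"\n",
"def mcc(tp, tn, fp, fn):\n",
"    # Matthews correlation from confusion counts; here 0.0 when undefined\n",
"    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))\n",
"    return 0.0 if denom == 0 else (tp * tn - fp * fn) / denom\n",
"\n",
"# Derived None: 1 when a row has no functional labels\n",
"none_true = [int(sum(r) == 0) for r in y_true]\n",
"none_pred = [int(sum(r) == 0) for r in y_pred]\n",
"\n",
"print(exact_match(y_true, y_pred))     # 2 of 3 rows fully correct\n",
"print(hamming(y_true, y_pred))         # 1 wrong decision out of 9\n",
"print(sample_jaccard(y_true, y_pred))  # (1.0 + 0.5 + 1.0) / 3\n",
"print(none_true, none_pred)\n",
"```\n",
"\n",
"Exact match fails sentence 2 entirely because of one missed label, while Hamming loss counts that miss as one wrong decision out of nine; this is why the two metrics can tell different stories."
]
},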
{
"cell_type": "markdown",
"id": "0af919a8-c95a-4598-847e-142fedee656c",
"metadata": {},
"source": [
"# Label Training Data with genAI\n",
"\n",
"- Determine which genAI model (or models) performs best on the holdout\n",
"- Assume that this will also be the best model for labeling the train data\n",
"- Label the train data at least 3 times (i.e., 3 runs: `[1, 2, 3]`) with the best genAI model (or models)\n",
"- Build the final train set from majority votes across runs (and across models, if you pool genAI models for possibly better performance)\n",
"> I am giving an example here using GPT-4.1. ***This will most likely not be the best model! Try others on the holdout!***"
]
},
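{
"cell_type": "markdown",
"id": "8c2e3f51-4d6b-4a9f-8e72-3b5c7d9f1a22",
"metadata": {},
"source": [
"Why at least three runs? A rough intuition, under the strong assumption that runs err independently with the same per-label accuracy *p*: the majority label is correct when more than half of the *n* runs are correct, which is a binomial tail probability. Repeated runs of the same model are positively correlated in practice, so the real gain is usually smaller than this idealized calculation; treat it as an intuition, not a guarantee. The accuracy values below are hypothetical.\n",
"\n",
"```python\n",
"from math import comb\n",
"\n",
"def majority_correct_prob(p, n):\n",
"    # P(more than half of n independent runs are correct), each correct with probability p\n",
"    return sum(comb(n, k) * p**k * (1 - p)**(n - k) for k in range(n // 2 + 1, n + 1))\n",
"\n",
"for p in (0.80, 0.90):\n",
"    for n in (1, 3, 5):\n",
"        print(f'p={p:.2f}, runs={n}: {majority_correct_prob(p, n):.4f}')\n",
"```\n",
"\n",
"For example, at p = 0.80 three independent runs would raise majority accuracy to 0.896. With an even n, ties count as failures here, which matches the strict-majority rule in the majority vote step."
]
},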
{
"cell_type": "code",
"execution_count": null,
"id": "45c897fb-16e2-4ebe-abaf-9ab6be1dffe2",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# OUTPUT CONFIGURATION - SET for Train\n",
"# =============================================================================\n",
"CHECKPOINT_DIR = \"./checkpoints\"\n",
"OUTPUT_DIR = \"./output\"\n",
"TRAIN_DIR = f\"{OUTPUT_DIR}/train\""
]
},
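{
"cell_type": "markdown",
"id": "9d3f4a62-5e7c-4b0a-9f83-4c6d8e0a2b33",
"metadata": {},
"source": [
"The majority vote step expects one pickle per run in a `Run{n}` subfolder of `TRAIN_DIR`, named like `train_azure_gpt-4.1_chat_run1.pkl`. This layout is inferred from `discover_models` and the `parse_filename` docstring; `run_classification` is defined earlier in the notebook and may name files differently. A small, hypothetical helper like the following can list which run files are still missing before you vote:\n",
"\n",
"```python\n",
"import os\n",
"\n",
"def expected_run_files(base_dir, dataset_tag, vendor, model, mode, runs):\n",
"    # Hypothetical helper: build the paths that discover_models() looks for\n",
"    return {\n",
"        run: os.path.join(base_dir, f'Run{run}', f'{dataset_tag}_{vendor}_{model}_{mode}_run{run}.pkl')\n",
"        for run in runs\n",
"    }\n",
"\n",
"def missing_runs(base_dir, dataset_tag, vendor, model, mode, runs):\n",
"    # Runs whose output file does not exist yet\n",
"    paths = expected_run_files(base_dir, dataset_tag, vendor, model, mode, runs)\n",
"    return [run for run, path in paths.items() if not os.path.exists(path)]\n",
"\n",
"print(expected_run_files('./output/train', 'train', 'azure', 'gpt-4.1', 'chat', [1, 2, 3])[1])\n",
"print(missing_runs('./output/train', 'train', 'azure', 'gpt-4.1', 'chat', [1, 2, 3]))\n",
"```\n",
"\n",
"An empty list means every run file is in place; the paths assume the `train_` prefix and `chat` mode from the docstring example."
]
},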
{
"cell_type": "code",
"execution_count": null,
"id": "72367ec1-6d23-44e3-81f5-826f53518666",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# Load Train dataset\n",
"# =============================================================================\n",
"\n",
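"# Load Data: train set to be labeled by the genAI model\n",
"df = pd.read_pickle(\"Holdout_Train/train.pkl\")\n",
"df = df.head(101)  # Example run on the first 101 sentences only (limits API cost); remove this line to label the full train set\n",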
"\n",
"# =============================================================================\n",
"# Query genAI model via API\n",
"# =============================================================================\n",
"\n",
"# Here an example for gpt-4.1 via UNC Azure API. \n",
"# Most likely, this will not be the best model! Try other vendors and other models.\n",
"results = run_classification(\n",
" df,\n",
" sentence_col=\"sentence\", # --> Check if your file has the column as named here where the text is (sentence vs sentences vs text vs tweet vs ... )\n",
" vendor=\"azure\",\n",
" model=\"gpt-4.1\",\n",
"    runs=[1,2,3], # 3 runs for consistency / replicability; the majority vote per label across runs is taken later\n",
" output_dir=TRAIN_DIR, # overrides default: OUTPUT_DIR vs. HOLDOUT_DIR vs. TRAIN_DIR vs. PRELIM_DIR\n",
" checkpoint_dir=CHECKPOINT_DIR, # overrides default\n",
")"
]
},
{
"cell_type": "markdown",
"id": "16420fc3-a55a-4e90-a306-38f52ff8a3f6",
"metadata": {},
"source": [
"# Majority Votes\n",
"\n",
"- Across Runs\n",
"- If you are ***pooling models***, then you need to also do this ***Across Models***: THIS CODE DOES NOT DO THAT!\n",
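"- Ties are not handled explicitly: a label needs strictly more than half of the valid runs, so a tie (possible with an even number of runs, or when a run errored on a sentence and was skipped) resolves to 0\n",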
"- Assumes multi-label problem (one sentence can have multiple labels)"
]
},
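{
"cell_type": "markdown",
"id": "ae4a5b73-6f8d-4c1b-8a94-5d7e9f1b3c44",
"metadata": {},
"source": [
"The voting rule described above, reduced to plain Python for a single sentence (counting each class at most once per run): a class is kept only if it appears in strictly more than half of the valid runs, and *None* is derived when no class survives. The class subset and run outputs are hypothetical.\n",
"\n",
"```python\n",
"CLASSES_DEMO = ['Marketing', 'Finance', 'IT']  # hypothetical subset of classes\n",
"\n",
"def majority_labels(run_label_lists, classes):\n",
"    # run_label_lists: one list of predicted labels per (non-error) run\n",
"    total = len(run_label_lists)\n",
"    votes = {cls: sum(cls in labels for labels in run_label_lists) for cls in classes}\n",
"    row = {cls: int(votes[cls] > total / 2) for cls in classes}  # strict majority\n",
"    row['None'] = int(all(v == 0 for v in row.values()))  # derived, not voted on\n",
"    return row, votes\n",
"\n",
"print(majority_labels([['Finance'], ['Finance', 'IT'], []], CLASSES_DEMO))\n",
"print(majority_labels([['IT'], []], CLASSES_DEMO))  # tie -> IT = 0\n",
"```\n",
"\n",
"The second call shows a tie: with two valid runs, one vote for IT is not a strict majority, so the sentence ends up as *None*."
]
},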
{
"cell_type": "code",
"execution_count": null,
"id": "45172423-8146-478a-bfcf-7060e52f60c0",
"metadata": {},
"outputs": [],
"source": [
"# =============================================================================\n",
"# MAJORITY VOTE PER MODEL (across its own runs)\n",
"# =============================================================================\n",
"\n",
"import pandas as pd\n",
"import numpy as np\n",
"import glob\n",
"import ast\n",
"import re\n",
"import os\n",
"\n",
"# -----------------------------\n",
"# Configuration\n",
"# -----------------------------\n",
"\n",
"TRAIN_DIR = \"./output/train\" # Now pointing to train (same code as before, but a different folder for the files)\n",
"RUNS = [1, 2, 3]\n",
"CLASSES = [\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "07170d67-f36c-45c6-a41c-2843cc5c52e0",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------\n",
"# Helper Functions\n",
"# -----------------------------\n",
"\n",
"def parse_labels_safe(labels):\n",
" \"\"\"Parse labels from various formats.\"\"\"\n",
" if labels is None:\n",
" return []\n",
" if isinstance(labels, list):\n",
" return labels\n",
" if isinstance(labels, str):\n",
" try:\n",
" parsed = ast.literal_eval(labels)\n",
" return parsed if isinstance(parsed, list) else []\n",
"        except (ValueError, SyntaxError, TypeError):\n",
" return []\n",
" return []\n",
"\n",
"def parse_filename(filename):\n",
" \"\"\"\n",
" Parse filename like 'train_azure_gpt-4.1_chat_run1.pkl'\n",
" into (vendor, model, mode, run).\n",
" Handles optional dataset prefix (holdout_, train_, prelim_, output_).\n",
" \"\"\"\n",
" name = filename.replace('.pkl', '').replace('.csv', '')\n",
" \n",
" # Strip known dataset prefixes\n",
" for prefix in ['holdout_', 'train_', 'prelim_', 'output_']:\n",
" if name.startswith(prefix):\n",
" name = name[len(prefix):]\n",
" break\n",
" \n",
" match = re.search(r'_run(\\d+)$', name)\n",
" if not match:\n",
" return None\n",
" run = int(match.group(1))\n",
" name = name[:match.start()]\n",
" parts = name.rsplit('_', 1)\n",
" if len(parts) != 2:\n",
" return None\n",
" mode = parts[1]\n",
" remainder = parts[0]\n",
"    vendor, sep, model = remainder.partition('_')\n",
"    if not sep:  # no vendor/model separator: not a recognized filename\n",
"        return None\n",
" return vendor, model, mode, run\n",
"\n",
"def discover_models(train_dir, runs):\n",
" \"\"\"Find all vendor_model_mode combos that have ALL specified runs.\"\"\"\n",
" models = {}\n",
" for run in runs:\n",
" run_dir = f\"{train_dir}/Run{run}\"\n",
" if not os.path.exists(run_dir):\n",
" continue\n",
" for filepath in glob.glob(f\"{run_dir}/*.pkl\"):\n",
"            filename = os.path.basename(filepath)  # portable across OS path separators\n",
" parsed = parse_filename(filename)\n",
" if parsed is None:\n",
" continue\n",
" vendor, model, mode, file_run = parsed\n",
" if file_run != run:\n",
" continue\n",
" key = (vendor, model, mode)\n",
" if key not in models:\n",
" models[key] = {}\n",
" models[key][run] = filepath\n",
" \n",
" # Keep only models with ALL specified runs\n",
" return {k: v for k, v in models.items() if all(r in v for r in runs)}\n",
"\n",
"def compute_majority_vote(model_runs, classes, runs):\n",
" \"\"\"\n",
" Compute majority vote for a single model across its runs.\n",
" \n",
" Votes are counted for the 6 functional classes only.\n",
" None is derived: if no functional class gets majority, None = 1.\n",
" None_votes tracks how many runs returned an empty label list (for reference).\n",
" \n",
" Returns DataFrame with sentence + binary labels + vote counts.\n",
" \"\"\"\n",
" # Load all runs\n",
" run_dfs = {}\n",
" for run in runs:\n",
" run_dfs[run] = pd.read_pickle(model_runs[run])\n",
" \n",
" # Collect votes per sentence\n",
" votes = {}\n",
" for run in runs:\n",
" df = run_dfs[run]\n",
" for _, row in df.iterrows():\n",
" sent_id = row['id']\n",
" sentence = row['sentence']\n",
" labels = parse_labels_safe(row['labels'])\n",
" \n",
" # Skip error rows\n",
" if 'error' in df.columns and pd.notna(row.get('error')):\n",
" continue\n",
" \n",
" if sent_id not in votes:\n",
" votes[sent_id] = {\n",
" 'sentence': sentence,\n",
" **{cls: 0 for cls in classes},\n",
" 'none_runs': 0,\n",
" 'total_runs': 0,\n",
" }\n",
" \n",
" votes[sent_id]['total_runs'] += 1\n",
" \n",
" if len(labels) == 0:\n",
" votes[sent_id]['none_runs'] += 1\n",
" else:\n",
" for label in labels:\n",
" if label in classes:\n",
" votes[sent_id][label] += 1\n",
" \n",
" # Calculate majority\n",
" results = []\n",
" for sent_id, data in votes.items():\n",
" total = data['total_runs']\n",
" threshold = total / 2 # > 50%\n",
" \n",
" row = {'id': sent_id, 'sentence': data['sentence']}\n",
" \n",
" # Majority vote on 6 functional classes\n",
" for cls in classes:\n",
" row[cls] = 1 if data[cls] > threshold else 0\n",
" \n",
" # None is derived: all functional classes = 0\n",
" row['None'] = 1 if all(row[cls] == 0 for cls in classes) else 0\n",
" \n",
" # Vote counts (for reference / pseudo-probabilities)\n",
" for cls in classes:\n",
" row[f'{cls}_votes'] = data[cls]\n",
" row['None_votes'] = data['none_runs']\n",
" row['total_runs'] = total\n",
" \n",
" results.append(row)\n",
" \n",
" df_result = pd.DataFrame(results)\n",
" label_cols = classes + ['None']\n",
" vote_cols = [f'{cls}_votes' for cls in classes] + ['None_votes', 'total_runs']\n",
" df_result = df_result[['id', 'sentence'] + label_cols + vote_cols]\n",
" df_result = df_result.sort_values('id').reset_index(drop=True)\n",
" \n",
" return df_result\n",
"\n",
"# -----------------------------\n",
"# Main\n",
"# -----------------------------\n",
"\n",
"print(\"=\"*70)\n",
"print(\"MAJORITY VOTE PER MODEL\")\n",
"print(\"=\"*70)\n",
"print(f\"Train dir: {TRAIN_DIR}\")\n",
"print(f\"Runs: {RUNS}\")\n",
"\n",
"# Discover models\n",
"models = discover_models(TRAIN_DIR, RUNS)\n",
"\n",
"print(f\"\\nFound {len(models)} model(s) with all {len(RUNS)} runs:\")\n",
"for (vendor, model, mode), paths in models.items():\n",
" print(f\" {vendor} / {model} / {mode}\")\n",
"\n",
"# Process each model\n",
"all_majority = {}\n",
"\n",
"for (vendor, model, mode), paths in models.items():\n",
" model_key = f\"{vendor}_{model}_{mode}\"\n",
" \n",
" print(f\"\\n{'='*70}\")\n",
" print(f\"MODEL: {vendor} / {model} / {mode}\")\n",
" print(f\"{'='*70}\")\n",
" \n",
" # Compute majority vote\n",
" df_mv = compute_majority_vote(paths, CLASSES, RUNS)\n",
" all_majority[model_key] = df_mv\n",
" \n",
" # Report\n",
" print(f\"Total sentences: {len(df_mv)}\")\n",
"    print(f\"Runs per sentence: min={df_mv['total_runs'].min()}, max={df_mv['total_runs'].max()} (lower if error rows were skipped)\")\n",
" \n",
" print(f\"\\n{'Class':<15} {'Count':>8} {'Share':>10}\")\n",
" print(\"-\"*35)\n",
" for cls in CLASSES:\n",
" count = df_mv[cls].sum()\n",
" share = count / len(df_mv) * 100\n",
" print(f\"{cls:<15} {count:>8} {share:>9.1f}%\")\n",
" \n",
" # None (derived)\n",
" none_count = df_mv['None'].sum()\n",
" none_share = none_count / len(df_mv) * 100\n",
" print(\"-\"*35)\n",
" print(f\"{'None (derived)':<15} {none_count:>8} {none_share:>9.1f}%\")\n",
" \n",
" # Multi-label distribution\n",
" num_labels = df_mv[CLASSES].sum(axis=1)\n",
" print(f\"\\n{'# Labels':<15} {'Count':>8} {'Share':>10}\")\n",
" print(\"-\"*35)\n",
"    for n in range(len(CLASSES) + 1):\n",
" count = (num_labels == n).sum()\n",
" if count > 0:\n",
" share = count / len(df_mv) * 100\n",
" label_text = f\"{n} label{'s' if n != 1 else ''}\"\n",
" print(f\"{label_text:<15} {count:>8} {share:>9.1f}%\")\n",
" \n",
" # Save (with dataset prefix for consistency)\n",
" dataset_tag = os.path.basename(os.path.normpath(TRAIN_DIR)).lower()\n",
" save_path = f\"{TRAIN_DIR}/{dataset_tag}_{model_key}_majority_vote\"\n",
" df_mv.to_csv(f\"{save_path}.csv\", index=False)\n",
" df_mv.to_pickle(f\"{save_path}.pkl\")\n",
" print(f\"\\nSaved: {save_path}.csv (.pkl)\")\n",
"\n",
"# -----------------------------\n",
"# Summary across models\n",
"# -----------------------------\n",
"\n",
"if len(all_majority) > 1:\n",
" print(f\"\\n{'='*70}\")\n",
" print(\"COMPARISON ACROSS MODELS\")\n",
" print(f\"{'='*70}\")\n",
" \n",
" # Header\n",
" model_keys = list(all_majority.keys())\n",
" header = f\"{'Class':<15}\"\n",
" for key in model_keys:\n",
" short = key.split('_', 1)[1] # Remove vendor prefix for display\n",
" header += f\" {short:>20}\"\n",
" print(header)\n",
" print(\"-\" * (15 + 21 * len(model_keys)))\n",
" \n",
" for cls in CLASSES:\n",
" row = f\"{cls:<15}\"\n",
" for key in model_keys:\n",
" df_mv = all_majority[key]\n",
" count = df_mv[cls].sum()\n",
" share = count / len(df_mv) * 100\n",
"            row += f\" {count:>11} ({share:>5.1f}%)\"\n",
" print(row)\n",
" \n",
" # None row\n",
" row = f\"{'None (derived)':<15}\"\n",
" for key in model_keys:\n",
" df_mv = all_majority[key]\n",
" count = df_mv['None'].sum()\n",
" share = count / len(df_mv) * 100\n",
"        row += f\" {count:>11} ({share:>5.1f}%)\"\n",
" print(\"-\" * (15 + 21 * len(model_keys)))\n",
" print(row)\n",
"\n",
"print(f\"\\n{'='*70}\")\n",
"print(\"DONE\")\n",
"print(f\"{'='*70}\")"
]
},
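{
"cell_type": "markdown",
"id": "bf5b6c84-7a9e-4d2c-9ba5-6e8f0a2c4d55",
"metadata": {},
"source": [
"The cell above votes within each model only. One way to pool models, sketched here as a hypothetical extension that the notebook does not implement: sum each model's raw vote counts (the `{cls}_votes` and `total_runs` columns written above) per sentence `id`, then apply the same strict-majority rule to the pooled totals. Plain dicts stand in for DataFrame rows, and the second model key is made up.\n",
"\n",
"```python\n",
"from collections import defaultdict\n",
"\n",
"def pool_votes(model_rows, classes):\n",
"    # model_rows: {model_key: [row, ...]}, each row with id, {cls}_votes, total_runs\n",
"    pooled = defaultdict(lambda: {'total_runs': 0, **{cls: 0 for cls in classes}})\n",
"    for rows in model_rows.values():\n",
"        for r in rows:\n",
"            p = pooled[r['id']]\n",
"            p['total_runs'] += r['total_runs']\n",
"            for cls in classes:\n",
"                p[cls] += r[f'{cls}_votes']\n",
"    result = {}\n",
"    for sent_id, p in pooled.items():\n",
"        labels = {cls: int(p[cls] > p['total_runs'] / 2) for cls in classes}\n",
"        labels['None'] = int(not any(labels.values()))  # derived, as before\n",
"        result[sent_id] = labels\n",
"    return result\n",
"\n",
"classes = ['Finance', 'IT']\n",
"model_rows = {\n",
"    'azure_gpt-4.1_chat': [{'id': 7, 'Finance_votes': 2, 'IT_votes': 1, 'total_runs': 3}],\n",
"    'other_model_chat': [{'id': 7, 'Finance_votes': 2, 'IT_votes': 2, 'total_runs': 3}],\n",
"}\n",
"print(pool_votes(model_rows, classes))\n",
"```\n",
"\n",
"With two 3-run models the pooled total is 6, so ties (IT: 3 of 6) become possible and resolve to 0 under the strict rule; pooling an odd total number of runs avoids this."
]
},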
{
"cell_type": "markdown",
"id": "43c77538-5661-4801-ad7f-aaa06c6eb828",
"metadata": {},
"source": [
"# Fine-Tune a pretrained LLM to create a Vertical AI model\n",
"- We now have training data for fine-tuning a pretrained LLM (here, RoBERTa Large) to become a classifier for business functions\n",
"- You can experiment with different pretrained LLMs on Huggingface that may be more appropriate for your fine-tuning purpose:\n",
"> https://huggingface.co/models?pipeline_tag=fill-mask&library=transformers&sort=trending"
]
},
{
"cell_type": "markdown",
"id": "2ac54fc6-18a0-4577-941c-4aac33d2dd07",
"metadata": {},
"source": [
"### **IMPORTANT**: There is a hyperparameter that determines how long text (in tokens) can be at most for fine-tuning.\n",
"```\n",
"MAX_LENGTH = 128 # max length of sentences in tokens\n",
"```\n",
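"> * I set it to 128. If your texts are longer, increase it; otherwise everything beyond 128 tokens is truncated (cut off) and never seen by the model.\n",
"> * The larger the max length, the slower and more memory-hungry training becomes. Because every sentence is padded to `MAX_LENGTH` (`padding='max_length'`), a larger value costs time even for short sentences.\n",
"> * Why? See class 8 on Deep Learning. Hint: more input tokens means more to embed and contextualize; self-attention compares every token with every other token, so its cost grows quadratically with length, which takes compute and (GPU) memory."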
]
},
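{
"cell_type": "markdown",
"id": "c06c7d95-8baf-4e3d-8cb6-7f9a1b3d5e66",
"metadata": {},
"source": [
"Two quick calculations can guide the choice of `MAX_LENGTH`. First, each self-attention layer builds a score matrix per head whose size grows with the square of the sequence length, so going from 128 to 256 roughly quadruples the attention part of the work (the rest of the network grows roughly linearly). Second, you can set `MAX_LENGTH` from the observed token counts of your sentences. The sketch assumes you already have a list of token counts (for example, from running the RoBERTa tokenizer without truncation); the 99% coverage rule and rounding up to a multiple of 8 are illustrative choices, not requirements.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def attention_cost_ratio(new_len, old_len=128):\n",
"    # Relative size of one self-attention score matrix (length x length)\n",
"    return (new_len / old_len) ** 2\n",
"\n",
"def suggest_max_length(token_counts, coverage=0.99, multiple=8, cap=512):\n",
"    # Shortest length covering the given share of sentences, rounded up to a multiple;\n",
"    # cap=512 because RoBERTa accepts at most 512 tokens per input\n",
"    counts = sorted(token_counts)\n",
"    idx = min(len(counts) - 1, max(0, math.ceil(coverage * len(counts)) - 1))\n",
"    return min(cap, multiple * math.ceil(counts[idx] / multiple))\n",
"\n",
"print(attention_cost_ratio(256))                          # 4.0\n",
"print(suggest_max_length([20, 35, 41, 60, 75, 90, 130]))  # 136\n",
"```\n",
"\n",
"Here one sentence of 130 tokens already exceeds 128, so the sketch suggests 136. Token counts must come from the same tokenizer you fine-tune with, because different tokenizers split text differently."
]
},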
{
"cell_type": "markdown",
"id": "a912ff49-3530-42a4-bbc8-624b9b7b695f",
"metadata": {},
"source": [
"## Part 1: Imports and Configuration\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42db0c31-14f2-4354-9e51-cc81a1ca85bc",
"metadata": {},
"outputs": [],
"source": [
"# Install if needed:\n",
"# pip install -q -U transformers datasets accelerate scikit-learn\n",
"\n",
"import os\n",
"import torch\n",
"import numpy as np\n",
"import pandas as pd\n",
"import json\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import (\n",
" f1_score,\n",
" roc_auc_score,\n",
" matthews_corrcoef,\n",
" precision_score,\n",
" recall_score,\n",
" classification_report,\n",
" confusion_matrix,\n",
" hamming_loss,\n",
" jaccard_score,\n",
")\n",
"from transformers import (\n",
" RobertaTokenizer,\n",
" RobertaForSequenceClassification,\n",
" Trainer,\n",
" TrainingArguments,\n",
" EarlyStoppingCallback,\n",
")\n",
"from datasets import Dataset\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"# -----------------------------\n",
"# Configuration\n",
"# -----------------------------\n",
"\n",
"MODEL_NAME = \"roberta-large\"\n",
"CLASSES = [\"Marketing\", \"Finance\", \"Accounting\", \"Operations\", \"IT\", \"HR\"]\n",
"NUM_LABELS = len(CLASSES)\n",
"\n",
"# Hyperparameters\n",
"LEARNING_RATE = 2e-5\n",
"WEIGHT_DECAY = 0.01\n",
"WARMUP_RATIO = 0.1\n",
"NUM_EPOCHS = 4\n",
"BATCH_SIZE = 32 # Reduce if OOM (out of memory)\n",
"GRADIENT_ACCUMULATION = 2 # Effective batch size = 32 * 2 = 64\n",
"MAX_LENGTH = 128 # max length of sentences in tokens\n",
"SEED = 42\n",
"\n",
"# Early stopping\n",
"EARLY_STOPPING_PATIENCE = 2\n",
"\n",
"# Threshold for binary predictions\n",
"THRESHOLD = 0.5\n",
"\n",
"# Mixed precision (True = faster, less memory; False = more stable)\n",
"USE_MIXED_PRECISION = True\n",
"\n",
"# -----------------------------\n",
"# Device Selection (CUDA > MPS > CPU)\n",
"# -----------------------------\n",
"\n",
"def get_device_and_precision():\n",
" \"\"\"Detect best available device and set precision flags.\"\"\"\n",
" if torch.cuda.is_available():\n",
" device = torch.device(\"cuda\")\n",
" device_name = torch.cuda.get_device_name(0)\n",
" fp16 = USE_MIXED_PRECISION\n",
" bf16 = False\n",
" use_mps = False\n",
" print(f\"Device: CUDA ({device_name})\")\n",
" print(f\" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n",
" \n",
" elif torch.backends.mps.is_available():\n",
" device = torch.device(\"mps\")\n",
" fp16 = False\n",
" bf16 = USE_MIXED_PRECISION\n",
" use_mps = True\n",
" print(f\"Device: Apple Silicon (MPS)\")\n",
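"        os.environ[\"PYTORCH_MPS_HIGH_WATERMARK_RATIO\"] = \"0.0\"  # 0.0 disables the MPS memory cap (risk of system-wide OOM); likely only takes effect if set before MPS memory is first allocated\n",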
" \n",
" else:\n",
" device = torch.device(\"cpu\")\n",
" fp16 = False\n",
" bf16 = False\n",
" use_mps = False\n",
" print(f\"Device: CPU (this will be slow!)\")\n",
" \n",
" precision = \"mixed (fp16)\" if fp16 else \"mixed (bf16)\" if bf16 else \"full (fp32)\"\n",
" print(f\" Precision: {precision}\")\n",
" \n",
" return device, fp16, bf16, use_mps\n",
"\n",
"device, use_fp16, use_bf16, use_mps = get_device_and_precision()\n",
"\n",
"print(f\"\\nConfiguration:\")\n",
"print(f\" Model: {MODEL_NAME}\")\n",
"print(f\" Classes: {CLASSES}\")\n",
"print(f\" Epochs: {NUM_EPOCHS}\")\n",
"print(f\" Batch size: {BATCH_SIZE} x {GRADIENT_ACCUMULATION} = {BATCH_SIZE * GRADIENT_ACCUMULATION}\")\n",
"print(f\" Learning rate: {LEARNING_RATE}\")\n",
"print(f\" Max length: {MAX_LENGTH}\")"
]
},
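{
"cell_type": "markdown",
"id": "d17d8ea6-9cb0-4f4e-9dc7-8a0b2c4e6f77",
"metadata": {},
"source": [
"The batch settings above determine how many optimizer updates (and warmup steps) training performs. A back-of-the-envelope sketch, assuming 9,000 training sentences (hypothetical). It approximates how the Hugging Face `Trainer` counts steps; the exact count can differ by a step depending on version and rounding, and `EarlyStoppingCallback` may end training before the last step.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def training_schedule(n_train, batch_size, grad_accum, epochs, warmup_ratio):\n",
"    # Approximate optimizer steps, including linear-warmup steps\n",
"    effective_batch = batch_size * grad_accum\n",
"    batches_per_epoch = math.ceil(n_train / batch_size)\n",
"    steps_per_epoch = max(1, batches_per_epoch // grad_accum)  # one update per grad_accum batches\n",
"    total_steps = steps_per_epoch * epochs\n",
"    warmup_steps = math.ceil(warmup_ratio * total_steps)\n",
"    return effective_batch, steps_per_epoch, total_steps, warmup_steps\n",
"\n",
"print(training_schedule(9000, 32, 2, 4, 0.1))  # (64, 141, 564, 57)\n",
"```\n",
"\n",
"Doubling `GRADIENT_ACCUMULATION` instead of `BATCH_SIZE` gives the same effective batch with roughly the same memory per forward pass, at the cost of half as many updates per epoch."
]
},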
{
"cell_type": "markdown",
"id": "c3c3e7e6-8c47-4690-a490-b49dd3c08edb",
"metadata": {},
"source": [
"## Part 2: Prepare Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54863fc9-94f2-44ae-8ac8-c80b459b10da",
"metadata": {},
"outputs": [],
"source": [
"#!pip install iterative-stratification"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "81180705-cf64-4f5e-acd4-34e20e4cca39",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------\n",
"# Load Data\n",
"# -----------------------------\n",
"\n",
"print(\"=\"*60)\n",
"print(\"LOADING DATA\")\n",
"print(\"=\"*60)\n",
"\n",
"df_train_full = pd.read_pickle(\"./output/train/train_azure_gpt-4.1_chat_majority_vote.pkl\")\n",
"df_holdout = pd.read_pickle(\"./Holdout_Train/holdout_human.pkl\")\n",
"\n",
"print(f\"Train (full): {len(df_train_full)}\")\n",
"print(f\"Holdout: {len(df_holdout)}\")\n",
"\n",
"# -----------------------------\n",
"# Where to save model and results\n",
"# -----------------------------\n",
"model_save_path = \"./roberta_multilabel/best_model\"\n",
"\n",
"\n",
"# -----------------------------\n",
"# Train/Validation Split (80/20)\n",
"# -----------------------------\n",
"\n",
"# Option 1: Iterative stratification (best for multi-label)\n",
"# pip install iterative-stratification\n",
"try:\n",
" from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit\n",
" \n",
" msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)\n",
" train_idx, val_idx = next(msss.split(df_train_full, df_train_full[CLASSES].values))\n",
" \n",
" df_train = df_train_full.iloc[train_idx].copy()\n",
" df_val = df_train_full.iloc[val_idx].copy()\n",
" print(\"Using: Iterative multi-label stratification\")\n",
"\n",
"except ImportError:\n",
" # Option 2: Simple random split (no stratification)\n",
" df_train, df_val = train_test_split(\n",
" df_train_full, \n",
" test_size=0.2, \n",
" random_state=SEED,\n",
" )\n",
" print(\"Using: Random split (install 'iterative-stratification' for better stratification)\")\n",
"\n",
"print(f\"\\nSplit:\")\n",
"print(f\" Train: {len(df_train)}\")\n",
"print(f\" Validation: {len(df_val)}\")\n",
"print(f\" Holdout: {len(df_holdout)}\")\n",
"\n",
"# -----------------------------\n",
"# Extract Labels\n",
"# -----------------------------\n",
"\n",
"def get_labels_array(df):\n",
" \"\"\"Extract labels as numpy array.\"\"\"\n",
" return df[CLASSES].values.astype(np.float32)\n",
"\n",
"train_labels = get_labels_array(df_train)\n",
"val_labels = get_labels_array(df_val)\n",
"holdout_labels = get_labels_array(df_holdout)\n",
"\n",
"print(f\"\\nLabel shapes:\")\n",
"print(f\" Train: {train_labels.shape}\")\n",
"print(f\" Val: {val_labels.shape}\")\n",
"print(f\" Holdout: {holdout_labels.shape}\")\n",
"\n",
"# Class distribution\n",
"print(f\"\\nTrain class distribution:\")\n",
"for i, cls in enumerate(CLASSES):\n",
" count = train_labels[:, i].sum()\n",
" pct = count / len(train_labels) * 100\n",
" print(f\" {cls}: {int(count)} ({pct:.1f}%)\")\n",
"\n",
"# Verify validation has all classes\n",
"print(f\"\\nValidation class distribution:\")\n",
"for i, cls in enumerate(CLASSES):\n",
" count = val_labels[:, i].sum()\n",
" pct = count / len(val_labels) * 100\n",
" print(f\" {cls}: {int(count)} ({pct:.1f}%)\")\n",
"\n",
"# -----------------------------\n",
"# Tokenization\n",
"# -----------------------------\n",
"\n",
"print(f\"\\n{'='*60}\")\n",
"print(\"TOKENIZATION\")\n",
"print(\"=\"*60)\n",
"\n",
"tokenizer = RobertaTokenizer.from_pretrained(MODEL_NAME)\n",
"\n",
"def tokenize_data(texts, labels):\n",
" \"\"\"Tokenize texts and create HuggingFace Dataset.\"\"\"\n",
" encodings = tokenizer(\n",
" texts.tolist(),\n",
" truncation=True,\n",
" padding='max_length',\n",
" max_length=MAX_LENGTH,\n",
" return_tensors=None,\n",
" )\n",
" \n",
" dataset = Dataset.from_dict({\n",
" 'input_ids': encodings['input_ids'],\n",
" 'attention_mask': encodings['attention_mask'],\n",
" 'labels': labels.tolist(),\n",
" })\n",
" \n",
" return dataset\n",
"\n",
"print(\"Tokenizing datasets...\")\n",
"train_dataset = tokenize_data(df_train['sentence'].values, train_labels)\n",
"val_dataset = tokenize_data(df_val['sentence'].values, val_labels)\n",
"holdout_dataset = tokenize_data(df_holdout['sentence'].values, holdout_labels)\n",
"\n",
"print(f\" Train: {len(train_dataset)}\")\n",
"print(f\" Validation: {len(val_dataset)}\")\n",
"print(f\" Holdout: {len(holdout_dataset)}\")\n",
"\n",
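    "# Sequence length stats: padding='max_length' makes every input_ids row exactly MAX_LENGTH long,\n",
    "# so count real (non-padding) tokens from the attention mask instead\n",
    "train_lengths = [sum(mask) for mask in train_dataset['attention_mask']]\n",
    "n_at_max = sum(n >= MAX_LENGTH for n in train_lengths)\n",
    "print(f\"  Sequence lengths: min={min(train_lengths)}, max={max(train_lengths)}, mean={np.mean(train_lengths):.0f}\")\n",
    "print(f\"  Sentences at MAX_LENGTH (possibly truncated): {n_at_max}\")"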
]
},
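  {
   "cell_type": "markdown",
   "id": "c7a1e2d4-0b3f-4e6a-9d21-5f8a3b7c1e01",
   "metadata": {},
   "source": [
    "The class distributions printed above can also be compared programmatically. The sketch below is an illustrative addition, not part of the original pipeline: `class_prevalence` and `max_prevalence_gap` are hypothetical helpers that assume labels are multi-hot 0/1 rows, as returned by `get_labels_array`. The gap is the largest absolute difference in per-class prevalence between two splits. Iterative stratification will usually produce a smaller gap than a random split, although this is not guaranteed on small datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7a1e2d4-0b3f-4e6a-9d21-5f8a3b7c1e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "def class_prevalence(label_rows):\n",
    "    # Fraction of rows in which each class is active (multi-hot 0/1 rows)\n",
    "    n_rows = len(label_rows)\n",
    "    n_classes = len(label_rows[0])\n",
    "    return [sum(row[j] for row in label_rows) / n_rows for j in range(n_classes)]\n",
    "\n",
    "\n",
    "def max_prevalence_gap(labels_a, labels_b):\n",
    "    # Hypothetical helper: largest absolute per-class prevalence difference between two splits\n",
    "    prev_a = class_prevalence(labels_a)\n",
    "    prev_b = class_prevalence(labels_b)\n",
    "    return max(abs(a - b) for a, b in zip(prev_a, prev_b))\n",
    "\n",
    "\n",
    "# Toy example with three classes (not real data)\n",
    "toy_train = [[1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1]]\n",
    "toy_val = [[1, 0, 0], [0, 0, 1]]\n",
    "print(max_prevalence_gap(toy_train, toy_val))\n",
    "\n",
    "# With the notebook's numpy arrays (iterating a 2-D array yields its rows):\n",
    "# print(max_prevalence_gap(train_labels, val_labels))"
   ]
  },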
{
"cell_type": "markdown",
"id": "e9fabf48-8b21-4303-b8d4-cc2d1c4a17cd",
"metadata": {},
"source": [
"## PART 3: FINE-TUNE MODEL"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0bdfa218-edd5-41af-bdab-5ccced4510f7",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------\n",
"# Load Model\n",
"# -----------------------------\n",
"\n",
"print(\"=\"*60)\n",
"print(\"MODEL\")\n",
"print(\"=\"*60)\n",
"\n",
"model = RobertaForSequenceClassification.from_pretrained(\n",
" MODEL_NAME,\n",
" num_labels=NUM_LABELS,\n",
" problem_type=\"multi_label_classification\", # Enables BCE loss\n",
")\n",
"\n",
"if not use_mps:\n",
" model.to(device)\n",
"\n",
"total_params = sum(p.numel() for p in model.parameters())\n",
"trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Total parameters: {total_params:,}\")\n",
"print(f\"Trainable parameters: {trainable_params:,}\")\n",
"\n",
"# -----------------------------\n",
"# Metrics Function\n",
"# -----------------------------\n",
"\n",
"def compute_metrics(eval_pred):\n",
" \"\"\"Compute metrics for validation during training.\"\"\"\n",
" logits, labels = eval_pred\n",
" probs = torch.sigmoid(torch.tensor(logits)).numpy()\n",
" preds = (probs >= THRESHOLD).astype(int)\n",
" \n",
" f1_micro = f1_score(labels, preds, average='micro', zero_division=0)\n",
" f1_macro = f1_score(labels, preds, average='macro', zero_division=0)\n",
" f1_samples = f1_score(labels, preds, average='samples', zero_division=0)\n",
" \n",
" try:\n",
" auc_macro = roc_auc_score(labels, probs, average='macro')\n",
" except ValueError:\n",
" auc_macro = 0.0\n",
" \n",
" mcc_scores = []\n",
" for i in range(labels.shape[1]):\n",
" try:\n",
" mcc = matthews_corrcoef(labels[:, i], preds[:, i])\n",
" mcc_scores.append(mcc)\n",
    "        except ValueError:\n",
" mcc_scores.append(0.0)\n",
" mcc_macro = np.mean(mcc_scores)\n",
" \n",
" return {\n",
" 'f1_micro': f1_micro,\n",
" 'f1_macro': f1_macro,\n",
" 'f1_samples': f1_samples,\n",
" 'auc_macro': auc_macro,\n",
" 'mcc_macro': mcc_macro,\n",
" }\n",
"\n",
"# -----------------------------\n",
"# Training Arguments\n",
"# -----------------------------\n",
"\n",
"training_args = TrainingArguments(\n",
" output_dir=\"./roberta_multilabel\",\n",
" \n",
" # Training\n",
" num_train_epochs=NUM_EPOCHS,\n",
" per_device_train_batch_size=BATCH_SIZE,\n",
" per_device_eval_batch_size=BATCH_SIZE * 2,\n",
" gradient_accumulation_steps=GRADIENT_ACCUMULATION,\n",
" \n",
" # Optimizer\n",
" learning_rate=LEARNING_RATE,\n",
" weight_decay=WEIGHT_DECAY,\n",
" warmup_ratio=WARMUP_RATIO,\n",
" optim=\"adamw_torch\",\n",
" lr_scheduler_type=\"linear\",\n",
" max_grad_norm=1.0,\n",
" \n",
" # Evaluation & Saving\n",
" eval_strategy=\"epoch\",\n",
" save_strategy=\"epoch\",\n",
" load_best_model_at_end=True,\n",
" metric_for_best_model=\"f1_macro\",\n",
" greater_is_better=True,\n",
" save_total_limit=2,\n",
" \n",
" # Logging\n",
" logging_dir=\"./logs\",\n",
" logging_steps=50,\n",
" logging_first_step=True,\n",
" report_to=\"none\",\n",
" \n",
" # Precision\n",
" fp16=use_fp16,\n",
" bf16=use_bf16,\n",
" use_mps_device=use_mps,\n",
" \n",
" # Reproducibility\n",
" seed=SEED,\n",
" data_seed=SEED,\n",
" \n",
" # Performance\n",
" dataloader_num_workers=0 if use_mps else 2,\n",
" dataloader_pin_memory=not use_mps,\n",
" remove_unused_columns=False,\n",
")\n",
"\n",
"# -----------------------------\n",
"# Trainer\n",
"# -----------------------------\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=train_dataset,\n",
" eval_dataset=val_dataset,\n",
" compute_metrics=compute_metrics,\n",
" callbacks=[\n",
" EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE),\n",
" ],\n",
")\n",
"\n",
"# -----------------------------\n",
"# Train\n",
"# -----------------------------\n",
"\n",
"print(f\"\\n{'='*60}\")\n",
"print(\"TRAINING\")\n",
"print(\"=\"*60)\n",
"\n",
"train_result = trainer.train()\n",
"\n",
"print(f\"\\n{'='*60}\")\n",
"print(\"TRAINING COMPLETE\")\n",
"print(\"=\"*60)\n",
"print(f\"Runtime: {train_result.metrics['train_runtime']:.0f} seconds\")\n",
"print(f\"Samples/second: {train_result.metrics['train_samples_per_second']:.1f}\")\n",
"print(f\"Final loss: {train_result.metrics['train_loss']:.4f}\")\n",
"\n",
"# Validation results\n",
"print(f\"\\nValidation (Best Model):\")\n",
"val_results = trainer.evaluate()\n",
"for key, value in val_results.items():\n",
" if isinstance(value, float):\n",
" print(f\" {key}: {value:.4f}\")\n",
"\n",
"# -----------------------------\n",
"# Save Model\n",
"# -----------------------------\n",
"\n",
"trainer.save_model(model_save_path)\n",
"tokenizer.save_pretrained(model_save_path)\n",
"print(f\"\\nModel saved to: {model_save_path}\")"
]
},
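  {
   "cell_type": "markdown",
   "id": "c7a1e2d4-0b3f-4e6a-9d21-5f8a3b7c1e03",
   "metadata": {},
   "source": [
    "With `problem_type` set to `multi_label_classification`, the Hugging Face model treats each of the six labels as an independent binary decision: during training it applies `BCEWithLogitsLoss` to the raw logits, and `compute_metrics` converts logits to probabilities with a sigmoid and thresholds each one at `THRESHOLD`. The plain-Python sketch below reproduces that arithmetic for a single sentence, assuming the default mean reduction over labels. It is for illustration only; `sigmoid`, `bce_with_logits`, and `predict_labels` are hypothetical helpers and are not used by the `Trainer`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7a1e2d4-0b3f-4e6a-9d21-5f8a3b7c1e04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def sigmoid(x):\n",
    "    # Logistic function mapping a logit to a probability in (0, 1)\n",
    "    return 1.0 / (1.0 + math.exp(-x))\n",
    "\n",
    "\n",
    "def bce_with_logits(logits, targets):\n",
    "    # Mean binary cross-entropy over labels, computed directly from raw logits.\n",
    "    # Numerically stable form per label: max(x, 0) - x*y + log(1 + exp(-|x|))\n",
    "    losses = [max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))\n",
    "              for x, y in zip(logits, targets)]\n",
    "    return sum(losses) / len(losses)\n",
    "\n",
    "\n",
    "def predict_labels(logits, threshold=0.5):\n",
    "    # Each label is switched on independently when its probability >= threshold\n",
    "    return [int(sigmoid(x) >= threshold) for x in logits]\n",
    "\n",
    "\n",
    "# Toy logits for one sentence over six classes (illustrative values only)\n",
    "toy_logits = [2.0, -1.5, 0.0, -3.0, 0.8, -0.2]\n",
    "toy_targets = [1, 0, 1, 0, 1, 0]\n",
    "print(predict_labels(toy_logits))\n",
    "print(round(bce_with_logits(toy_logits, toy_targets), 4))"
   ]
  },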
{
"cell_type": "markdown",
"id": "44acf40a-da7d-4aa3-912c-2bdf883fdc54",
"metadata": {},
"source": [
    "## PART 4: EVALUATE ON HOLDOUT (ground-truth human expert labels)"
]
},
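  {
   "cell_type": "markdown",
   "id": "c7a1e2d4-0b3f-4e6a-9d21-5f8a3b7c1e05",
   "metadata": {},
   "source": [
    "The evaluation report describes Macro F1 as class-balanced and Micro F1 as dominated by frequent classes. The difference lies in where the averaging happens: macro F1 computes F1 per class and then takes the unweighted mean, whereas micro F1 pools true positives, false positives, and false negatives across classes before computing a single F1. The sketch below uses made-up counts for one frequent and one rare class. It mirrors what `f1_score` computes with `average='macro'` and `average='micro'` (with `zero_division=0`), but is written in plain Python for transparency; `f1_from_counts` and `macro_micro_f1` are hypothetical helpers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7a1e2d4-0b3f-4e6a-9d21-5f8a3b7c1e06",
   "metadata": {},
   "outputs": [],
   "source": [
    "def f1_from_counts(tp, fp, fn):\n",
    "    # F1 = 2TP / (2TP + FP + FN); defined as 0.0 when the denominator is 0\n",
    "    denom = 2 * tp + fp + fn\n",
    "    return 2 * tp / denom if denom else 0.0\n",
    "\n",
    "\n",
    "def macro_micro_f1(counts):\n",
    "    # counts: dict mapping class name -> (tp, fp, fn)\n",
    "    per_class = {c: f1_from_counts(*v) for c, v in counts.items()}\n",
    "    macro = sum(per_class.values()) / len(per_class)  # unweighted mean of per-class F1\n",
    "    tp = sum(v[0] for v in counts.values())\n",
    "    fp = sum(v[1] for v in counts.values())\n",
    "    fn = sum(v[2] for v in counts.values())\n",
    "    micro = f1_from_counts(tp, fp, fn)  # F1 of the pooled counts\n",
    "    return per_class, macro, micro\n",
    "\n",
    "\n",
    "# Hypothetical counts: a frequent class handled well, a rare class handled poorly\n",
    "toy_counts = {'Finance': (90, 10, 10), 'HR': (2, 3, 3)}\n",
    "toy_per_class, toy_macro, toy_micro = macro_micro_f1(toy_counts)\n",
    "print(toy_per_class)\n",
    "print(round(toy_macro, 3), round(toy_micro, 3))"
   ]
  },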
{
"cell_type": "code",
"execution_count": null,
"id": "df813061-90b2-4076-87ac-5c5be10ee048",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------\n",
    "# Set the performance-metric output path first! It is derived from the model save path defined earlier\n",
"# -----------------------------\n",
"\n",
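    "MODEL_DIR = model_save_path.replace(\"best_model\", \"performance\")\n",
    "if MODEL_DIR == model_save_path:\n",
    "    # 'best_model' is not in the path: write metrics to a subfolder instead of mixing them with model files\n",
    "    MODEL_DIR = os.path.join(model_save_path, \"performance\")\n",
    "os.makedirs(MODEL_DIR, exist_ok=True)\n",
    "print(f\"Performance outputs: {MODEL_DIR}\")"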
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0fede292-3cfc-4a69-8dfd-dd21e224572d",
"metadata": {},
"outputs": [],
"source": [
"# -----------------------------\n",
"# Evaluation Functions\n",
"# -----------------------------\n",
"\n",
"def evaluate_multilabel(y_true, y_pred, y_prob, class_names):\n",
" \"\"\"\n",
" Comprehensive multi-label evaluation.\n",
" \n",
" Each label is treated as an independent binary decision.\n",
" Macro averaging gives equal weight to each class (recommended for imbalanced data).\n",
" \"\"\"\n",
" n_classes = len(class_names)\n",
" \n",
" # Per-class metrics\n",
" per_class = {}\n",
" for i, cls in enumerate(class_names):\n",
" y_t = y_true[:, i]\n",
" y_p = y_pred[:, i]\n",
" y_pr = y_prob[:, i]\n",
" \n",
" try:\n",
" auc = roc_auc_score(y_t, y_pr)\n",
    "        except ValueError:\n",
    "            auc = 0.0  # AUC is undefined when the holdout has only one value for this class\n",
" \n",
" per_class[cls] = {\n",
" 'f1': f1_score(y_t, y_p, zero_division=0),\n",
" 'precision': precision_score(y_t, y_p, zero_division=0),\n",
" 'recall': recall_score(y_t, y_p, zero_division=0),\n",
" 'mcc': matthews_corrcoef(y_t, y_p),\n",
" 'auc': auc,\n",
" 'support': int(y_t.sum()),\n",
" }\n",
" \n",
" # Macro averages (equal weight to each class)\n",
" macro = {\n",
" 'f1_macro': np.mean([per_class[c]['f1'] for c in class_names]),\n",
" 'precision_macro': np.mean([per_class[c]['precision'] for c in class_names]),\n",
" 'recall_macro': np.mean([per_class[c]['recall'] for c in class_names]),\n",
" 'mcc_macro': np.mean([per_class[c]['mcc'] for c in class_names]),\n",
" 'auc_macro': np.mean([per_class[c]['auc'] for c in class_names]),\n",
" }\n",
" \n",
" # Micro averages (pooled across all classes)\n",
" micro = {\n",
" 'f1_micro': f1_score(y_true, y_pred, average='micro', zero_division=0),\n",
" 'precision_micro': precision_score(y_true, y_pred, average='micro', zero_division=0),\n",
" 'recall_micro': recall_score(y_true, y_pred, average='micro', zero_division=0),\n",
" }\n",
" \n",
" # Sample-based metrics\n",
" # zero_division=1: correct \"None\" predictions (both empty) score 1.0, not 0.0\n",
" sample = {\n",
" 'f1_samples': f1_score(y_true, y_pred, average='samples', zero_division=1),\n",
" 'jaccard_samples': jaccard_score(y_true, y_pred, average='samples', zero_division=1),\n",
" }\n",
" \n",
" # Overall metrics\n",
" overall = {\n",
" 'exact_match_ratio': (y_pred == y_true).all(axis=1).mean(),\n",
" 'hamming_loss': hamming_loss(y_true, y_pred),\n",
" }\n",
" \n",
" return {\n",
" 'per_class': per_class,\n",
" 'macro': macro,\n",
" 'micro': micro,\n",
" 'sample': sample,\n",
" 'overall': overall,\n",
" }\n",
"\n",
"\n",
"def print_evaluation_report(metrics, class_names):\n",
" \"\"\"Print formatted evaluation report.\"\"\"\n",
" \n",
" print(\"=\"*70)\n",
" print(\"HOLDOUT EVALUATION REPORT\")\n",
" print(\"=\"*70)\n",
" \n",
" # Per-class table\n",
" print(\"\\n--- Per-Class Metrics (Each Label = Independent Binary Decision) ---\")\n",
" print(f\"{'Class':<12} {'F1':>7} {'AUC':>7} {'MCC':>7} {'Prec':>7} {'Rec':>7} {'Support':>8}\")\n",
" print(\"-\"*60)\n",
" \n",
" for cls in class_names:\n",
" m = metrics['per_class'][cls]\n",
" print(f\"{cls:<12} {m['f1']:>7.4f} {m['auc']:>7.4f} {m['mcc']:>7.4f} \"\n",
" f\"{m['precision']:>7.4f} {m['recall']:>7.4f} {m['support']:>8}\")\n",
" \n",
" # Macro averages\n",
" print(\"-\"*60)\n",
" m = metrics['macro']\n",
" print(f\"{'Macro Avg':<12} {m['f1_macro']:>7.4f} {m['auc_macro']:>7.4f} {m['mcc_macro']:>7.4f} \"\n",
" f\"{m['precision_macro']:>7.4f} {m['recall_macro']:>7.4f}\")\n",
" \n",
" # Summary\n",
" print(\"\\n--- Summary Metrics ---\")\n",
" print(f\" Macro F1: {metrics['macro']['f1_macro']:.4f} ← Primary (class-balanced)\")\n",
" print(f\" Macro AUC: {metrics['macro']['auc_macro']:.4f} ← Threshold-independent\")\n",
" print(f\" Macro MCC: {metrics['macro']['mcc_macro']:.4f} ← Robust to imbalance\")\n",
" print(f\" Micro F1: {metrics['micro']['f1_micro']:.4f} ← Dominated by frequent classes\")\n",
" print(f\" Sample F1: {metrics['sample']['f1_samples']:.4f} ← Per-instance average\")\n",
" print(f\" Exact Match: {metrics['overall']['exact_match_ratio']:.4f} ← All labels correct\")\n",
" print(f\" Hamming Loss: {metrics['overall']['hamming_loss']:.4f} ← Fraction wrong (lower=better)\")\n",
" \n",
" # For paper\n",
" print(\"\\n--- For Report ---\")\n",
" print(f\" Primary: Macro F1 = {metrics['macro']['f1_macro']:.4f}\")\n",
" print(f\" Secondary: Macro AUC = {metrics['macro']['auc_macro']:.4f}, Macro MCC = {metrics['macro']['mcc_macro']:.4f}\")\n",
" print(f\" Overall: Exact Match = {metrics['overall']['exact_match_ratio']:.4f}\")\n",
"\n",
"\n",
"# -----------------------------\n",
"# Get Predictions\n",
"# -----------------------------\n",
"\n",
"print(\"=\"*60)\n",
"print(\"HOLDOUT PREDICTIONS\")\n",
"print(\"=\"*60)\n",
"\n",
"holdout_output = trainer.predict(holdout_dataset)\n",
"holdout_logits = holdout_output.predictions\n",
"holdout_probs = torch.sigmoid(torch.tensor(holdout_logits)).numpy()\n",
"holdout_preds = (holdout_probs >= THRESHOLD).astype(int)\n",
"\n",
"print(f\"Holdout samples: {len(holdout_labels)}\")\n",
"print(f\"Predictions shape: {holdout_preds.shape}\")\n",
"\n",
"# -----------------------------\n",
"# Evaluate\n",
"# -----------------------------\n",
"\n",
"metrics = evaluate_multilabel(\n",
" y_true=holdout_labels,\n",
" y_pred=holdout_preds,\n",
" y_prob=holdout_probs,\n",
" class_names=CLASSES,\n",
")\n",
"\n",
"print_evaluation_report(metrics, CLASSES)\n",
"\n",
"# -----------------------------\n",
"# None Class (Derived)\n",
"# -----------------------------\n",
"\n",
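    "print(\"\\n--- 'None' Class (Derived: all classes = 0) ---\")\n",
    "print(f\"{'Class':<12} {'F1':>7} {'AUC':>7} {'MCC':>7} {'Prec':>7} {'Rec':>7} {'Support':>8}\")\n",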
"\n",
"holdout_none_true = (holdout_labels.sum(axis=1) == 0).astype(int)\n",
"holdout_none_pred = (holdout_preds.sum(axis=1) == 0).astype(int)\n",
"holdout_none_prob = 1 - holdout_probs.max(axis=1)\n",
"\n",
"none_f1 = f1_score(holdout_none_true, holdout_none_pred, zero_division=0)\n",
"none_mcc = matthews_corrcoef(holdout_none_true, holdout_none_pred)\n",
"try:\n",
" none_auc = roc_auc_score(holdout_none_true, holdout_none_prob)\n",
    "except ValueError:\n",
" none_auc = 0.0\n",
"none_precision = precision_score(holdout_none_true, holdout_none_pred, zero_division=0)\n",
"none_recall = recall_score(holdout_none_true, holdout_none_pred, zero_division=0)\n",
"none_support = int(holdout_none_true.sum())\n",
"\n",
"print(f\"{'None':<12} {none_f1:>7.4f} {none_auc:>7.4f} {none_mcc:>7.4f} \"\n",
" f\"{none_precision:>7.4f} {none_recall:>7.4f} {none_support:>8}\")\n",
"\n",
"# -----------------------------\n",
"# Confusion Matrices\n",
"# -----------------------------\n",
"\n",
"print(\"\\n--- Confusion Matrices (per class) ---\")\n",
"print(f\"{'Class':<12} {'TP':>6} {'TN':>6} {'FP':>6} {'FN':>6}\")\n",
"print(\"-\"*40)\n",
"\n",
    "for i, cls in enumerate(CLASSES):\n",
    "    # labels=[0, 1] keeps the matrix 2x2 even if one value is absent from both y_t and y_p\n",
    "    y_t = holdout_labels[:, i]\n",
    "    y_p = holdout_preds[:, i]\n",
    "    tn, fp, fn, tp = confusion_matrix(y_t, y_p, labels=[0, 1]).ravel()\n",
    "    print(f\"{cls:<12} {tp:>6} {tn:>6} {fp:>6} {fn:>6}\")\n",
    "\n",
    "none_tn, none_fp, none_fn, none_tp = confusion_matrix(\n",
    "    holdout_none_true, holdout_none_pred, labels=[0, 1]\n",
    ").ravel()\n",
    "print(f\"{'None':<12} {none_tp:>6} {none_tn:>6} {none_fp:>6} {none_fn:>6}\")\n",
"\n",
"# -----------------------------\n",
"# Multi-label Analysis\n",
"# -----------------------------\n",
"\n",
"print(\"\\n--- Multi-label Analysis ---\")\n",
"\n",
"true_label_counts = holdout_labels.sum(axis=1)\n",
"pred_label_counts = holdout_preds.sum(axis=1)\n",
"\n",
"print(f\"Average labels per sentence:\")\n",
"print(f\" True: {true_label_counts.mean():.2f}\")\n",
"print(f\" Predicted: {pred_label_counts.mean():.2f}\")\n",
"\n",
"print(f\"\\nLabel count distribution:\")\n",
"print(f\"{'# Labels':<10} {'True':>8} {'Predicted':>10}\")\n",
"print(\"-\"*30)\n",
    "for n in range(len(CLASSES) + 1):  # 0 up to all classes active\n",
" true_count = (true_label_counts == n).sum()\n",
" pred_count = (pred_label_counts == n).sum()\n",
" print(f\"{n:<10} {true_count:>8} {pred_count:>10}\")\n",
"\n",
"# -----------------------------\n",
"# Save Results\n",
"# -----------------------------\n",
"\n",
"print(\"\\n\" + \"=\"*60)\n",
"print(\"SAVING RESULTS\")\n",
"print(\"=\"*60)\n",
"\n",
"# Save predictions\n",
"df_holdout_results = df_holdout.copy()\n",
"for i, cls in enumerate(CLASSES):\n",
" df_holdout_results[f'{cls}_prob'] = holdout_probs[:, i]\n",
" df_holdout_results[f'{cls}_pred'] = holdout_preds[:, i]\n",
"df_holdout_results['None_prob'] = holdout_none_prob\n",
"df_holdout_results['None_pred'] = holdout_none_pred\n",
"\n",
"df_holdout_results.to_csv(f\"{MODEL_DIR}/holdout_predictions.csv\", index=False)\n",
"df_holdout_results.to_pickle(f\"{MODEL_DIR}/holdout_predictions.pkl\")\n",
"print(f\" Predictions: {MODEL_DIR}/holdout_predictions.csv\")\n",
"\n",
"# Save metrics\n",
"metrics_summary = {\n",
" 'model': MODEL_NAME,\n",
" 'threshold': THRESHOLD,\n",
" 'holdout_size': len(df_holdout),\n",
" 'macro': metrics['macro'],\n",
" 'micro': metrics['micro'],\n",
" 'sample': metrics['sample'],\n",
" 'overall': metrics['overall'],\n",
" 'per_class': metrics['per_class'],\n",
" 'none_class': {\n",
" 'f1': none_f1, 'auc': none_auc, 'mcc': none_mcc,\n",
" 'precision': none_precision, 'recall': none_recall, 'support': none_support,\n",
" },\n",
" 'hyperparameters': {\n",
" 'learning_rate': LEARNING_RATE,\n",
" 'weight_decay': WEIGHT_DECAY,\n",
" 'warmup_ratio': WARMUP_RATIO,\n",
" 'num_epochs': NUM_EPOCHS,\n",
" 'batch_size': BATCH_SIZE,\n",
" 'gradient_accumulation': GRADIENT_ACCUMULATION,\n",
" 'max_length': MAX_LENGTH,\n",
" },\n",
"}\n",
"\n",
"with open(f\"{MODEL_DIR}/holdout_metrics.json\", 'w') as f:\n",
" json.dump(metrics_summary, f, indent=2, default=float)\n",
"print(f\" Metrics: {MODEL_DIR}/holdout_metrics.json\")\n",
"\n",
"# Save training history\n",
"with open(f\"{MODEL_DIR}/training_history.json\", 'w') as f:\n",
" json.dump(trainer.state.log_history, f, indent=2)\n",
"print(f\" History: {MODEL_DIR}/training_history.json\")\n",
"\n",
"print(\"\\n\" + \"=\"*60)\n",
"print(\"DONE\")\n",
"print(\"=\"*60)\n",
"print(f\"\\nFinal Results:\")\n",
"print(f\" Macro F1: {metrics['macro']['f1_macro']:.4f}\")\n",
"print(f\" Macro AUC: {metrics['macro']['auc_macro']:.4f}\")\n",
"print(f\" Macro MCC: {metrics['macro']['mcc_macro']:.4f}\")\n",
"print(f\" Exact Match: {metrics['overall']['exact_match_ratio']:.4f}\")"
]
},
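  {
   "cell_type": "markdown",
   "id": "c7a1e2d4-0b3f-4e6a-9d21-5f8a3b7c1e07",
   "metadata": {},
   "source": [
    "The summary metrics reward different kinds of correctness. Exact match counts a sentence as correct only if all six labels match; Hamming loss is the fraction of individual label decisions that are wrong; and sample F1 averages a per-sentence F1, where (with `zero_division=1`, as in `evaluate_multilabel`) a sentence with no true and no predicted labels scores 1.0. The derived None class is simply an all-zero row. The plain-Python sketch below computes these on three toy sentences so the numbers can be checked by hand. It is illustrative only; the helper names are hypothetical and are not part of the pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7a1e2d4-0b3f-4e6a-9d21-5f8a3b7c1e08",
   "metadata": {},
   "outputs": [],
   "source": [
    "def exact_match_ratio(rows_true, rows_pred):\n",
    "    # Share of rows where every label matches\n",
    "    return sum(t == p for t, p in zip(rows_true, rows_pred)) / len(rows_true)\n",
    "\n",
    "\n",
    "def hamming_loss_py(rows_true, rows_pred):\n",
    "    # Share of individual label decisions that are wrong\n",
    "    wrong = sum(a != b for t, p in zip(rows_true, rows_pred) for a, b in zip(t, p))\n",
    "    return wrong / (len(rows_true) * len(rows_true[0]))\n",
    "\n",
    "\n",
    "def sample_f1(rows_true, rows_pred):\n",
    "    # Per-row F1 = 2TP / (|true| + |pred|), averaged over rows; empty-vs-empty scores 1.0\n",
    "    scores = []\n",
    "    for t, p in zip(rows_true, rows_pred):\n",
    "        tp = sum(a * b for a, b in zip(t, p))\n",
    "        denom = sum(t) + sum(p)\n",
    "        scores.append(2 * tp / denom if denom else 1.0)\n",
    "    return sum(scores) / len(scores)\n",
    "\n",
    "\n",
    "def none_flags(rows):\n",
    "    # Derived None class: 1 when no business function is active in the row\n",
    "    return [int(sum(r) == 0) for r in rows]\n",
    "\n",
    "\n",
    "# Toy rows over six classes (not real data)\n",
    "toy_true = [[1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]\n",
    "toy_pred = [[1, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]]\n",
    "print(exact_match_ratio(toy_true, toy_pred))\n",
    "print(hamming_loss_py(toy_true, toy_pred))\n",
    "print(sample_f1(toy_true, toy_pred))\n",
    "print(none_flags(toy_true), none_flags(toy_pred))"
   ]
  },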
{
"cell_type": "markdown",
"id": "91936e6d-3591-4a86-956c-9090d7e46bb6",
"metadata": {},
"source": [
"**Please cite this paper** if you use any part or all of this code in a project - be it commercial or academic: \n",
"\n",
"> Ringel, Daniel, *Creating Synthetic Experts with Generative Artificial Intelligence* (December 5, 2023). Kenan Institute of Private Enterprise Research Paper No. 4542949, Available at SSRN: https://ssrn.com/abstract=4542949 or http://dx.doi.org/10.2139/ssrn.4542949 \n"
]
},
{
"cell_type": "markdown",
"id": "5e5e8344-2674-4daf-b468-6afbf7aa15bf",
"metadata": {},
"source": [
    "*This notebook was developed by Daniel M. Ringel in January 2026 with the help of the various vendors' API documentation (and examples) as well as genAI models from OpenAI and Anthropic.*"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}