Spaces:
Sleeping
Sleeping
""" | |
Domain Dataset Module for Cross-Domain Uncertainty Quantification | |
This module provides functionality for loading and managing datasets from different domains | |
for evaluating uncertainty quantification methods across domains. | |
""" | |
import os | |
import json | |
import pandas as pd | |
import numpy as np | |
from typing import List, Dict, Any, Union, Optional, Tuple | |
from datasets import load_dataset | |
class DomainDataset:
    """
    Common base for datasets that belong to a single evaluation domain.

    Concrete subclasses fill in ``load``, ``get_samples`` and
    ``get_prompt_template``; the versions here only raise
    ``NotImplementedError``.
    """

    def __init__(self, name: str, domain: str):
        """
        Record the dataset identity; the data itself is loaded later.

        Args:
            name: Identifier of the dataset
            domain: Domain category (e.g., 'medical', 'legal', 'general')
        """
        self.name, self.domain = name, domain
        # Stays None until a subclass's load() fills it in.
        self.data = None

    def load(self) -> None:
        """Read the dataset into ``self.data``; subclasses must override."""
        raise NotImplementedError("Subclasses must implement this method")

    def get_samples(self, n: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Return prompt-ready samples; subclasses must override.

        Args:
            n: Maximum number of samples to return (None for all)

        Returns:
            List of sample dicts holding prompts and expected outputs
        """
        raise NotImplementedError("Subclasses must implement this method")

    def get_prompt_template(self) -> str:
        """
        Return the prompt template for the domain; subclasses must override.

        Returns:
            Prompt template string
        """
        raise NotImplementedError("Subclasses must implement this method")
class MedicalQADataset(DomainDataset):
    """Dataset for medical question answering."""

    def __init__(self, data_path: Optional[str] = None):
        """
        Initialize the medical QA dataset.

        Args:
            data_path: Path to a local ``.csv`` or ``.json`` dataset file
                (None to use the default MedMCQA sample / synthetic data)
        """
        super().__init__("medical_qa", "medical")
        self.data_path = data_path

    def load(self) -> None:
        """
        Load the medical QA dataset into ``self.data`` as a DataFrame.

        The local file at ``data_path`` is used when it exists. Otherwise the
        first 100 training rows of MedMCQA are requested through
        ``load_dataset``. If that fails for any reason, a small synthetic
        question set is used instead.

        Raises:
            ValueError: If ``data_path`` exists but is not a .csv or .json file.
        """
        if self.data_path and os.path.exists(self.data_path):
            self.data = self._load_local_file(self.data_path)
            return
        if self.data_path:
            # A path was supplied but is missing: make the substitution visible
            # instead of silently evaluating on a different dataset.
            print(f"Data file not found: {self.data_path}; falling back to default data")
        try:
            dataset = load_dataset("medmcqa", split="train[:100]")
            self.data = dataset.to_pandas()
        except Exception as e:
            # Deliberate best-effort: any download/parse failure degrades to
            # synthetic data rather than aborting the evaluation.
            print(f"Failed to load MedMCQA dataset: {e}")
            self.data = self._create_synthetic_data()

    @staticmethod
    def _load_local_file(path: str) -> pd.DataFrame:
        """
        Read a local .csv or .json file into a DataFrame.

        Raises:
            ValueError: If the file extension is neither .csv nor .json.
        """
        lowered = path.lower()
        if lowered.endswith('.csv'):
            return pd.read_csv(path)
        if lowered.endswith('.json'):
            with open(path, 'r', encoding='utf-8') as f:
                # json.load yields a list/dict; wrap it so get_samples can rely
                # on DataFrame APIs such as `.columns`.
                return pd.DataFrame(json.load(f))
        raise ValueError(f"Unsupported file format: {path}")

    def _create_synthetic_data(self) -> pd.DataFrame:
        """Create synthetic medical QA data for testing."""
        questions = [
            "What are the common symptoms of myocardial infarction?",
            "How does insulin regulate blood glucose levels?",
            "What is the mechanism of action for ACE inhibitors?",
            "What are the diagnostic criteria for rheumatoid arthritis?",
            "How does the SARS-CoV-2 virus enter human cells?",
            "What are the main side effects of chemotherapy?",
            "How does the blood-brain barrier function?",
            "What is the pathophysiology of type 2 diabetes?",
            "How do vaccines create immunity?",
            "What are the stages of chronic kidney disease?"
        ]
        # Create a dataframe with questions only (answers would be generated by LLMs)
        return pd.DataFrame({
            'question': questions,
            'domain': ['medical'] * len(questions)
        })

    def get_samples(self, n: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get samples from the medical QA dataset, loading it on first use.

        Args:
            n: Number of samples to return (None for all)

        Returns:
            List of dicts with 'domain', 'question' and 'prompt' keys

        Raises:
            ValueError: If the data has neither a 'question' nor a
                'question_text' column.
        """
        if self.data is None:
            self.load()
        if 'question' in self.data.columns:
            column = 'question'
        elif 'question_text' in self.data.columns:
            column = 'question_text'
        else:
            raise ValueError("Dataset does not contain question column")
        # Skip missing entries (e.g. blank CSV cells) so no prompt reads "nan".
        questions = self.data[column].dropna().tolist()
        if n is not None:
            questions = questions[:n]
        template = self.get_prompt_template()
        return [
            {
                'domain': self.domain,
                'question': question,
                'prompt': template.format(question=question),
            }
            for question in questions
        ]

    def get_prompt_template(self) -> str:
        """
        Get the prompt template for medical domain.

        Returns:
            Prompt template string with a ``{question}`` placeholder
        """
        return "You are a medical expert. Please answer the following medical question accurately and concisely:\n\n{question}"
class LegalQADataset(DomainDataset):
    """Dataset for legal question answering."""

    def __init__(self, data_path: Optional[str] = None):
        """
        Initialize the legal QA dataset.

        Args:
            data_path: Path to a local ``.csv`` or ``.json`` dataset file
                (None to use the built-in synthetic questions)
        """
        super().__init__("legal_qa", "legal")
        self.data_path = data_path

    def load(self) -> None:
        """
        Load the legal QA dataset into ``self.data`` as a DataFrame.

        The local file at ``data_path`` is used when it exists; otherwise the
        built-in synthetic questions are used.

        Raises:
            ValueError: If ``data_path`` exists but is not a .csv or .json file.
        """
        if self.data_path and os.path.exists(self.data_path):
            self.data = self._load_local_file(self.data_path)
            return
        if self.data_path:
            # A path was supplied but is missing: make the substitution visible.
            print(f"Data file not found: {self.data_path}; using synthetic data")
        self.data = self._create_synthetic_data()

    @staticmethod
    def _load_local_file(path: str) -> pd.DataFrame:
        """
        Read a local .csv or .json file into a DataFrame.

        Raises:
            ValueError: If the file extension is neither .csv nor .json.
        """
        lowered = path.lower()
        if lowered.endswith('.csv'):
            return pd.read_csv(path)
        if lowered.endswith('.json'):
            with open(path, 'r', encoding='utf-8') as f:
                # json.load yields a list/dict; wrap it so get_samples can rely
                # on DataFrame APIs such as `.columns`.
                return pd.DataFrame(json.load(f))
        raise ValueError(f"Unsupported file format: {path}")

    def _create_synthetic_data(self) -> pd.DataFrame:
        """Create synthetic legal QA data for testing."""
        questions = [
            "What constitutes a breach of contract?",
            "How is intellectual property protected under international law?",
            "What are the elements of negligence in tort law?",
            "How does the doctrine of stare decisis function in common law systems?",
            "What rights are protected under the Fourth Amendment?",
            "What is the difference between a patent and a copyright?",
            "How does arbitration differ from litigation?",
            "What constitutes insider trading under securities law?",
            "What are the legal requirements for a valid will?",
            "How does diplomatic immunity work under international law?"
        ]
        # Create a dataframe with questions only
        return pd.DataFrame({
            'question': questions,
            'domain': ['legal'] * len(questions)
        })

    def get_samples(self, n: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get samples from the legal QA dataset, loading it on first use.

        Args:
            n: Number of samples to return (None for all)

        Returns:
            List of dicts with 'domain', 'question' and 'prompt' keys

        Raises:
            ValueError: If the data has neither a 'question' nor a
                'question_text' column.
        """
        if self.data is None:
            self.load()
        # Accept the same question column names as the other domain datasets.
        if 'question' in self.data.columns:
            column = 'question'
        elif 'question_text' in self.data.columns:
            column = 'question_text'
        else:
            raise ValueError("Dataset does not contain question column")
        # Skip missing entries (e.g. blank CSV cells) so no prompt reads "nan".
        questions = self.data[column].dropna().tolist()
        if n is not None:
            questions = questions[:n]
        template = self.get_prompt_template()
        return [
            {
                'domain': self.domain,
                'question': question,
                'prompt': template.format(question=question),
            }
            for question in questions
        ]

    def get_prompt_template(self) -> str:
        """
        Get the prompt template for legal domain.

        Returns:
            Prompt template string with a ``{question}`` placeholder
        """
        return "You are a legal expert. Please answer the following legal question accurately and concisely:\n\n{question}"
class GeneralKnowledgeDataset(DomainDataset):
    """Dataset for general knowledge question answering."""

    def __init__(self, data_path: Optional[str] = None):
        """
        Initialize the general knowledge dataset.

        Args:
            data_path: Path to a local ``.csv`` or ``.json`` dataset file
                (None to use the default TriviaQA sample / synthetic data)
        """
        super().__init__("general_knowledge", "general")
        self.data_path = data_path

    def load(self) -> None:
        """
        Load the general knowledge dataset into ``self.data`` as a DataFrame.

        The local file at ``data_path`` is used when it exists. Otherwise the
        first 100 training rows of TriviaQA ("unfiltered") are requested
        through ``load_dataset``. If that fails for any reason, a small
        synthetic question set is used instead.

        Raises:
            ValueError: If ``data_path`` exists but is not a .csv or .json file.
        """
        if self.data_path and os.path.exists(self.data_path):
            self.data = self._load_local_file(self.data_path)
            return
        if self.data_path:
            # A path was supplied but is missing: make the substitution visible
            # instead of silently evaluating on a different dataset.
            print(f"Data file not found: {self.data_path}; falling back to default data")
        try:
            # NOTE(review): split slicing may still download/prepare the whole
            # split; consider streaming if startup time matters.
            dataset = load_dataset("trivia_qa", "unfiltered", split="train[:100]")
            self.data = dataset.to_pandas()
        except Exception as e:
            # Deliberate best-effort: any download/parse failure degrades to
            # synthetic data rather than aborting the evaluation.
            print(f"Failed to load TriviaQA dataset: {e}")
            self.data = self._create_synthetic_data()

    @staticmethod
    def _load_local_file(path: str) -> pd.DataFrame:
        """
        Read a local .csv or .json file into a DataFrame.

        Raises:
            ValueError: If the file extension is neither .csv nor .json.
        """
        lowered = path.lower()
        if lowered.endswith('.csv'):
            return pd.read_csv(path)
        if lowered.endswith('.json'):
            with open(path, 'r', encoding='utf-8') as f:
                # json.load yields a list/dict; wrap it so get_samples can rely
                # on DataFrame APIs such as `.columns`.
                return pd.DataFrame(json.load(f))
        raise ValueError(f"Unsupported file format: {path}")

    def _create_synthetic_data(self) -> pd.DataFrame:
        """Create synthetic general knowledge data for testing."""
        questions = [
            "What is the capital of France?",
            "Who wrote the novel '1984'?",
            "What is the chemical symbol for gold?",
            "Which planet is known as the Red Planet?",
            "Who painted the Mona Lisa?",
            "What is the largest ocean on Earth?",
            "What year did World War II end?",
            "What is the tallest mountain in the world?",
            "Who was the first person to step on the moon?",
            "What is the speed of light in a vacuum?"
        ]
        # Create a dataframe with questions only
        return pd.DataFrame({
            'question': questions,
            'domain': ['general'] * len(questions)
        })

    def get_samples(self, n: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get samples from the general knowledge dataset, loading it on first use.

        Args:
            n: Number of samples to return (None for all)

        Returns:
            List of dicts with 'domain', 'question' and 'prompt' keys

        Raises:
            ValueError: If the data has neither a 'question' nor a
                'question_text' column.
        """
        if self.data is None:
            self.load()
        if 'question' in self.data.columns:
            column = 'question'
        elif 'question_text' in self.data.columns:
            column = 'question_text'
        else:
            raise ValueError("Dataset does not contain question column")
        # Skip missing entries (e.g. blank CSV cells) so no prompt reads "nan".
        questions = self.data[column].dropna().tolist()
        if n is not None:
            questions = questions[:n]
        template = self.get_prompt_template()
        return [
            {
                'domain': self.domain,
                'question': question,
                'prompt': template.format(question=question),
            }
            for question in questions
        ]

    def get_prompt_template(self) -> str:
        """
        Get the prompt template for general knowledge domain.

        Returns:
            Prompt template string with a ``{question}`` placeholder
        """
        return "Please answer the following general knowledge question accurately and concisely:\n\n{question}"
# Factory function to create domain datasets
def create_domain_dataset(domain: str, data_path: Optional[str] = None) -> DomainDataset:
    """
    Instantiate the dataset class registered for ``domain``.

    Args:
        domain: Domain category ('medical', 'legal', 'general')
        data_path: Optional local dataset file forwarded to the constructor
            (None to use that dataset's default source)

    Returns:
        A new, not-yet-loaded domain dataset instance

    Raises:
        ValueError: If ``domain`` is not one of the supported categories.
    """
    candidates = (
        ("medical", MedicalQADataset),
        ("legal", LegalQADataset),
        ("general", GeneralKnowledgeDataset),
    )
    for key, dataset_cls in candidates:
        if domain == key:
            return dataset_cls(data_path)
    raise ValueError(f"Unsupported domain: {domain}")