""" | |
LLM Interface Module for Cross-Domain Uncertainty Quantification | |
This module provides a unified interface for interacting with large language models, | |
supporting multiple model architectures and uncertainty quantification methods. | |
""" | |
import torch | |
import numpy as np | |
from typing import List, Dict, Any, Union, Optional | |
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM | |
from tqdm import tqdm | |


class LLMInterface:
    """Interface for interacting with large language models with uncertainty quantification."""

    def __init__(
        self,
        model_name: str,
        model_type: str = "causal",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        cache_dir: Optional[str] = None,
        max_length: int = 512,
        temperature: float = 1.0,
        top_p: float = 1.0,
        num_beams: int = 1
    ):
        """
        Initialize the LLM interface.

        Args:
            model_name: Name of the Hugging Face model to use
            model_type: Type of model ('causal' or 'seq2seq')
            device: Device to run the model on ('cpu' or 'cuda')
            cache_dir: Directory to cache models
            max_length: Maximum length of generated sequences
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            num_beams: Number of beams for beam search
        """
        self.model_name = model_name
        self.model_type = model_type
        self.device = device
        self.cache_dir = cache_dir
        self.max_length = max_length
        self.temperature = temperature
        self.top_p = top_p
        self.num_beams = num_beams
        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            cache_dir=cache_dir
        )

        # Load model based on type
        if model_type == "causal":
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            ).to(device)
        elif model_type == "seq2seq":
            self.model = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                torch_dtype=torch.float16 if device == "cuda" else torch.float32
            ).to(device)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")

        # Response cache for efficiency
        self.response_cache = {}

    def generate(
        self,
        prompt: str,
        num_samples: int = 1,
        return_logits: bool = False,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Generate responses from the model with uncertainty quantification.

        Args:
            prompt: Input text prompt
            num_samples: Number of samples to generate (for MC methods)
            return_logits: Whether to return token logits
            **kwargs: Additional generation parameters

        Returns:
            Dictionary containing:
                - response: The generated text
                - samples: Multiple samples if num_samples > 1
                - logits: Token logits if return_logits is True
        """
        # Check cache first (sort kwargs so equivalent calls map to the same key)
        cache_key = (prompt, num_samples, return_logits, str(sorted(kwargs.items())))
        if cache_key in self.response_cache:
            return self.response_cache[cache_key]
        # Prepare inputs
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        # Set generation parameters
        gen_kwargs = {
            # max_new_tokens counts only generated tokens, matching the
            # "maximum length of generated sequences" semantics of max_length
            "max_new_tokens": self.max_length,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "num_beams": self.num_beams,
            "do_sample": self.temperature > 0,  # greedy decoding when temperature is 0
            # Fall back to EOS for tokenizers without a dedicated pad token
            "pad_token_id": self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None
            else self.tokenizer.eos_token_id
        }
        gen_kwargs.update(kwargs)
        # Generate multiple samples if requested
        samples = []
        all_logits = []
        for _ in range(num_samples):
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    output_scores=return_logits,
                    return_dict_in_generate=True,
                    **gen_kwargs
                )

            # Extract generated tokens
            if self.model_type == "causal":
                # Causal models echo the prompt, so strip the prompt tokens
                gen_tokens = outputs.sequences[0, inputs.input_ids.shape[1]:]
            else:
                gen_tokens = outputs.sequences[0]

            # Decode tokens to text
            gen_text = self.tokenizer.decode(gen_tokens, skip_special_tokens=True)
            samples.append(gen_text)

            # Extract logits if requested
            if return_logits and hasattr(outputs, "scores"):
                all_logits.append([score.cpu().numpy() for score in outputs.scores])
        # Prepare result
        result = {
            "response": samples[0],  # Primary response is first sample
            "samples": samples
        }
        if return_logits:
            result["logits"] = all_logits

        # Cache result
        self.response_cache[cache_key] = result
        return result
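
    # Downstream note (illustrative, not prescribed by this module): Monte Carlo
    # style estimators typically consume result["samples"] (agreement across
    # stochastic generations), while token-level methods consume result["logits"]
    # (one list of per-step score arrays per sample). With do_sample disabled,
    # repeated samples are identical and sample-based dispersion collapses to zero.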

    def batch_generate(
        self,
        prompts: List[str],
        **kwargs
    ) -> List[Dict[str, Any]]:
        """
        Generate responses for a batch of prompts.

        Args:
            prompts: List of input text prompts
            **kwargs: Additional generation parameters

        Returns:
            List of generation results for each prompt
        """
        results = []
        for prompt in tqdm(prompts, desc="Generating responses"):
            results.append(self.generate(prompt, **kwargs))
        return results

    def clear_cache(self):
        """Clear the response cache."""
        self.response_cache = {}
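

# ---------------------------------------------------------------------------
# Illustrative usage sketch. Nothing below is part of the interface itself:
# the checkpoint name "gpt2", the prompt, and the mean-entropy score are
# example assumptions showing how the samples and logits returned by
# generate() could feed a simple uncertainty estimate.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    llm = LLMInterface("gpt2", model_type="causal", max_length=32)

    result = llm.generate(
        "What is the capital of France?",
        num_samples=3,
        return_logits=True
    )
    print("Primary response:", result["response"])
    print("All samples:", result["samples"])

    # Mean per-step predictive entropy of the first sample, computed from the
    # step-wise score arrays (each of shape (1, vocab_size)).
    entropies = []
    for scores in result["logits"][0]:
        logits = scores.astype(np.float64)
        logits -= logits.max(axis=-1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=-1, keepdims=True)
        entropies.append(float(-(probs * np.log(probs + 1e-12)).sum()))
    print("Mean token entropy:", sum(entropies) / max(len(entropies), 1))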