# NOTE: removed non-Python page residue ("Spaces: Running Running") left by extraction
""" | |
HuggingFace Chat Model Wrapper for vision models like Qwen2-VL | |
""" | |
import base64
import os
from typing import Any, Dict, List, Optional

import requests
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from pydantic import Field
class HuggingFaceChat(BaseChatModel):
    """Chat model wrapper for the HuggingFace Inference API.

    Supports multi-modal (text + base64 image) ``HumanMessage`` content for
    vision models such as Qwen2-VL via the OpenAI-compatible chat-completions
    endpoint, falling back to the plain text-generation endpoint when the
    chat endpoint fails.
    """

    model: str = Field(description="HuggingFace model name")
    temperature: float = Field(default=0.0, description="Temperature for sampling")
    max_tokens: int = Field(default=1000, description="Max tokens to generate")
    api_token: Optional[str] = Field(default=None, description="HF API token")

    def __init__(self, model: str, temperature: float = 0.0, **kwargs):
        """Initialize the model, resolving the API token.

        The token comes from ``api_token`` in kwargs or the ``HF_TOKEN``
        environment variable.

        Raises:
            ValueError: If no token is supplied and ``HF_TOKEN`` is unset.
        """
        # pop() rather than get(): the original left api_token inside kwargs
        # and also passed api_token=... explicitly, so super().__init__ raised
        # "got multiple values for keyword argument 'api_token'" whenever the
        # caller provided a token directly.
        api_token = kwargs.pop("api_token", None) or os.getenv("HF_TOKEN")
        if not api_token:
            raise ValueError("HF_TOKEN environment variable is required")
        super().__init__(
            model=model, temperature=temperature, api_token=api_token, **kwargs
        )

    @property
    def _llm_type(self) -> str:
        # BaseChatModel declares _llm_type as a property; without @property,
        # LangChain callbacks would receive a bound method instead of a str.
        return "huggingface_chat"

    def _format_message_for_hf(self, message: HumanMessage) -> Dict[str, Any]:
        """Convert a LangChain message to the HF chat-completions format.

        Plain string content passes through unchanged. List content is mapped
        item-by-item; base64 payloads are extracted from ``data:image`` URLs,
        and other image URLs are forwarded as-is.
        """
        if isinstance(message.content, str):
            return {"role": "user", "content": message.content}
        # Handle multi-modal content (text + images).
        formatted_content: List[Dict[str, Any]] = []
        for item in message.content:
            if item["type"] == "text":
                formatted_content.append({"type": "text", "text": item["text"]})
            elif item["type"] == "image_url":
                image_url = item["image_url"]["url"]
                if image_url.startswith("data:image"):
                    # Strip the "data:image/...;base64," prefix; maxsplit=1
                    # protects payloads that themselves contain a comma.
                    base64_data = image_url.split(",", 1)[1]
                    formatted_content.append({"type": "image", "image": base64_data})
                else:
                    # Forward plain http(s) image URLs instead of silently
                    # dropping the image (the original discarded them).
                    formatted_content.append({"type": "image", "image": image_url})
        return {"role": "user", "content": formatted_content}

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs,
    ) -> ChatResult:
        """Generate a response via the HF chat-completions endpoint.

        Args:
            messages: Conversation messages; only ``HumanMessage``s are sent.
            stop: Optional stop sequences, forwarded to the API.
            run_manager: LangChain callback manager (unused here).

        Returns:
            A ChatResult whose generation wraps an AIMessage (model output
            must be an AIMessage, not a HumanMessage, so downstream chains
            attribute the role correctly).
        """
        formatted_messages = [
            self._format_message_for_hf(msg)
            for msg in messages
            if isinstance(msg, HumanMessage)
        ]
        api_url = f"https://api-inference.huggingface.co/models/{self.model}/v1/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json",
        }
        payload: Dict[str, Any] = {
            "model": self.model,
            "messages": formatted_messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": False,
        }
        if stop:
            payload["stop"] = stop
        try:
            response = requests.post(api_url, headers=headers, json=payload, timeout=60)
            response.raise_for_status()
            result = response.json()
            content = result["choices"][0]["message"]["content"]
            return ChatResult(
                generations=[ChatGeneration(message=AIMessage(content=content))]
            )
        except requests.exceptions.RequestException:
            # Fall back to the simple text-only API if chat completions fail.
            return self._fallback_generate(messages, **kwargs)

    def _fallback_generate(self, messages: List[BaseMessage], **kwargs) -> ChatResult:
        """Fallback to the simple HF text-generation Inference API.

        Images cannot be sent through this endpoint, so only the text parts
        of the messages are concatenated into a single prompt.
        """
        try:
            api_url = f"https://api-inference.huggingface.co/models/{self.model}"
            headers = {
                "Authorization": f"Bearer {self.api_token}",
                "Content-Type": "application/json",
            }
            # Extract text content only for the fallback prompt.
            text_content = ""
            for msg in messages:
                if isinstance(msg, HumanMessage):
                    if isinstance(msg.content, str):
                        text_content += msg.content
                    else:
                        for item in msg.content:
                            if item["type"] == "text":
                                text_content += item["text"] + "\n"
            payload = {
                "inputs": text_content,
                "parameters": {
                    "temperature": self.temperature,
                    "max_new_tokens": self.max_tokens,
                },
            }
            response = requests.post(api_url, headers=headers, json=payload, timeout=60)
            response.raise_for_status()
            result = response.json()
            if isinstance(result, list) and len(result) > 0:
                content = result[0].get("generated_text", "No response generated")
            else:
                content = "Error: Invalid response format"
            return ChatResult(
                generations=[ChatGeneration(message=AIMessage(content=content))]
            )
        except Exception as e:
            # Last resort: surface the error as the model's reply rather than
            # crashing the surrounding chain.
            error_msg = f"HuggingFace API Error: {str(e)}. Please check your API key and model availability."
            return ChatResult(
                generations=[ChatGeneration(message=AIMessage(content=error_msg))]
            )