# Source: Omniscient01 / hf_chat.py
# Author: Andy Lee
# Commit: 78ec24e — "fix: keep simple, and use hf_token for qwen"
"""
HuggingFace Chat Model Wrapper for vision models like Qwen2-VL
"""
import base64
import os
from typing import Any, Dict, List, Optional

import requests
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from pydantic import Field
class HuggingFaceChat(BaseChatModel):
    """Chat model wrapper for the HuggingFace Inference API.

    Supports plain-text and multi-modal (text + base64 data-URL image)
    user messages against the OpenAI-compatible chat-completions endpoint
    (used by vision models such as Qwen2-VL). If the chat endpoint fails
    or returns an unexpected shape, it falls back to the simple
    text-generation inference API, and as a last resort returns the error
    text as the model response rather than raising.
    """

    model: str = Field(description="HuggingFace model name")
    temperature: float = Field(default=0.0, description="Temperature for sampling")
    max_tokens: int = Field(default=1000, description="Max tokens to generate")
    api_token: Optional[str] = Field(default=None, description="HF API token")

    def __init__(self, model: str, temperature: float = 0.0, **kwargs):
        """Initialize the wrapper, resolving the API token.

        Args:
            model: HuggingFace model identifier (e.g. "Qwen/Qwen2-VL-7B-Instruct").
            temperature: Sampling temperature; 0.0 is effectively greedy.
            **kwargs: May include ``api_token``; otherwise the ``HF_TOKEN``
                environment variable is used.

        Raises:
            ValueError: If no API token is supplied and HF_TOKEN is unset.
        """
        # pop() (not get()) so api_token is removed from kwargs; otherwise
        # super().__init__ would receive it twice and raise a TypeError
        # whenever a caller passes api_token explicitly.
        api_token = kwargs.pop("api_token", None) or os.getenv("HF_TOKEN")
        if not api_token:
            raise ValueError("HF_TOKEN environment variable is required")
        super().__init__(
            model=model, temperature=temperature, api_token=api_token, **kwargs
        )

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for serialization and telemetry."""
        return "huggingface_chat"

    def _format_message_for_hf(self, message: HumanMessage) -> Dict[str, Any]:
        """Convert a LangChain HumanMessage to the HF chat message format.

        Plain-string content passes through unchanged. List content is
        translated item by item: text parts are kept, and data-URL images
        are reduced to their raw base64 payload. Non-data-URL images are
        silently dropped (the chat endpoint expects inline image data).
        """
        if isinstance(message.content, str):
            return {"role": "user", "content": message.content}

        # Multi-modal content: a list of {"type": ...} parts.
        formatted_content = []
        for item in message.content:
            if item["type"] == "text":
                formatted_content.append({"type": "text", "text": item["text"]})
            elif item["type"] == "image_url":
                image_url = item["image_url"]["url"]
                if image_url.startswith("data:image"):
                    # Strip the "data:image/...;base64," prefix; maxsplit=1
                    # guards against any further commas in the payload.
                    base64_data = image_url.split(",", 1)[1]
                    formatted_content.append({"type": "image", "image": base64_data})
        return {"role": "user", "content": formatted_content}

    def _generate(self, messages: List[BaseMessage], **kwargs) -> ChatResult:
        """Generate a response via the HF chat-completions endpoint.

        Only HumanMessages are forwarded (system/AI messages are ignored
        by design — "keep simple"). Any transport failure or unexpected
        response shape triggers the text-only fallback instead of raising.
        """
        formatted_messages = [
            self._format_message_for_hf(msg)
            for msg in messages
            if isinstance(msg, HumanMessage)
        ]

        api_url = (
            f"https://api-inference.huggingface.co/models/{self.model}"
            "/v1/chat/completions"
        )
        headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "messages": formatted_messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": False,
        }
        try:
            response = requests.post(api_url, headers=headers, json=payload, timeout=60)
            response.raise_for_status()
            result = response.json()
            content = result["choices"][0]["message"]["content"]
        except (requests.exceptions.RequestException, KeyError, IndexError, TypeError):
            # KeyError/IndexError/TypeError cover malformed JSON shapes that
            # previously escaped the RequestException-only handler; all of
            # them should route to the simple-API fallback.
            return self._fallback_generate(messages, **kwargs)
        # Model output is an assistant message, not a human one.
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=content))]
        )

    def _fallback_generate(self, messages: List[BaseMessage], **kwargs) -> ChatResult:
        """Fallback to the simple HF text-generation inference API.

        Images cannot be sent through this endpoint, so only the text
        parts of the messages are concatenated into a single prompt.
        Never raises: on any failure the error text becomes the response.
        """
        try:
            api_url = f"https://api-inference.huggingface.co/models/{self.model}"
            headers = {
                "Authorization": f"Bearer {self.api_token}",
                "Content-Type": "application/json",
            }
            # Flatten text content only; image parts are dropped here.
            text_content = ""
            for msg in messages:
                if isinstance(msg, HumanMessage):
                    if isinstance(msg.content, str):
                        text_content += msg.content
                    else:
                        for item in msg.content:
                            if item["type"] == "text":
                                text_content += item["text"] + "\n"
            payload = {
                "inputs": text_content,
                "parameters": {
                    "temperature": self.temperature,
                    "max_new_tokens": self.max_tokens,
                },
            }
            response = requests.post(api_url, headers=headers, json=payload, timeout=60)
            response.raise_for_status()
            result = response.json()
            # The simple API returns a list of {"generated_text": ...} dicts.
            if isinstance(result, list) and len(result) > 0:
                content = result[0].get("generated_text", "No response generated")
            else:
                content = "Error: Invalid response format"
            return ChatResult(
                generations=[ChatGeneration(message=AIMessage(content=content))]
            )
        except Exception as e:
            # Last resort: surface the error as the model's reply so the
            # calling pipeline keeps running instead of crashing.
            error_msg = f"HuggingFace API Error: {str(e)}. Please check your API key and model availability."
            return ChatResult(
                generations=[ChatGeneration(message=AIMessage(content=error_msg))]
            )