Spaces:
Running
Running
File size: 5,562 Bytes
4d37e51 78ec24e 4d37e51 78ec24e 4d37e51 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
"""
HuggingFace Chat Model Wrapper for vision models like Qwen2-VL
"""
import os
import base64
import requests
from typing import List, Dict, Any, Optional
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.outputs import ChatResult, ChatGeneration
from pydantic import Field
class HuggingFaceChat(BaseChatModel):
"""Chat model wrapper for HuggingFace Inference API"""
model: str = Field(description="HuggingFace model name")
temperature: float = Field(default=0.0, description="Temperature for sampling")
max_tokens: int = Field(default=1000, description="Max tokens to generate")
api_token: Optional[str] = Field(default=None, description="HF API token")
def __init__(self, model: str, temperature: float = 0.0, **kwargs):
api_token = kwargs.get("api_token") or os.getenv("HF_TOKEN")
if not api_token:
raise ValueError("HF_TOKEN environment variable is required")
super().__init__(
model=model, temperature=temperature, api_token=api_token, **kwargs
)
@property
def _llm_type(self) -> str:
return "huggingface_chat"
def _format_message_for_hf(self, message: HumanMessage) -> Dict[str, Any]:
"""Convert LangChain message to HuggingFace format"""
if isinstance(message.content, str):
return {"role": "user", "content": message.content}
# Handle multi-modal content (text + images)
formatted_content = []
for item in message.content:
if item["type"] == "text":
formatted_content.append({"type": "text", "text": item["text"]})
elif item["type"] == "image_url":
# Extract base64 data from data URL
image_url = item["image_url"]["url"]
if image_url.startswith("data:image"):
# Extract base64 data
base64_data = image_url.split(",")[1]
formatted_content.append({"type": "image", "image": base64_data})
return {"role": "user", "content": formatted_content}
def _generate(self, messages: List[BaseMessage], **kwargs) -> ChatResult:
"""Generate response using HuggingFace Inference API"""
# Format messages for HF API
formatted_messages = []
for msg in messages:
if isinstance(msg, HumanMessage):
formatted_messages.append(self._format_message_for_hf(msg))
# Prepare API request
api_url = f"https://api-inference.huggingface.co/models/{self.model}/v1/chat/completions"
headers = {
"Authorization": f"Bearer {self.api_token}",
"Content-Type": "application/json",
}
payload = {
"model": self.model,
"messages": formatted_messages,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"stream": False,
}
try:
response = requests.post(api_url, headers=headers, json=payload, timeout=60)
response.raise_for_status()
result = response.json()
content = result["choices"][0]["message"]["content"]
return ChatResult(
generations=[ChatGeneration(message=HumanMessage(content=content))]
)
except requests.exceptions.RequestException as e:
# Fallback to simple text-only API if chat completions fail
return self._fallback_generate(messages, **kwargs)
def _fallback_generate(self, messages: List[BaseMessage], **kwargs) -> ChatResult:
"""Fallback to simple HF Inference API"""
try:
# Use simple inference API as fallback
api_url = f"https://api-inference.huggingface.co/models/{self.model}"
headers = {
"Authorization": f"Bearer {self.api_token}",
"Content-Type": "application/json",
}
# Extract text content only for fallback
text_content = ""
for msg in messages:
if isinstance(msg, HumanMessage):
if isinstance(msg.content, str):
text_content += msg.content
else:
for item in msg.content:
if item["type"] == "text":
text_content += item["text"] + "\n"
payload = {
"inputs": text_content,
"parameters": {
"temperature": self.temperature,
"max_new_tokens": self.max_tokens,
},
}
response = requests.post(api_url, headers=headers, json=payload, timeout=60)
response.raise_for_status()
result = response.json()
if isinstance(result, list) and len(result) > 0:
content = result[0].get("generated_text", "No response generated")
else:
content = "Error: Invalid response format"
return ChatResult(
generations=[ChatGeneration(message=HumanMessage(content=content))]
)
except Exception as e:
# Last resort fallback
error_msg = f"HuggingFace API Error: {str(e)}. Please check your API key and model availability."
return ChatResult(
generations=[ChatGeneration(message=HumanMessage(content=error_msg))]
)
|