Spaces:
Runtime error
Runtime error
from typing import List, Dict, Any, Union, Optional | |
import io | |
import os | |
import base64 | |
from PIL import Image | |
import mimetypes | |
import google.generativeai as genai | |
import tempfile | |
import time | |
from urllib.parse import urlparse | |
import requests | |
from io import BytesIO | |
class GeminiWrapper:
    """Wrapper around a Google Gemini ``GenerativeModel``.

    Accepts a list of typed messages (text, image, audio, video), uploads any
    media to Gemini, and returns the generated text.

    Note:
        ``print_cost``, ``use_langfuse`` and the ``metadata`` argument of
        ``__call__`` are accepted for interface compatibility but are not used
        by this class (``accumulated_cost`` is only initialised to 0).
    """

    # Message "type" values whose content is uploaded to Gemini as a file.
    _MEDIA_TYPES = ("image", "audio", "video")
    # URL schemes treated as remote files that must be downloaded first.
    _URL_SCHEMES = ("http", "https")
    # Seconds to sleep between polls while Gemini is processing an upload.
    _POLL_INTERVAL = 3
    # Request timeout (seconds) passed to ``generate_content``.
    _GENERATE_TIMEOUT = 600

    def __init__(
        self,
        model_name: str = "gemini-1.5-pro-002",
        temperature: float = 0.7,
        print_cost: bool = False,
        verbose: bool = False,
        use_langfuse: bool = False,
    ):
        """
        Initialize the Gemini wrapper and configure the underlying model.

        Args:
            model_name: Name of the model to use. Anything up to and including
                the last "/" (e.g. a "gemini/" provider prefix) is stripped.
            temperature: Sampling temperature for completion.
            print_cost: Stored on the instance; not used by this class.
            verbose: Whether to print upload progress messages.
            use_langfuse: Accepted for compatibility; currently ignored.

        Raises:
            ValueError: If neither GEMINI_API_KEY nor GOOGLE_API_KEY is set.
        """
        # rpartition returns the whole string when there is no "/".
        self.model_name = model_name.rpartition("/")[2]
        self.temperature = temperature
        self.print_cost = print_cost
        self.verbose = verbose
        self.accumulated_cost = 0

        api_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")
        if not api_key:
            raise ValueError("No API_KEY found. Please set the `GEMINI_API_KEY` or `GOOGLE_API_KEY` environment variable.")
        genai.configure(api_key=api_key)

        generation_config = {
            "temperature": self.temperature,
            "top_p": 0.95,
            "response_mime_type": "text/plain",
        }
        # Every listed harm category is set to BLOCK_NONE.
        safety_settings = [
            {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
            {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
        ]
        self.model = genai.GenerativeModel(
            model_name=self.model_name,
            safety_settings=safety_settings,
            generation_config=generation_config,
        )

    def _get_mime_type(self, file_path: str) -> str:
        """
        Get the MIME type of a file based on its extension.

        Args:
            file_path: Path (or URL) of the file.

        Returns:
            MIME type as a string (e.g., "image/jpeg", "audio/mpeg").

        Raises:
            ValueError: If no MIME type can be guessed.
        """
        mime_type, _ = mimetypes.guess_type(file_path)
        if mime_type is None:
            raise ValueError(f"Unsupported file type: {file_path}")
        return mime_type

    def _get_url_mime_type(self, url: str) -> str:
        """
        Get the MIME type of a remote file from its URL.

        The URL's path component is tried first, so a query string or fragment
        (``photo.png?size=large``) does not hide the extension. If that gives
        no answer, the full URL is passed to ``_get_mime_type``.

        Args:
            url: http(s) URL of the file.

        Returns:
            MIME type as a string.

        Raises:
            ValueError: If no MIME type can be determined.
        """
        mime_type, _ = mimetypes.guess_type(urlparse(url).path)
        if mime_type is not None:
            return mime_type
        return self._get_mime_type(url)

    def _is_url(self, value: str) -> bool:
        """Return True if *value* is an http(s) URL rather than a local path."""
        # urlparse lower-cases the scheme, so "HTTPS://..." is accepted too.
        return urlparse(value).scheme in self._URL_SCHEMES

    def _download_file(self, url: str, timeout: float = 60.0) -> str:
        """
        Download a file from a URL and save it as a temporary file.

        The caller owns the returned file and is responsible for deleting it.

        Args:
            url: URL of the file to download.
            timeout: Seconds passed to ``requests.get`` so a stalled server
                cannot block forever.

        Returns:
            Path to the temporary file (keeps the URL path's extension).

        Raises:
            ValueError: If the server does not answer with HTTP 200.
        """
        response = requests.get(url, timeout=timeout)
        if response.status_code != 200:
            raise ValueError(f"Failed to download file from URL: {url}")
        suffix = os.path.splitext(urlparse(url).path)[1]
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
            temp_file.write(response.content)
        return temp_file.name

    def _save_image_to_temp(self, image: "Image.Image") -> str:
        """
        Save a PIL Image to a temporary PNG file.

        The caller owns the returned file and is responsible for deleting it.

        Args:
            image: PIL Image object.

        Returns:
            Path to the temporary file.
        """
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
            image.save(temp_file, format="PNG")
        return temp_file.name

    @staticmethod
    def _remove_temp_file(path: str) -> None:
        """Delete a temporary file created by this class, ignoring OS errors."""
        try:
            os.remove(path)
        except OSError:
            # Best-effort cleanup: never mask the upload result or its error.
            pass

    def _upload_to_gemini(self, file_path: str, mime_type: Optional[str] = None):
        """
        Upload the given file to Gemini.

        Args:
            file_path: Path to the file.
            mime_type: MIME type of the file.

        Returns:
            Uploaded file object returned by ``genai.upload_file``.
        """
        return genai.upload_file(file_path, mime_type=mime_type)

    def _wait_for_processing(self, uploaded_file):
        """
        Poll Gemini until an uploaded file leaves the PROCESSING state.

        Args:
            uploaded_file: File object returned by ``_upload_to_gemini``.

        Returns:
            The latest file object fetched via ``genai.get_file``.

        Raises:
            ValueError: If Gemini reports the file state as FAILED.
        """
        while uploaded_file.state.name == "PROCESSING":
            if self.verbose:
                print(".", end="", flush=True)
            time.sleep(self._POLL_INTERVAL)
            uploaded_file = genai.get_file(uploaded_file.name)
        if uploaded_file.state.name == "FAILED":
            raise ValueError(
                f"Gemini failed to process uploaded file {uploaded_file.name}: "
                f"state {uploaded_file.state.name}"
            )
        if self.verbose:
            print("Upload successfully")
        return uploaded_file

    def _prepare_media(self, content: Union[str, "Image.Image"]):
        """
        Upload one media item to Gemini and wait until it is ready.

        Args:
            content: A PIL image, an http(s) URL, or a local file path.

        Returns:
            The processed Gemini file object.

        Raises:
            ValueError: For unsupported content, unknown MIME types, failed
                downloads, or files Gemini fails to process.
        """
        temp_path = None
        if isinstance(content, Image.Image):
            temp_path = self._save_image_to_temp(content)
            file_path, mime_type = temp_path, "image/png"
        elif isinstance(content, str):
            if self._is_url(content):
                # Resolve the MIME type before downloading so an unsupported
                # URL fails fast without leaving a temp file behind.
                mime_type = self._get_url_mime_type(content)
                temp_path = self._download_file(content)
                file_path = temp_path
            else:
                file_path = content
                mime_type = self._get_mime_type(file_path)
        else:
            raise ValueError("Unsupported content type")

        try:
            uploaded_file = self._upload_to_gemini(file_path, mime_type)
        finally:
            # Only temp copies made here are removed; a caller's own local file
            # is never deleted.
            if temp_path is not None:
                self._remove_temp_file(temp_path)
        return self._wait_for_processing(uploaded_file)

    def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
        """
        Process messages and return the model's completion.

        Args:
            messages: List of dicts with a 'type' key ("text", "image",
                "audio" or "video") and a 'content' key (text, PIL image,
                http(s) URL, or local file path).
            metadata: Accepted for compatibility; currently ignored.

        Returns:
            Generated text, or the string form of the response's
            ``prompt_feedback`` when no text can be read from the response.

        Raises:
            ValueError: For an unsupported message type or media content.
        """
        contents = []
        for msg in messages:
            msg_type = msg["type"]
            if msg_type == "text":
                contents.append(msg["content"])
            elif msg_type in self._MEDIA_TYPES:
                contents.append(self._prepare_media(msg["content"]))
            else:
                raise ValueError("Unsupported message type")

        response = self.model.generate_content(
            contents, request_options={"timeout": self._GENERATE_TIMEOUT}
        )
        try:
            return response.text
        except Exception as e:
            # Deliberate best-effort boundary: if reading ``response.text``
            # raises (presumably no usable candidate, e.g. a blocked prompt),
            # report the prompt feedback instead of propagating.
            print(e)
            print(response.prompt_feedback)
            return str(response.prompt_feedback)
if __name__ == "__main__":
    # Running this file directly does nothing.
    ...