import os
from typing import Any, Dict, List, Optional

import vertexai
from vertexai.generative_models import GenerativeModel, Part


# TODO: check if this is the correct way to use Vertex AI
# TODO: add langfuse support
class VertexAIWrapper:
    """Wrapper for Vertex AI to support Gemini models."""

    def __init__(
        self,
        model_name: str = "gemini-1.5-pro",
        temperature: float = 0.7,
        print_cost: bool = False,
        verbose: bool = False,
        use_langfuse: bool = False,
    ):
"""Initialize the Vertex AI wrapper. | |
Args: | |
model_name: Name of the model to use (e.g. "gemini-1.5-pro") | |
temperature: Temperature for generation between 0 and 1 | |
print_cost: Whether to print the cost of the completion | |
verbose: Whether to print verbose output | |
use_langfuse: Whether to enable Langfuse logging | |
""" | |
self.model_name = model_name | |
self.temperature = temperature | |
self.print_cost = print_cost | |
self.verbose = verbose | |
        # Initialize Vertex AI from environment configuration.
        project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
        location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
        if not project_id:
            raise ValueError("No GOOGLE_CLOUD_PROJECT found in environment variables")
        vertexai.init(project=project_id, location=location)
        self.model = GenerativeModel(model_name)

    def __call__(self, messages: List[Dict[str, Any]], metadata: Optional[Dict[str, Any]] = None) -> str:
        """Process messages and return the completion.

        Args:
            messages: List of message dictionaries containing type and content
            metadata: Optional metadata dictionary (currently unused; reserved for logging integrations)

        Returns:
            Generated text response from the model

        Raises:
            ValueError: If a message type is not supported
        """
        parts = []
        for msg in messages:
            if msg["type"] == "text":
                parts.append(Part.from_text(msg["content"]))
            elif msg["type"] in ["image", "video"]:
                mime_type = "video/mp4" if msg["type"] == "video" else "image/jpeg"
                if isinstance(msg["content"], str):
                    # String content is treated as a GCS URI (e.g. gs://bucket/object).
                    parts.append(Part.from_uri(
                        msg["content"],
                        mime_type=mime_type
                    ))
                else:
                    # Otherwise the content is expected to be raw bytes.
                    parts.append(Part.from_data(
                        msg["content"],
                        mime_type=mime_type
                    ))
            else:
                raise ValueError(f"Unsupported message type: {msg['type']}")
        response = self.model.generate_content(
            parts,
            generation_config={
                "temperature": self.temperature,
                "top_p": 0.95,
            },
        )
        return response.text
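

# Usage sketch (not part of the original module): assumes GOOGLE_CLOUD_PROJECT is set in the
# environment and application-default credentials are available; "gs://my-bucket/sample.jpg"
# is a hypothetical GCS URI used purely for illustration.
if __name__ == "__main__":
    llm = VertexAIWrapper(model_name="gemini-1.5-pro", temperature=0.2, verbose=True)
    answer = llm([
        {"type": "text", "content": "Describe the attached image in one sentence."},
        {"type": "image", "content": "gs://my-bucket/sample.jpg"},  # hypothetical GCS URI
    ])
    print(answer)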