import os
from abc import ABC, abstractmethod
from functools import cached_property
from typing import ClassVar, Literal, Optional, Union

import httpx
from httpx import Limits, Timeout
from openai import AsyncOpenAI
from openai.types.chat.chat_completion import (
    ChatCompletion,
)
from pydantic import BaseModel

from proxy_lite.history import MessageHistory
from proxy_lite.logger import logger
from proxy_lite.serializer import (
    BaseSerializer,
    OpenAICompatibleSerializer,
)
from proxy_lite.tools import Tool


class BaseClientConfig(BaseModel):
    http_timeout: float = 50
    http_concurrent_connections: int = 50


class BaseClient(BaseModel, ABC):
    config: BaseClientConfig
    serializer: ClassVar[BaseSerializer]

    @abstractmethod
    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        """
        Create a completion from the model.

        Subclasses are expected to adapt this to endpoints that handle
        requests differently, and should raise appropriate warnings.

        Returns:
            ChatCompletion: OpenAI ChatCompletion format, for consistency.
        """
        ...

    @classmethod
    def create(cls, config: BaseClientConfig) -> "BaseClient":
        # Dispatch on the config's `name` discriminator field.
        supported_clients = {
            "openai": OpenAIClient,
            "convergence": ConvergenceClient,
        }
        if config.name not in supported_clients:
            error_message = f"Unsupported client: {config.name}."
            raise ValueError(error_message)
        return supported_clients[config.name](config=config)

    @property
    def http_client(self) -> httpx.AsyncClient:
        return httpx.AsyncClient(
            timeout=Timeout(self.config.http_timeout),
            limits=Limits(
                max_connections=self.config.http_concurrent_connections,
                max_keepalive_connections=self.config.http_concurrent_connections,
            ),
        )
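
# Illustrative usage sketch, not part of the module: the factory dispatches
# on the config's `name` field, so callers build a concrete config and let
# `BaseClient.create` pick the matching client. Assumes an async context and
# a populated MessageHistory named `history` (hypothetical):
#
#     client = BaseClient.create(ConvergenceClientConfig())
#     completion = await client.create_completion(messages=history)
#     print(completion.choices[0].message.content)
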
class OpenAIClientConfig(BaseClientConfig):
    name: Literal["openai"] = "openai"
    model_id: str = "gpt-4o"
    # Default to an empty string so the field remains a `str` when the
    # environment variable is unset.
    api_key: str = os.environ.get("OPENAI_API_KEY", "")


class OpenAIClient(BaseClient):
    config: OpenAIClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()

    @cached_property
    def external_client(self) -> AsyncOpenAI:
        return AsyncOpenAI(
            api_key=self.config.api_key,
            http_client=self.http_client,
        )

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        base_params = {
            "model": self.config.model_id,
            "messages": self.serializer.serialize_messages(messages),
            "temperature": temperature,
        }
        optional_params = {
            "seed": seed,
            "tools": self.serializer.serialize_tools(tools) if tools else None,
            "tool_choice": "required" if tools else None,
            "response_format": {"type": "json_object"} if response_format else {"type": "text"},
        }
        # Forward only the optional parameters that were actually provided.
        base_params.update({k: v for k, v in optional_params.items() if v is not None})
        return await self.external_client.chat.completions.create(**base_params)
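
# Illustrative sketch, not part of the module: requesting JSON output.
# Passing any BaseModel subclass here only switches the request to OpenAI's
# `json_object` mode; the model's schema itself is not sent. `Plan` and
# `history` are hypothetical names used for the example:
#
#     class Plan(BaseModel):
#         steps: list[str]
#
#     client = OpenAIClient(config=OpenAIClientConfig(model_id="gpt-4o-mini"))
#     completion = await client.create_completion(
#         messages=history,
#         temperature=0.2,
#         response_format=Plan,
#     )
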
class ConvergenceClientConfig(BaseClientConfig):
    name: Literal["convergence"] = "convergence"
    model_id: str = "convergence-ai/proxy-lite-7b"
    api_base: str = "http://localhost:8000/v1"
    api_key: str = "none"


class ConvergenceClient(OpenAIClient):
    config: ConvergenceClientConfig
    serializer: ClassVar[OpenAICompatibleSerializer] = OpenAICompatibleSerializer()
    _model_validated: bool = False

    async def _validate_model(self) -> None:
        # Confirm, once per client, that the configured model is actually
        # being served by the endpoint.
        try:
            response = await self.external_client.models.list()
            assert self.config.model_id in [model.id for model in response.data], (
                f"Model {self.config.model_id} not found in {response.data}"
            )
            self._model_validated = True
            logger.debug(f"Model {self.config.model_id} validated and connected to cluster")
        except Exception as e:
            logger.error(f"Error retrieving model: {e}")
            raise

    @cached_property
    def external_client(self) -> AsyncOpenAI:
        return AsyncOpenAI(
            api_key=self.config.api_key,
            base_url=self.config.api_base,
            http_client=self.http_client,
        )

    async def create_completion(
        self,
        messages: MessageHistory,
        temperature: float = 0.7,
        seed: Optional[int] = None,
        tools: Optional[list[Tool]] = None,
        response_format: Optional[type[BaseModel]] = None,
    ) -> ChatCompletion:
        if not self._model_validated:
            await self._validate_model()
        base_params = {
            "model": self.config.model_id,
            "messages": self.serializer.serialize_messages(messages),
            "temperature": temperature,
        }
        optional_params = {
            "seed": seed,
            "tools": self.serializer.serialize_tools(tools) if tools else None,
            # Unlike OpenAIClient, tool use is left optional here.
            "tool_choice": "auto" if tools else None,
            # The serving endpoint is expected to accept the response model directly.
            "response_format": response_format if response_format else {"type": "text"},
        }
        base_params.update({k: v for k, v in optional_params.items() if v is not None})
        return await self.external_client.chat.completions.create(**base_params)
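
# Illustrative sketch, not part of the module: pointing the client at a
# local OpenAI-compatible server (for example, a vLLM instance serving
# convergence-ai/proxy-lite-7b). The URL is this config's default and
# assumes such a server is running; `history` is hypothetical:
#
#     config = ConvergenceClientConfig(api_base="http://localhost:8000/v1")
#     client = ConvergenceClient(config=config)
#     completion = await client.create_completion(messages=history)
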
ClientConfigTypes = Union[OpenAIClientConfig, ConvergenceClientConfig]
ClientTypes = Union[OpenAIClient, ConvergenceClient]
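
# Illustrative sketch, not defined in this module: because each config type
# carries a distinct `name` literal, ClientConfigTypes can act as a pydantic
# discriminated union (assumes pydantic v2; `AgentSettings` is hypothetical):
#
#     from pydantic import Field
#
#     class AgentSettings(BaseModel):
#         client: ClientConfigTypes = Field(discriminator="name")
#
#     settings = AgentSettings.model_validate({"client": {"name": "convergence"}})
#     assert isinstance(settings.client, ConvergenceClientConfig)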