|
import json |
|
import logging |
|
from abc import ABC, abstractmethod |
|
from enum import Enum |
|
from functools import cached_property |
|
from typing import Any, Literal, Optional, Self |
|
|
|
from pydantic import BaseModel |
|
|
|
from proxy_lite.history import ToolCall |
|
from proxy_lite.tools import Tool, ToolExecutionResponse |
|
|
|
|
|
class EventType(str, Enum):
    """Discriminator values identifying the kind of an ``Event``.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # Value carried by ``Observation.type``.
    OBSERVATION = "observation"
    # Value carried by ``Action.type``.
    ACTION = "action"
    # No event class in this module uses this value as its ``type`` default.
    MESSAGE = "message"
|
|
|
|
|
class Event(BaseModel):
    """Base model for events; subclasses pin ``type`` to one ``EventType``."""

    # Discriminator telling the concrete event kinds apart.
    type: EventType
|
|
|
|
|
class State(BaseModel):
    """Environment state carried by an ``Observation``.

    Every field is optional and defaults to ``None``.
    """

    # Textual content of the state.
    text: str | None = None
    # Image content as a string (the encoding is not specified in this module).
    image: str | None = None
    # HTML content of the state.
    html: str | None = None
    # Tool execution responses attached to this state.
    tool_responses: list[ToolExecutionResponse] | None = None
|
|
|
|
|
class Observation(Event):
    """Event of type ``EventType.OBSERVATION`` wrapping an environment ``State``."""

    type: Literal[EventType.OBSERVATION] = EventType.OBSERVATION
    # Required: the state being reported.
    state: State
    # Required flag; presumably marks the end of the interaction -- confirm with environments.
    terminated: bool
    # Optional scalar reward; ``None`` when not provided.
    reward: float | None = None
    # Optional free-form metadata.
    info: dict[str, Any] | None = None
|
|
|
|
|
class Action(Event):
    """Event of type ``EventType.ACTION`` carrying optional text and tool calls."""

    type: Literal[EventType.ACTION] = EventType.ACTION
    # Optional text accompanying the action.
    text: str | None = None
    # Optional tool calls requested by this action.
    tool_calls: list[ToolCall] | None = None
    # Optional free-form metadata.
    info: dict[str, Any] | None = None
|
|
|
|
|
class BaseEnvironmentConfig(BaseModel):
    """Empty base model for environment configurations.

    Declares no fields itself; it is the type of ``BaseEnvironment.config``
    and the class type stored in the ``Environments`` config registry.
    """
|
|
|
|
|
class BaseEnvironment(BaseModel, ABC):
    """Abstract base for asynchronous environments.

    Subclasses must provide ``info_for_user``, ``tools``, ``initialise``,
    ``execute_action``, ``observe`` and ``evaluate``. The base class supplies
    a no-op async context manager, tool dispatch by function name
    (``execute_tool``) and a default ``get_info`` returning an empty dict.
    """

    config: BaseEnvironmentConfig
    logger: logging.Logger | None = None

    class Config:
        # ``logging.Logger`` is not a pydantic-native type, so arbitrary
        # types have to be permitted for the ``logger`` field.
        arbitrary_types_allowed = True

    async def __aenter__(self) -> Self:
        """Enter the async context; the base implementation acquires nothing."""
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        """Exit the async context; the base implementation releases nothing.

        Returns ``None``, so exceptions raised inside the ``async with`` body
        are not suppressed.
        """
        pass

    @property
    @abstractmethod
    def info_for_user(self) -> str:
        """Return a string of information intended for the user."""

    @cached_property
    @abstractmethod
    def tools(self) -> list[Tool]:
        """Return the tools whose methods ``execute_tool`` can dispatch to."""

    @abstractmethod
    async def initialise(self) -> Observation:
        """Initialise the environment and return an ``Observation``."""

    @abstractmethod
    async def execute_action(self, action: Action) -> Observation:
        """Apply ``action`` to the environment and return an ``Observation``."""

    @abstractmethod
    async def observe(self) -> Observation:
        """Return an ``Observation`` of the environment."""

    @abstractmethod
    async def evaluate(self, **kwargs: Any) -> dict[str, Any]:
        """Evaluate the environment and return a dict of results.

        Args:
            **kwargs: Implementation-defined keyword arguments.
        """

    async def execute_tool(self, tool_call: ToolCall) -> Any:
        """Dispatch ``tool_call`` to the first tool exposing a matching attribute.

        ``tool_call.function`` is indexed by ``"name"`` (the attribute looked
        up on each tool) and ``"arguments"`` (a JSON document decoded into
        keyword arguments).

        Args:
            tool_call: The call to execute.

        Returns:
            Whatever the awaited tool method returns.

        Raises:
            ValueError: If no tool in ``self.tools`` has an attribute named
                after the requested function.
        """
        function = tool_call.function
        name = function["name"]
        for tool in self.tools:
            if not hasattr(tool, name):
                continue
            arguments = json.loads(function["arguments"])
            # The decoded payload may itself be a JSON string
            # (double-encoded arguments); decode it a second time.
            if isinstance(arguments, str):
                arguments = json.loads(arguments)
            return await getattr(tool, name)(**arguments)
        msg = f'No tool function with name "{name}"'
        raise ValueError(msg)

    async def get_info(self) -> dict[str, Any]:
        """Return extra environment information; empty by default."""
        return {}
|
|
|
|
|
class Environments:
    """Name-keyed registry of environment classes and their config classes.

    Both registries are class-level dicts shared by every user of this class;
    registering an already-used name overwrites the previous entry.
    """

    _environment_registry: dict[str, type[BaseEnvironment]] = {}
    _environment_config_registry: dict[str, type[BaseEnvironmentConfig]] = {}

    @classmethod
    def register_environment(cls, name: str):
        """
        Decorator to register an Environment class under a given name.

        The decorated class is returned unchanged.

        Example:
            @Environments.register_environment("my_environment")
            class MyEnvironment(BaseEnvironment):
                ...
        """

        def decorator(env_cls: type[BaseEnvironment]) -> type[BaseEnvironment]:
            cls._environment_registry[name] = env_cls
            return env_cls

        return decorator

    @classmethod
    def register_environment_config(cls, name: str):
        """
        Decorator to register an Environment configuration class under a given name.

        The decorated class is returned unchanged.

        Example:
            @Environments.register_environment_config("my_environment")
            class MyEnvironmentConfig(BaseEnvironmentConfig):
                ...
        """

        def decorator(config_cls: type[BaseEnvironmentConfig]) -> type[BaseEnvironmentConfig]:
            cls._environment_config_registry[name] = config_cls
            return config_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> type[BaseEnvironment]:
        """
        Retrieve a registered Environment class by its name.

        Raises:
            ValueError: If no such environment is found.
        """
        try:
            return cls._environment_registry[name]
        except KeyError:
            # The KeyError only repeats the name, so drop it from the
            # traceback rather than showing it as a second failure.
            raise ValueError(f"Environment '{name}' not found.") from None

    @classmethod
    def get_config(cls, name: str) -> type[BaseEnvironmentConfig]:
        """
        Retrieve a registered Environment configuration class by its name.

        Raises:
            ValueError: If no such configuration is found.
        """
        try:
            return cls._environment_config_registry[name]
        except KeyError:
            # The KeyError only repeats the name, so drop it from the
            # traceback rather than showing it as a second failure.
            raise ValueError(f"Environment config for '{name}' not found.") from None
|
|