|
import logging |
|
from abc import ABC, abstractmethod |
|
from functools import cached_property |
|
from typing import Optional, Self, Type, cast |
|
|
|
from pydantic import BaseModel, Field |
|
|
|
from proxy_lite.environments.environment_base import Action, Observation |
|
from proxy_lite.tools import Tool |
|
|
|
|
|
class BaseSolverConfig(BaseModel): |
|
pass |
|
|
|
|
|
class BaseSolver(BaseModel, ABC): |
|
task: Optional[str] = None |
|
env_tools: list[Tool] = Field(default_factory=list) |
|
config: BaseSolverConfig |
|
logger: logging.Logger | None = None |
|
|
|
class Config: |
|
arbitrary_types_allowed = True |
|
|
|
async def __aenter__(self) -> Self: |
|
return self |
|
|
|
async def __aexit__(self, exc_type, exc_value, traceback) -> None: |
|
pass |
|
|
|
@cached_property |
|
@abstractmethod |
|
def tools(self) -> list[Tool]: ... |
|
|
|
@abstractmethod |
|
async def initialise( |
|
self, |
|
task: str, |
|
env_tools: list[Tool], |
|
env_info: str, |
|
) -> None: |
|
""" |
|
Initialise the solution with the given task. |
|
""" |
|
... |
|
|
|
@abstractmethod |
|
async def act(self, observation: Observation) -> Action: |
|
""" |
|
Return an action for interacting with the environment. |
|
""" |
|
... |
|
|
|
async def is_complete(self, observation: Observation) -> bool: |
|
""" |
|
Return a boolean indicating if the task is complete. |
|
""" |
|
return observation.terminated |
|
|
|
|
|
class Solvers: |
|
_solver_registry: dict[str, type[BaseSolver]] = {} |
|
_solver_config_registry: dict[str, type[BaseSolverConfig]] = {} |
|
|
|
@classmethod |
|
def register_solver(cls, name: str): |
|
""" |
|
Decorator to register a Solver class under a given name. |
|
|
|
Example: |
|
@Solvers.register_solver("my_solver") |
|
class MySolver(BaseSolver): |
|
... |
|
""" |
|
|
|
def decorator(solver_cls: type[BaseSolver]) -> type[BaseSolver]: |
|
cls._solver_registry[name] = solver_cls |
|
return solver_cls |
|
|
|
return decorator |
|
|
|
@classmethod |
|
def register_solver_config(cls, name: str): |
|
""" |
|
Decorator to register a Solver configuration class under a given name. |
|
|
|
Example: |
|
@Solvers.register_solver_config("my_solver") |
|
class MySolverConfig(BaseSolverConfig): |
|
... |
|
""" |
|
|
|
def decorator(config_cls: type[BaseSolverConfig]) -> type[BaseSolverConfig]: |
|
cls._solver_config_registry[name] = config_cls |
|
return config_cls |
|
|
|
return decorator |
|
|
|
@classmethod |
|
def get(cls, name: str) -> type[BaseSolver]: |
|
""" |
|
Retrieve a registered Solver class by its name. |
|
|
|
Raises: |
|
ValueError: If no such solver is found. |
|
""" |
|
try: |
|
return cast(Type[BaseSolver], cls._solver_registry[name]) |
|
except KeyError: |
|
raise ValueError(f"Solver '{name}' not found.") |
|
|
|
@classmethod |
|
def get_config(cls, name: str) -> type[BaseSolverConfig]: |
|
""" |
|
Retrieve a registered Solver configuration class by its name. |
|
|
|
Raises: |
|
ValueError: If no such config is found. |
|
""" |
|
try: |
|
return cast(Type[BaseSolverConfig], cls._solver_config_registry[name]) |
|
except KeyError: |
|
raise ValueError(f"Solver config for '{name}' not found.") |
|
|