|
from __future__ import annotations |
|
|
|
import datetime |
|
import json |
|
import os |
|
import uuid |
|
from pathlib import Path |
|
from typing import Any, Optional, Self |
|
|
|
from pydantic import BaseModel, Field |
|
|
|
from proxy_lite.environments import EnvironmentConfigTypes |
|
from proxy_lite.environments.environment_base import Action, Observation |
|
from proxy_lite.history import MessageHistory |
|
from proxy_lite.solvers import SolverConfigTypes |
|
|
|
|
|
class Run(BaseModel): |
|
run_id: str |
|
task: str |
|
created_at: str |
|
complete: bool = False |
|
terminated_at: str | None = None |
|
evaluation: dict[str, Any] | None = None |
|
history: list[Observation | Action] = Field(default_factory=list) |
|
solver_history: MessageHistory | None = None |
|
result: str | None = None |
|
env_info: dict[str, Any] = Field(default_factory=dict) |
|
environment: Optional[EnvironmentConfigTypes] = None |
|
solver: Optional[SolverConfigTypes] = None |
|
|
|
@classmethod |
|
def initialise(cls, task: str) -> Self: |
|
run_id = str(uuid.uuid4()) |
|
return cls( |
|
run_id=run_id, |
|
task=task, |
|
created_at=str(datetime.datetime.now(datetime.UTC)), |
|
) |
|
|
|
@classmethod |
|
def load(cls, run_id: str) -> Self: |
|
with open(Path(__file__).parent.parent.parent / "local_trajectories" / f"{run_id}.json", "r") as f: |
|
return cls(**json.load(f)) |
|
|
|
@property |
|
def observations(self) -> list[Observation]: |
|
return [h for h in self.history if isinstance(h, Observation)] |
|
|
|
@property |
|
def actions(self) -> list[Action]: |
|
return [h for h in self.history if isinstance(h, Action)] |
|
|
|
@property |
|
def last_action(self) -> Action | None: |
|
return self.actions[-1] if self.actions else None |
|
|
|
@property |
|
def last_observation(self) -> Observation | None: |
|
return self.observations[-1] if self.observations else None |
|
|
|
def record( |
|
self, |
|
observation: Optional[Observation] = None, |
|
action: Optional[Action] = None, |
|
solver_history: Optional[MessageHistory] = None, |
|
) -> None: |
|
|
|
if observation and action: |
|
raise ValueError("Only one of observation and action can be provided") |
|
if observation: |
|
self.history.append(observation) |
|
if action: |
|
self.history.append(action) |
|
if solver_history: |
|
self.solver_history = solver_history |
|
|
|
def terminate(self) -> None: |
|
self.terminated_at = str(datetime.datetime.now(datetime.UTC)) |
|
|
|
|
|
class DataRecorder: |
|
def __init__(self, local_folder: str | None = None): |
|
self.local_folder = local_folder |
|
|
|
def initialise_run(self, task: str) -> Run: |
|
self.local_folder = Path(__file__).parent.parent.parent / "local_trajectories" |
|
os.makedirs(self.local_folder, exist_ok=True) |
|
return Run.initialise(task) |
|
|
|
async def terminate( |
|
self, |
|
run: Run, |
|
save: bool = True, |
|
) -> None: |
|
run.terminate() |
|
if save: |
|
await self.save(run) |
|
|
|
async def save(self, run: Run) -> None: |
|
json_payload = run.model_dump() |
|
with open(self.local_folder / f"{run.run_id}.json", "w") as f: |
|
json.dump(json_payload, f) |
|
|