|
import re |
|
import tempfile |
|
import logging |
|
import dataclasses |
|
|
|
from browsergym.core.action.highlevel import HighLevelActionSet |
|
from browsergym.experiments.agent import Agent |
|
from browsergym.experiments.loop import AbstractAgentArgs, EnvArgs, ExpArgs, get_exp_result |
|
from browsergym.utils.obs import flatten_axtree_to_str |
|
|
|
|
|
class MiniwobTestAgent(Agent):
    """Scripted agent that clicks the first button found in the accessibility tree.

    The agent does no planning: it scans the flattened AXTree text for a line
    that starts with a bracketed numeric bid and mentions "button", then emits
    a ``click`` action on that bid.
    """

    # Action space restricted to the "bid" subset of high-level actions.
    action_set = HighLevelActionSet(subsets="bid")

    def obs_preprocessor(self, obs: dict):
        """Reduce the raw observation to a flattened text rendering of its AXTree.

        Args:
            obs: Raw observation; must contain the key ``"axtree_object"``.

        Returns:
            A dict with the single key ``"axtree_txt"``.
        """
        axtree_txt = flatten_axtree_to_str(obs["axtree_object"])
        return {"axtree_txt": axtree_txt}

    def get_action(self, obs: dict) -> tuple[str, dict]:
        """Return a click action targeting the first button line in the AXTree text.

        Args:
            obs: Preprocessed observation containing ``"axtree_txt"``.

        Returns:
            A ``(action, info)`` pair where ``action`` is a ``click("<bid>")``
            string and ``info`` carries a ``"think"`` message.

        Raises:
            Exception: If no line of the AXTree text matches the button pattern.
        """
        # First line shaped like "[<digits>] ... button" (case-insensitive).
        button_line = re.search(
            r"^\s*\[(\d+)\].*button",
            obs["axtree_txt"],
            re.MULTILINE | re.IGNORECASE,
        )
        if not button_line:
            raise Exception("Can't find the button's bid")

        button_bid = button_line.group(1)
        action = f'click("{button_bid}")'
        return action, {"think": "I'm clicking the button as requested."}
|
|
|
|
|
@dataclasses.dataclass
class MiniwobTestAgentArgs(AbstractAgentArgs):
    """Agent arguments whose factory produces a ``MiniwobTestAgent``."""

    def make_agent(self):
        """Build and return a fresh ``MiniwobTestAgent`` instance."""
        agent = MiniwobTestAgent()
        return agent
|
|
|
|
|
def test_run_exp():
    """Run the scripted agent on ``miniwob.click-test`` and check the experiment record.

    The experiment is prepared and run inside a temporary directory. The test
    then checks the number of recorded steps, a set of expected record fields,
    and that the cumulative step time stays under 5 seconds (logging a warning
    above 3 seconds).
    """
    experiment = ExpArgs(
        agent_args=MiniwobTestAgentArgs(),
        env_args=EnvArgs(task_name="miniwob.click-test", task_seed=42),
    )

    # Fields the experiment record must contain, with their exact values.
    expected_record = {
        "env_args.task_name": "miniwob.click-test",
        "env_args.task_seed": 42,
        "env_args.headless": True,
        "env_args.record_video": False,
        "n_steps": 1,
        "cum_reward": 1.0,
        "terminated": True,
        "truncated": False,
    }

    with tempfile.TemporaryDirectory() as workdir:
        experiment.prepare(workdir)
        experiment.run()

        # All result checks stay inside the ``with`` block so the experiment
        # directory still exists while results are being read.
        result = get_exp_result(experiment.exp_dir)
        record = result.get_exp_record()

        assert len(result.steps_info) == 2

        for field_name, expected_value in expected_record.items():
            assert field_name in record
            assert record[field_name] == expected_value

        elapsed = record["stats.cum_step_elapsed"]
        assert elapsed < 5
        if elapsed > 3:
            logging.warning(
                f"miniwob.click-test is taking {elapsed:.2f}s (> 3s) to solve with an oracle."
            )
|
|