|
import base64
import html
import io
import logging
import multiprocessing
import os
import queue
from pathlib import Path

import gymnasium as gym
import numpy as np
from PIL import Image

from agent.checklist import generate_checklist
from agent.reward import get_ar_reward
from browser_agent import BrowserAgent
|
|
|
|
|
# Module-level logger, pinned to INFO so per-step agent progress is emitted.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
|
|
|
# Static assets for the HTML views, read once at import time from the
# "templates" directory next to this file.
templates_dir = Path(__file__).parent / "templates"

# Decode explicitly as UTF-8: Path.read_text() otherwise uses the
# locale-dependent default encoding, so the same files could decode
# differently (or fail) across platforms.
CSS_RM_CARDS: str = (templates_dir / "rm_cards.css").read_text(encoding="utf-8")
CSS_TRAJECTORY: str = (templates_dir / "trajectory.css").read_text(encoding="utf-8")
# str.format template for one reward-model card; the placeholders it is
# filled with are additional_class, k, suffix, thought, action, reward, feedback.
CARD_HTML_TEMPLATE: str = (templates_dir / "card.html").read_text(encoding="utf-8")

# Reward-model endpoint configuration. Both variables are required: a missing
# one raises KeyError at import time.
RM_BASE_URL = os.environ['RM_BASE_URL']
RM_MODEL_NAME = os.environ['RM_MODEL_NAME']
|
|
|
def return_state(state, screenshot=None):
    """Build a status-only update tuple.

    The tuple has the same 5-slot shape as the full updates yielded by
    ``run_agent``: ``(status, checklist, rm_cards, screenshot, trajectory)``.
    Only the status text and (optionally) the screenshot are filled in; the
    remaining slots are ``None``.
    """
    checklist = rm_cards = trajectory = None
    return (state, checklist, rm_cards, screenshot, trajectory)
|
|
|
def _collapse_blank_lines(text):
    """Return *text* with every run of two or more newlines collapsed to one."""
    while '\n\n' in text:
        text = text.replace('\n\n', '\n')
    return text


def _rank_candidates(candidates, rewards, thoughts, tie_tolerance=0.01):
    """Reorder policy candidates so that the one to execute comes first.

    Candidate 0 is kept as the pick when the best reward is within
    ``tie_tolerance`` of its reward (logged as the "most frequent action";
    presumably ``agent.get_action`` returns candidates most-frequent-first --
    TODO confirm). Otherwise the highest-reward candidate is picked. The
    remaining candidates follow in descending reward order.

    Args:
        candidates: Candidate dicts from the policy.
        rewards: Reward-model score per candidate (must support ``.index``).
        thoughts: Reward-model feedback per candidate.
        tie_tolerance: Largest reward gap treated as a tie with candidate 0.

    Returns:
        ``(candidates, rewards, thoughts)`` as lists reordered in lockstep.
    """
    best_reward = max(rewards)
    diff_reward = abs(best_reward - rewards[0])
    if diff_reward <= tie_tolerance:
        logger.info(f"diff_reward: {diff_reward} -> most frequent action")
        max_index = 0
    else:
        logger.info(f"diff_reward: {diff_reward} -> highest reward")
        max_index = rewards.index(best_reward)

    # Picked candidate first, then the rest by descending reward; sorted() is
    # stable, so equal rewards keep their original relative order.
    order = sorted(range(len(rewards)), key=lambda i: (i != max_index, -rewards[i]))
    return (
        [candidates[i] for i in order],
        [rewards[i] for i in order],
        [thoughts[i] for i in order],
    )


def run_agent(instruction: str, model_name: str = "gpt-4o", start_url: str = "about:blank",
              use_html: bool = False, use_axtree: bool = True, use_screenshot: bool = False,
              max_steps: int = 20):
    """Roll out a reward-model-guided browser agent, yielding UI updates as it goes.

    Opens a ``browsergym/openended`` environment at *start_url* and sends
    *instruction* to it as a user message. It then builds a checklist with
    ``generate_checklist`` and runs up to *max_steps* steps. Each step asks
    the policy (``BrowserAgent``) for candidate actions, scores them with the
    reward model (``get_ar_reward``) and executes the chosen one. The loop
    stops once the chosen action messages the user or the environment
    reports termination/truncation.

    Args:
        instruction: Task goal for the agent.
        model_name: Policy model passed to ``BrowserAgent``.
        start_url: Page the environment opens first.
        use_html, use_axtree, use_screenshot: Observation modalities passed to
            ``BrowserAgent``.
        max_steps: Maximum number of actions to take.

    Yields:
        5-tuples ``(status_markdown, checklist, rm_cards, screenshot,
        trajectory)``. Status-only updates (from ``return_state``) carry
        ``None`` for checklist, rm_cards and trajectory. ``rm_cards`` is a
        list of ``{'thought', 'action', 'feedback', 'reward'}`` dicts with the
        executed candidate first. ``trajectory`` holds one
        ``(screenshot, [{'action', 'reward'}, ...])`` entry per executed step.

    Raises:
        RuntimeError: If the policy returns no candidate actions.

    The environment is closed on every exit path, including when the consumer
    stops iterating early.
    """
    logger.info(f"Starting agent with instruction: {instruction}")
    logger.info(f"Configuration: model={model_name}, start_url={start_url}")

    trajectory = []
    # Running THOUGHT/ACTION history passed to the reward model.
    trajectory_str = ''
    agent = BrowserAgent(
        model_name=model_name,
        use_html=use_html,
        use_axtree=use_axtree,
        use_screenshot=use_screenshot,
    )

    logger.info("Initializing BrowserGym environment")
    yield return_state("## Initializing BrowserGym environment...", None)
    env = gym.make(
        "browsergym/openended",
        task_kwargs={
            "start_url": start_url,
            "goal": instruction,
        },
        wait_for_user_message=True,
    )
    try:
        obs, _ = env.reset()
        logger.info("Environment initialized")

        logger.info("Sending user instruction to environment")
        obs, _, _, _, _ = env.step({
            "type": "send_msg_to_user",
            "message": instruction,
        })
        processed_obs = agent.obs_preprocessor(obs)
        logger.info(f"Obs: {processed_obs.keys()}")
        logger.info(f"axtree_txt: {processed_obs['axtree_txt']}")

        yield return_state("## Generating checklist...", obs['som_screenshot'])
        checklist = generate_checklist(intent=instruction, start_url=start_url,
                                       text_observation=processed_obs['axtree_txt'])

        current_screenshot = obs['som_screenshot'].copy()
        yield "## Rollout actions from policy...", checklist, [], current_screenshot, trajectory.copy()

        for step_count in range(max_steps):
            logger.info(f"Step {step_count}: Getting next action")
            candidates, _ = agent.get_action(processed_obs)
            if not candidates:
                # Without this, max() below fails with an opaque ValueError.
                raise RuntimeError("Policy returned no action candidates")

            yield return_state("## Rewarding actions...", current_screenshot)

            current_url = processed_obs['open_pages_urls'][processed_obs['active_page_index'][0]]
            total_rewards, total_thoughts = get_ar_reward(
                dataset=[
                    {
                        'text_observation': processed_obs['axtree_txt'],
                        'intent': instruction,
                        'trajectory': trajectory_str,
                        'current_url': current_url,
                        'checklist': checklist,
                        'thought': cand['thought'],
                        'action': cand['action'],
                    }
                    for cand in candidates
                ],
                base_url=RM_BASE_URL,
                model_name=RM_MODEL_NAME,
            )

            candidates, total_rewards, total_thoughts = _rank_candidates(
                candidates, total_rewards, total_thoughts)
            best_cand = candidates[0]
            agent.action_history.append(best_cand['response'])
            action = best_cand['action']
            # trajectory_str separates steps with blank lines, so blank lines
            # inside a thought must be collapsed before it is appended.
            thought = _collapse_blank_lines(best_cand['thought'])

            current_cards = [
                {'thought': cand['thought'], 'action': cand['action'],
                 'feedback': feedback, 'reward': round(reward, 2)}
                for cand, reward, feedback in zip(candidates, total_rewards, total_thoughts)
            ]
            trajectory_str += f'THOUGHT {step_count+1}: {thought}\nACTION {step_count+1}: {action}\n\n'

            logger.info(f"Step {step_count}: Executing action: {action}")
            yield f"## Executing action: {action}", checklist, current_cards, current_screenshot, trajectory.copy()

            if action.startswith('send_msg_to_user'):
                # Answering the user ends the episode without stepping the env.
                terminated, truncated = True, False
            else:
                obs, _, terminated, truncated, _ = env.step(action)
                # Record the screenshot the action was taken on, not the new one.
                trajectory.append((
                    processed_obs['som_screenshot'],
                    [{'action': cand['action'], 'reward': round(r, 2)}
                     for cand, r in zip(candidates, total_rewards)],
                ))
                processed_obs = agent.obs_preprocessor(obs)
                current_screenshot = processed_obs['som_screenshot'].copy()

            logger.info(f"Step {step_count}: Saved screenshot and updated trajectory")
            yield "## Rollout actions from policy...", checklist, current_cards, current_screenshot, trajectory.copy()

            if terminated or truncated:
                logger.info(f"Episode ended: terminated={terminated}, truncated={truncated}")
                yield return_state("## Episode ended", current_screenshot)
                break
    finally:
        logger.info("Finished")
        try:
            env.close()
        except Exception:
            # Never let a cleanup failure mask the original error.
            logger.exception("Failed to close BrowserGym environment")
|
|
|
|
|
def run_agent_worker(instruction, model_name, start_url, use_html, use_axtree, use_screenshot, max_steps, return_queue):
    """Run ``run_agent`` (intended for a child process) and forward its updates to a queue.

    Every tuple yielded by ``run_agent`` is put on *return_queue* unchanged.
    If the rollout raises, an error status tuple is queued. It has the same
    5-slot shape as regular updates, so consumers that unpack 5 values keep
    working, and its checklist slot is ``None`` so it is treated as a
    status-only update. ``None`` is always queued last as the end-of-stream
    sentinel.
    """
    try:
        for result in run_agent(instruction, model_name, start_url, use_html, use_axtree, use_screenshot, max_steps):
            return_queue.put(result)
    except Exception as e:
        # Queue the error before logging so the consumer is informed even if
        # logging itself fails.
        return_queue.put(("Error occurred in agent process", None, None, None, None))
        logger.exception(f"Error in agent worker process: {e}")
    finally:
        return_queue.put(None)
|
|
|
def run_agent_wrapper(instruction, model_name="gpt-4o", start_url="about:blank",
                      use_html=False, use_axtree=True, use_screenshot=False, max_steps=20,
                      poll_interval=1.0):
    """Run the agent in a separate process and yield its updates as they arrive.

    Updates are read from a ``multiprocessing.Queue`` filled by
    ``run_agent_worker`` until its ``None`` sentinel is received.

    Args:
        instruction, model_name, start_url, use_html, use_axtree,
        use_screenshot, max_steps: Forwarded to ``run_agent_worker``.
        poll_interval: Seconds to wait on the queue before checking whether
            the child process is still alive.

    Yields:
        Whatever the worker puts on the queue (the 5-tuples of ``run_agent``).

    If the child dies without sending the sentinel (crash, OOM kill, ...),
    iteration ends instead of blocking forever. The child is terminated and
    joined on every exit path, including when the consumer stops early.
    """
    return_queue = multiprocessing.Queue()
    p = multiprocessing.Process(
        target=run_agent_worker,
        args=(instruction, model_name, start_url, use_html, use_axtree, use_screenshot, max_steps, return_queue),
    )
    p.daemon = True
    p.start()

    try:
        while True:
            try:
                result = return_queue.get(timeout=poll_interval)
            except queue.Empty:
                if p.is_alive():
                    continue
                # The child has exited. Anything it flushed before exiting is
                # already readable, so take one last non-blocking look first.
                try:
                    result = return_queue.get_nowait()
                except queue.Empty:
                    logger.error(f"Agent process exited (exitcode={p.exitcode}) without signalling completion")
                    break
            if result is None:
                break
            yield result
    finally:
        if p.is_alive():
            p.terminate()
        p.join()
|
|
|
def _image_to_src(img):
    """Return an ``<img>`` ``src`` value for a screenshot.

    NumPy arrays and PIL images are JPEG-encoded into a base64 data URI.
    Any other value (presumably an existing URL / data URI) is returned
    unchanged.
    """
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    if isinstance(img, Image.Image):
        if img.mode not in ("RGB", "L"):
            # JPEG cannot store alpha/palette modes such as RGBA, so flatten first.
            img = img.convert("RGB")
        buffer = io.BytesIO()
        img.save(buffer, format="JPEG")
        img_str = base64.b64encode(buffer.getvalue()).decode()
        return f"data:image/jpeg;base64,{img_str}"
    return img


def _render_rm_cards(rm_cards):
    """Render reward-model candidate cards as HTML; the first card is marked best."""
    parts = [f"""
<style>
{CSS_RM_CARDS}
</style>
<div class="rm-cards-container">
"""]
    for idx, card in enumerate(rm_cards):
        # Model-generated text is escaped so it renders literally and cannot
        # inject markup into the page.
        parts.append(CARD_HTML_TEMPLATE.format(
            additional_class='top-candidate' if idx == 0 else '',
            k=idx + 1,
            suffix='(best)' if idx == 0 else '',
            thought=html.escape(str(card['thought'])),
            action=html.escape(str(card['action'])),
            reward=card['reward'],
            feedback=html.escape(str(card['feedback'])),
        ))
    parts.append("</div>")
    return "".join(parts)


def _render_trajectory_step(idx, img, cands):
    """Render one executed step (screenshot plus scored candidates) as HTML.

    Candidate 0 is labelled as the selected action.
    """
    img_src = html.escape(str(_image_to_src(img)))
    parts = [f"""
<div class="step-container">
<div class="step-header">Step {idx + 1}</div>
<div class="step-content">
<div class="step-image">
<img src="{img_src}" alt="Browser state">
</div>
<div class="step-info">
<div class="box-title">Action Candidates:</div>
<div class="action-candidates">
"""]
    for i, cand in enumerate(cands):
        selected = i == 0
        parts.append(f"""
<div class="candidate-box{' selected' if selected else ''}">
<div class="box-title">
Action {i + 1}{' (Selected)' if selected else ''}
<span class="reward-text">Reward: {cand['reward']}</span>
</div>
<pre>{html.escape(str(cand['action']))}</pre>
</div>
""")
    parts.append("""
</div>
</div>
</div>
</div>
""")
    return "".join(parts)


def _wrap_trajectory(step_htmls):
    """Wrap pre-rendered step fragments in the styled trajectory container."""
    return f"""
<style>
{CSS_TRAJECTORY}
</style>
<div class="trajectory-container">
""" + "".join(step_htmls) + "</div>"


def process_run(instruction, model_name, start_url):
    """Run the agent and turn its updates into renderable view values.

    Yields:
        5-tuples ``(status, screenshot, checklist_view, rm_cards_html,
        trajectory_html)`` (presumably bound to UI output components --
        confirm against the caller). Status-only updates (checklist ``None``)
        re-send the last checklist and trajectory HTML with ``None`` for the
        cards. After the stream ends, the final status is re-emitted once
        with the most recent cards restored. Nothing extra is emitted when
        the stream was empty.
    """
    trajectory_generator = run_agent_wrapper(
        instruction,
        model_name,
        start_url,
        use_html=False,
        use_axtree=True,
        use_screenshot=False,
    )

    state = screenshot = None
    rm_cards_html = last_checklist_view = last_trajectory_html = None
    # Per-step HTML cache: the trajectory only grows, so each step (and its
    # screenshot encoding) is rendered once instead of on every update.
    rendered_steps = []
    received_any = False

    for state, checklist_view, rm_cards, screenshot, trajectory in trajectory_generator:
        received_any = True
        if checklist_view is None:
            yield state, screenshot, last_checklist_view, None, last_trajectory_html
            continue

        rm_cards_html = _render_rm_cards(rm_cards)

        if len(trajectory) < len(rendered_steps):
            # Defensive: a shorter trajectory means the cache no longer applies.
            rendered_steps.clear()
        for idx in range(len(rendered_steps), len(trajectory)):
            img, cands = trajectory[idx]
            rendered_steps.append(_render_trajectory_step(idx, img, cands))
        last_trajectory_html = _wrap_trajectory(rendered_steps)

        last_checklist_view = checklist_view
        yield state, screenshot, last_checklist_view, rm_cards_html, last_trajectory_html

    if received_any:
        yield state, screenshot, last_checklist_view, rm_cards_html, last_trajectory_html
|
|