|
from abc import ABC, abstractmethod |
|
|
|
from .action import ACTION_SPACE_PROMPT |
|
from .eval_type import ( |
|
GROUNDING, |
|
PROGRESS_LIKERT_SCALE, |
|
PROGRESS_THREE_CLASS, |
|
PROGRESS_WITH_CHECKLIST, |
|
PROGRESS_WITH_CHECKLIST_IN_PROGRESS, |
|
PROGRESS_OURS |
|
) |
|
from .input_information import ( |
|
USER_INSTRUCTION, |
|
TRAJECTORY, |
|
AGENT_RESPONSE, |
|
CHECKLIST, |
|
CURRENT_URL, |
|
TEXT_OBSERVATION, |
|
SOM_IMAGE_OBSERVATION, |
|
COORD_IMAGE_OBSERVATION |
|
) |
|
from .judge_prompt import ( |
|
JUDGE_GROUNDING_PROMPT_TEMPLATE, |
|
JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE, |
|
JUDGE_THREE_CLASS_PROMPT_TEMPLATE, |
|
JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE, |
|
JUDGE_OURS_PROMPT_TEMPLATE, |
|
JUDGE_OURS_WO_CHECKLIST_PROMPT_TEMPLATE |
|
) |
|
from .checklist_prompt import ( |
|
CHECKLIST_SYSTEM_PROMPT, |
|
CHECKLIST_USER_PROMPT, |
|
CHECKLIST_OURS_SYSTEM_PROMPT, |
|
CHECKLIST_OURS_USER_PROMPT |
|
) |
|
from .image_utils import image_to_base64_url |
|
|
|
|
|
class Message(ABC): |
|
@abstractmethod |
|
def get_messages(self): |
|
pass |
|
|
|
class BaseMessage(Message): |
|
def __init__(self, input_info:dict, use_multimodal:bool=False): |
|
self.input_info = input_info |
|
self.use_multimodal = use_multimodal |
|
|
|
def _get_system_message(self): |
|
system_message = {"role": "system", "content": "You are a helpful assistant."} |
|
return system_message |
|
|
|
def _process_multimodal_message(self, prompt: str, image_list: list[str]): |
|
multimodal_message = [] |
|
text_prompt_prefix = prompt.split("<IMAGE_PLACEHOLDER>")[0] |
|
text_prompt_suffix = prompt.split("<IMAGE_PLACEHOLDER>")[1] |
|
multimodal_message.append({"type": "text", "text": text_prompt_prefix}) |
|
for i, image in enumerate(image_list): |
|
|
|
|
|
multimodal_message.append({"type": "image_url", "image_url": {"url": image_to_base64_url(image), "detail": "low"}}) |
|
multimodal_message.append({"type": "text", "text": text_prompt_suffix}) |
|
return {"role": "user", "content": multimodal_message} |
|
|
|
def _get_user_message(self): |
|
user_prompt = "What is the capital of France?" |
|
if self.use_multimodal: |
|
image_list = self.input_info.get("image_list", []) |
|
user_message = self._process_multimodal_message(user_prompt, image_list) |
|
else: |
|
user_message = {"role": "user", "content": user_prompt} |
|
return user_message |
|
|
|
def get_messages(self): |
|
message = [] |
|
system_message = self._get_system_message() |
|
user_message = self._get_user_message() |
|
|
|
message.append(system_message) |
|
|
|
message.append(user_message) |
|
return message |
|
|
|
|
|
class ProgressMessage(BaseMessage): |
|
''' |
|
Progress Judge Message |
|
''' |
|
def __init__(self, input_info:dict, use_multimodal:bool, prompt_type:str, text_obs: str, image_obs: str, use_checklist:bool, use_in_progress:bool): |
|
super().__init__(input_info, use_multimodal) |
|
self.prompt_type = prompt_type |
|
self.text_obs = text_obs |
|
self.image_obs = image_obs |
|
self.use_checklist = use_checklist |
|
self.use_in_progress = use_in_progress |
|
|
|
def _get_system_message(self): |
|
if self.prompt_type == "likert_scale": |
|
system_message = {"role": "system", "content": JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE["system"]} |
|
elif self.prompt_type == "three_class": |
|
system_message = {"role": "system", "content": JUDGE_THREE_CLASS_PROMPT_TEMPLATE["system"]} |
|
elif self.prompt_type == "with_checklist": |
|
system_message = {"role": "system", "content": JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE["system"]} |
|
elif self.prompt_type == "ours": |
|
system_message = {"role": "system", "content": JUDGE_OURS_PROMPT_TEMPLATE["system"]} |
|
else: |
|
raise ValueError(f"Invalid prompt type: {self.prompt_type}") |
|
return system_message |
|
|
|
def _setup_input_information(self): |
|
observation = "## Current State\n" |
|
|
|
observation += CURRENT_URL |
|
|
|
|
|
if self.text_obs: |
|
observation += TEXT_OBSERVATION |
|
|
|
|
|
if self.image_obs == "som": |
|
observation += SOM_IMAGE_OBSERVATION |
|
elif self.image_obs == "coord": |
|
observation += COORD_IMAGE_OBSERVATION |
|
|
|
|
|
if self.use_checklist: |
|
input_information = USER_INSTRUCTION + TRAJECTORY + observation + CHECKLIST + AGENT_RESPONSE |
|
else: |
|
input_information = USER_INSTRUCTION + TRAJECTORY + observation + AGENT_RESPONSE |
|
|
|
return input_information |
|
|
|
def _setup_task_info(self): |
|
if self.prompt_type == "likert_scale": |
|
task_description = PROGRESS_LIKERT_SCALE["task_description"] |
|
output_format = PROGRESS_LIKERT_SCALE["output_format"] |
|
elif self.prompt_type == "three_class": |
|
task_description = PROGRESS_THREE_CLASS["task_description"] |
|
output_format = PROGRESS_THREE_CLASS["output_format"] |
|
elif self.prompt_type == "with_checklist": |
|
if self.use_in_progress: |
|
task_description = PROGRESS_WITH_CHECKLIST_IN_PROGRESS["task_description"] |
|
output_format = PROGRESS_WITH_CHECKLIST_IN_PROGRESS["output_format"] |
|
else: |
|
task_description = PROGRESS_WITH_CHECKLIST["task_description"] |
|
output_format = PROGRESS_WITH_CHECKLIST["output_format"] |
|
else: |
|
raise ValueError(f"Invalid prompt type: {self.prompt_type}") |
|
return task_description, output_format |
|
|
|
def _get_user_prompt_template(self): |
|
if self.prompt_type == "likert_scale": |
|
user_prompt = JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE["user"] |
|
elif self.prompt_type == "three_class": |
|
user_prompt = JUDGE_THREE_CLASS_PROMPT_TEMPLATE["user"] |
|
elif self.prompt_type == "with_checklist": |
|
user_prompt = JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE["user"] |
|
else: |
|
raise ValueError(f"Invalid prompt type: {self.prompt_type}") |
|
return user_prompt |
|
|
|
def _get_user_message(self): |
|
|
|
input_information_template = self._setup_input_information() |
|
input_information = input_information_template.format(**self.input_info) |
|
|
|
if self.prompt_type == "ours": |
|
if self.use_checklist: |
|
user_prompt = JUDGE_OURS_PROMPT_TEMPLATE["user"].format( |
|
input_information=input_information, |
|
) |
|
else: |
|
user_prompt = JUDGE_OURS_WO_CHECKLIST_PROMPT_TEMPLATE["user"].format( |
|
input_information=input_information, |
|
) |
|
else: |
|
task_description, output_format = self._setup_task_info() |
|
|
|
user_prompt_template = self._get_user_prompt_template() |
|
user_prompt = user_prompt_template.format( |
|
action_space=ACTION_SPACE_PROMPT, |
|
task_description=task_description, |
|
input_information=input_information, |
|
output_format=output_format |
|
) |
|
|
|
|
|
if self.use_multimodal: |
|
image_list = self.input_info.get("image_list", []) |
|
user_message = self._process_multimodal_message(user_prompt, image_list) |
|
else: |
|
user_message = {"role": "user", "content": user_prompt} |
|
|
|
return user_message |
|
|
|
|
|
class GroundingMessage(BaseMessage): |
|
''' |
|
Grounding Judge Message |
|
''' |
|
def __init__(self, input_info:dict, use_multimodal:bool, prompt_type:str, text_obs: str, image_obs: str): |
|
super().__init__(input_info, use_multimodal) |
|
self.prompt_type = prompt_type |
|
self.text_obs = text_obs |
|
self.image_obs = image_obs |
|
|
|
def _get_system_message(self): |
|
if self.prompt_type == "ours": |
|
|
|
system_message = {"role": "system", "content": "You are a helpful assistant."} |
|
elif self.prompt_type == "default": |
|
system_message = {"role": "system", "content": JUDGE_GROUNDING_PROMPT_TEMPLATE["system"]} |
|
else: |
|
raise ValueError(f"Invalid prompt type: {self.prompt_type}") |
|
return system_message |
|
|
|
def _setup_input_information(self): |
|
observation = "## Current State\n" |
|
|
|
observation += CURRENT_URL |
|
|
|
|
|
if self.text_obs: |
|
observation += TEXT_OBSERVATION |
|
|
|
|
|
if self.image_obs == "som": |
|
observation += SOM_IMAGE_OBSERVATION |
|
elif self.image_obs == "coord": |
|
observation += COORD_IMAGE_OBSERVATION |
|
|
|
|
|
input_information = USER_INSTRUCTION + observation + AGENT_RESPONSE |
|
|
|
return input_information |
|
|
|
def _get_user_message(self): |
|
if self.prompt_type == "ours": |
|
|
|
user_message = {"role": "user", "content": "TODO"} |
|
elif self.prompt_type == "default": |
|
action_space = ACTION_SPACE_PROMPT |
|
task_description = GROUNDING["task_description"] |
|
output_format = GROUNDING["output_format"] |
|
input_information_template = self._setup_input_information() |
|
input_information = input_information_template.format(**self.input_info) |
|
|
|
user_prompt = JUDGE_GROUNDING_PROMPT_TEMPLATE["user"].format( |
|
action_space=action_space, |
|
task_description=task_description, |
|
input_information=input_information, |
|
output_format=output_format |
|
) |
|
|
|
|
|
if self.use_multimodal: |
|
image_list = self.input_info.get("image_list", []) |
|
user_message = self._process_multimodal_message(user_prompt, image_list) |
|
else: |
|
user_message = {"role": "user", "content": user_prompt} |
|
else: |
|
raise ValueError(f"Invalid prompt type: {self.prompt_type}") |
|
return user_message |
|
|
|
|
|
class ChecklistMessage(BaseMessage): |
|
''' |
|
Checklist Message |
|
''' |
|
def __init__(self, input_info:dict, use_multimodal:bool, prompt_type:str): |
|
super().__init__(input_info, use_multimodal) |
|
self.prompt_type = prompt_type |
|
|
|
def _get_system_message(self): |
|
if self.prompt_type == "ours": |
|
|
|
system_message = {"role": "system", "content": CHECKLIST_OURS_SYSTEM_PROMPT} |
|
elif self.prompt_type == "default": |
|
system_message = {"role": "system", "content": CHECKLIST_SYSTEM_PROMPT} |
|
else: |
|
raise ValueError(f"Invalid prompt type: {self.prompt_type}") |
|
return system_message |
|
|
|
def _get_user_message(self): |
|
if self.prompt_type == "ours": |
|
user_message = {"role": "user", "content": CHECKLIST_OURS_USER_PROMPT.format(**self.input_info)} |
|
elif self.prompt_type == "default": |
|
user_message = {"role": "user", "content": CHECKLIST_USER_PROMPT.format(**self.input_info)} |
|
else: |
|
raise ValueError(f"Invalid prompt type: {self.prompt_type}") |
|
return user_message |
|
|
|
|
|
def get_messages(input_info:dict, inference_mode:str, prompt_type:str, text_obs:str=None, image_obs:str=None, use_multimodal:bool=False, use_checklist:bool=False, use_in_progress:bool=False): |
|
message_list = [] |
|
if inference_mode == "judge_grounding": |
|
message = GroundingMessage(input_info, use_multimodal=use_multimodal, prompt_type=prompt_type, text_obs=text_obs, image_obs=image_obs) |
|
elif inference_mode == "judge_progress": |
|
message = ProgressMessage(input_info, use_multimodal=use_multimodal, prompt_type=prompt_type, text_obs=text_obs, image_obs=image_obs, use_checklist=use_checklist, use_in_progress=use_in_progress) |
|
elif inference_mode == "checklist_generation": |
|
message = ChecklistMessage(input_info, use_multimodal=False, prompt_type=prompt_type) |
|
else: |
|
raise ValueError(f"Invalid inference mode: {inference_mode}") |
|
|
|
system_message, user_message = message.get_messages() |
|
|
|
message_list.append(system_message) |
|
message_list.append(user_message) |
|
return message_list |