from abc import ABC, abstractmethod

import json
import math
import time

import requests
import numpy as np
from langchain_openai import ChatOpenAI
from langsmith import Client

from .prompts import get_messages
from .prompts.judge_prompt import (
    JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE,
    JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE,
    JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE,
    JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE,
)
from .prompts.image_utils import image_to_base64_url
from .prompts.utils import convert_dict_messages

MAX_RETRY = 3
RETRY_SLEEP = 5

MODEL_COST_MAPPING = {
    "gpt-4o-mini": {
        "input_token_cost": 0.15,
        "output_token_cost": 0.6,
    },
    "gpt-4o": {
        "input_token_cost": 2.5,
        "output_token_cost": 10,
    },
}
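
# The costs above are USD per 1M tokens, matching the `/ 1_000_000` arithmetic
# in the agents below. Illustrative example (hypothetical token counts):
# 2,000 prompt tokens and 500 completion tokens on "gpt-4o" cost
# 2.5 * 2000/1e6 + 10 * 500/1e6 = $0.01.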


class Agent(ABC):
    @abstractmethod
    def generate_response(self, inputs: dict) -> str:
        pass


class BaseAgent(Agent):
    def __init__(self, agent_config: dict):
        self.agent_config = agent_config
        self._setup()

    def _init_llm_object(self, **extra_kwargs):
        # Copy the config so per-call overrides (e.g. n, temperature) don't
        # leak into the shared agent_config.
        config = {**self.agent_config, **extra_kwargs}

        # Both variants share everything except the logprobs arguments.
        llm_kwargs = dict(
            model=config["model_name"],
            base_url=config["base_url"],
            api_key=config["api_key"],
            temperature=config["temperature"],
            timeout=300,
            n=config.get("n", None),
        )
        if config.get("use_log_probs", False):
            llm_kwargs.update(logprobs=True, top_logprobs=10)
        self.llm = ChatOpenAI(**llm_kwargs)

    def _setup(self):
        self._init_llm_object()

        self.temperature = self.agent_config["temperature"]
        self.num_generate = self.agent_config["num_generate"]
        self.use_checklist = self.agent_config.get("use_checklist", False)
        self.use_multimodal = self.agent_config.get("use_multimodal", False)

        # Token pricing applies only to known models served through a hosted API;
        # anything else (e.g. a self-hosted endpoint) is treated as free.
        model_cost = MODEL_COST_MAPPING.get(self.agent_config["model_name"])
        if model_cost and "api" in self.agent_config["base_url"]:
            self.input_token_cost = model_cost["input_token_cost"]
            self.output_token_cost = model_cost["output_token_cost"]
        else:
            self.input_token_cost = 0.0
            self.output_token_cost = 0.0

    def generate_with_retry(self, model_input, constraint_str_list: list = None):
        total_input_tokens = 0
        total_output_tokens = 0
        if self.temperature == 0:
            # Greedy decoding is deterministic, so retrying would not change the output.
            response = self.llm.invoke(model_input)
            total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
            total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
        else:
            response = None
            for i in range(MAX_RETRY):
                try:
                    response = self.llm.invoke(model_input)
                    total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
                    total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
                    if constraint_str_list:
                        # Retry until every required substring appears in the response.
                        pass_constraint_num = 0
                        for constraint_str in constraint_str_list:
                            if constraint_str in response.content:
                                pass_constraint_num += 1
                        if pass_constraint_num == len(constraint_str_list):
                            break
                        print(f"Agent has format issue, retry... {i+1}/{MAX_RETRY}")
                    else:
                        break
                except Exception as e:
                    print(f"Agent returned an Error: {e}")
                    response = None
                    time.sleep(RETRY_SLEEP)

        cost = self.input_token_cost * total_input_tokens / 1_000_000 + self.output_token_cost * total_output_tokens / 1_000_000

        if response is None:
            return "", cost
        return response.content, cost

    def prepare_message(self, model_input: dict, prompt_type: str):
        # Stub: subclasses build the actual prompt via get_messages().
        message = []
        return message

    def generate_response(self, model_input: dict, prompt_type: str, constraint_str_list: list = None):
        total_cost = 0
        response_list = []

        message = self.prepare_message(model_input, prompt_type)

        # Sample num_generate independent responses and accumulate the cost.
        for i in range(self.num_generate):
            response, cost = self.generate_with_retry(message, constraint_str_list)
            response_list.append(response)
            total_cost += cost

        return response_list, total_cost
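
    # Illustrative return shape (hypothetical values): with num_generate=2,
    # generate_response yields (["<answer>Yes</answer>", "<answer>No</answer>"], 0.0042),
    # i.e. one raw response string per sample plus the summed API cost in USD.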


class GroundingJudgeAgent(BaseAgent):
    def __init__(self, agent_config: dict):
        # BaseAgent.__init__ already calls _setup(), so nothing else is needed here.
        super().__init__(agent_config)

    def prepare_message(self, model_input: dict, prompt_type):
        message = get_messages(
            input_info=model_input,
            inference_mode="judge_grounding",
            prompt_type=prompt_type,
            use_multimodal=self.use_multimodal,
            text_obs=self.agent_config["text_obs_type"],
            image_obs=self.agent_config["image_obs_type"]
        )
        return message


class ProgressJudgeAgent(BaseAgent):
    def __init__(self, agent_config: dict):
        # BaseAgent.__init__ already calls _setup(), so nothing else is needed here.
        super().__init__(agent_config)

    def prepare_message(self, model_input: dict, prompt_type):
        input_type = self.agent_config["input_type"]
        if input_type == "text_only":
            use_multimodal = False
            text_obs = self.agent_config["text_obs_type"]
            image_obs = None
        elif input_type == "image_only":
            use_multimodal = True
            text_obs = None
            image_obs = self.agent_config["image_obs_type"]
        elif input_type == "text_image":
            use_multimodal = True
            text_obs = self.agent_config["text_obs_type"]
            image_obs = self.agent_config["image_obs_type"]
        else:
            raise ValueError(f"Invalid input type: {input_type}")

        use_in_progress = bool(self.agent_config["use_in_progress"])

        message = get_messages(
            input_info=model_input,
            inference_mode="judge_progress",
            prompt_type=prompt_type,
            use_checklist=self.use_checklist,
            use_multimodal=use_multimodal,
            text_obs=text_obs,
            image_obs=image_obs,
            use_in_progress=use_in_progress
        )
        return message

    def get_judge_probs(self, logprobs: list):
        """Extract yes/no/in-progress probabilities from the verdict token
        inside the <answer>...</answer> span of a judge response."""
        # Token spellings that count as each verdict. "Ġ" and "Ċ" are the
        # byte-level BPE markers for a leading space and a leading newline,
        # so each surface form is listed in all three variants.
        target_judge = {
            "yes": [
                "ĠYes", "Yes", "ĊYes",
                "Ġyes", "yes", "Ċyes",
                "ĠYES", "YES", "ĊYES",
                "ĠDone", "Done", "ĊDone",
                "ĠCompleted", "Completed", "ĊCompleted",
                "ĠCorrect", "Correct", "ĊCorrect"
            ],
            "no": [
                "ĠNo", "No", "ĊNo",
                "ĠNO", "NO", "ĊNO",
                "ĠNot", "Not", "ĊNot",
                "ĠNone", "None", "ĊNone",
                "ĠNope", "Nope", "ĊNope",
                "ĠUn", "Un", "ĊUn",
                "ĠWrong", "Wrong", "ĊWrong"
            ],
            "in": [
                "ĠIn", "In", "ĊIn",
                "ĠPending", "Pending", "ĊPending",
                "ĠPart", "Part", "ĊPart",
                "ĠPartial", "Partial", "ĊPartial",
                "ĠInProgress", "InProgress", "ĊInProgress"
            ]
        }
        response_str = ""
        judge_probs_list = []
        for log_prob in logprobs:
            # Only inspect tokens that appear after the "<answer>" tag.
            if "<answer>" in response_str:
                find_judge_str = False
                for judge_type in target_judge:
                    if log_prob["token"] in target_judge[judge_type]:
                        find_judge_str = True
                        break
                if find_judge_str:
                    # Sum the probability mass of every matching verdict variant
                    # among the top-10 alternatives, then move back to log space.
                    token_judge_dict = {
                        "yes": None,
                        "no": None,
                        "in": None
                    }
                    for token_info in log_prob["top_logprobs"]:
                        for judge_type in target_judge:
                            for judge_str in target_judge[judge_type]:
                                if judge_str in token_info["token"]:
                                    if token_judge_dict[judge_type] is None:
                                        token_judge_dict[judge_type] = math.exp(token_info["logprob"])
                                    else:
                                        token_judge_dict[judge_type] += math.exp(token_info["logprob"])

                    token_judge_dict = {
                        "yes": math.log(token_judge_dict["yes"]) if token_judge_dict["yes"] is not None else -float('inf'),
                        "no": math.log(token_judge_dict["no"]) if token_judge_dict["no"] is not None else -float('inf'),
                        "in": math.log(token_judge_dict["in"]) if token_judge_dict["in"] is not None else -float('inf')
                    }
                    judge_probs_list.append(token_judge_dict)

            if "</answer>" in response_str:
                break

            response_str += log_prob["token"]

        if len(judge_probs_list) == 0:
            # No verdict token was found inside <answer>...</answer>.
            return [{
                "yes": 0.0,
                "no": 0.0,
                "in": 0.0
            }]

        # Renormalize so the yes/no/in masses of each verdict token sum to 1.
        final_judge_probs_list = []
        for judge_probs in judge_probs_list:
            exp_logprobs = [math.exp(x) for x in [judge_probs["yes"], judge_probs["no"], judge_probs["in"]]]
            sum_exp_logprobs = sum(exp_logprobs)
            normalized_probs = [x / sum_exp_logprobs for x in exp_logprobs]
            final_judge_probs_list.append({
                "yes": normalized_probs[0],
                "no": normalized_probs[1],
                "in": normalized_probs[2]
            })
        return final_judge_probs_list
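
    # Worked example (hypothetical numbers): if the verdict token's top_logprobs
    # put probability 0.70 on "yes"-type tokens, 0.20 on "no"-type tokens, and
    # none on "in"-type tokens, get_judge_probs yields
    # {"yes": 0.70/0.90, "no": 0.20/0.90, "in": 0.0} ≈ {"yes": 0.78, "no": 0.22, "in": 0.0}.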

    def generate_probs(self, model_input: dict, prompt_type: str, n=1, temperature=None):
        message = self.prepare_message(model_input, prompt_type)
        messages = convert_dict_messages(message)

        # Rebuild the LLM client so the requested n (and optional temperature) take effect.
        kwargs = {"n": n}
        if temperature is not None:
            kwargs["temperature"] = temperature
        self._init_llm_object(**kwargs)

        try:
            response = self.llm.generate([messages])
        finally:
            print("request url: ", self.agent_config["base_url"])

        response_list = []
        for generation in response.generations[0]:
            logprobs = generation.message.response_metadata["logprobs"]["content"]
            response_list.append(
                {
                    "response": generation.message.content,
                    "judge_probs": self.get_judge_probs(logprobs)
                }
            )

        total_input_tokens = response.llm_output["token_usage"]["prompt_tokens"]
        total_output_tokens = response.llm_output["token_usage"]["completion_tokens"]
        total_cost = self.input_token_cost * total_input_tokens / 1_000_000 + self.output_token_cost * total_output_tokens / 1_000_000

        return response_list, total_cost
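
    # Illustrative return shape for generate_probs with n=2 (hypothetical values):
    # ([{"response": "...<answer>Yes</answer>",
    #    "judge_probs": [{"yes": 0.90, "no": 0.08, "in": 0.02}]},
    #   {"response": "...<answer>In Progress</answer>",
    #    "judge_probs": [{"yes": 0.10, "no": 0.05, "in": 0.85}]}],
    #  0.0013)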


class ChecklistGenerationAgent(BaseAgent):
    def __init__(self, agent_config: dict):
        # BaseAgent.__init__ already calls _setup(), so nothing else is needed here.
        super().__init__(agent_config)

    def prepare_message(self, model_input: dict, prompt_type):
        message = get_messages(
            input_info=model_input,
            inference_mode="checklist_generation",
            prompt_type=prompt_type
        )
        return message


class ClassifierRewardAgent(Agent):
    def __init__(self, url: str, use_checklist: bool = False, use_multimodal: bool = False):
        self.url = url
        self.use_checklist = use_checklist
        self.use_multimodal = use_multimodal

    def _process_multimodal_message(self, prompt: str, image_list: list[str]):
        # Assumes exactly one <IMAGE_PLACEHOLDER> in the prompt and uses only
        # the first image in image_list.
        text_prompt_prefix, text_prompt_suffix = prompt.split("<IMAGE_PLACEHOLDER>", 1)
        multimodal_message = [
            {"type": "text", "text": text_prompt_prefix},
            {"type": "image", "image": image_to_base64_url(image_list[0])},
            {"type": "text", "text": text_prompt_suffix}
        ]
        return multimodal_message

    def _make_query(self, user_prompt_template: dict, model_input: dict | list[dict]):
        if self.use_multimodal:
            tmp_user_prompt = user_prompt_template["user"].format(**model_input)
            user_prompt = self._process_multimodal_message(tmp_user_prompt, model_input["image_list"])
        else:
            user_prompt = user_prompt_template["user"].format(**model_input)
        assistant_prompt = user_prompt_template["assistant"].format(**model_input)
        query = [
            {"role": "user", "content": user_prompt},
            {"role": "assistant", "content": assistant_prompt}
        ]
        return query

    def prepare_message(self, model_input: dict | list[dict], batch: bool = False):
        if self.use_checklist:
            if self.use_multimodal:
                user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE
            else:
                user_prompt_template = JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE
        else:
            if self.use_multimodal:
                user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
            else:
                user_prompt_template = JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE

        if self.use_multimodal:
            if batch:
                message = [self._make_query(user_prompt_template, input) for input in model_input]
            else:
                message = [self._make_query(user_prompt_template, model_input)]
        else:
            if batch:
                message = {
                    "query": [self._make_query(user_prompt_template, input) for input in model_input],
                    "prompts": []
                }
            else:
                message = {
                    "query": self._make_query(user_prompt_template, model_input),
                    "prompts": []
                }

        return message
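
    # Illustrative text-only payload produced by prepare_message (field
    # contents are hypothetical):
    # {"query": [{"role": "user", "content": "<formatted judge prompt>"},
    #            {"role": "assistant", "content": "<formatted agent response>"}],
    #  "prompts": []}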

    def get_rm_score(self, message: dict | list):
        headers = {"Content-Type": "application/json"}

        try:
            if self.use_multimodal:
                response = requests.post(
                    self.url,
                    json={"messages": message},
                    timeout=600
                )
            else:
                response = requests.post(
                    self.url,
                    headers=headers,
                    data=json.dumps(message),
                    timeout=300
                )
            response.raise_for_status()

            response_json = response.json()

            # The two server types use different response schemas:
            # "get_reward" endpoints return a list under "rewards", while
            # "pooling" endpoints return a single value under "reward".
            if "get_reward" in self.url:
                if "rewards" not in response_json:
                    print(f"Error: 'rewards' key not found in API response: {response_json}")
                    return []
                return response_json["rewards"]
            elif "pooling" in self.url:
                return response_json["reward"]
            else:
                raise ValueError(f"Invalid URL: {self.url}")

        except requests.exceptions.Timeout:
            print(f"Error: Request timed out to {self.url}")
            return []
        except requests.exceptions.RequestException as e:
            print(f"Error during request to {self.url}: {e}")
            return []
        except json.JSONDecodeError:
            print(f"Error: Failed to decode JSON response from {self.url}")
            return []
        except KeyError as e:
            print(f"Error: Missing key {e} in response from {self.url}")
            return []

    def generate_response(self, model_input: dict | list[dict], batch: bool = False):
        message = self.prepare_message(model_input, batch=batch)
        rewards = self.get_rm_score(message)

        # No LLM tokens are consumed, so the cost is always 0.
        return rewards, 0
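

if __name__ == "__main__":
    # Minimal usage sketch. The config keys mirror what _setup() and
    # _init_llm_object() read above; the model name, URLs, API key, and
    # observation types are placeholders, not values from the original source.
    judge_config = {
        "model_name": "gpt-4o-mini",
        "base_url": "https://api.openai.com/v1",
        "api_key": "YOUR_API_KEY",
        "temperature": 0.7,
        "num_generate": 1,
        "n": 1,
        "use_checklist": True,
        "use_multimodal": False,
        "input_type": "text_only",
        "text_obs_type": "accessibility_tree",
        "image_obs_type": None,
        "use_in_progress": True,
    }
    judge = ProgressJudgeAgent(judge_config)
    # responses, cost = judge.generate_response(model_input, prompt_type="...")

    rm = ClassifierRewardAgent("http://localhost:8000/get_reward", use_checklist=True)
    # rewards, _ = rm.generate_response(model_input)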