UniWorld-V1 / univa /utils /prompter.py
LinB203
init
0c8d55e
class Prompter:
def __init__(self):
pass
def get_train_prompt(self, data: list[dict]) -> list[dict]:
pass
def __call__(self, data: list[dict], generate_format: bool = True) -> str:
pass
class Qwen2Prompter(Prompter):
def __init__(self):
self.bos_token = "<|im_start|>"
self.eos_token = "<|im_end|>"
self.role_list = ["user", "assistant", "system"]
self.assistant_role = "assistant"
self.system_role = "system"
self.user_role = "user"
self.default_system_prompt = "You are a helpful assistant."
self.prompt_template = "{bos_token}{role}\n{prompt}{eos_token}"
def get_train_prompt(self, data: list[dict]) -> list[dict]:
prompt_list = []
conversation_length = len(data)
for idx, item in enumerate(data):
if item["from"] not in self.role_list:
raise ValueError(f"Role {item['from']} is not in the role list")
if item["from"] == self.assistant_role:
prompt_list.append(
{
"prompt": f"{self.bos_token}{item['from']}\n",
"is_labels": False,
"from": item["from"],
}
)
prompt_list.append(
{
"prompt": f"{item['value']}{self.eos_token}",
"is_labels": True,
"from": item["from"],
} # Make it labels
)
elif item["from"] == self.system_role or item["from"] == self.user_role:
prompt_list.append(
{
"prompt": f"{self.bos_token}{item['from']}\n{item['value']}{self.eos_token}",
"is_labels": False,
"from": item["from"],
}
)
else:
raise ValueError(f"Role {item['from']} is not in the role list")
if idx != conversation_length - 1:
prompt_list.append(
{"prompt": "\n", "is_labels": False, "from": item["from"]}
)
return prompt_list
def __call__(self, data: list[dict]) -> str:
prompt_list = []
for item in data:
if item["from"] not in self.role_list:
raise ValueError(f"Role {item['from']} is not in the role list")
prompt_list.append(
self.prompt_template.format(
bos_token=self.bos_token,
role=item["from"],
prompt=item["value"],
eos_token=self.eos_token,
)
)
prompt_list.append(
self.prompt_template.format(
bos_token=self.bos_token,
role=self.assistant_role,
prompt="",
eos_token="",
)
)
return "\n".join(prompt_list)
class Qwen2VLPrompter(Prompter):
def __init__(self):
self.bos_token = "<|im_start|>"
self.eos_token = "<|im_end|>"
self.role_list = ["user", "assistant", "system"]
self.assistant_role = "assistant"
self.system_role = "system"
self.user_role = "user"
self.default_system_prompt = "You are a helpful assistant."
self.prompt_template = "{bos_token}{role}\n{prompt}{eos_token}"
def get_train_prompt(self, data: list[dict]) -> list[dict]:
prompt_list = []
conversation_length = len(data)
for idx, item in enumerate(data):
if item["from"] not in self.role_list:
raise ValueError(f"Role {item['from']} is not in the role list")
if item["from"] == self.assistant_role:
prompt_list.append(
{
"prompt": f"{self.bos_token}{item['from']}\n",
"is_labels": False,
"from": item["from"],
}
)
prompt_list.append(
{
"prompt": f"{item['value']}{self.eos_token}",
"is_labels": True,
"from": item["from"],
} # Make it labels
)
elif item["from"] == self.system_role or item["from"] == self.user_role:
prompt_list.append(
{
"prompt": f"{self.bos_token}{item['from']}\n{item['value']}{self.eos_token}",
"is_labels": False,
"from": item["from"],
}
)
else:
raise ValueError(f"Role {item['from']} is not in the role list")
if idx != conversation_length - 1:
prompt_list.append(
{"prompt": "\n", "is_labels": False, "from": item["from"]}
)
return prompt_list
def __call__(self, data: list[dict]) -> str:
prompt_list = []
for item in data:
if item["from"] not in self.role_list:
raise ValueError(f"Role {item['from']} is not in the role list")
prompt_list.append(
self.prompt_template.format(
bos_token=self.bos_token,
role=item["from"],
prompt=item["value"],
eos_token=self.eos_token,
)
)
prompt_list.append(
self.prompt_template.format(
bos_token=self.bos_token,
role=self.assistant_role,
prompt="",
eos_token="",
)
)
return "\n".join(prompt_list)
PROMPT_TYPE = {
'llava': Qwen2Prompter,
'qwen2vl': Qwen2VLPrompter,
'qwen2p5vl': Qwen2VLPrompter,
}