Spaces:
Runtime error
Runtime error
class Prompter: | |
def __init__(self): | |
pass | |
def get_train_prompt(self, data: list[dict]) -> list[dict]: | |
pass | |
def __call__(self, data: list[dict], generate_format: bool = True) -> str: | |
pass | |
class Qwen2Prompter(Prompter): | |
def __init__(self): | |
self.bos_token = "<|im_start|>" | |
self.eos_token = "<|im_end|>" | |
self.role_list = ["user", "assistant", "system"] | |
self.assistant_role = "assistant" | |
self.system_role = "system" | |
self.user_role = "user" | |
self.default_system_prompt = "You are a helpful assistant." | |
self.prompt_template = "{bos_token}{role}\n{prompt}{eos_token}" | |
def get_train_prompt(self, data: list[dict]) -> list[dict]: | |
prompt_list = [] | |
conversation_length = len(data) | |
for idx, item in enumerate(data): | |
if item["from"] not in self.role_list: | |
raise ValueError(f"Role {item['from']} is not in the role list") | |
if item["from"] == self.assistant_role: | |
prompt_list.append( | |
{ | |
"prompt": f"{self.bos_token}{item['from']}\n", | |
"is_labels": False, | |
"from": item["from"], | |
} | |
) | |
prompt_list.append( | |
{ | |
"prompt": f"{item['value']}{self.eos_token}", | |
"is_labels": True, | |
"from": item["from"], | |
} # Make it labels | |
) | |
elif item["from"] == self.system_role or item["from"] == self.user_role: | |
prompt_list.append( | |
{ | |
"prompt": f"{self.bos_token}{item['from']}\n{item['value']}{self.eos_token}", | |
"is_labels": False, | |
"from": item["from"], | |
} | |
) | |
else: | |
raise ValueError(f"Role {item['from']} is not in the role list") | |
if idx != conversation_length - 1: | |
prompt_list.append( | |
{"prompt": "\n", "is_labels": False, "from": item["from"]} | |
) | |
return prompt_list | |
def __call__(self, data: list[dict]) -> str: | |
prompt_list = [] | |
for item in data: | |
if item["from"] not in self.role_list: | |
raise ValueError(f"Role {item['from']} is not in the role list") | |
prompt_list.append( | |
self.prompt_template.format( | |
bos_token=self.bos_token, | |
role=item["from"], | |
prompt=item["value"], | |
eos_token=self.eos_token, | |
) | |
) | |
prompt_list.append( | |
self.prompt_template.format( | |
bos_token=self.bos_token, | |
role=self.assistant_role, | |
prompt="", | |
eos_token="", | |
) | |
) | |
return "\n".join(prompt_list) | |
class Qwen2VLPrompter(Prompter): | |
def __init__(self): | |
self.bos_token = "<|im_start|>" | |
self.eos_token = "<|im_end|>" | |
self.role_list = ["user", "assistant", "system"] | |
self.assistant_role = "assistant" | |
self.system_role = "system" | |
self.user_role = "user" | |
self.default_system_prompt = "You are a helpful assistant." | |
self.prompt_template = "{bos_token}{role}\n{prompt}{eos_token}" | |
def get_train_prompt(self, data: list[dict]) -> list[dict]: | |
prompt_list = [] | |
conversation_length = len(data) | |
for idx, item in enumerate(data): | |
if item["from"] not in self.role_list: | |
raise ValueError(f"Role {item['from']} is not in the role list") | |
if item["from"] == self.assistant_role: | |
prompt_list.append( | |
{ | |
"prompt": f"{self.bos_token}{item['from']}\n", | |
"is_labels": False, | |
"from": item["from"], | |
} | |
) | |
prompt_list.append( | |
{ | |
"prompt": f"{item['value']}{self.eos_token}", | |
"is_labels": True, | |
"from": item["from"], | |
} # Make it labels | |
) | |
elif item["from"] == self.system_role or item["from"] == self.user_role: | |
prompt_list.append( | |
{ | |
"prompt": f"{self.bos_token}{item['from']}\n{item['value']}{self.eos_token}", | |
"is_labels": False, | |
"from": item["from"], | |
} | |
) | |
else: | |
raise ValueError(f"Role {item['from']} is not in the role list") | |
if idx != conversation_length - 1: | |
prompt_list.append( | |
{"prompt": "\n", "is_labels": False, "from": item["from"]} | |
) | |
return prompt_list | |
def __call__(self, data: list[dict]) -> str: | |
prompt_list = [] | |
for item in data: | |
if item["from"] not in self.role_list: | |
raise ValueError(f"Role {item['from']} is not in the role list") | |
prompt_list.append( | |
self.prompt_template.format( | |
bos_token=self.bos_token, | |
role=item["from"], | |
prompt=item["value"], | |
eos_token=self.eos_token, | |
) | |
) | |
prompt_list.append( | |
self.prompt_template.format( | |
bos_token=self.bos_token, | |
role=self.assistant_role, | |
prompt="", | |
eos_token="", | |
) | |
) | |
return "\n".join(prompt_list) | |
PROMPT_TYPE = { | |
'llava': Qwen2Prompter, | |
'qwen2vl': Qwen2VLPrompter, | |
'qwen2p5vl': Qwen2VLPrompter, | |
} |