|
from langchain.prompts import PromptTemplate |
|
from langchain.chains import RetrievalQA |
|
from langchain.vectorstores import Chroma |
|
import sys |
|
import os |
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
from model_to_llm import model_to_llm |
|
from get_vectordb import get_vectordb |
|
import sys |
|
import re |
|
|
|
class QA_chain_self(): |
|
"""" |
|
不带历史记录的问答链 |
|
- model:调用的模型名称 |
|
- temperature:温度系数,控制生成的随机性 |
|
- top_k:返回检索的前k个相似文档 |
|
- file_path:建库文件所在路径 |
|
- persist_path:向量数据库持久化路径 |
|
- appid:星火需要输入 |
|
- api_key:所有模型都需要 |
|
- Spark_api_secret:星火秘钥 |
|
- Wenxin_secret_key:文心秘钥 |
|
- embeddings:使用的embedding模型 |
|
- embedding_key:使用的embedding模型的秘钥(智谱或者OpenAI) |
|
- template:可以自定义提示模板,没有输入则使用默认的提示模板default_template_rq |
|
""" |
|
|
|
|
|
default_template_rq = """你是一个三国大乱斗系统的AI助手。请根据以下角色信息回答问题。 |
|
|
|
角色信息: |
|
{context} |
|
|
|
问题: {question} |
|
|
|
请按照以下规则回答: |
|
1. 如果问题是抽取角色或角色信息相关,请不要只回答角色名,而是回答以下信息: |
|
- 抽取结果:[角色名] |
|
- 角色特点:[简要描述] |
|
- 属性值:[列出关键属性] |
|
- 技能说明:[列出技能效果] |
|
|
|
|
|
2. 如果问题是战斗规则说明相关: |
|
请直接引用下面"回合制对战规则"部分的内容,包括: |
|
- 回合顺序:每回合速度快的一方先出手 |
|
- 行动选择:每回合可选择普通攻击、使用技能、休息(回复1%体力和10灵力) |
|
- 技能使用:需要支付相应消耗,无法支付则无法发动 |
|
- 伤害计算的逻辑:- 普通攻击伤害 = (攻击方攻击-防御方防御)/防御方耐力*2 |
|
- 技能附加效果(如增伤、减防、附加状态)独立计算。 |
|
- 胜负判定:体力降为0或以下即判负 |
|
|
|
3. 如果问题是对战、模拟战、战斗相关: |
|
- 对战双方:[双方角色名] |
|
- 行动类型:[普通攻击/技能/休息] |
|
- 行动结果:[详细描述] |
|
- 伤害计算:[如果有伤害,显示计算过程] |
|
- 状态变化:[双方状态变化] |
|
|
|
4. 如果问题是状态查询相关: |
|
- 当前状态:[详细描述] |
|
- 可用行动:[可选操作] |
|
- 建议策略:[战术建议] |
|
|
|
回答要严谨简约,并保持专业性和准确性。如果不知道答案,请直接说明,不要编造信息。""" |
|
|
|
def __init__(self, model:str, temperature:float=0.0, top_k:int=4, file_path:str=None, persist_path:str=None, appid:str=None, api_key:str=None, Spark_api_secret:str=None,Wenxin_secret_key:str=None, embedding = "openai", embedding_key = None, template=default_template_rq): |
|
self.model = model |
|
self.temperature = temperature |
|
self.top_k = top_k |
|
self.file_path = file_path |
|
self.persist_path = persist_path |
|
self.appid = appid |
|
self.api_key = api_key |
|
self.Spark_api_secret = Spark_api_secret |
|
self.Wenxin_secret_key = Wenxin_secret_key |
|
self.embedding = embedding |
|
self.embedding_key = embedding_key |
|
self.template = template |
|
self.vectordb = get_vectordb(self.file_path, self.persist_path, self.embedding,self.embedding_key) |
|
self.llm = model_to_llm(self.model, self.temperature, self.appid, self.api_key, self.Spark_api_secret,self.Wenxin_secret_key) |
|
|
|
self.QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context","question"], |
|
template=self.template) |
|
self.retriever = self.vectordb.as_retriever( |
|
search_type="similarity_score_threshold", |
|
search_kwargs={ |
|
'k': self.top_k, |
|
'score_threshold': 0.3 |
|
} |
|
) |
|
|
|
self.qa_chain = RetrievalQA.from_chain_type(llm=self.llm, |
|
retriever=self.retriever, |
|
return_source_documents=True, |
|
chain_type_kwargs={"prompt":self.QA_CHAIN_PROMPT}) |
|
|
|
|
|
|
|
|
|
def answer(self, question:str=None, temperature = None, top_k = 4): |
|
"""" |
|
核心方法,调用问答链 |
|
arguments: |
|
- question:用户提问 |
|
""" |
|
|
|
if len(question) == 0: |
|
return "" |
|
|
|
if temperature == None: |
|
temperature = self.temperature |
|
|
|
if top_k == None: |
|
top_k = self.top_k |
|
|
|
result = self.qa_chain({"query": question, "temperature": temperature, "top_k": top_k}) |
|
answer = result["result"] |
|
answer = re.sub(r"\\n", '<br/>', answer) |
|
return answer |
|
|
|
def battle_analysis(self, character1: str, character2: str): |
|
""" |
|
分析两个角色之间的战斗 |
|
""" |
|
question = f"请分析{character1}和{character2}之间的战斗,谁会获胜?请详细分析双方的属性、技能和战斗策略。" |
|
return self.answer(question, temperature=0.7, top_k=6) |
|
|
|
def skill_analysis(self, character: str, skill: str): |
|
""" |
|
分析特定角色的特定技能 |
|
""" |
|
question = f"请详细分析{character}的{skill}技能,包括技能效果、消耗和最佳使用时机。" |
|
return self.answer(question, temperature=0.5, top_k=4) |
|
|
|
def character_comparison(self, character1: str, character2: str): |
|
""" |
|
比较两个角色的属性 |
|
""" |
|
question = f"请详细比较{character1}和{character2}的属性,包括攻击力、防御力、体力、耐力、法力、闪避和速度,并分析各自的优势和劣势。" |
|
return self.answer(question, temperature=0.5, top_k=6) |
|
|
|
def battle_strategy(self, character: str, opponent: str): |
|
""" |
|
为特定角色提供战斗策略 |
|
""" |
|
question = f"如果{character}要对抗{opponent},应该采用什么战斗策略?请根据双方的属性和技能给出具体的建议。" |
|
return self.answer(question, temperature=0.7, top_k=6) |
|
|
|
def draw_character(self, specific_character: str = None): |
|
""" |
|
抽取角色卡 |
|
specific_character: 如果指定了具体角色,则抽取该角色;否则随机抽取 |
|
""" |
|
if specific_character: |
|
question = f"请抽取{specific_character}的角色卡,并详细介绍其属性和技能特点。" |
|
else: |
|
question = "请随机抽取一张三国人物卡,并详细介绍该角色的属性和技能特点。" |
|
return self.answer(question, temperature=0.7, top_k=4) |
|
|
|
def battle_simulation(self, player_character: str, action: str): |
|
""" |
|
模拟战斗 |
|
player_character: 玩家角色 |
|
action: 玩家选择的行动(普通攻击/使用技能/休息) |
|
""" |
|
question = f"玩家使用{player_character}选择了{action},请模拟一场战斗,详细描述战斗过程和结果。" |
|
return self.answer(question, temperature=0.8, top_k=6) |
|
|
|
def character_status(self, character: str): |
|
""" |
|
查询角色状态 |
|
""" |
|
question = f"请查询{character}的当前状态,包括体力、灵力、技能冷却等信息。" |
|
return self.answer(question, temperature=0.5, top_k=4) |
|
|