|
|
|
|
|
''' |
|
@File : call_llm.py |
|
@Time : 2023/10/18 10:45:00 |
|
@Author : Logan Zou |
|
@Version : 1.0 |
|
@Contact : loganzou0421@163.com |
|
@License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA |
|
@Desc : 将各个大模型的原生接口封装在一个接口 |
|
''' |
|
|
|
import openai |
|
import json |
|
import requests |
|
import _thread as thread |
|
import base64 |
|
import datetime |
|
from dotenv import load_dotenv, find_dotenv |
|
import hashlib |
|
import hmac |
|
import os |
|
import queue |
|
from urllib.parse import urlparse |
|
import ssl |
|
from datetime import datetime |
|
from time import mktime |
|
from urllib.parse import urlencode |
|
from wsgiref.handlers import format_date_time |
|
import zhipuai |
|
from langchain.utils import get_from_dict_or_env |
|
|
|
import websocket |
|
|
|
def get_completion(
    prompt: str,
    model: str,
    temperature=0.1,
    api_key=None,
    secret_key=None,
    access_token=None,
    appid=None,
    api_secret=None,
    max_tokens=2048,
):
    """Route a single-turn prompt to the vendor wrapper that serves ``model``.

    Args:
        prompt: User prompt text.
        model: Model name; it selects which vendor wrapper is called.
        temperature: Sampling temperature forwarded to the wrapper.
        api_key: Vendor API key; forwarded as-is (the wrappers resolve ``None``
            themselves).
        secret_key: Secret key, forwarded only to the wenxin wrapper.
        access_token: Accepted for compatibility; this function does not
            forward it to any wrapper.
        appid: App id, forwarded only to the Spark wrapper.
        api_secret: API secret, forwarded only to the Spark wrapper.
        max_tokens: Output-length limit; forwarded to every wrapper except the
            wenxin one, whose signature has no such parameter.

    Returns:
        The selected wrapper's return value, or the literal string
        "不正确的模型" when ``model`` is not in any of the known groups.
    """
    openai_models = ("gpt-3.5-turbo", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-0613", "gpt-4", "gpt-4-32k")
    wenxin_models = ("ERNIE-Bot", "ERNIE-Bot-4", "ERNIE-Bot-turbo")
    spark_models = ("Spark-1.5", "Spark-2.0", "Spark-X1")
    glm_models = ("chatglm_pro", "chatglm_std", "chatglm_lite")
    ali_models = ("qwen-turbo", "qwen-plus", "qwen-max")

    if model in openai_models:
        return get_completion_gpt(prompt, model, temperature, api_key, max_tokens)
    if model in wenxin_models:
        return get_completion_wenxin(prompt, model, temperature, api_key, secret_key)
    if model in spark_models:
        return get_completion_spark(prompt, model, temperature, api_key, appid, api_secret, max_tokens)
    if model in glm_models:
        return get_completion_glm(prompt, model, temperature, api_key, max_tokens)
    if model in ali_models:
        return get_completion_ali(prompt, model, temperature, api_key, max_tokens)

    # Unknown model names are reported through the return value, not an exception.
    return "不正确的模型"
|
|
|
def get_completion_gpt(prompt: str, model: str, temperature: float, api_key: str, max_tokens: int):
    """Send a single user message through ``openai.ChatCompletion`` and return the reply text.

    Args:
        prompt: User prompt text, sent as the only message (role ``user``).
        model: Model name passed to the API.
        temperature: Sampling temperature passed to the API.
        api_key: OpenAI API key. When ``None`` it is resolved with
            ``parse_llm_api_key("openai")``.
        max_tokens: Completion length limit passed to the API.

    Returns:
        The ``content`` of the first choice's message.

    Note:
        The key is stored on the module-level ``openai.api_key``, so it stays
        set for later calls made through the ``openai`` module.
    """
    # Identity check: ``None`` is a singleton, so ``is`` is the correct test.
    if api_key is None:
        api_key = parse_llm_api_key("openai")
    openai.api_key = api_key

    messages = [{"role": "user", "content": prompt}]
    response = openai.ChatCompletion.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return response.choices[0].message["content"]
|
|
|
def get_access_token(api_key, secret_key):
    """Exchange a Baidu API key / secret key pair for an OAuth access token.

    Args:
        api_key: Application API key, sent as ``client_id``.
        secret_key: Application secret key, sent as ``client_secret``.

    Returns:
        The ``access_token`` field of the JSON response, or ``None`` when the
        response carries no such field.
    """
    url = "https://aip.baidubce.com/oauth/2.0/token"
    # Let requests build the query string so that reserved characters in the
    # credentials ('&', '=', '+', ...) are percent-encoded instead of
    # corrupting the query.
    params = {
        "grant_type": "client_credentials",
        "client_id": api_key,
        "client_secret": secret_key,
    }
    payload = json.dumps("")
    headers = {
        'Content-Type': 'application/json',
        'Accept': 'application/json'
    }

    # A timeout keeps a stalled connection from blocking the caller forever.
    response = requests.request("POST", url, params=params, headers=headers, data=payload, timeout=30)
    return response.json().get("access_token")
|
|
|
def get_completion_wenxin(prompt: str, model: str, temperature: float, api_key: str, secret_key: str):
    """Send a single user message to a Baidu ERNIE chat endpoint and return the reply text.

    Args:
        prompt: User prompt text, sent as the only message (role ``user``).
        model: ERNIE model name; selects the chat endpoint. Unrecognised names
            fall back to ``eb-instant``, the endpoint that was previously used
            for every model.
        temperature: Sampling temperature. It is sent only when it lies in
            ``(0, 1]``; otherwise the field is omitted and the service default
            applies.
        api_key: Baidu API key. When it or ``secret_key`` is ``None`` both are
            resolved with ``parse_llm_api_key("wenxin")``.
        secret_key: Baidu secret key.

    Returns:
        The ``result`` field of the JSON response.

    Raises:
        KeyError: The response has no ``result`` field (for example an error
            payload); the message includes the decoded response.
    """
    if api_key is None or secret_key is None:
        api_key, secret_key = parse_llm_api_key("wenxin")

    access_token = get_access_token(api_key, secret_key)

    # Each ERNIE model is served from its own path under .../wenxinworkshop/chat/.
    endpoints = {
        "ERNIE-Bot": "completions",
        "ERNIE-Bot-4": "completions_pro",
        "ERNIE-Bot-turbo": "eb-instant",
    }
    endpoint = endpoints.get(model, "eb-instant")
    url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{endpoint}?access_token={access_token}"

    body = {
        "messages": [
            {
                "role": "user",
                "content": "{}".format(prompt)
            }
        ]
    }
    # Out-of-range values are dropped rather than sent, so a call that worked
    # before (when temperature was ignored entirely) keeps working.
    if temperature is not None and 0 < temperature <= 1:
        body["temperature"] = temperature
    payload = json.dumps(body)
    headers = {
        'Content-Type': 'application/json'
    }

    response = requests.request("POST", url, headers=headers, data=payload)

    js = json.loads(response.text)
    if "result" not in js:
        # Same exception type as the bare lookup, but with the payload attached.
        raise KeyError(f"wenxin response has no 'result' field: {js}")
    return js["result"]
|
|
|
def get_completion_spark(prompt: str, model: str, temperature: float, api_key: str, appid: str, api_secret: str, max_tokens: int):
    """Ask an iFlytek Spark model for a reply to ``prompt`` over WebSocket.

    The request always carries a fixed system message (a "Three Kingdoms
    brawl" game assistant persona) followed by the user prompt.

    Args:
        prompt: User prompt text.
        model: ``"Spark-X1"`` and ``"Spark-1.5"`` select their own endpoint and
            domain; any other value uses the v2.1 endpoint.
        temperature: Sampling temperature forwarded to ``spark_main``.
        api_key: Spark API key. When any of ``api_key``, ``appid`` or
            ``api_secret`` is ``None``, all three are resolved with
            ``parse_llm_api_key("spark")``.
        appid: Spark application id.
        api_secret: Spark API secret.
        max_tokens: Output-length limit forwarded to ``spark_main``.

    Returns:
        Whatever ``spark_main`` returns.
    """
    if any(credential is None for credential in (api_key, appid, api_secret)):
        api_key, appid, api_secret = parse_llm_api_key("spark")

    # model name -> (domain, WebSocket endpoint); anything else uses the v2.1 route.
    routes = {
        "Spark-X1": ("x1", "wss://spark-api.xf-yun.com/v1/x1"),
        "Spark-1.5": ("general", "ws://spark-api.xf-yun.com/v1.1/chat"),
    }
    fallback_route = ("generalv2", "ws://spark-api.xf-yun.com/v2.1/chat")
    domain, endpoint_url = routes.get(model, fallback_route)

    system_prompt = """你是一个三国大乱斗系统的AI助手。你能提供以下功能。
系统功能:
1. 角色抽取:随机抽取三国人物卡并展示完整信息,包括:
- 角色名
- 角色特点
- 属性值
- 技能说明
2.对战规程介绍:
- 回合制对战规则:
- 每回合速度快的一方先出手
- 行动选择:每回合可选择普通攻击、使用技能、休息(回复1%体力和10灵力)
- 技能使用:需要支付相应消耗,无法支付则无法发动
- 伤害计算的逻辑:- 普通攻击伤害 = (攻击方攻击-防御方防御)/防御方耐力*2
- 技能附加效果(如增伤、减防、附加状态)独立计算。
- 胜负判定:体力降为0或以下即判负
3. 对战系统:为抽取的角色随机匹配对手进行回合制对战。"""

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]

    return spark_main(appid, api_key, api_secret, endpoint_url, domain, conversation, temperature, max_tokens)
|
|
|
def get_completion_glm(prompt: str, model: str, temperature: float, api_key: str, max_tokens: int):
    """Send a single user message through ``zhipuai.model_api.invoke`` and return the reply text.

    Args:
        prompt: User prompt text, sent as the only message (role ``user``).
        model: Model name passed to the API.
        temperature: Sampling temperature passed to the API.
        api_key: Zhipu API key. When ``None`` it is resolved with
            ``parse_llm_api_key("zhipuai")``.
        max_tokens: Output-length limit passed to the API.

    Returns:
        The first choice's ``content`` with surrounding double quotes removed
        first and surrounding spaces removed second.

    Note:
        The key is stored on the module-level ``zhipuai.api_key``.
    """
    # Identity check: ``None`` is a singleton, so ``is`` is the correct test.
    if api_key is None:
        api_key = parse_llm_api_key("zhipuai")
    zhipuai.api_key = api_key

    response = zhipuai.model_api.invoke(
        model=model,
        prompt=[{"role": "user", "content": prompt}],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return response["data"]["choices"][0]["content"].strip('"').strip(" ")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level buffer for the Spark reply: spark_main() resets it to "" before
# each request, on_message() appends every streamed content fragment to it, and
# spark_main() returns it once the WebSocket loop has finished.
answer = ""
|
|
|
class Ws_Param(object):
    """Spark credentials plus endpoint parts, able to build the signed WebSocket URL."""

    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        """Store the credentials and split ``Spark_url`` into host and path."""
        endpoint = urlparse(Spark_url)

        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = endpoint.netloc
        self.path = endpoint.path
        self.Spark_url = Spark_url

        # Stored defaults; create_url() does not read them.
        self.temperature = 0
        self.max_tokens = 2048

    def create_url(self):
        """Return ``Spark_url`` with ``authorization``, ``date`` and ``host`` query parameters.

        The signature is the base64 HMAC-SHA256, keyed with ``APISecret``, of
        the three lines ``host: ...``, ``date: ...`` and ``GET <path> HTTP/1.1``.
        The authorization value is the base64 of the descriptor string that
        carries the API key and that signature.
        """
        # HTTP-date for the current time.
        date = format_date_time(mktime(datetime.now().timetuple()))

        # The text that gets signed: host line, date line, request line.
        signature_origin = "\n".join((
            "host: " + self.host,
            "date: " + date,
            "GET " + self.path + " HTTP/1.1",
        ))

        digest = hmac.new(
            self.APISecret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            hashlib.sha256,
        ).digest()
        signature_sha_base64 = base64.b64encode(digest).decode(encoding='utf-8')

        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        query = urlencode({
            "authorization": authorization,
            "date": date,
            "host": self.host,
        })
        return self.Spark_url + '?' + query
|
|
|
|
|
|
|
def on_error(ws, error):
    """WebSocket error callback: write the error to stdout after a ``### error:`` tag."""
    print("### error: " + str(error))
|
|
|
|
|
|
|
def on_close(ws, one, two):
    """WebSocket close callback: print a single space followed by a newline.

    The two extra positional arguments are accepted and ignored.
    """
    print(" ", end="\n")
|
|
|
|
|
|
|
def on_open(ws):
    """WebSocket open callback: start ``run(ws)`` on a new thread to send the request."""
    worker_args = (ws,)
    thread.start_new_thread(run, worker_args)
|
|
|
|
|
def run(ws, *args):
    """Build the request from the attributes stored on ``ws`` and send it as JSON.

    Reads ``ws.appid``, ``ws.domain``, ``ws.question``, ``ws.temperature`` and
    ``ws.max_tokens``; extra positional arguments are ignored.
    """
    request_body = gen_params(
        appid=ws.appid,
        domain=ws.domain,
        question=ws.question,
        temperature=ws.temperature,
        max_tokens=ws.max_tokens,
    )
    ws.send(json.dumps(request_body))
|
|
|
|
|
|
|
def on_message(ws, message):
    """WebSocket message callback: echo one streamed frame and accumulate the reply.

    A frame whose header code is non-zero is printed and the socket is closed.
    Otherwise any ``reasoning_content`` fragment is printed, and any
    ``content`` fragment is printed and appended to the module-level
    ``answer``. A separator banner is printed before the first content
    fragment that follows reasoning output. The socket is closed when the
    frame's status is 2. Any exception raised while handling the frame is
    printed together with the raw message, and the socket is closed.
    """
    global answer, isFirstcontent
    try:
        data = json.loads(message)
        code = data['header']['code']
        if code != 0:
            print(f'请求错误: {code}, {data}')
            ws.close()
            return

        choices = data["payload"]["choices"]
        status = choices["status"]
        text = choices['text'][0]

        if 'reasoning_content' in text and text['reasoning_content'] != '':
            print(text["reasoning_content"], end="")
            # Remember that reasoning was shown so the banner precedes the answer.
            isFirstcontent = True

        if 'content' in text and text['content'] != '':
            fragment = text["content"]
            if isFirstcontent:
                print("\n*******************以上为思维链内容,模型回复内容如下********************\n")
            print(fragment, end="")
            isFirstcontent = False
            answer += fragment

        if status == 2:
            ws.close()
    except Exception as e:
        print(f"处理消息时出错: {str(e)}")
        print(f"原始消息: {message}")
        ws.close()
|
|
|
|
|
def gen_params(appid, domain, question, temperature, max_tokens):
    """Assemble the Spark request body from the app id, chat settings and question.

    Args:
        appid: Value placed in ``header.app_id``.
        domain: Value placed in ``parameter.chat.domain``.
        question: Message list placed, unchanged, in ``payload.message.text``.
        temperature: Value placed in ``parameter.chat.temperature``.
        max_tokens: Value placed in ``parameter.chat.max_tokens``.

    Returns:
        A dict with ``header``, ``parameter`` and ``payload`` sections; the
        ``uid`` in the header is the fixed string "1234".
    """
    header = {"app_id": appid, "uid": "1234"}
    chat_settings = {
        "domain": domain,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    return {
        "header": header,
        "parameter": {"chat": chat_settings},
        "payload": {"message": {"text": question}},
    }
|
|
|
|
|
def spark_main(appid, api_key, api_secret, Spark_url, domain, question, temperature, max_tokens, sslopt=None):
    """Run one Spark request over WebSocket and return the accumulated reply.

    Args:
        appid: Spark application id.
        api_key: Spark API key.
        api_secret: Spark API secret used to sign the URL.
        Spark_url: WebSocket endpoint.
        domain: Model domain placed in the request.
        question: Message list sent as the conversation.
        temperature: Sampling temperature placed in the request.
        max_tokens: Output-length limit placed in the request.
        sslopt: Optional SSL options passed to ``run_forever``. ``None`` (the
            default) keeps the library defaults, i.e. certificates are
            verified. Pass ``{"cert_reqs": ssl.CERT_NONE}`` only when
            verification must be turned off on purpose.

    Returns:
        The text collected in the module-level ``answer`` by ``on_message``;
        it is still the empty string when no content frame arrived.

    Raises:
        ValueError: ``appid``, ``api_key`` or ``api_secret`` is missing or empty.
    """
    global answer, isFirstcontent

    if not all([appid, api_key, api_secret]):
        raise ValueError("缺少必要的认证参数:appid, api_key, api_secret")

    # Reset the state shared with on_message() before starting a new request.
    answer = ""
    isFirstcontent = False

    wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()

    ws = websocket.WebSocketApp(wsUrl,
                                on_message=on_message,
                                on_error=on_error,
                                on_close=on_close,
                                on_open=on_open)
    # run() reads these attributes back when the socket opens.
    ws.appid = appid
    ws.question = question
    ws.domain = domain
    ws.temperature = temperature
    ws.max_tokens = max_tokens

    # Certificate verification used to be switched off unconditionally
    # (cert_reqs=ssl.CERT_NONE), which lets a man-in-the-middle read the signed
    # URL and the conversation on wss:// endpoints. It is now opt-in via sslopt.
    ws.run_forever(sslopt=sslopt)

    return answer
|
|
|
def parse_llm_api_key(model: str, env_file: dict = None):
    """Look up the credentials for a vendor in ``env_file`` or the environment.

    Args:
        model: Vendor key: ``"openai"``, ``"wenxin"``, ``"spark"``,
            ``"zhipuai"`` or ``"ali"``.
        env_file: Mapping to read from. When ``None``, the file found by
            ``find_dotenv()`` is loaded with ``load_dotenv`` and ``os.environ``
            is used.

    Returns:
        ``"openai"``: the ``OPENAI_API_KEY`` value.
        ``"wenxin"``: the tuple ``(wenxin_api_key, wenxin_secret_key)``.
        ``"spark"``: the tuple ``(spark_api_key, spark_appid, spark_api_secret)``.
        ``"zhipuai"``: the result of
        ``get_from_dict_or_env(env_file, "zhipuai_api_key", "ZHIPUAI_API_KEY")``.
        ``"ali"``: the ``ali_api_key`` value.

    Raises:
        ValueError: ``model`` is not one of the supported vendor keys.
        KeyError: A required entry is missing from the mapping (direct lookups).
    """
    # The annotation used to be ``dict()``, a call that builds an empty dict
    # instance rather than naming a type.
    if env_file is None:
        _ = load_dotenv(find_dotenv())
        env_file = os.environ

    if model == "openai":
        return env_file["OPENAI_API_KEY"]
    elif model == "wenxin":
        return env_file["wenxin_api_key"], env_file["wenxin_secret_key"]
    elif model == "spark":
        return env_file["spark_api_key"], env_file["spark_appid"], env_file["spark_api_secret"]
    elif model == "zhipuai":
        return get_from_dict_or_env(env_file, "zhipuai_api_key", "ZHIPUAI_API_KEY")
    elif model == "ali":
        return env_file["ali_api_key"]
    else:
        raise ValueError(f"model{model} not support!!!")
|
|
|
|
|
def get_completion_ali(prompt: str, model: str, temperature: float, api_key: str, max_tokens: int):
    """Send a single user message to Alibaba Qwen via the OpenAI-compatible endpoint.

    Args:
        prompt: User prompt text, sent as the only message (role ``user``).
        model: Model name placed in the request body.
        temperature: Sampling temperature placed in the request body.
        api_key: DashScope API key, sent as a Bearer token. When ``None`` it is
            resolved with ``parse_llm_api_key("ali")``.
        max_tokens: Output-length limit placed in the request body.

    Returns:
        The first choice's message content on HTTP 200; otherwise the string
        "请求失败: " followed by the response text.
    """
    if api_key is None:
        api_key = parse_llm_api_key("ali")

    # The request used to go to the bare compatible-mode base URL with a native
    # DashScope body ("input"/"parameters") and read "output.text" back; the
    # base URL is not an endpoint, so the call could not succeed. Endpoint,
    # body and response parsing now all follow the OpenAI-compatible format.
    url = "https://dashscope.aliyuncs.com/compatible-mode/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": prompt
            }
        ],
        "temperature": temperature,
        "max_tokens": max_tokens
    }

    response = requests.post(url, headers=headers, json=payload)
    if response.status_code == 200:
        return response.json()["choices"][0]["message"]["content"]
    return f"请求失败: {response.text}"
|
|