|
|
|
|
|
''' |
|
@File : spark_llm.py |
|
@Time : 2023/10/16 18:53:26 |
|
@Author : Logan Zou |
|
@Version : 1.0 |
|
@Contact : loganzou0421@163.com |
|
@License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA |
|
@Desc : 基于讯飞星火大模型自定义 LLM 类 |
|
''' |
|
|
|
from langchain.llms.base import LLM |
|
from typing import Any, List, Mapping, Optional, Dict, Union, Tuple |
|
from pydantic import Field |
|
from llm.self_llm import Self_LLM |
|
import json |
|
import requests |
|
from langchain.callbacks.manager import CallbackManagerForLLMRun |
|
import _thread as thread |
|
import base64 |
|
import datetime |
|
import hashlib |
|
import hmac |
|
import json |
|
from urllib.parse import urlparse |
|
import ssl |
|
from datetime import datetime |
|
from time import mktime |
|
from urllib.parse import urlencode |
|
from wsgiref.handlers import format_date_time |
|
import websocket |
|
import queue |
|
|
|
class Spark_LLM(Self_LLM):
    """LangChain LLM wrapper for the iFlytek Spark WebSocket API.

    The instance keeps the dialogue in ``self.text`` as a list of
    ``{"role": ..., "content": ...}`` dicts and passes the whole list to
    spark_main on every ``_call``, so successive calls form one multi-turn
    conversation. Before each request the oldest turns are discarded while
    the history exceeds ``max_history_chars`` characters; without this the
    history grows on every call (presumably until it exceeds the model's
    input limit).
    """

    # WebSocket endpoint of the Spark API.
    url: str = "wss://spark-api.xf-yun.com/v1/x1"

    # iFlytek credentials; _call refuses to run unless appid, api_secret and
    # api_key are all set (api_key/temperature are presumably declared on
    # Self_LLM -- they are assigned in __init__ but not declared here).
    appid: Optional[str] = None
    api_secret: Optional[str] = None

    # Passed to spark_main as the ``domain`` argument.
    domain: str = "x1"

    # Passed to spark_main as the ``max_tokens`` argument.
    max_tokens: int = 4096

    # Model name stored on the instance (the request itself uses ``domain``).
    model: str = "Spark-X1"

    # Conversation history: role/content dicts sent with every request.
    text: List[Dict[str, str]] = []

    # Character budget for ``text``; older turns beyond it are dropped.
    max_history_chars: int = 8000

    def __init__(self, model: str = "Spark-X1", temperature: float = 0.0, appid: Optional[str] = None,
                 api_secret: Optional[str] = None, api_key: Optional[str] = None):
        """Create a Spark client.

        Args:
            model: model name stored on the instance.
            temperature: passed to spark_main on every call.
            appid: iFlytek application id.
            api_secret: iFlytek API secret (HMAC key for the request signature).
            api_key: iFlytek API key.
        """
        super().__init__()
        self.model = model
        self.temperature = temperature
        self.appid = appid
        self.api_secret = api_secret
        self.api_key = api_key
        # Start every instance with its own empty history.
        self.text = []

    def getText(self, role, content):
        """Append a ``{"role": role, "content": content}`` message to the history.

        Returns:
            ``self.text`` itself (not a copy), now ending with the new message.
        """
        self.text.append({"role": role, "content": content})
        return self.text

    def _trim_history(self) -> None:
        """Drop the oldest messages until the history fits ``max_history_chars``.

        The newest message (the prompt about to be sent) is never removed. A
        leading assistant message is dropped as well, so the history still
        starts with a user turn after trimming.
        """
        total = sum(len(str(message["content"])) for message in self.text)
        while len(self.text) > 1 and (
            total > self.max_history_chars or self.text[0]["role"] == "assistant"
        ):
            total -= len(str(self.text.pop(0)["content"]))

    def _call(self, prompt : str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any) -> str:
        """Send *prompt*, together with the retained history, and return the reply.

        ``stop``, ``run_manager`` and extra keyword arguments are accepted for
        LangChain compatibility but are not used.

        Returns:
            The reply text, or the string "请求失败" when spark_main raised.

        Raises:
            ValueError: if api_key, appid or api_secret is missing.
        """
        if self.api_key is None or self.appid is None or self.api_secret is None:
            # All three credentials are required to sign the request.
            print("请填入 Key")
            raise ValueError("Key 不存在")

        print("正在准备问题...")
        question = self.getText("user", prompt)
        # Trim in place: ``question`` is the same list object as self.text.
        self._trim_history()

        try:
            print("正在调用星火大模型...")
            response = spark_main(self.appid, self.api_key, self.api_secret, self.url,
                                  self.domain, question, self.temperature, self.max_tokens)
        except Exception as e:
            # Remove the unanswered prompt so the next request does not carry
            # two consecutive user turns (and the history does not keep growing).
            self.text.pop()
            print(f"请求失败: {str(e)}")
            return "请求失败"

        if not response:
            # Nothing was generated; do not keep a half-finished exchange.
            self.text.pop()
            return response

        print("收到模型回复,正在保存...")
        self.getText("assistant", response)
        print("Spark_LLM._call 执行完成")
        return response

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM type."""
        return "spark"
|
|
|
# Module-level state shared by spark_main and on_message (both reset/update it
# through ``global``):
#   answer         -- accumulates the streamed reply text of the current request.
#   isFirstcontent -- set after a reasoning chunk and cleared after an answer
#                     chunk, so on_message prints a separator line before the
#                     first answer chunk that follows reasoning output.
# Declaring both here keeps them defined even if a callback runs before
# spark_main has initialised them.
answer = ""
isFirstcontent = False
|
|
|
class Ws_Param:
    """Spark credentials plus the endpoint, able to build a signed WebSocket URL.

    Args:
        APPID: application id (stored; not part of the URL).
        APIKey: key embedded in the ``authorization`` value.
        APISecret: secret used as the HMAC-SHA256 key.
        Spark_url: full ``wss://`` endpoint; its host and path are signed.
    """

    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        parsed = urlparse(Spark_url)
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = parsed.netloc
        self.path = parsed.path
        self.Spark_url = Spark_url
        # Not read by create_url; kept as attributes for any external readers.
        self.temperature = 0
        self.max_tokens = 2048

    def create_url(self):
        """Return ``Spark_url`` with ``authorization``, ``date`` and ``host`` query params."""
        # HTTP-date string built from the current local time.
        stamp = format_date_time(mktime(datetime.now().timetuple()))

        # String to sign: host line, date line and request line.
        to_sign = "\n".join((
            "host: " + self.host,
            "date: " + stamp,
            "GET " + self.path + " HTTP/1.1",
        ))
        digest = hmac.new(
            self.APISecret.encode('utf-8'),
            to_sign.encode('utf-8'),
            digestmod=hashlib.sha256,
        ).digest()
        signature = base64.b64encode(digest).decode('utf-8')

        auth_plain = (
            f'api_key="{self.APIKey}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        auth_b64 = base64.b64encode(auth_plain.encode('utf-8')).decode('utf-8')

        # Query order is authorization, date, host.
        query = urlencode({"authorization": auth_b64, "date": stamp, "host": self.host})
        return self.Spark_url + "?" + query
|
|
|
|
|
|
|
def on_error(ws, error):
    """WebSocket error callback: print the error reported by the client library."""
    print("### error:", error, sep=" ")
|
|
|
|
|
|
|
def on_close(ws, one, two):
    """WebSocket close callback: print a line containing a single space.

    ``one`` and ``two`` receive the close details from the websocket client
    (presumably status code and reason) and are ignored.
    """
    print(" ", end="\n")
|
|
|
|
|
|
|
def on_open(ws):
    """WebSocket open callback: send the request from a new background thread via run(ws)."""
    worker_args = (ws,)
    thread.start_new_thread(run, worker_args)
|
|
|
|
|
def run(ws, *args):
    """Serialize the chat request and send it over *ws*.

    Runs on the worker thread started by on_open. Reads ``appid``, ``domain``
    and ``question`` from *ws*. ``temperature`` and ``max_tokens`` are also
    read when present and not None (spark_main sets them); otherwise
    gen_params' defaults apply. On failure the error is recorded on
    ``ws.error`` and the socket is closed.
    """
    try:
        params = {"appid": ws.appid, "domain": ws.domain, "question": ws.question}
        # Forward the caller's sampling settings instead of silently using
        # gen_params' defaults.
        for name in ("temperature", "max_tokens"):
            value = getattr(ws, name, None)
            if value is not None:
                params[name] = value
        data = json.dumps(gen_params(**params))
        ws.send(data)
    except Exception as e:
        print(f"发送数据时出错: {str(e)}")
        ws.error = f"发送数据时出错: {str(e)}"
        ws.close()


def on_message(ws, message):
    """Handle one streamed frame from the Spark API.

    Reasoning text (``reasoning_content``) and answer text (``content``) are
    echoed to stdout. Answer text is appended to the module-level ``answer``.
    The socket is closed on a non-zero ``header.code``, on the final frame
    (``status == 2``), or if the frame cannot be processed. Failures are
    stored in ``ws.error`` and normal completion sets ``ws.finished`` so that
    spark_main can tell a complete reply from a failed one.
    """
    global answer, isFirstcontent
    try:
        data = json.loads(message)
        code = data['header']['code']
        if code != 0:
            print(f'请求错误: {code}, {data}')
            ws.error = f'请求错误: {code}, {data}'
            ws.close()
            return

        choices = data["payload"]["choices"]
        status = choices["status"]
        text = choices['text'][0]

        reasoning_content = text.get('reasoning_content', '')
        if reasoning_content:
            print(reasoning_content, end="")
            # Remember that reasoning was shown so a separator precedes the answer.
            isFirstcontent = True

        content = text.get('content', '')
        if content:
            if isFirstcontent:
                print("\n*******************以上为思维链内容,模型回复内容如下********************\n")
            print(content, end="")
            isFirstcontent = False
            answer += content

        if status == 2:
            # Final frame of the stream.
            ws.finished = True
            ws.close()
    except Exception as e:
        print(f"处理消息时出错: {str(e)}")
        ws.error = f"处理消息时出错: {str(e)}"
        ws.close()


def gen_params(appid, domain, question, temperature=1.2, max_tokens=32768):
    """Build the Spark chat request payload.

    Args:
        appid: placed in ``header.app_id``.
        domain: value of ``parameter.chat.domain``.
        question: message list placed in ``payload.message.text``.
        temperature: value of ``parameter.chat.temperature`` (default 1.2).
        max_tokens: value of ``parameter.chat.max_tokens`` (default 32768).

    Returns:
        A dict ready to be serialized with ``json.dumps``.
    """
    data = {
        "header": {
            "app_id": appid,
            "uid": "1234"
        },
        "parameter": {
            "chat": {
                "domain": domain,
                "temperature": temperature,
                "max_tokens": max_tokens
            }
        },
        "payload": {
            "message": {
                "text": question
            }
        }
    }
    return data


def spark_main(appid, api_key, api_secret, Spark_url, domain, question, temperature, max_tokens,
               verify_ssl=True):
    """Send *question* to the Spark API over a WebSocket and return the reply text.

    Blocks until the connection closes. Streamed text is echoed to stdout by
    on_message; the answer text is accumulated in the module-level ``answer``
    and returned.

    Args:
        appid, api_key, api_secret: iFlytek credentials (all required).
        Spark_url: ``wss://`` endpoint to connect to.
        domain: ``parameter.chat.domain`` value.
        question: list of ``{"role": ..., "content": ...}`` messages.
        temperature, max_tokens: forwarded into the request via run/gen_params.
        verify_ssl: verify the server's TLS certificate (default True). Pass
            False to skip verification (not recommended).

    Raises:
        ValueError: a credential is missing.
        RuntimeError: the server returned a non-zero code, a frame could not
            be sent or handled, or the connection closed before the final frame.
    """
    print("spark_main 开始执行...")

    if not all([appid, api_key, api_secret]):
        raise ValueError("缺少必要的认证参数:appid, api_key, api_secret")

    global answer, isFirstcontent
    answer = ""
    isFirstcontent = False

    print("正在创建 WebSocket 连接...")
    wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
    websocket.enableTrace(False)
    wsUrl = wsParam.create_url()

    ws = websocket.WebSocketApp(wsUrl,
                                on_message=on_message,
                                on_error=on_error,
                                on_close=on_close,
                                on_open=on_open)
    # Per-request data read by run() on the sender thread.
    ws.appid = appid
    ws.question = question
    ws.domain = domain
    ws.temperature = temperature
    ws.max_tokens = max_tokens
    # Outcome flags written by run()/on_message().
    ws.finished = False
    ws.error = None

    print("正在启动 WebSocket 连接...")
    # The request carries a signed authorization value and user data, so the
    # server certificate is verified unless explicitly disabled.
    cert_reqs = ssl.CERT_REQUIRED if verify_ssl else ssl.CERT_NONE
    ws.run_forever(sslopt={"cert_reqs": cert_reqs})

    if ws.error is not None:
        raise RuntimeError(ws.error)
    if not ws.finished:
        raise RuntimeError("星火大模型连接在回复完成前已关闭")

    print("spark_main 执行完成")
    return answer
|
|
|
|
|
|
|
|
|
|
|
|