File size: 1,575 Bytes
d04a18c 45c5d09 d04a18c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from embedding.zhipuai_embedding import ZhipuAIEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from llm.call_llm import parse_llm_api_key
def get_embedding(embedding: str, embedding_key: str = None, env_file: str = None, device: str = "cpu"):
    """Build an embedding model for the named backend.

    Args:
        embedding: Backend identifier. Supported values are ``"m3e"``,
            ``"openai"`` and ``"zhipuai"``.
        embedding_key: API key for the ``"openai"`` / ``"zhipuai"`` backends.
            When ``None``, the key is resolved via ``parse_llm_api_key``.
            Not used for ``"m3e"``.
        env_file: Currently unused by this function; kept for interface
            compatibility. NOTE(review): presumably intended to be forwarded to
            the key lookup -- confirm against ``parse_llm_api_key``.
        device: Device string passed to the m3e model (e.g. ``"cpu"`` or
            ``"cuda"``). Only used for ``"m3e"``. Defaults to ``"cpu"``.

    Returns:
        A ``HuggingFaceEmbeddings``, ``OpenAIEmbeddings`` or
        ``ZhipuAIEmbeddings`` instance, depending on ``embedding``.

    Raises:
        ValueError: If ``embedding`` is not a supported backend name.
        Exception: Any error raised while constructing a model is printed
            and then re-raised unchanged.
    """
    if embedding == "m3e":
        try:
            # Load by Hugging Face model name so the model is downloaded
            # automatically rather than read from a local path.
            return HuggingFaceEmbeddings(
                model_name="moka-ai/m3e-base",
                model_kwargs={"device": device},
                encode_kwargs={"normalize_embeddings": True},
            )
        except Exception as e:
            print(f"m3e 模型初始化失败: {str(e)}")
            raise

    # Reject unknown backends before attempting any API-key lookup, so the
    # caller gets a clear ValueError instead of a key-resolution failure.
    if embedding not in ("openai", "zhipuai"):
        raise ValueError(f"embedding {embedding} not support ")

    if embedding_key is None:
        embedding_key = parse_llm_api_key(embedding)

    if embedding == "openai":
        try:
            return OpenAIEmbeddings(openai_api_key=embedding_key)
        except Exception as e:
            print(f"OpenAI embedding 模型初始化失败: {str(e)}")
            raise

    # Only "zhipuai" remains after the validation above.
    try:
        return ZhipuAIEmbeddings(zhipuai_api_key=embedding_key)
    except Exception as e:
        print(f"智谱 embedding 模型初始化失败: {str(e)}")
        raise
|