# SANGUO / embedding / call_embedding.py
# (scraped page header: author "konghuan", commit 45c5d09)
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from embedding.zhipuai_embedding import ZhipuAIEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from llm.call_llm import parse_llm_api_key
def get_embedding(embedding: str, embedding_key: str = None, env_file: str = None,
                  device: str = "cpu"):
    """Build a LangChain embedding model for the named backend.

    Args:
        embedding: Backend name. Supported values are "m3e" (the HuggingFace
            model ``moka-ai/m3e-base``), "openai" and "zhipuai".
        embedding_key: API key for the "openai" / "zhipuai" backends. When
            None, it is resolved via ``parse_llm_api_key(embedding)``. It is
            not used for "m3e".
        env_file: Accepted for interface compatibility; this function does
            not currently use it.
        device: Device passed to the HuggingFace model for "m3e" (e.g.
            "cpu" or "cuda"). Defaults to "cpu", the previously hard-coded
            value, so existing callers are unaffected.

    Returns:
        A ``HuggingFaceEmbeddings``, ``OpenAIEmbeddings`` or
        ``ZhipuAIEmbeddings`` instance, depending on ``embedding``.

    Raises:
        ValueError: If ``embedding`` is not a supported backend name.
        Exception: Any error raised while constructing the model is printed
            and re-raised unchanged. Errors from ``parse_llm_api_key``
            propagate as-is.
    """
    if embedding == 'm3e':
        try:
            # Reference the model by its HuggingFace Hub name (the original note
            # says it is downloaded automatically) instead of a local path.
            return HuggingFaceEmbeddings(
                model_name="moka-ai/m3e-base",
                model_kwargs={'device': device},
                encode_kwargs={'normalize_embeddings': True},
            )
        except Exception as e:
            print(f"m3e 模型初始化失败: {str(e)}")
            raise

    # Reject unknown backends before any API-key lookup. Otherwise an
    # unsupported name would reach parse_llm_api_key first and could fail
    # there with an unrelated error instead of this clear one.
    if embedding not in ("openai", "zhipuai"):
        raise ValueError(f"embedding {embedding} not support ")

    if embedding_key is None:
        embedding_key = parse_llm_api_key(embedding)

    if embedding == "openai":
        try:
            return OpenAIEmbeddings(openai_api_key=embedding_key)
        except Exception as e:
            print(f"OpenAI embedding 模型初始化失败: {str(e)}")
            raise

    # Only "zhipuai" remains after the validation above.
    try:
        return ZhipuAIEmbeddings(zhipuai_api_key=embedding_key)
    except Exception as e:
        print(f"智谱 embedding 模型初始化失败: {str(e)}")
        raise