|
import os |
|
import sys |
|
sys.path.append(os.path.dirname(os.path.dirname(__file__))) |
|
|
|
from embedding.zhipuai_embedding import ZhipuAIEmbeddings |
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings |
|
from langchain.embeddings.openai import OpenAIEmbeddings |
|
from llm.call_llm import parse_llm_api_key |
|
|
|
|
|
|
|
|
|
def _init_or_report(label, factory):
    """Call ``factory()`` and return its result.

    If construction raises, print "<label> 模型初始化失败: <error>" and
    re-raise the original exception unchanged.
    """
    try:
        return factory()
    except Exception as e:
        print(f"{label} 模型初始化失败: {str(e)}")
        raise


def get_embedding(embedding: str, embedding_key: str = None, env_file: str = None):
    """Build a LangChain embedding model for the given backend name.

    Args:
        embedding: Backend name. Must be one of ``"m3e"``, ``"openai"`` or
            ``"zhipuai"``.
        embedding_key: API key for the remote backends. When ``None``, it is
            resolved with ``parse_llm_api_key(embedding)``. It is ignored for
            ``"m3e"``, which runs locally.
        env_file: Currently unused by this function.
            NOTE(review): possibly meant to be forwarded to
            ``parse_llm_api_key``; confirm that helper's signature first.

    Returns:
        ``HuggingFaceEmbeddings`` (``moka-ai/m3e-base`` on CPU, normalized
        embeddings) for ``"m3e"``, ``OpenAIEmbeddings`` for ``"openai"``, or
        ``ZhipuAIEmbeddings`` for ``"zhipuai"``.

    Raises:
        ValueError: If ``embedding`` is not a supported backend name.
        Exception: Any error raised while constructing the model. It is
            printed and then re-raised unchanged.
    """
    # Reject unknown names before any key lookup, so callers always get this
    # ValueError instead of whatever the key parser does with an unknown name.
    if embedding not in ("m3e", "openai", "zhipuai"):
        raise ValueError(f"embedding {embedding} not support ")

    if embedding == "m3e":
        # Local model: needs no API key.
        return _init_or_report(
            "m3e",
            lambda: HuggingFaceEmbeddings(
                model_name="moka-ai/m3e-base",
                model_kwargs={'device': 'cpu'},
                encode_kwargs={'normalize_embeddings': True},
            ),
        )

    if embedding_key is None:
        embedding_key = parse_llm_api_key(embedding)

    if embedding == "openai":
        return _init_or_report(
            "OpenAI embedding",
            lambda: OpenAIEmbeddings(openai_api_key=embedding_key),
        )

    # Only "zhipuai" remains after the validation above.
    return _init_or_report(
        "智谱 embedding",
        lambda: ZhipuAIEmbeddings(zhipuai_api_key=embedding_key),
    )
|
|