File size: 1,575 Bytes
d04a18c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45c5d09
 
 
 
 
 
d04a18c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from embedding.zhipuai_embedding import ZhipuAIEmbeddings
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from llm.call_llm import parse_llm_api_key




def get_embedding(embedding: str, embedding_key: str = None, env_file: str = None):
    """Build and return the embedding model selected by ``embedding``.

    Supported values of ``embedding``:

    * ``'m3e'``     -- ``HuggingFaceEmbeddings`` for ``moka-ai/m3e-base`` on
      CPU with normalized embeddings. No API key is looked up for this
      backend; the function returns before the key is resolved.
    * ``'openai'``  -- ``OpenAIEmbeddings`` built with ``embedding_key``.
    * ``'zhipuai'`` -- ``ZhipuAIEmbeddings`` built with ``embedding_key``.

    Args:
        embedding: Name of the embedding backend (see the list above).
        embedding_key: API key for the hosted backends. When it is ``None``
            the key is obtained from ``parse_llm_api_key(embedding)``.
        env_file: Accepted for interface compatibility; this function does
            not read it.

    Returns:
        The constructed embedding model instance.

    Raises:
        ValueError: If ``embedding`` is not one of the supported names.
        Exception: Any error raised while constructing a model is printed
            and then re-raised unchanged.
    """
    if embedding == 'm3e':
        try:
            # Pass the HuggingFace model name directly so the model is
            # downloaded automatically instead of loaded from a local path.
            model = HuggingFaceEmbeddings(
                model_name="moka-ai/m3e-base",
                model_kwargs={'device': 'cpu'},  # pick 'cpu' or 'cuda' to match your setup
                encode_kwargs={'normalize_embeddings': True}
            )
            return model
        except Exception as e:
            print(f"m3e 模型初始化失败: {str(e)}")
            raise

    # Identity check: only a genuinely missing key triggers the lookup, and
    # the result cannot be influenced by a custom ``__eq__`` on the argument.
    if embedding_key is None:
        embedding_key = parse_llm_api_key(embedding)

    if embedding == "openai":
        try:
            model = OpenAIEmbeddings(openai_api_key=embedding_key)
            return model
        except Exception as e:
            print(f"OpenAI embedding 模型初始化失败: {str(e)}")
            raise
    elif embedding == "zhipuai":
        try:
            model = ZhipuAIEmbeddings(zhipuai_api_key=embedding_key)
            return model
        except Exception as e:
            print(f"智谱 embedding 模型初始化失败: {str(e)}")
            raise
    else:
        raise ValueError(f"embedding {embedding} not support ")