File size: 8,622 Bytes
d04a18c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45c5d09
 
 
d04a18c
 
45c5d09
 
d04a18c
 
 
45c5d09
d04a18c
45c5d09
 
 
d04a18c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45c5d09
d04a18c
 
 
45c5d09
 
 
d04a18c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
import os
import sys
import re
import json
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import tempfile
from dotenv import load_dotenv, find_dotenv
from embedding.call_embedding import get_embedding
from langchain.document_loaders import UnstructuredFileLoader
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain.document_loaders import PyMuPDFLoader
from langchain.document_loaders import UnstructuredWordDocumentLoader
from langchain.vectorstores import Chroma
from langchain.schema import Document

# 禁用 Pebblo 安全模块
# os.environ["PEBBLO_DISABLED"] = "1"  # 新增环境变量

# 设置模型缓存目录
CACHE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
# os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
# 修改环境变量设置(保留一个即可)
os.environ['HF_HOME'] = CACHE_DIR

# 首先实现基本配置
# 原代码
DEFAULT_DB_PATH = "./knowledge_db/sanguo_characters"
DEFAULT_PERSIST_PATH = "./vector_db/chroma_sanguo"




class CharacterTextSplitter(TextSplitter):
    """专门用于处理角色JSON数据的文本分割器"""
    
    def split_text(self, text: str) -> list[str]:
        # 使用更健壮的正则表达式匹配每个角色的JSON数据
        pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
        matches = re.finditer(pattern, text)
        
        # 将每个匹配的JSON字符串转换为文本块
        chunks = []
        for match in matches:
            try:
                # 解析JSON数据
                char_data = json.loads(match.group())
                
                # 检查必要字段
                if 'name' not in char_data:
                    print(f"警告:发现缺少name字段的JSON数据: {match.group()[:100]}...")
                    continue
                
                # 处理技能数据,将stamina_cost转换为endurance_cost
                if 'skills' in char_data:
                    for skill in char_data['skills']:
                        if 'stamina_cost' in skill:
                            skill['endurance_cost'] = skill.pop('stamina_cost')
                
                # 将JSON数据转换为易读的文本格式
                char_text = f"角色:{char_data['name']}\n"
                char_text += f"攻击力:{char_data['attack']}\n"
                char_text += f"防御力:{char_data['defense']}\n"
                char_text += f"体力:{char_data['stamina']}\n"
                char_text += f"耐力:{char_data['endurance']}\n"
                char_text += f"法力:{char_data['mana']}\n"
                char_text += f"闪避:{char_data['dodge']}\n"
                char_text += f"速度:{char_data['speed']}\n"
                char_text += "技能:\n"
                for skill in char_data['skills']:
                    char_text += f"- {skill['name']}{skill['effect']}\n"
                    if 'endurance_cost' in skill and 'mana_cost' in skill:
                        char_text += f"  耐力消耗:{skill['endurance_cost']},法力消耗:{skill['mana_cost']}\n"
                chunks.append(char_text)
            except json.JSONDecodeError as e:
                print(f"JSON解析错误: {e}")
                print(f"问题数据: {match.group()[:100]}...")
                continue
            except KeyError as e:
                print(f"缺少字段: {e}")
                print(f"问题数据: {match.group()[:100]}...")
                continue
        return chunks

    def split_documents(self, documents: list[Document]) -> list[Document]:
        """分割文档列表"""
        texts = []
        metadatas = []
        for doc in documents:
            texts.extend(self.split_text(doc.page_content))
            metadatas.extend([doc.metadata] * len(self.split_text(doc.page_content)))
        return [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]


def get_files(dir_path):
    file_list = []
    for filepath, dirnames, filenames in os.walk(dir_path):
        for filename in filenames:
            file_list.append(os.path.join(filepath, filename))
    return file_list


def file_loader(file, loaders):
    if isinstance(file, tempfile._TemporaryFileWrapper):
        file = file.name
    if not os.path.isfile(file):
        [file_loader(os.path.join(file, f), loaders) for f in  os.listdir(file)]
        return
    file_type = file.split('.')[-1].lower()
    if file_type == 'pdf':
        loaders.append(PyMuPDFLoader(file))
    elif file_type == 'md':
        pattern = r"不存在|风控"
        match = re.search(pattern, file)
        if not match:
            loaders.append(UnstructuredMarkdownLoader(file))
    elif file_type == 'txt':
        loaders.append(UnstructuredFileLoader(file))
    elif file_type == 'docx':
        loaders.append(UnstructuredWordDocumentLoader(file))
    return


def create_db_info(files=DEFAULT_DB_PATH, embeddings="openai", persist_directory=DEFAULT_PERSIST_PATH):
    if embeddings == 'openai' or embeddings == 'm3e' or embeddings =='zhipuai':
        vectordb = create_db(files, persist_directory, embeddings)
    return ""


def create_db(files=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="openai"):
    """
    该函数用于加载文件,切分文档,生成文档的嵌入向量,创建向量数据库。

    参数:
    file: 存放文件的路径。
    embeddings: 用于生产 Embedding 的模型

    返回:
    vectordb: 创建的数据库。
    """
    if files == None:
        return "can't load empty file"
    if type(files) != list:
        files = [files]
    
    print(f"正在处理文件路径: {files}")
    
    loaders = []
    [file_loader(file, loaders) for file in files]
    print(f"找到的加载器数量: {len(loaders)}")
    
    docs = []
    for loader in loaders:
        if loader is not None:
            loaded_docs = loader.load()
            print(f"\n加载的文档数量: {len(loaded_docs)}")
            # 打印第一个文档的内容示例
            if loaded_docs:
                print("\n文档内容示例:")
                print("-" * 50)
                print(loaded_docs[0].page_content[:500])  # 只打印前500个字符
                print("-" * 50)
                print("\n文档元数据:")
                print(loaded_docs[0].metadata)
                print("-" * 50)
            docs.extend(loaded_docs)
    
    print(f"\n总文档数量: {len(docs)}")
    
    if len(docs) == 0:
        print("警告:没有找到任何文档!")
        return None
    
    # 使用自定义的角色文本分割器
    text_splitter = CharacterTextSplitter()
    split_docs = text_splitter.split_documents(docs)
    print(f"\n分割后的文档数量: {len(split_docs)}")
    
    if len(split_docs) == 0:
        print("警告:分割后没有文档!")
        return None
    
    # 保存分割后的文档到文件
    split_docs_dir = os.path.join(os.path.dirname(persist_directory), "split_docs")
    os.makedirs(split_docs_dir, exist_ok=True)
    split_docs_file = os.path.join(split_docs_dir, "split_documents.txt")
    
    with open(split_docs_file, "w", encoding="utf-8") as f:
        for i, doc in enumerate(split_docs, 1):
            f.write(f"\n文档 {i}:\n")
            f.write("-" * 50 + "\n")
            f.write(doc.page_content)
            f.write("\n" + "-" * 50 + "\n")
    
    print(f"\n分割后的文档已保存到: {split_docs_file}")
    
    if type(embeddings) == str:
        embeddings = get_embedding(embedding=embeddings)
    # 修正参数名称和初始化方式
    vectordb = Chroma.from_documents(
        documents=split_docs,
        embedding=embeddings,
        persist_directory=persist_directory,
        collection_metadata={"hnsw:space": "cosine"}  # 新增元数据配置
    )

    vectordb.persist()
    return vectordb


def presit_knowledge_db(vectordb):
    """
    该函数用于持久化向量数据库。

    参数:
    vectordb: 要持久化的向量数据库。
    """
    vectordb.persist()


def load_knowledge_db(path, embeddings):
    """
    该函数用于加载向量数据库。

    参数:
    path: 要加载的向量数据库路径。
    embeddings: 向量数据库使用的 embedding 模型。

    返回:
    vectordb: 加载的数据库。
    """
    vectordb = Chroma(
        persist_directory=path,
        embedding_function=embeddings
    )
    return vectordb


if __name__ == "__main__":
    create_db(embeddings="m3e")