import os
import sys
import re
import json

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

import tempfile

from embedding.call_embedding import get_embedding
from langchain.document_loaders import (
    PyMuPDFLoader,
    UnstructuredFileLoader,
    UnstructuredMarkdownLoader,
    UnstructuredWordDocumentLoader,
)
from langchain.text_splitter import TextSplitter
from langchain.schema import Document
from langchain.vectorstores import Chroma

# Disable the Pebblo security module
# os.environ["PEBBLO_DISABLED"] = "1"  # added environment variable

# Model cache directory
CACHE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
# os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR  # only one of these cache variables is needed
os.environ['HF_HOME'] = CACHE_DIR

# Basic configuration: default source and persistence paths
DEFAULT_DB_PATH = "./knowledge_db/sanguo_characters"
DEFAULT_PERSIST_PATH = "./vector_db/chroma_sanguo"


class CharacterTextSplitter(TextSplitter):
    """Text splitter specialised for character JSON data.

    Each chunk corresponds to one character object with fields such as
    name, attack, defense, stamina, endurance, mana, dodge, speed and skills.
    """

    def split_text(self, text: str) -> list[str]:
        # Match each character's JSON object with a regex that tolerates one level of nesting
        pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
        matches = re.finditer(pattern, text)

        # Convert each matched JSON string into a text chunk
        chunks = []
        for match in matches:
            try:
                # Parse the JSON data
                char_data = json.loads(match.group())

                # Check required fields
                if 'name' not in char_data:
                    print(f"Warning: JSON data missing the 'name' field: {match.group()[:100]}...")
                    continue

                # Normalise skill data: rename stamina_cost to endurance_cost
                if 'skills' in char_data:
                    for skill in char_data['skills']:
                        if 'stamina_cost' in skill:
                            skill['endurance_cost'] = skill.pop('stamina_cost')

                # Render the JSON data as human-readable text
                char_text = f"角色:{char_data['name']}\n"
                char_text += f"攻击力:{char_data['attack']}\n"
                char_text += f"防御力:{char_data['defense']}\n"
                char_text += f"体力:{char_data['stamina']}\n"
                char_text += f"耐力:{char_data['endurance']}\n"
                char_text += f"法力:{char_data['mana']}\n"
                char_text += f"闪避:{char_data['dodge']}\n"
                char_text += f"速度:{char_data['speed']}\n"
                char_text += "技能:\n"
                for skill in char_data['skills']:
                    char_text += f"- {skill['name']}:{skill['effect']}\n"
                    if 'endurance_cost' in skill and 'mana_cost' in skill:
                        char_text += f"  耐力消耗:{skill['endurance_cost']},法力消耗:{skill['mana_cost']}\n"
                chunks.append(char_text)
            except json.JSONDecodeError as e:
                print(f"JSON parse error: {e}")
                print(f"Offending data: {match.group()[:100]}...")
                continue
            except KeyError as e:
                print(f"Missing field: {e}")
                print(f"Offending data: {match.group()[:100]}...")
                continue
        return chunks

    def split_documents(self, documents: list[Document]) -> list[Document]:
        """Split a list of documents."""
        texts = []
        metadatas = []
        for doc in documents:
            chunks = self.split_text(doc.page_content)  # split once, reuse for the metadata count
            texts.extend(chunks)
            metadatas.extend([doc.metadata] * len(chunks))
        return [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]


def get_files(dir_path):
    file_list = []
    for filepath, dirnames, filenames in os.walk(dir_path):
        for filename in filenames:
            file_list.append(os.path.join(filepath, filename))
    return file_list


def file_loader(file, loaders):
    if isinstance(file, tempfile._TemporaryFileWrapper):
        file = file.name
    if not os.path.isfile(file):
        # Recurse into directories
        [file_loader(os.path.join(file, f), loaders) for f in os.listdir(file)]
        return
    file_type = file.split('.')[-1].lower()
    if file_type == 'pdf':
        loaders.append(PyMuPDFLoader(file))
    elif file_type == 'md':
        # Skip markdown files whose names contain these keywords
        pattern = r"不存在|风控"
        match = re.search(pattern, file)
        if not match:
            loaders.append(UnstructuredMarkdownLoader(file))
    elif file_type == 'txt':
        loaders.append(UnstructuredFileLoader(file))
    elif file_type == 'docx':
        loaders.append(UnstructuredWordDocumentLoader(file))
    return


def create_db_info(files=DEFAULT_DB_PATH, embeddings="openai", persist_directory=DEFAULT_PERSIST_PATH):
    if embeddings in ('openai', 'm3e', 'zhipuai'):
        create_db(files, persist_directory, embeddings)
    return ""


def create_db(files=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="openai"):
    """
    Load files, split the documents, generate embeddings, and create a vector database.

    Args:
        files: path(s) of the source files.
        embeddings: model used to generate the embeddings.

    Returns:
        vectordb: the created vector database.
    """
    if files is None:
        return "can't load empty file"
    if not isinstance(files, list):
        files = [files]
    print(f"Processing file paths: {files}")

    loaders = []
    [file_loader(file, loaders) for file in files]
    print(f"Number of loaders found: {len(loaders)}")

    docs = []
    for loader in loaders:
        if loader is not None:
            loaded_docs = loader.load()
            print(f"\nNumber of documents loaded: {len(loaded_docs)}")
            # Print a sample of the first document's content
            if loaded_docs:
                print("\nDocument content sample:")
                print("-" * 50)
                print(loaded_docs[0].page_content[:500])  # only the first 500 characters
                print("-" * 50)
                print("\nDocument metadata:")
                print(loaded_docs[0].metadata)
                print("-" * 50)
            docs.extend(loaded_docs)

    print(f"\nTotal number of documents: {len(docs)}")
    if len(docs) == 0:
        print("Warning: no documents found!")
        return None

    # Use the custom character text splitter
    text_splitter = CharacterTextSplitter()
    split_docs = text_splitter.split_documents(docs)
    print(f"\nNumber of documents after splitting: {len(split_docs)}")

    if len(split_docs) == 0:
        print("Warning: no documents left after splitting!")
        return None

    # Save the split documents to a file for inspection
    split_docs_dir = os.path.join(os.path.dirname(persist_directory), "split_docs")
    os.makedirs(split_docs_dir, exist_ok=True)
    split_docs_file = os.path.join(split_docs_dir, "split_documents.txt")
    with open(split_docs_file, "w", encoding="utf-8") as f:
        for i, doc in enumerate(split_docs, 1):
            f.write(f"\nDocument {i}:\n")
            f.write("-" * 50 + "\n")
            f.write(doc.page_content)
            f.write("\n" + "-" * 50 + "\n")
    print(f"\nSplit documents saved to: {split_docs_file}")

    if isinstance(embeddings, str):
        embeddings = get_embedding(embedding=embeddings)

    # Build the Chroma vector store
    vectordb = Chroma.from_documents(
        documents=split_docs,
        embedding=embeddings,
        persist_directory=persist_directory,
        collection_metadata={"hnsw:space": "cosine"}  # use cosine distance for the HNSW index
    )

    vectordb.persist()
    return vectordb


def presit_knowledge_db(vectordb):
    """
    Persist the vector database to disk.

    Args:
        vectordb: the vector database to persist.
    """
    vectordb.persist()


def load_knowledge_db(path, embeddings):
    """
    Load a persisted vector database.

    Args:
        path: path of the vector database to load.
        embeddings: embedding model used by the vector database.

    Returns:
        vectordb: the loaded database.
    """
    vectordb = Chroma(
        persist_directory=path,
        embedding_function=embeddings
    )
    return vectordb


if __name__ == "__main__":
    create_db(embeddings="m3e")
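
    # Minimal retrieval sketch: load the store persisted above and run a
    # similarity search against it. It assumes the build succeeded and that
    # get_embedding(embedding="m3e") returns the same embedding model used
    # to create the store; the query string and k value are illustrative only.
    embedding_model = get_embedding(embedding="m3e")
    vectordb = load_knowledge_db(DEFAULT_PERSIST_PATH, embedding_model)
    hits = vectordb.similarity_search("关羽有哪些技能?", k=3)  # "What skills does Guan Yu have?"
    for hit in hits:
        print(hit.page_content)
        print("-" * 50)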