# SANGUO/database/create_db.py
import os
import sys
import re
import json
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import tempfile
from dotenv import load_dotenv, find_dotenv
from embedding.call_embedding import get_embedding
from langchain.document_loaders import UnstructuredFileLoader
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain.document_loaders import PyMuPDFLoader
from langchain.document_loaders import UnstructuredWordDocumentLoader
from langchain.vectorstores import Chroma
from langchain.schema import Document
# Disable the Pebblo security module (optional)
# os.environ["PEBBLO_DISABLED"] = "1"  # extra environment variable
# Model cache directory
CACHE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
# os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
# HF_HOME supersedes the TRANSFORMERS_CACHE setting above; keep only one of the two
os.environ['HF_HOME'] = CACHE_DIR
# Basic configuration: default knowledge-base and vector-store paths
DEFAULT_DB_PATH = "./knowledge_db/sanguo_characters"
DEFAULT_PERSIST_PATH = "./vector_db/chroma_sanguo"
class CharacterTextSplitter(TextSplitter):
"""专门用于处理角色JSON数据的文本分割器"""
def split_text(self, text: str) -> list[str]:
        # Match each character's JSON object with a regex that tolerates one level of nested braces
pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
matches = re.finditer(pattern, text)
        # Convert each matched JSON string into a readable text chunk
chunks = []
for match in matches:
try:
                # Parse the JSON data
char_data = json.loads(match.group())
                # Check required fields
                if 'name' not in char_data:
                    print(f"Warning: JSON data missing the 'name' field: {match.group()[:100]}...")
continue
                # Normalize skill data: rename stamina_cost to endurance_cost
if 'skills' in char_data:
for skill in char_data['skills']:
if 'stamina_cost' in skill:
skill['endurance_cost'] = skill.pop('stamina_cost')
                # Render the JSON data as human-readable text
char_text = f"角色:{char_data['name']}\n"
char_text += f"攻击力:{char_data['attack']}\n"
char_text += f"防御力:{char_data['defense']}\n"
char_text += f"体力:{char_data['stamina']}\n"
char_text += f"耐力:{char_data['endurance']}\n"
char_text += f"法力:{char_data['mana']}\n"
char_text += f"闪避:{char_data['dodge']}\n"
char_text += f"速度:{char_data['speed']}\n"
char_text += "技能:\n"
for skill in char_data['skills']:
char_text += f"- {skill['name']}{skill['effect']}\n"
if 'endurance_cost' in skill and 'mana_cost' in skill:
char_text += f" 耐力消耗:{skill['endurance_cost']},法力消耗:{skill['mana_cost']}\n"
chunks.append(char_text)
except json.JSONDecodeError as e:
print(f"JSON解析错误: {e}")
print(f"问题数据: {match.group()[:100]}...")
continue
except KeyError as e:
print(f"缺少字段: {e}")
print(f"问题数据: {match.group()[:100]}...")
continue
return chunks
    def split_documents(self, documents: list[Document]) -> list[Document]:
        """Split a list of documents into per-character chunks."""
        texts = []
        metadatas = []
        for doc in documents:
            # Split once and reuse the result instead of calling split_text twice
            doc_chunks = self.split_text(doc.page_content)
            texts.extend(doc_chunks)
            metadatas.extend([doc.metadata] * len(doc_chunks))
        return [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
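# Minimal usage sketch of the splitter (illustrative only; the inline JSON string is a
# hypothetical stand-in for the real character data):
#   splitter = CharacterTextSplitter()
#   chunks = splitter.split_text('{"name": "张飞", "attack": 95, "defense": 85, '
#                                '"stamina": 100, "endurance": 90, "mana": 20, '
#                                '"dodge": 15, "speed": 70, "skills": []}')
#   print(chunks[0])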
def get_files(dir_path):
    """Recursively collect every file path under dir_path."""
    file_list = []
for filepath, dirnames, filenames in os.walk(dir_path):
for filename in filenames:
file_list.append(os.path.join(filepath, filename))
return file_list
def file_loader(file, loaders):
    """Append an appropriate document loader for `file` to `loaders`; recurses into directories."""
    if isinstance(file, tempfile._TemporaryFileWrapper):
        file = file.name
    if not os.path.isfile(file):
        # Recurse into directories instead of using a side-effect list comprehension
        for f in os.listdir(file):
            file_loader(os.path.join(file, f), loaders)
        return
file_type = file.split('.')[-1].lower()
if file_type == 'pdf':
loaders.append(PyMuPDFLoader(file))
    elif file_type == 'md':
        # Skip markdown files whose names contain these flagged keywords
        pattern = r"不存在|风控"
        match = re.search(pattern, file)
        if not match:
            loaders.append(UnstructuredMarkdownLoader(file))
elif file_type == 'txt':
loaders.append(UnstructuredFileLoader(file))
elif file_type == 'docx':
loaders.append(UnstructuredWordDocumentLoader(file))
return
def create_db_info(files=DEFAULT_DB_PATH, embeddings="openai", persist_directory=DEFAULT_PERSIST_PATH):
    """Build the vector database when a supported embedding backend is requested; returns an empty string."""
    if embeddings in ('openai', 'm3e', 'zhipuai'):
        vectordb = create_db(files, persist_directory, embeddings)
    return ""
def create_db(files=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="openai"):
"""
该函数用于加载文件,切分文档,生成文档的嵌入向量,创建向量数据库。
参数:
file: 存放文件的路径。
embeddings: 用于生产 Embedding 的模型
返回:
vectordb: 创建的数据库。
"""
if files == None:
return "can't load empty file"
if type(files) != list:
files = [files]
print(f"正在处理文件路径: {files}")
loaders = []
[file_loader(file, loaders) for file in files]
print(f"找到的加载器数量: {len(loaders)}")
docs = []
for loader in loaders:
if loader is not None:
loaded_docs = loader.load()
print(f"\n加载的文档数量: {len(loaded_docs)}")
# 打印第一个文档的内容示例
if loaded_docs:
print("\n文档内容示例:")
print("-" * 50)
print(loaded_docs[0].page_content[:500]) # 只打印前500个字符
print("-" * 50)
print("\n文档元数据:")
print(loaded_docs[0].metadata)
print("-" * 50)
docs.extend(loaded_docs)
print(f"\n总文档数量: {len(docs)}")
if len(docs) == 0:
print("警告:没有找到任何文档!")
return None
    # Use the custom character text splitter
    text_splitter = CharacterTextSplitter()
    split_docs = text_splitter.split_documents(docs)
    print(f"\nNumber of chunks after splitting: {len(split_docs)}")
    if len(split_docs) == 0:
        print("Warning: no documents left after splitting!")
        return None
    # Save the split documents to a file for inspection
    split_docs_dir = os.path.join(os.path.dirname(persist_directory), "split_docs")
    os.makedirs(split_docs_dir, exist_ok=True)
    split_docs_file = os.path.join(split_docs_dir, "split_documents.txt")
    with open(split_docs_file, "w", encoding="utf-8") as f:
        for i, doc in enumerate(split_docs, 1):
            f.write(f"\nDocument {i}:\n")
            f.write("-" * 50 + "\n")
            f.write(doc.page_content)
            f.write("\n" + "-" * 50 + "\n")
    print(f"\nSplit documents saved to: {split_docs_file}")
    if isinstance(embeddings, str):
        embeddings = get_embedding(embedding=embeddings)
    # Build the Chroma store from the split documents
    vectordb = Chroma.from_documents(
        documents=split_docs,
        embedding=embeddings,
        persist_directory=persist_directory,
        collection_metadata={"hnsw:space": "cosine"}  # use cosine distance for the HNSW index
    )
vectordb.persist()
return vectordb
def presit_knowledge_db(vectordb):
"""
该函数用于持久化向量数据库。
参数:
vectordb: 要持久化的向量数据库。
"""
vectordb.persist()
def load_knowledge_db(path, embeddings):
"""
该函数用于加载向量数据库。
参数:
path: 要加载的向量数据库路径。
embeddings: 向量数据库使用的 embedding 模型。
返回:
vectordb: 加载的数据库。
"""
vectordb = Chroma(
persist_directory=path,
embedding_function=embeddings
)
return vectordb
if __name__ == "__main__":
create_db(embeddings="m3e")