import os
import sys
import re
import json
import tempfile

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from dotenv import load_dotenv, find_dotenv
from embedding.call_embedding import get_embedding
from langchain.document_loaders import (
    PyMuPDFLoader,
    UnstructuredFileLoader,
    UnstructuredMarkdownLoader,
    UnstructuredWordDocumentLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain.vectorstores import Chroma
from langchain.schema import Document

# Cache downloaded models under <project>/models via the HF_HOME environment variable.
CACHE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
os.environ['HF_HOME'] = CACHE_DIR

# Default paths for the source knowledge base and the persisted Chroma store.
DEFAULT_DB_PATH = "./knowledge_db/sanguo_characters"
DEFAULT_PERSIST_PATH = "./vector_db/chroma_sanguo"


class CharacterTextSplitter(TextSplitter):
    """Text splitter dedicated to character JSON data."""

    def split_text(self, text: str) -> list[str]:
        # Match JSON objects (allowing one level of nesting, e.g. the "skills"
        # entries) anywhere in the raw text.
        pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
        matches = re.finditer(pattern, text)

        chunks = []
        for match in matches:
            try:
                char_data = json.loads(match.group())

                if 'name' not in char_data:
                    print(f"Warning: JSON object missing the 'name' field: {match.group()[:100]}...")
                    continue

                # Normalize the field name: rename 'stamina_cost' to 'endurance_cost'.
                if 'skills' in char_data:
                    for skill in char_data['skills']:
                        if 'stamina_cost' in skill:
                            skill['endurance_cost'] = skill.pop('stamina_cost')

                # Render the character record as one readable text chunk
                # (labels stay in Chinese to match the knowledge base).
                char_text = f"角色:{char_data['name']}\n"
                char_text += f"攻击力:{char_data['attack']}\n"
                char_text += f"防御力:{char_data['defense']}\n"
                char_text += f"体力:{char_data['stamina']}\n"
                char_text += f"耐力:{char_data['endurance']}\n"
                char_text += f"法力:{char_data['mana']}\n"
                char_text += f"闪避:{char_data['dodge']}\n"
                char_text += f"速度:{char_data['speed']}\n"
                char_text += "技能:\n"
                for skill in char_data['skills']:
                    char_text += f"- {skill['name']}:{skill['effect']}\n"
                    if 'endurance_cost' in skill and 'mana_cost' in skill:
                        char_text += f"  耐力消耗:{skill['endurance_cost']},法力消耗:{skill['mana_cost']}\n"
                chunks.append(char_text)
            except json.JSONDecodeError as e:
                print(f"JSON parsing error: {e}")
                print(f"Offending data: {match.group()[:100]}...")
                continue
            except KeyError as e:
                print(f"Missing field: {e}")
                print(f"Offending data: {match.group()[:100]}...")
                continue
        return chunks

    def split_documents(self, documents: list[Document]) -> list[Document]:
        """Split a list of documents into per-character chunks."""
        texts = []
        metadatas = []
        for doc in documents:
            # Split each document once and reuse the result for both lists.
            doc_chunks = self.split_text(doc.page_content)
            texts.extend(doc_chunks)
            metadatas.extend([doc.metadata] * len(doc_chunks))
        return [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
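

# A minimal usage sketch of CharacterTextSplitter (the sample record below is
# purely illustrative and not taken from the real knowledge base):
#
#   splitter = CharacterTextSplitter()
#   sample = ('{"name": "赵云", "attack": 90, "defense": 80, "stamina": 85, '
#             '"endurance": 75, "mana": 60, "dodge": 70, "speed": 88, '
#             '"skills": [{"name": "示例技能", "effect": "示例效果", '
#             '"stamina_cost": 10, "mana_cost": 5}]}')
#   chunks = splitter.split_text(sample)
#   # -> one formatted text chunk, with stamina_cost renamed to endurance_cost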


def get_files(dir_path):
    """Recursively collect the paths of all files under dir_path."""
    file_list = []
    for filepath, dirnames, filenames in os.walk(dir_path):
        for filename in filenames:
            file_list.append(os.path.join(filepath, filename))
    return file_list


def file_loader(file, loaders):
    """Append an appropriate document loader for `file` to `loaders`.

    Directories are walked recursively; unsupported file types are skipped.
    """
    if isinstance(file, tempfile._TemporaryFileWrapper):
        file = file.name
    if not os.path.isfile(file):
        # `file` is a directory: recurse into its entries.
        for f in os.listdir(file):
            file_loader(os.path.join(file, f), loaders)
        return
    file_type = file.split('.')[-1].lower()
    if file_type == 'pdf':
        loaders.append(PyMuPDFLoader(file))
    elif file_type == 'md':
        # Skip markdown files whose names match the filter pattern.
        pattern = r"不存在|风控"
        match = re.search(pattern, file)
        if not match:
            loaders.append(UnstructuredMarkdownLoader(file))
    elif file_type == 'txt':
        loaders.append(UnstructuredFileLoader(file))
    elif file_type == 'docx':
        loaders.append(UnstructuredWordDocumentLoader(file))
    return
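

# Illustrative sketch (hypothetical paths): file_loader can be pointed at a
# directory and will fill `loaders`, which are then consumed via .load():
#
#   loaders = []
#   file_loader("./knowledge_db/sanguo_characters", loaders)
#   docs = [doc for loader in loaders for doc in loader.load()]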


def create_db_info(files=DEFAULT_DB_PATH, embeddings="openai", persist_directory=DEFAULT_PERSIST_PATH):
    """Build the vector database when a supported embedding backend is requested."""
    if embeddings in ('openai', 'm3e', 'zhipuai'):
        create_db(files, persist_directory, embeddings)
    return ""


def create_db(files=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="openai"):
    """
    Load files, split them into documents, embed the documents, and build the
    vector database.

    Args:
        files: Path(s) of the files to load.
        persist_directory: Directory where the vector database is persisted.
        embeddings: Embedding model (or its name) used to generate the embeddings.

    Returns:
        vectordb: The created vector database.
    """
    if files is None:
        return "can't load empty file"
    if not isinstance(files, list):
        files = [files]

    print(f"Processing file paths: {files}")

    loaders = []
    for file in files:
        file_loader(file, loaders)
    print(f"Number of loaders found: {len(loaders)}")

    docs = []
    for loader in loaders:
        if loader is not None:
            loaded_docs = loader.load()
            print(f"\nNumber of documents loaded: {len(loaded_docs)}")

            if loaded_docs:
                print("\nSample document content:")
                print("-" * 50)
                print(loaded_docs[0].page_content[:500])
                print("-" * 50)
                print("\nDocument metadata:")
                print(loaded_docs[0].metadata)
                print("-" * 50)
            docs.extend(loaded_docs)

    print(f"\nTotal number of documents: {len(docs)}")

    if len(docs) == 0:
        print("Warning: no documents found!")
        return None

    # Split the loaded documents into one chunk per character record.
    text_splitter = CharacterTextSplitter()
    split_docs = text_splitter.split_documents(docs)
    print(f"\nNumber of documents after splitting: {len(split_docs)}")

    if len(split_docs) == 0:
        print("Warning: no documents after splitting!")
        return None

    # Save the split documents to a text file for inspection.
    split_docs_dir = os.path.join(os.path.dirname(persist_directory), "split_docs")
    os.makedirs(split_docs_dir, exist_ok=True)
    split_docs_file = os.path.join(split_docs_dir, "split_documents.txt")

    with open(split_docs_file, "w", encoding="utf-8") as f:
        for i, doc in enumerate(split_docs, 1):
            f.write(f"\nDocument {i}:\n")
            f.write("-" * 50 + "\n")
            f.write(doc.page_content)
            f.write("\n" + "-" * 50 + "\n")

    print(f"\nSplit documents saved to: {split_docs_file}")

    # Resolve an embedding model from its name if a string was passed in.
    if isinstance(embeddings, str):
        embeddings = get_embedding(embedding=embeddings)

    vectordb = Chroma.from_documents(
        documents=split_docs,
        embedding=embeddings,
        persist_directory=persist_directory,
        # Use cosine distance for the HNSW index.
        collection_metadata={"hnsw:space": "cosine"}
    )

    vectordb.persist()
    return vectordb


def presit_knowledge_db(vectordb):
    """
    Persist the vector database.

    Args:
        vectordb: The vector database to persist.
    """
    vectordb.persist()


def load_knowledge_db(path, embeddings):
    """
    Load a persisted vector database.

    Args:
        path: Path of the vector database to load.
        embeddings: The embedding model used by the vector database.

    Returns:
        vectordb: The loaded vector database.
    """
    vectordb = Chroma(
        persist_directory=path,
        embedding_function=embeddings
    )
    return vectordb
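

# A hedged usage sketch: reload the persisted store and run a similarity
# search (the query string below is only an example):
#
#   embedding = get_embedding(embedding="m3e")
#   vectordb = load_knowledge_db(DEFAULT_PERSIST_PATH, embedding)
#   results = vectordb.similarity_search("角色技能", k=3)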


if __name__ == "__main__":
    create_db(embeddings="m3e")