File size: 8,622 Bytes
d04a18c 45c5d09 d04a18c 45c5d09 d04a18c 45c5d09 d04a18c 45c5d09 d04a18c 45c5d09 d04a18c 45c5d09 d04a18c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 |
import os
import sys
import re
import json
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import tempfile
from dotenv import load_dotenv, find_dotenv
from embedding.call_embedding import get_embedding
from langchain.document_loaders import UnstructuredFileLoader
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain.document_loaders import PyMuPDFLoader
from langchain.document_loaders import UnstructuredWordDocumentLoader
from langchain.vectorstores import Chroma
from langchain.schema import Document
# 禁用 Pebblo 安全模块
# os.environ["PEBBLO_DISABLED"] = "1" # 新增环境变量
# 设置模型缓存目录
CACHE_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "models")
# os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
# 修改环境变量设置(保留一个即可)
os.environ['HF_HOME'] = CACHE_DIR
# 首先实现基本配置
# 原代码
DEFAULT_DB_PATH = "./knowledge_db/sanguo_characters"
DEFAULT_PERSIST_PATH = "./vector_db/chroma_sanguo"
class CharacterTextSplitter(TextSplitter):
"""专门用于处理角色JSON数据的文本分割器"""
def split_text(self, text: str) -> list[str]:
# 使用更健壮的正则表达式匹配每个角色的JSON数据
pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
matches = re.finditer(pattern, text)
# 将每个匹配的JSON字符串转换为文本块
chunks = []
for match in matches:
try:
# 解析JSON数据
char_data = json.loads(match.group())
# 检查必要字段
if 'name' not in char_data:
print(f"警告:发现缺少name字段的JSON数据: {match.group()[:100]}...")
continue
# 处理技能数据,将stamina_cost转换为endurance_cost
if 'skills' in char_data:
for skill in char_data['skills']:
if 'stamina_cost' in skill:
skill['endurance_cost'] = skill.pop('stamina_cost')
# 将JSON数据转换为易读的文本格式
char_text = f"角色:{char_data['name']}\n"
char_text += f"攻击力:{char_data['attack']}\n"
char_text += f"防御力:{char_data['defense']}\n"
char_text += f"体力:{char_data['stamina']}\n"
char_text += f"耐力:{char_data['endurance']}\n"
char_text += f"法力:{char_data['mana']}\n"
char_text += f"闪避:{char_data['dodge']}\n"
char_text += f"速度:{char_data['speed']}\n"
char_text += "技能:\n"
for skill in char_data['skills']:
char_text += f"- {skill['name']}:{skill['effect']}\n"
if 'endurance_cost' in skill and 'mana_cost' in skill:
char_text += f" 耐力消耗:{skill['endurance_cost']},法力消耗:{skill['mana_cost']}\n"
chunks.append(char_text)
except json.JSONDecodeError as e:
print(f"JSON解析错误: {e}")
print(f"问题数据: {match.group()[:100]}...")
continue
except KeyError as e:
print(f"缺少字段: {e}")
print(f"问题数据: {match.group()[:100]}...")
continue
return chunks
def split_documents(self, documents: list[Document]) -> list[Document]:
"""分割文档列表"""
texts = []
metadatas = []
for doc in documents:
texts.extend(self.split_text(doc.page_content))
metadatas.extend([doc.metadata] * len(self.split_text(doc.page_content)))
return [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
def get_files(dir_path):
file_list = []
for filepath, dirnames, filenames in os.walk(dir_path):
for filename in filenames:
file_list.append(os.path.join(filepath, filename))
return file_list
def file_loader(file, loaders):
if isinstance(file, tempfile._TemporaryFileWrapper):
file = file.name
if not os.path.isfile(file):
[file_loader(os.path.join(file, f), loaders) for f in os.listdir(file)]
return
file_type = file.split('.')[-1].lower()
if file_type == 'pdf':
loaders.append(PyMuPDFLoader(file))
elif file_type == 'md':
pattern = r"不存在|风控"
match = re.search(pattern, file)
if not match:
loaders.append(UnstructuredMarkdownLoader(file))
elif file_type == 'txt':
loaders.append(UnstructuredFileLoader(file))
elif file_type == 'docx':
loaders.append(UnstructuredWordDocumentLoader(file))
return
def create_db_info(files=DEFAULT_DB_PATH, embeddings="openai", persist_directory=DEFAULT_PERSIST_PATH):
if embeddings == 'openai' or embeddings == 'm3e' or embeddings =='zhipuai':
vectordb = create_db(files, persist_directory, embeddings)
return ""
def create_db(files=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="openai"):
"""
该函数用于加载文件,切分文档,生成文档的嵌入向量,创建向量数据库。
参数:
file: 存放文件的路径。
embeddings: 用于生产 Embedding 的模型
返回:
vectordb: 创建的数据库。
"""
if files == None:
return "can't load empty file"
if type(files) != list:
files = [files]
print(f"正在处理文件路径: {files}")
loaders = []
[file_loader(file, loaders) for file in files]
print(f"找到的加载器数量: {len(loaders)}")
docs = []
for loader in loaders:
if loader is not None:
loaded_docs = loader.load()
print(f"\n加载的文档数量: {len(loaded_docs)}")
# 打印第一个文档的内容示例
if loaded_docs:
print("\n文档内容示例:")
print("-" * 50)
print(loaded_docs[0].page_content[:500]) # 只打印前500个字符
print("-" * 50)
print("\n文档元数据:")
print(loaded_docs[0].metadata)
print("-" * 50)
docs.extend(loaded_docs)
print(f"\n总文档数量: {len(docs)}")
if len(docs) == 0:
print("警告:没有找到任何文档!")
return None
# 使用自定义的角色文本分割器
text_splitter = CharacterTextSplitter()
split_docs = text_splitter.split_documents(docs)
print(f"\n分割后的文档数量: {len(split_docs)}")
if len(split_docs) == 0:
print("警告:分割后没有文档!")
return None
# 保存分割后的文档到文件
split_docs_dir = os.path.join(os.path.dirname(persist_directory), "split_docs")
os.makedirs(split_docs_dir, exist_ok=True)
split_docs_file = os.path.join(split_docs_dir, "split_documents.txt")
with open(split_docs_file, "w", encoding="utf-8") as f:
for i, doc in enumerate(split_docs, 1):
f.write(f"\n文档 {i}:\n")
f.write("-" * 50 + "\n")
f.write(doc.page_content)
f.write("\n" + "-" * 50 + "\n")
print(f"\n分割后的文档已保存到: {split_docs_file}")
if type(embeddings) == str:
embeddings = get_embedding(embedding=embeddings)
# 修正参数名称和初始化方式
vectordb = Chroma.from_documents(
documents=split_docs,
embedding=embeddings,
persist_directory=persist_directory,
collection_metadata={"hnsw:space": "cosine"} # 新增元数据配置
)
vectordb.persist()
return vectordb
def presit_knowledge_db(vectordb):
"""
该函数用于持久化向量数据库。
参数:
vectordb: 要持久化的向量数据库。
"""
vectordb.persist()
def load_knowledge_db(path, embeddings):
"""
该函数用于加载向量数据库。
参数:
path: 要加载的向量数据库路径。
embeddings: 向量数据库使用的 embedding 模型。
返回:
vectordb: 加载的数据库。
"""
vectordb = Chroma(
persist_directory=path,
embedding_function=embeddings
)
return vectordb
if __name__ == "__main__":
create_db(embeddings="m3e")
|