# Session-aware chat memory utilities: persistent TinyDB store, JSON-backed
# chat history, and embedding-based semantic search over past messages.
import json
import os
import uuid
from datetime import datetime, timezone
from threading import Lock
from typing import Dict, List, Optional

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from tinydb import TinyDB, Query
# === Constants ===
HISTORY_FILE = "history_backup.json"  # JSON backup for the in-memory chat history
MEMORY_DB_PATH = "memory.json"        # TinyDB database file for persistent memory
# === Persistent Memory with Session Tokens ===
class PersistentMemory:
    """Thread-safe, TinyDB-backed store of user/bot message pairs keyed by session id.

    All database access is serialized through a single lock; records carry an
    aware UTC ISO-8601 timestamp.
    """

    def __init__(self, path: str = MEMORY_DB_PATH):
        self.db = TinyDB(path)  # on-disk JSON document store
        self.lock = Lock()      # serializes all TinyDB operations

    def add(self, session_id: str, user_msg: str, bot_msg: str) -> None:
        """Append one user/bot exchange for *session_id*."""
        with self.lock:
            self.db.insert({
                "session_id": session_id,
                "user": user_msg,
                "bot": bot_msg,
                # datetime.utcnow() is deprecated (Python 3.12+) and naive;
                # store an aware UTC timestamp instead.
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })

    def get_last(self, session_id: str, n: int = 5) -> str:
        """Return the last *n* exchanges for a session, formatted for prompting.

        Returns an empty string when the session has no records.
        """
        with self.lock:
            items = self.db.search(Query().session_id == session_id)[-n:]
        return "\n".join(f"User: {x['user']}\nAI: {x['bot']}" for x in items)

    def clear(self, session_id: Optional[str] = None) -> None:
        """Delete one session's records, or all records when no session is given."""
        with self.lock:
            if session_id:
                self.db.remove(Query().session_id == session_id)
            else:
                self.db.truncate()

    def all(self, session_id: Optional[str] = None) -> List[Dict]:
        """Return raw records for one session, or the entire table."""
        with self.lock:
            if session_id:
                return self.db.search(Query().session_id == session_id)
            return self.db.all()
# === JSON-Backed In-Memory Chat History with Sessions === |
class ChatHistory:
    """Per-session in-memory chat log mirrored to a JSON backup file.

    Fixes a locking defect in the original: add() invoked save() while already
    holding the non-reentrant Lock that save() also acquires (a deadlock when
    the call sits inside the ``with`` block, an unlocked write otherwise).
    Disk writes now happen via a private helper under a single acquisition.
    """

    def __init__(self, backup_path: str = HISTORY_FILE):
        self.histories: Dict[str, List[Dict[str, str]]] = {}
        self.backup_path = backup_path
        self.lock = Lock()  # guards self.histories and file writes
        self.load()

    def add(self, session_id: str, role: str, message: str) -> None:
        """Record one message for *session_id* and persist the backup file."""
        with self.lock:
            self.histories.setdefault(session_id, []).append({
                "role": role,
                "message": message,
                # aware UTC timestamp; utcnow() is deprecated since 3.12
                "timestamp": datetime.now(timezone.utc).isoformat(),
            })
            self._write()  # lock already held; do not re-acquire

    def get_all(self, session_id: str) -> List[Dict[str, str]]:
        """Return the session's entries ([] if unknown); a copy, so callers
        cannot mutate internal state."""
        with self.lock:
            return list(self.histories.get(session_id, []))

    def save(self) -> None:
        """Thread-safe flush of all histories to the backup file."""
        with self.lock:
            self._write()

    def _write(self) -> None:
        """Dump histories to disk. Caller must hold self.lock."""
        with open(self.backup_path, "w", encoding="utf-8") as f:
            json.dump(self.histories, f, indent=2)

    def load(self) -> None:
        """Load histories from the backup file if it exists."""
        if os.path.exists(self.backup_path):
            with open(self.backup_path, "r", encoding="utf-8") as f:
                self.histories = json.load(f)

    def export_text(self, session_id: str) -> str:
        """Render a session as 'role (timestamp): message' lines."""
        with self.lock:
            entries = list(self.histories.get(session_id, []))
        return "\n".join(
            f"{entry['role']} ({entry['timestamp']}): {entry['message']}"
            for entry in entries
        )

    def search(self, session_id: str, query: str) -> List[Dict[str, str]]:
        """Case-insensitive substring search over a session's messages."""
        needle = query.lower()
        with self.lock:
            entries = list(self.histories.get(session_id, []))
        return [entry for entry in entries if needle in entry["message"].lower()]
# === Semantic Search with Session Context ===
class SemanticSearch:
    """Per-session message log ranked against queries via sentence-transformer
    embeddings and cosine similarity."""

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(model_name)
        self.session_histories: Dict[str, List[Dict[str, str]]] = {}

    def add_to_history(self, session_id: str, role: str, message: str) -> None:
        """Append one message to the session's log."""
        bucket = self.session_histories.setdefault(session_id, [])
        bucket.append({"role": role, "message": message})

    def semantic_search(self, session_id: str, query: str, top_k: int = 3) -> List[Dict[str, str]]:
        """Return up to *top_k* history entries most similar to *query*.

        Embeds every stored message plus the query in one batch, scores with
        cosine similarity, and returns entries in descending similarity order.
        """
        history = self.session_histories.get(session_id, [])
        if not history:
            return []
        texts = [entry["message"] for entry in history]
        vectors = self.model.encode(texts + [query], convert_to_tensor=True)
        query_vec = vectors[-1].unsqueeze(0)  # shape (1, dim) for sklearn
        doc_vecs = vectors[:-1]
        scores = cosine_similarity(query_vec, doc_vecs)[0]
        best = scores.argsort()[-top_k:][::-1]  # highest-similarity first
        return [history[i] for i in best]

    def export_history(self, session_id: str) -> str:
        """Render the session log as 'role: message' lines."""
        lines = [
            f"{m['role']}: {m['message']}"
            for m in self.session_histories.get(session_id, [])
        ]
        return "\n".join(lines)
# === Singleton Instances ===
# Module-level shared instances. NOTE(review): constructing these has side
# effects at import time — PersistentMemory/ChatHistory touch their backing
# files and SemanticSearch loads the sentence-transformer model.
persistent_memory = PersistentMemory()
chat_history = ChatHistory()
semantic_search = SemanticSearch()
# === Unified Session Chat API ===
def create_session_id() -> str:
    """Return a fresh random session token (a UUID4 rendered as a string)."""
    return f"{uuid.uuid4()}"
def add_chat_message(session_id: str, user_msg: str, bot_msg: str) -> None:
    """Fan one user/bot exchange out to all three backing stores."""
    persistent_memory.add(session_id, user_msg, bot_msg)
    for role, text in (("User", user_msg), ("AI", bot_msg)):
        chat_history.add(session_id, role, text)
        semantic_search.add_to_history(session_id, role, text)
def get_recent_conversation(session_id: str, n: int = 5) -> str:
    """Return the last *n* exchanges of a session as a prompt-ready string."""
    return persistent_memory.get_last(session_id, n)
def export_full_history_text(session_id: str) -> str:
    """Return the session's full chat history as timestamped text lines."""
    return chat_history.export_text(session_id)
def search_chat_history_simple(session_id: str, query: str) -> List[Dict[str, str]]:
    """Case-insensitive substring search over a session's chat history."""
    return chat_history.search(session_id, query)
def search_chat_history_semantic(session_id: str, query: str, top_k: int = 3) -> List[Dict[str, str]]:
    """Embedding-based similarity search over a session's chat history."""
    return semantic_search.semantic_search(session_id, query, top_k)
def _demo() -> None:
    """Smoke-test the session chat API end to end.

    Writes the backing files and loads the embedding model, so it must only
    run when this module is executed as a script — the original ran these
    side effects unconditionally at import time.
    """
    session = create_session_id()
    add_chat_message(session, "What is LangChain?", "LangChain is a framework for developing applications powered by LLMs.")
    add_chat_message(session, "What is OpenAI?", "OpenAI is an AI research lab behind ChatGPT.")
    print(get_recent_conversation(session))
    print(search_chat_history_simple(session, "LangChain"))
    print(search_chat_history_semantic(session, "framework"))


if __name__ == "__main__":
    _demo()