# Hugging Face Spaces page residue: "No application file".
import faiss
import gradio as gr
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, pipeline
# Configuration
class Config:
    """Static settings read by the retrieval store and the chat model."""

    # Hugging Face hub id of the chat model loaded into the text-generation pipeline.
    model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    # sentence-transformers model used to embed stored texts and queries.
    embedding_model: str = "all-MiniLM-L6-v2"
    # Expected output width of the embedding model; used when sizing the FAISS index.
    vector_dim: int = 384
    # Maximum number of stored texts retrieved as context per query.
    top_k: int = 3
    # NOTE(review): not referenced anywhere in the visible code — text chunking
    # appears to be unimplemented; confirm before relying on this value.
    chunk_size: int = 256
# Vector Database
class VectorDB:
    """In-memory retrieval store holding one FAISS vector per stored text.

    Texts are embedded with the sentence-transformers model named by
    ``Config.embedding_model``. ``self.texts[i]`` is the text whose embedding
    sits at row ``i`` of ``self.index``.
    """

    def __init__(self):
        self.embedding_model = SentenceTransformer(Config.embedding_model)
        # Size the index from the loaded model, so swapping Config.embedding_model
        # cannot cause a dimension mismatch on the first add. Config.vector_dim
        # is the fallback when the model does not report its output width.
        dim = self.embedding_model.get_sentence_embedding_dimension() or Config.vector_dim
        # Exact (brute-force) L2 index. Every vector is L2-normalised before it
        # is added or queried, so ranking by L2 distance matches ranking by
        # cosine similarity.
        self.index = faiss.IndexFlatL2(dim)
        self.texts = []

    def _embed(self, text: str) -> np.ndarray:
        """Return the L2-normalised embedding of *text* as a (1, dim) float32 array."""
        vectors = np.ascontiguousarray(self.embedding_model.encode([text]), dtype=np.float32)
        faiss.normalize_L2(vectors)  # normalises in place
        return vectors

    def add_text(self, text: str):
        """Embed *text* and append it to both the index and ``self.texts``."""
        self.index.add(self._embed(text))
        self.texts.append(text)

    def search(self, query: str, top_k=None):
        """Return up to *top_k* stored texts closest to *query*, nearest first.

        Args:
            query: Free-text query to embed and look up.
            top_k: Maximum number of results; defaults to ``Config.top_k``.

        Returns:
            A list of stored texts. It is empty when nothing has been stored
            yet or when ``top_k`` is not positive.
        """
        if top_k is None:
            top_k = Config.top_k
        k = min(top_k, self.index.ntotal)
        if k <= 0:
            return []
        _, ids = self.index.search(self._embed(query), k)
        # FAISS pads missing results with -1. A plain `i < len(...)` check would
        # let -1 through and silently return the last stored text, so also
        # require a non-negative id.
        return [self.texts[i] for i in ids[0] if 0 <= i < len(self.texts)]
# Load Model
class TinyChatModel:
    """Chat wrapper around a Hugging Face text-generation pipeline for ``Config.model_name``."""

    def __init__(self):
        self.pipe = pipeline(
            "text-generation",
            model=Config.model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        # The pipeline already loaded the model's tokenizer; reuse it instead of
        # loading a second copy with AutoTokenizer.from_pretrained.
        self.tokenizer = self.pipe.tokenizer

    def generate_response(self, message: str, context: str = "") -> str:
        """Generate a sampled assistant reply to *message*.

        Args:
            message: The user's message.
            context: Optional retrieved text. When non-empty, it is sent as a
                system message prefixed with "Context:".

        Returns:
            The newly generated reply text, stripped of surrounding whitespace.
        """
        messages = [{"role": "user", "content": message}]
        if context:
            messages.insert(0, {"role": "system", "content": f"Context:\n{context}"})
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        outputs = self.pipe(
            prompt,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            # Return only the continuation. This replaces splitting the full
            # text on "<|assistant|>", which depended on TinyLlama's specific
            # chat template.
            return_full_text=False,
        )
        return outputs[0]["generated_text"].strip()
# Initialize
# Both constructors load (and on first run possibly download) model weights, so
# this work happens once at import time. The resulting instances are shared by
# the handler functions in this module.
vector_db: VectorDB = VectorDB()
chat_model: TinyChatModel = TinyChatModel()
def chat_interface(user_input):
    """Answer *user_input* using retrieved memory, then remember the exchange.

    The best matches from ``vector_db`` are joined with newlines and passed to
    ``chat_model`` as context. The resulting "User: ...\\nAssistant: ..." pair is
    then stored back into ``vector_db`` so later turns can retrieve it.

    Returns:
        The model's reply text.
    """
    retrieved = vector_db.search(user_input)
    answer = chat_model.generate_response(user_input, "\n".join(retrieved))
    exchange = f"User: {user_input}\nAssistant: {answer}"
    vector_db.add_text(exchange)
    return answer
def add_text_interface(text):
    """Store *text* in the retrieval memory and return a status message for the UI.

    Blank input (None, empty, or whitespace only) is ignored. Otherwise an empty
    entry would be embedded, and it could later be retrieved as useless context.
    """
    if not text or not text.strip():
        return "Nothing to add: the text box is empty."
    vector_db.add_text(text)
    return "Text added to memory!"
# Gradio UI
def _respond(message, history):
    """Run *message* through chat_interface and append the turn to the chat log.

    ``gr.Chatbot`` displays a list of [user, assistant] pairs, so the bare
    string returned by chat_interface cannot be used as its value directly.
    NOTE(review): assumes the Chatbot's default pair/"tuples" history format
    (Gradio 4.x/5.x); confirm against the pinned gradio version.

    Returns:
        ("", updated_history): clears the input box and refreshes the chat log.
    """
    history = list(history or [])
    if not message or not message.strip():
        # Nothing to send; leave the log unchanged and just clear the box.
        return "", history
    response = chat_interface(message)
    history.append([message, response])
    return "", history


with gr.Blocks() as demo:
    gr.Markdown("# 🦙 TinyChat - AI Chatbot")
    with gr.Row():
        chatbot = gr.Chatbot()
    with gr.Row():
        user_input = gr.Textbox(label="Your Message")
        send_btn = gr.Button("Send")
    with gr.Row():
        add_text_input = gr.Textbox(label="Add Knowledge to AI")
        add_text_btn = gr.Button("Add Text")
    # Shows add_text_interface's confirmation message. It sits at the same
    # layout position as the unnamed Textbox it replaces.
    add_text_status = gr.Textbox(label="Status", interactive=False)

    send_btn.click(_respond, inputs=[user_input, chatbot], outputs=[user_input, chatbot])
    add_text_btn.click(add_text_interface, inputs=add_text_input, outputs=add_text_status)

# Launch
if __name__ == "__main__":
    demo.launch()