import os
import re
import io

import fitz  # PyMuPDF
import torch
import chromadb
import pytesseract
import gradio as gr
from PIL import Image
from docx import Document
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# ---------------------------
# 🛠 Configuration
# ---------------------------
MANUALS_FOLDER = "./Manuals"
CHROMA_PATH = "./chroma_store"
COLLECTION_NAME = "manual_chunks"
CHUNK_SIZE = 750        # target chunk size, in words
CHUNK_OVERLAP = 100     # overlap between consecutive chunks, in words
MAX_CONTEXT_CHUNKS = 3
HF_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
HF_TOKEN = os.environ.get("HF_TOKEN")
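# HF_TOKEN is read from the environment (e.g. a Hugging Face Space secret). If it is
# missing, the app still indexes the manuals but skips loading the gated Llama model.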
# ---------------------------
# 🧹 Helpers
# ---------------------------
def clean(text):
    """Strip leading/trailing whitespace and drop empty lines."""
    lines = text.splitlines()
    return "\n".join(line.strip() for line in lines if line.strip())

def split_sentences(text):
    """Naive sentence splitter on terminal punctuation."""
    return re.split(r'(?<=[.!?])\s+', text.strip())
def chunk_sentences(sentences, max_len=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Group sentences into chunks of roughly `max_len` words, overlapping by ~`overlap` words."""
    chunks, chunk, length = [], [], 0
    for sent in sentences:
        tokens = len(sent.split())
        if length + tokens > max_len and chunk:
            chunks.append(" ".join(chunk))
            # Carry over trailing sentences totalling at most `overlap` words so that
            # neighbouring chunks share some context.
            carried, carried_len = [], 0
            for prev in reversed(chunk):
                prev_len = len(prev.split())
                if carried_len + prev_len > overlap:
                    break
                carried.insert(0, prev)
                carried_len += prev_len
            chunk, length = carried, carried_len
        chunk.append(sent)
        length += tokens
    if chunk:
        chunks.append(" ".join(chunk))
    return chunks
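# Illustrative usage: chunk_sentences(split_sentences(clean(raw_text))) yields chunks of
# roughly CHUNK_SIZE words, each sharing about CHUNK_OVERLAP trailing words with the
# previous chunk.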
def extract_text_from_pdf(path):
    """Extract text per page; fall back to OCR for pages with no embedded text."""
    doc = fitz.open(path)
    full_text = []
    for page in doc:
        text = page.get_text().strip()
        if not text:
            # Likely a scanned page: render it at 300 dpi and run OCR.
            try:
                pix = page.get_pixmap(dpi=300)
                img_data = pix.tobytes("png")
                img = Image.open(io.BytesIO(img_data))
                text = pytesseract.image_to_string(img).strip()
            except Exception:
                text = ""
        full_text.append(text)
    return "\n".join(full_text)
def extract_text_from_docx(path):
    doc = Document(path)
    return "\n".join(para.text for para in doc.paragraphs if para.text.strip())
def extract_metadata(filename):
    """Infer machine model and document type from the filename (heuristic substring match)."""
    name = filename.lower()
    model = next((m for m in ["se3hd", "se3", "se4", "symbio", "explore", "integrity x", "integrity sl", "everest", "engage", "inspire", "discover", "95t", "95x", "95c", "95r", "97c"] if m in name), "unknown")
    if "om" in name or "owner" in name:
        doc_type = "owner manual"
    elif "sm" in name or "service" in name:
        doc_type = "service manual"
    elif "assembly" in name:
        doc_type = "assembly instructions"
    elif "alert" in name:
        doc_type = "installer alert"
    elif "parts" in name:
        doc_type = "parts manual"
    elif "bulletin" in name:
        doc_type = "service bulletin"
    else:
        doc_type = "unknown"
    return model, doc_type
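# Illustrative examples (hypothetical filenames):
#   "SE3HD_OM_en.pdf"         -> ("se3hd", "owner manual")
#   "95T_Service_Manual.docx" -> ("95t", "service manual")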
# ---------------------------
# 📚 Build ChromaDB at Startup
# ---------------------------
def embed_all():
    """Re-index every manual in MANUALS_FOLDER into a fresh Chroma collection."""
    client = chromadb.PersistentClient(path=CHROMA_PATH)
    if COLLECTION_NAME in [c.name for c in client.list_collections()]:
        client.delete_collection(COLLECTION_NAME)
    collection = client.create_collection(COLLECTION_NAME)
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    records = []
    for fname in os.listdir(MANUALS_FOLDER):
        path = os.path.join(MANUALS_FOLDER, fname)
        if not fname.lower().endswith((".pdf", ".docx")):
            continue
        text = extract_text_from_pdf(path) if fname.lower().endswith(".pdf") else extract_text_from_docx(path)
        sents = split_sentences(clean(text))
        chunks = chunk_sentences(sents)
        model, doc_type = extract_metadata(fname)
        for i, chunk in enumerate(chunks):
            records.append({
                "id": f"{fname}::chunk_{i+1}",
                "text": chunk,
                "metadata": {"source_file": fname, "model": model, "doc_type": doc_type}
            })
    # Embed and insert in small batches to keep memory usage bounded.
    for i in range(0, len(records), 16):
        batch = records[i:i+16]
        texts = [r["text"] for r in batch]
        ids = [r["id"] for r in batch]
        metas = [r["metadata"] for r in batch]
        embeddings = embedder.encode(texts).tolist()
        collection.add(documents=texts, ids=ids, metadatas=metas, embeddings=embeddings)
    return collection, embedder
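# Note: the collection is deleted and rebuilt on every startup. For large manual sets it
# may be preferable to reuse the persisted store in CHROMA_PATH and only re-index new files.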
# ---------------------------
# 💬 Load HF Model
# ---------------------------
llm_pipe = None
if HF_TOKEN:
    tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(HF_MODEL, token=HF_TOKEN, torch_dtype=torch.float32)
    llm_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)
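# Note: device=-1 runs generation on CPU, and float32 weights for an 8B model need roughly
# 32 GB of RAM. On GPU hardware, loading with torch_dtype=torch.bfloat16 and device_map="auto"
# would be considerably lighter and faster.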
# ---------------------------
# 🔍 RAG Function
# ---------------------------
def run_query(question):
    if not question.strip():
        return "Please enter a question."
    if db is None or embedder is None:
        return "Chroma or embedder not ready."
    # Retrieve the most similar chunks for the question.
    q_embed = embedder.encode(question).tolist()
    res = db.query(query_embeddings=[q_embed], n_results=MAX_CONTEXT_CHUNKS)
    contexts = res["documents"][0]
    prompt = (
        "You are a technical assistant.\n"
        "Answer only using the context below.\n"
        "Say 'I don't know' if not found.\n\n"
    )
    context_text = "\n\n".join(contexts)
    final_prompt = prompt + f"Context:\n{context_text}\n\nQuestion: {question}\nAnswer:"
    if llm_pipe:
        result = llm_pipe(final_prompt, max_new_tokens=300)[0]["generated_text"]
        # The pipeline returns the prompt plus the completion; keep only the answer part.
        return result.split("Answer:")[-1].strip()
    return "Model not loaded."
# ---------------------------
# 🧠 Init embeddings once
# ---------------------------
db, embedder = embed_all()
# ---------------------------
# 🎛️ Gradio Interface
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 SmartManuals-AI: Ask Technical Questions about Your Manuals")
    question = gr.Textbox(placeholder="e.g. How do I reset the treadmill console?", label="Enter Question")
    submit = gr.Button("Get Answer")
    output = gr.Textbox(label="Answer")
    submit.click(fn=run_query, inputs=question, outputs=output)
demo.launch()