File size: 3,624 Bytes
dfb3884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9746dc2
dfb3884
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import tempfile
import streamlit as st
from dotenv import load_dotenv
 
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import BedrockEmbeddings
from langchain_community.chat_models import BedrockChat
from langchain.chains import RetrievalQA
import boto3
 
# Load AWS credentials from .env if available.
# NOTE: must run before boto3.client() below so the env vars are in place.
load_dotenv()

# Setup AWS Bedrock runtime client (region hard-coded to us-east-1).
bedrock_runtime = boto3.client("bedrock-runtime", region_name="us-east-1")

# UI setup — st.set_page_config must be the first Streamlit call on the page.
st.set_page_config(page_title="PDF chatbot", layout="wide")
st.title("RAG Demo - PDF Q&A")

st.markdown("""
1. **Upload Your Documents**: You can upload multiple PDF files for processing.
 
2. **Ask a Question**: Then ask any question based on the documents' content.
""")

# On-disk location of the persistent Chroma vector store (relative to CWD).
CHROMA_PATH = os.path.join(os.getcwd(), "chroma_db")
 
 
def main():
    """Render the Streamlit page: ingest uploaded PDFs into the persistent
    Chroma vector store, then answer user questions against it with a
    RetrievalQA chain backed by Bedrock Claude.

    Uses module-level ``bedrock_runtime`` and ``CHROMA_PATH``; all output
    goes to the Streamlit UI. Returns None.
    """
    st.header("Ask a question")

    # Amazon Titan embeddings back the persistent Chroma store.
    embeddings = BedrockEmbeddings(
        client=bedrock_runtime,
        model_id="amazon.titan-embed-text-v1"
    )
    vectorstore = Chroma(
        persist_directory=CHROMA_PATH,
        embedding_function=embeddings
    )

    # Sidebar: Upload & Process PDFs
    with st.sidebar:
        st.title("Menu:")
        uploaded_files = st.file_uploader(
            "Upload PDF files and click Submit",
            accept_multiple_files=True,
            key="pdf_uploader"
        )
        if st.button("Submit & Process", key="process_button") and uploaded_files:
            with st.spinner("Processing..."):
                for uploaded_file in uploaded_files:
                    try:
                        # PyPDFLoader needs a real file path, so spill the
                        # upload to a named temp file first.
                        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                            tmp_file.write(uploaded_file.getvalue())
                            tmp_path = tmp_file.name

                        try:
                            loader = PyPDFLoader(tmp_path)
                            pages = loader.load()
                        finally:
                            # Always remove the temp copy, even when parsing
                            # raises — the original only deleted on success,
                            # leaking one temp file per failed PDF.
                            os.unlink(tmp_path)

                        # enumerate instead of pages.index(page): index() is
                        # an O(n) scan per page (O(n^2) total) and returns the
                        # FIRST equal element, so duplicate pages (e.g. two
                        # identical blank pages) would get the wrong number.
                        for page_number, page in enumerate(pages, start=1):
                            page.metadata["page_number"] = page_number

                        text_splitter = RecursiveCharacterTextSplitter(
                            chunk_size=1000,
                            chunk_overlap=200,
                            separators=["\n\n", "\n", " ", ""]
                        )
                        chunks = text_splitter.split_documents(pages)

                        vectorstore.add_documents(chunks)
                        vectorstore.persist()

                    except Exception as e:
                        # Best-effort per file: surface the error in the UI
                        # and keep processing the remaining uploads.
                        st.error(f"Error processing {uploaded_file.name}: {str(e)}")
                        continue

                st.success("Vector store updated with uploaded documents.")

    # Main QA interface
    user_question = st.text_input("Ask a Question from the PDF Files", key="user_question")
    if user_question:
        retriever = vectorstore.as_retriever(search_type='similarity', search_kwargs={'k': 5})

        llm = BedrockChat(
            client=bedrock_runtime,
            model_id="anthropic.claude-v2",  # or v2:1
            model_kwargs={"temperature": 0.0}
        )

        # "stuff" chain type: concatenates all retrieved chunks into a
        # single prompt for the LLM.
        chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)

        with st.spinner("Generating answer..."):
            answer = chain.invoke({"query": user_question})
            st.write("**Reply:**", answer["result"])
 
 
# Script entry point: run the Streamlit app only when executed directly.
if __name__ == "__main__":
    main()