import streamlit as st
from transformers import pipeline
import torch
import PyPDF2
from io import BytesIO

# Page-level settings; applied before any other Streamlit element is drawn.
PAGE_SETTINGS = {
    "page_title": "TextSphere",
    "page_icon": "🤖",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**PAGE_SETTINGS)

# Fixed-position attribution footer, injected as raw HTML/CSS.
FOOTER_HTML = """
    <style>
        .footer {
            position: fixed;
            bottom: 0;
            right: 0;
            padding: 10px;
            font-size: 16px;
            color: #333;
            background-color: #f1f1f1;
        }
    </style>
    <div class="footer">
        Made with ❤️ by Baibhav Malviya
    </div>
"""
st.markdown(FOOTER_HTML, unsafe_allow_html=True)


@st.cache_resource
def load_models():
    """Build the Hugging Face pipelines used by the app.

    The result is cached with ``st.cache_resource`` so the models are
    constructed once rather than on every script rerun.

    Returns:
        A 4-tuple ``(text_classification, question_answering, translation,
        summarization)`` of pipeline objects. The translation pipeline is
        the English-to-French model.

    Raises:
        RuntimeError: If any pipeline fails to load. The original exception
            is chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        text_classification_model = pipeline(
            "text-classification",
            model="distilbert-base-uncased-finetuned-sst-2-english",
        )
        question_answering_model = pipeline(
            "question-answering",
            model="distilbert-base-uncased-distilled-squad",
        )
        translation_model = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-en-fr",
        )
        summarization_model = pipeline(
            "summarization",
            model="facebook/bart-large-cnn",
        )
    except Exception as e:
        # Chain the cause so the underlying download/initialisation error
        # is not lost when callers only see the RuntimeError.
        raise RuntimeError(f"Failed to load models: {e}") from e

    return (
        text_classification_model,
        question_answering_model,
        translation_model,
        summarization_model,
    )

def extract_text_from_pdf(uploaded_pdf):
    """Return the text of every page of a PDF as a single string.

    Args:
        uploaded_pdf: A file-like object containing the PDF, as accepted by
            ``PyPDF2.PdfReader``.

    Returns:
        The text of all pages joined with newlines, an empty string when no
        page yields any text, or ``None`` if the PDF could not be read (in
        which case an error is shown in the Streamlit UI).
    """
    try:
        pdf_reader = PyPDF2.PdfReader(uploaded_pdf)
        page_texts = (page.extract_text() for page in pdf_reader.pages)
        # Skip pages that yield nothing (empty or None) so one image-only
        # page cannot abort the whole extraction, and separate pages with a
        # newline so words at page boundaries are not fused together.
        return "\n".join(text for text in page_texts if text)
    except Exception as e:
        st.error(f"Error reading the PDF: {e}")
        return None


# Load all pipelines up front; load_models() is cached, so reruns reuse them.
try:
    classification_model, qa_model, translation_model, summarization_model = load_models()
except Exception as e:
    st.error(f"An error occurred while loading models: {e}")
    # The model names above are unbound at this point, so every task branch
    # further down would fail with a NameError. Halt this script run instead.
    st.stop()

# Tasks offered in the sidebar; the selected label drives the dispatch below.
TASK_CHOICES = [
    "Question Answering",
    "Text Classification",
    "Language Translation",
    "Text Summarization",
]

st.sidebar.title("AI Solutions")
option = st.sidebar.selectbox("Choose a task", TASK_CHOICES)

@st.cache_resource
def load_translation_model(model_name):
    """Return the translation pipeline for ``model_name``, cached per model.

    Streamlit re-executes the script on every widget interaction, so building
    the pipeline inline would reload the model on each rerun.
    """
    return pipeline("translation", model=model_name)


# Render the UI for the task chosen in the sidebar and run it on demand.
if option == "Question Answering":
    st.title("Question Answering")
    st.markdown("<h4 style='font-size: 20px;'>- because Google wasn't enough 😉</h4>", unsafe_allow_html=True)
    uploaded_pdf = st.file_uploader("Upload a PDF file (optional)", type="pdf")

    context_input = st.text_area("Enter context (a paragraph of text, or leave empty if using PDF):")
    question = st.text_input("Enter your question:")

    if st.button("Get Answer"):
        with st.spinner('Getting answer...'):
            try:
                # An uploaded PDF takes precedence over typed context. It is
                # parsed only on click, not on every rerun of the script.
                if uploaded_pdf:
                    context_input = extract_text_from_pdf(uploaded_pdf)

                if context_input and question:
                    answer = qa_model(question=question, context=context_input)
                    st.write("Answer:", answer['answer'])

                    st.balloons()
                else:
                    st.error("Please enter both context and a question.")
            except Exception as e:
                st.error(f"An error occurred: {e}")

elif option == "Text Classification":
    st.title("Text Classification")
    st.markdown("<h4 style='font-size: 20px;'>- where machines learn to hate spam as much we do 😅</h4>", unsafe_allow_html=True)
    text = st.text_area("Enter text for classification:")
    if st.button("Classify Text"):
        with st.spinner('Classifying text...'):
            try:
                # Reject empty input, matching the other task branches.
                if text:
                    classification = classification_model(text)
                    st.json(classification)

                    st.balloons()
                else:
                    st.error("Please enter text to classify.")
            except Exception as e:
                st.error(f"An error occurred: {e}")

elif option == "Language Translation":
    st.title("Language Translation (English to Multiple Languages)")
    st.markdown("<h4 style='font-size: 20px;'>- when 'translate' is the only button you know 😁</h4>", unsafe_allow_html=True)

    # Single source of truth for both the language menu and the model lookup.
    language_models = {
        "French": "Helsinki-NLP/opus-mt-en-fr",
        "Spanish": "Helsinki-NLP/opus-mt-en-es",
        "German": "Helsinki-NLP/opus-mt-en-de",
        "Italian": "Helsinki-NLP/opus-mt-en-it",
        "Portuguese": "Helsinki-NLP/opus-mt-en-pt",
        "Hindi": "Helsinki-NLP/opus-mt-en-hi"
    }
    target_language = st.selectbox("Choose target language", list(language_models))
    selected_model = language_models[target_language]

    text_to_translate = st.text_area(f"Enter text to translate from English to {target_language}:")
    if st.button("Translate"):
        with st.spinner('Translating text...'):
            try:
                if text_to_translate:
                    # Loaded inside the try so a failed download is reported
                    # in the UI; cached so each model is built only once.
                    translation_model = load_translation_model(selected_model)
                    translated_text = translation_model(text_to_translate)
                    st.write(f"Translated Text ({target_language}):", translated_text[0]['translation_text'])

                    st.balloons()
                else:
                    st.error("Please enter text to translate.")
            except Exception as e:
                st.error(f"An error occurred: {e}")

elif option == "Text Summarization":
    st.title("Text Summarization")
    st.markdown("<h4 style='font-size: 20px;'>- because who needs to read the whole article, anyway? 🥵</h4>", unsafe_allow_html=True)
    uploaded_pdf = st.file_uploader("Upload a PDF file (optional)", type="pdf")

    text_to_summarize = st.text_area("Enter text to summarize (or leave empty if using PDF):")

    if st.button("Summarize"):
        with st.spinner('Summarizing text...'):
            try:
                # An uploaded PDF takes precedence over typed text. It is
                # parsed only on click, not on every rerun of the script.
                if uploaded_pdf:
                    text_to_summarize = extract_text_from_pdf(uploaded_pdf)

                if text_to_summarize:
                    # truncation=True clips input that exceeds the model's
                    # maximum length (typical for whole PDFs) instead of
                    # letting the pipeline fail on over-long sequences.
                    summary = summarization_model(
                        text_to_summarize,
                        max_length=130,
                        min_length=30,
                        do_sample=False,
                        truncation=True,
                    )
                    st.write("Summary:", summary[0]['summary_text'])

                    st.balloons()
                else:
                    st.error("Please enter text or upload a PDF for summarization.")
            except Exception as e:
                st.error(f"An error occurred: {e}")