import io
import os
import re
import tempfile
from datetime import datetime

import plotly.graph_objects as go
import streamlit as st
from docx import Document
from docx.enum.text import WD_PARAGRAPH_ALIGNMENT
from docx.shared import Pt
from pydub import AudioSegment
from transformers import pipeline

# --- Page setup ------------------------------------------------------------
# set_page_config must be the first Streamlit call in the script.
st.set_page_config(
    page_title="Atma.ai - Session Summarizer + Export",
    layout="wide",
)

st.title("🧠 Atma.ai – Advanced Mental Health Session Summarizer")
st.markdown(
    "Upload a therapy session audio (Tamil-English mix) to view the transcript, "
    "summary, emotional analysis, and export everything to Word!"
)

# --- Audio upload ----------------------------------------------------------
# Returns None until the user picks a file; the processing block below keys
# off this value.
uploaded_file = st.file_uploader(
    "🎙️ Upload audio file",
    type=["wav", "mp3", "m4a"],
)

# Placeholder history of earlier sessions (hard-coded, not loaded from any
# store). It feeds both the summarizer prompt and the keyword-based
# "Contextual Observations" shown in the Summary tab.
# NOTE(review): replace with real per-patient session data.
PAST_SESSIONS = [
    {"date": "2024-04-15", "coping": "walking", "emotion": "anxiety", "notes": "high workload"},
    {"date": "2024-04-22", "coping": "journaling", "emotion": "stress", "notes": "difficulty sleeping"},
]

# Placeholder data for the "Trends" tab (hard-coded, not computed from sessions).
TREND_DATES = ["2024-04-01", "2024-04-08", "2024-04-15", "2024-04-22"]
TREND_ANXIETY = [70, 65, 55, 40]
TREND_SADNESS = [30, 20, 25, 15]

DOCX_MIME = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"


@st.cache_resource(show_spinner=False)
def _load_pipeline(task, model, **kwargs):
    """Build a Hugging Face pipeline once per process and reuse it.

    Streamlit re-executes this whole script on every widget interaction
    (including the download button), so without caching each rerun would
    reload every model from scratch.
    """
    return pipeline(task, model=model, **kwargs)


def _simulate_diarization(text):
    """Label sentences alternately as "Speaker 1" / "Speaker 2".

    This is NOT real diarization: the text is split after '.', '?' or '!'
    followed by whitespace, and labels simply alternate per sentence.
    Returns one "Speaker N: sentence" paragraph per sentence.
    """
    sentences = [s for s in re.split(r"(?<=[.?!])\s+", text) if s.strip()]
    return "".join(
        f"Speaker {1 if idx % 2 == 0 else 2}: {sentence}\n\n"
        for idx, sentence in enumerate(sentences)
    )


def _contextual_insights(transcript, past_sessions):
    """Return keyword-triggered observations comparing this session to past ones."""
    text = transcript.lower()
    insights = []
    if "music" in text and any("walking" in s["coping"] for s in past_sessions):
        insights.append("Patient previously mentioned walking as a helpful coping mechanism. This time, music is highlighted instead.")
    if "sleep" in text:
        insights.append("Sleep continues to be a recurring theme across sessions.")
    return insights


def _flatten_emotion_scores(raw):
    """Return a flat list of {"label", "score"} dicts for a single input text.

    Depending on the transformers version and arguments, classifying one
    string yields either [[{...}, ...]] or [{...}, ...]; normalize both.
    """
    if raw and isinstance(raw[0], list):
        return raw[0]
    return list(raw or [])


def _build_report_docx(transcript, summary_text, emotions):
    """Render the session report as .docx and return its bytes.

    Builds the file in memory so concurrent sessions never share an on-disk
    path. `emotions` is a flat list of {"label", "score"} dicts.
    """
    doc = Document()

    title = doc.add_heading("Session Summary - Atma.ai", 0)
    title.alignment = WD_PARAGRAPH_ALIGNMENT.CENTER

    date_paragraph = doc.add_paragraph()
    date_paragraph.add_run(f"Date: {datetime.now().strftime('%Y-%m-%d')}").italic = True

    def add_body(text):
        # Create the run explicitly: doc.add_paragraph("") creates no runs,
        # so indexing runs[0] would raise IndexError on empty text.
        doc.add_paragraph().add_run(text).font.size = Pt(12)

    doc.add_paragraph("\n")
    doc.add_heading("📝 Transcript", level=1)
    add_body(transcript)

    doc.add_paragraph("\n")
    doc.add_heading("📋 Summary", level=1)
    add_body(summary_text)

    doc.add_paragraph("\n")
    doc.add_heading("💬 Emotional Insights", level=1)
    for emo in emotions:
        add_body(f"{emo['label']}: {round(emo['score'] * 100, 2)}%")

    doc.add_paragraph("\n\n---\nGenerated by Atma.ai – Confidential", style="Intense Quote")

    buffer = io.BytesIO()
    doc.save(buffer)
    return buffer.getvalue()


if uploaded_file:
    st.audio(uploaded_file)

    # Results start empty. Each is filled only if its stage succeeds, and the
    # tabs below fall back to a warning for anything still missing.
    raw_transcript = ""
    diarized_transcript = None
    summary_text = None
    emotion_scores = None

    # A unique temp file per run. A fixed "temp_audio.wav" would be shared
    # (and clobbered) by concurrent sessions.
    fd, audio_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)

    try:
        # Convert to mono 16 kHz WAV. This is inside the try so decode failures
        # (unsupported file, missing ffmpeg) are reported to the user.
        audio = AudioSegment.from_file(uploaded_file)
        audio = audio.set_channels(1).set_frame_rate(16000)
        audio.export(audio_path, format="wav")

        # Transcribe
        st.info("🔄 Transcribing with Whisper (mixed-language support)...")
        asr = _load_pipeline("automatic-speech-recognition", "openai/whisper-large")
        # NOTE(review): decoding is forced to English even though the UI
        # advertises Tamil-English audio, presumably so the English-only
        # downstream models can consume the text. Confirm this is intended.
        result = asr(audio_path, return_timestamps=True, generate_kwargs={"language": "<|en|>"})
        raw_transcript = (result.get("text") or "").strip()

        if not raw_transcript:
            st.error("❌ Could not generate a transcript. Please try a different audio.")
        else:
            st.info("🗣️ Simulating speaker separation...")
            diarized_transcript = _simulate_diarization(raw_transcript)

            st.info("📋 Summarizing conversation...")
            summarizer = _load_pipeline("summarization", "philschmid/bart-large-cnn-samsum")

            st.info("🧠 Referencing prior session context...")
            rag_context = "\n".join(
                f"Session {i}: {s['coping']}, {s['emotion']}, {s['notes']}"
                for i, s in enumerate(PAST_SESSIONS, start=1)
            )
            prompt_input = f"Previous session context:\n{rag_context}\n\nCurrent session:\n{raw_transcript}"
            # truncation=True: clip input to the model's maximum length
            # (1024 tokens for BART-large) instead of failing on long sessions.
            summary = summarizer(prompt_input, max_length=256, min_length=60, do_sample=False, truncation=True)
            summary_text = summary[0]["summary_text"]

            st.info("🎭 Extracting emotional tones...")
            emotion_model = _load_pipeline(
                "text-classification",
                "j-hartmann/emotion-english-distilroberta-base",
                return_all_scores=True,
            )
            # truncation=True: the classifier also has a fixed maximum input length.
            emotion_scores = _flatten_emotion_scores(emotion_model(raw_transcript, truncation=True))
    except Exception as err:
        # Top-level UI boundary: surface any pipeline failure to the user.
        st.error(f"❌ Processing failed: {err}")
    finally:
        if os.path.exists(audio_path):
            os.remove(audio_path)

    # Tabs are created exactly once, after processing, and render whatever
    # results are available.
    tab1, tab2, tab3, tab4 = st.tabs(["📝 Transcript", "📋 Summary", "💬 Emotions", "📈 Trends"])

    with tab1:
        st.subheader("📝 Speaker-Simulated Transcript")
        if diarized_transcript:
            # No unsafe_allow_html: this is ASR output and must not be able to
            # inject raw HTML into the page.
            st.markdown(diarized_transcript)
        else:
            st.warning("Transcript not available.")

    with tab2:
        st.subheader("📋 Contextual Summary")
        if summary_text:
            insights = _contextual_insights(raw_transcript, PAST_SESSIONS)
            # Blank lines and bullets so Markdown does not merge the
            # observations into a single line.
            final_output = f"{summary_text}\n\n**Contextual Observations:**\n\n" + "\n".join(
                f"- {insight}" for insight in insights
            )
            st.write(final_output)
        else:
            st.warning("Summary not available.")

    with tab3:
        st.subheader("💬 Emotional Insights (Overall)")
        if emotion_scores:
            for emo in emotion_scores:
                st.write(f"{emo['label']}: {round(emo['score'] * 100, 2)}%")
        else:
            st.warning("No emotional data to display.")

    with tab4:
        st.subheader("📈 Emotional Trends Over Time")
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=TREND_DATES, y=TREND_ANXIETY, mode="lines+markers", name="Anxiety"))
        fig.add_trace(go.Scatter(x=TREND_DATES, y=TREND_SADNESS, mode="lines+markers", name="Sadness"))
        fig.update_layout(title="Emotional Trends", xaxis_title="Date", yaxis_title="Score (%)")
        st.plotly_chart(fig)

        st.subheader("📥 Export Session Report")
        if diarized_transcript and summary_text and emotion_scores:
            try:
                report_bytes = _build_report_docx(diarized_transcript, summary_text, emotion_scores)
            except Exception as err:
                st.error(f"❌ Could not build the report: {err}")
            else:
                # A single download button replaces the nested st.button ->
                # st.download_button pattern, which needed two clicks (two full reruns).
                st.download_button(
                    label="📥 Download Report",
                    data=report_bytes,
                    file_name="session_summary.docx",
                    mime=DOCX_MIME,
                )
        else:
            st.info("Report export becomes available once the transcript, summary and emotions are generated.")