#!/usr/bin/env python3
import os
import sys
import json
import logging
import re
from pathlib import Path

from flask import Flask, request, jsonify
from flask_cors import CORS
from dotenv import load_dotenv
from bloatectomy import bloatectomy
from werkzeug.utils import secure_filename
from langchain_groq import ChatGroq

# unstructured is treated as optional here: if it is missing, PDF parsing is
# skipped gracefully and plain-text reports are still read (see the check below).
try:
    from unstructured.partition.pdf import partition_pdf
except ImportError:
    partition_pdf = None

# --- Logging ---
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("patient-assistant")

# --- Load environment ---
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    logger.error("GROQ_API_KEY not set in environment")
    sys.exit(1)

# --- Flask app setup ---
BASE_DIR = Path(__file__).resolve().parent
REPORTS_ROOT = Path(os.getenv("REPORTS_ROOT", str(BASE_DIR / "reports")))
static_folder = BASE_DIR / "static"

app = Flask(__name__, static_folder=str(static_folder), static_url_path="/static")
CORS(app)

# Ensure the reports directory exists
os.makedirs(REPORTS_ROOT, exist_ok=True)

# --- LLM setup ---
llm = ChatGroq(
    model=os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
    temperature=0.0,
    max_tokens=1024,
    api_key=GROQ_API_KEY,
)

def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
    """Helper function to clean up text using the bloatectomy library."""
    try:
        b = bloatectomy(text, style=style, output="html")
        tokens = getattr(b, "tokens", None)
        if not tokens:
            return text
        return "\n".join(tokens)
    except Exception:
        logger.exception("Bloatectomy cleaning failed; returning original text")
        return text
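
# Illustrative usage (hypothetical note text): with style="remov", bloatectomy is
# expected to drop text duplicated across a note, e.g.
#   clean_notes_with_bloatectomy("BP stable.\nBP stable.\nNew cough today.")
# should return the note with the repeated "BP stable." line removed.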

# --- Agent prompt instructions ---
PATIENT_ASSISTANT_PROMPT = """
You are a patient assistant helping to analyze medical records and reports. Your primary task is to get the patient ID (PID) from the user at the start of the conversation.

Once you have the PID, you will be provided with a summary of the patient's medical reports. Use this information, along with the conversation history, to provide a comprehensive response.

Your tasks include:
- **First, ask for the patient ID.** Do not proceed with any other task until you have the PID.
- Analyzing medical records and reports to detect anomalies, redundant tests, or misleading treatments.
- Suggesting preventive care based on the overall patient health history.
- Optimizing healthcare costs by comparing past visits and treatments.
- Offering personalized lifestyle recommendations.
- Generating a natural, helpful reply to the user.

STRICT OUTPUT FORMAT (JSON ONLY):
Return a single JSON object with the following keys:
- assistant_reply: string  // a natural language reply to the user (short, helpful, always present)
- patientDetails: object  // keys may include name, problem, pid (patient ID), city, contact (update if user shared info)
- conversationSummary: string (optional)  // short summary of conversation + relevant patient docs

Rules:
- ALWAYS include `assistant_reply` as a non-empty string.
- Do NOT produce any text outside the JSON object.
- Be concise in `assistant_reply`. If you need more details, ask a targeted follow-up question.
- Do not make up information that is not present in the provided medical reports or conversation history.
"""

# --- JSON extraction helper ---
def extract_json_from_llm_response(raw_response: str) -> dict:
    """Safely extracts a JSON object from a string that might contain extra text or markdown."""
    default = {
        "assistant_reply": "I'm sorry — I couldn't understand that. Could you please rephrase?",
        "patientDetails": {},
        "conversationSummary": "",
    }

    if not raw_response or not isinstance(raw_response, str):
        return default

    # Find the JSON object, ignoring any markdown code fences
    m = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", raw_response)
    json_string = m.group(1).strip() if m else raw_response

    # Find the first opening brace and the last closing brace
    first = json_string.find('{')
    last = json_string.rfind('}')
    if first == -1 or last == -1 or first >= last:
        try:
            return json.loads(json_string)
        except Exception:
            logger.warning("Could not locate JSON braces in LLM output. Falling back to default.")
            return default

    candidate = json_string[first:last+1]
    # Remove trailing commas that might cause parsing issues
    candidate = re.sub(r',\s*(?=[}\]])', '', candidate)

    try:
        parsed = json.loads(candidate)
    except Exception as e:
        logger.warning("Failed to parse JSON from LLM output: %s", e)
        return default

    # Basic validation of the parsed JSON
    if isinstance(parsed, dict) and "assistant_reply" in parsed and isinstance(parsed["assistant_reply"], str) and parsed["assistant_reply"].strip():
        parsed.setdefault("patientDetails", {})
        parsed.setdefault("conversationSummary", "")
        return parsed
    else:
        logger.warning("Parsed JSON missing 'assistant_reply' or invalid format. Returning default.")
        return default
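
# Illustrative call (hypothetical LLM output wrapped in a markdown fence):
#   extract_json_from_llm_response('```json\n{"assistant_reply": "Hi"}\n```')
# should yield {"assistant_reply": "Hi", "patientDetails": {}, "conversationSummary": ""}.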

# --- Flask routes ---
@app.route("/", methods=["GET"])
def serve_frontend():
    """Serves the frontend HTML file."""
    try:
        return app.send_static_file("frontend_p.html")
    except Exception:
        return "<h3>frontend2.html not found in static/ — please add your frontend2.html there.</h3>", 404

@app.route("/upload_report", methods=["POST"])
def upload_report():
    """Handles the upload of a new PDF report for a specific patient."""
    if 'report' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400
    
    file = request.files['report']
    patient_id = request.form.get("patient_id")

    if file.filename == '' or not patient_id:
        return jsonify({"error": "No selected file or patient ID"}), 400

    filename = secure_filename(file.filename)
    if not filename:
        return jsonify({"error": "Invalid filename"}), 400
    patient_folder = REPORTS_ROOT / f"p_{patient_id}"
    os.makedirs(patient_folder, exist_ok=True)
    file_path = patient_folder / filename
    file.save(file_path)
    return jsonify({"message": f"File '{filename}' uploaded successfully for patient ID '{patient_id}'."}), 200

@app.route("/chat", methods=["POST"])
def chat():
    """Handles the chat conversation with the assistant."""
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "invalid request body"}), 400

    chat_history = data.get("chat_history") or []
    patient_state = data.get("patient_state") or {}
    patient_id = patient_state.get("patientDetails", {}).get("pid")
    
    # --- Prepare the state for the LLM ---
    state = patient_state.copy()
    state["lastUserMessage"] = ""
    if chat_history:
        # Find the last user message
        for msg in reversed(chat_history):
            if msg.get("role") == "user" and msg.get("content"):
                state["lastUserMessage"] = msg["content"]
                break

    combined_text_parts = []
    pid_guidance = ""
    # If a PID is not yet known, try to infer one from the user's last message;
    # otherwise the agent will be instructed to ask for it.
    if not patient_id:
        last_message = state.get("lastUserMessage", "")
        # A very basic heuristic: treat the first number in the message as the PID.
        pid_match = re.search(r"(\d+)", last_message)
        if pid_match:
            inferred_pid = pid_match.group(1)
            state["patientDetails"] = {"pid": inferred_pid}
            patient_id = inferred_pid
            # Now that we have a PID, tell the agent to process the reports.
            pid_guidance = f"The user provided a patient ID: {inferred_pid}. Please access their reports and respond."
        else:
            # If no PID is found, the agent should ask for it.
            pid_guidance = "The patient has not provided a patient ID. Please ask them to provide it to proceed."
    
    # If a PID is known, load the patient reports.
    if patient_id:
        patient_folder = REPORTS_ROOT / f"p_{patient_id}"
        if patient_folder.exists() and patient_folder.is_dir():
            for fname in sorted(os.listdir(patient_folder)):
                file_path = patient_folder / fname
                page_text = ""
                if partition_pdf is not None and file_path.suffix.lower() == ".pdf":
                    try:
                        elements = partition_pdf(filename=str(file_path))
                        page_text = "\n".join([el.text for el in elements if hasattr(el, 'text') and el.text])
                    except Exception:
                        logger.exception("Failed to parse PDF %s", file_path)
                else:
                    try:
                        page_text = file_path.read_text(encoding='utf-8', errors='ignore')
                    except Exception:
                        page_text = ""
                
                if page_text:
                    cleaned = clean_notes_with_bloatectomy(page_text, style="remov")
                    if cleaned:
                        combined_text_parts.append(cleaned)
    
    # Update the conversation summary with the parsed documents
    base_summary = state.get("conversationSummary", "") or ""
    docs_summary = "\n\n".join(combined_text_parts)
    if docs_summary:
        state["conversationSummary"] = (base_summary + "\n\n" + docs_summary).strip()
    else:
        state["conversationSummary"] = base_summary

    # --- Direct LLM Invocation ---
    user_prompt = f"""
Current patientDetails: {json.dumps(state.get("patientDetails", {}))}
Current conversationSummary: {state.get("conversationSummary", "")}
Last user message: {state.get("lastUserMessage", "")}
{pid_guidance}
Return ONLY valid JSON with keys: assistant_reply, patientDetails, conversationSummary.
"""

    messages = [
        {"role": "system", "content": PATIENT_ASSISTANT_PROMPT},
        {"role": "user", "content": user_prompt}
    ]
    
    try:
        logger.info("Invoking LLM with prepared state and prompt...")
        llm_response = llm.invoke(messages)
        raw_response = ""
        if hasattr(llm_response, "content"):
            raw_response = llm_response.content
        else:
            raw_response = str(llm_response)

        logger.info(f"Raw LLM response: {raw_response}")
        parsed_result = extract_json_from_llm_response(raw_response)
        
    except Exception as e:
        logger.exception("LLM invocation failed")
        return jsonify({"error": "LLM invocation failed", "detail": str(e)}), 500

    updated_state = parsed_result or {}

    assistant_reply = updated_state.get("assistant_reply")
    if not assistant_reply or not isinstance(assistant_reply, str) or not assistant_reply.strip():
        # Fallback to a polite message if the LLM response is invalid or empty
        assistant_reply = "I'm here to help — could you tell me more about your symptoms?"

    response_payload = {
        "assistant_reply": assistant_reply,
        "updated_state": updated_state,
    }

    return jsonify(response_payload)
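
# Illustrative request (assumes the server is running locally on port 5000):
#   curl -X POST http://localhost:5000/chat -H "Content-Type: application/json" \
#        -d '{"chat_history": [{"role": "user", "content": "My PID is 123"}], "patient_state": {}}'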

@app.route("/ping", methods=["GET"])
def ping():
    return jsonify({"status": "ok"})

if __name__ == "__main__":
    port = int(os.getenv("PORT", "5000"))
    app.run(host="0.0.0.0", port=port, debug=True)