#!/usr/bin/env python3
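"""Flask backend for a patient chat assistant.

Serves static/frontend2.html, reads per-patient reports from REPORTS_ROOT/p_<patient_id>/
(PDFs via unstructured, other files as plain text), cleans them with bloatectomy, and asks a
Groq-hosted LLM (via langchain_groq) for a JSON reply that the /chat route returns to the client.
"""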
import os
import json
import logging
import re
from typing import Dict, Any
from pathlib import Path
# unstructured is optional; the PDF-parsing branch below checks for its absence
try:
    from unstructured.partition.pdf import partition_pdf
except ImportError:
    partition_pdf = None
from flask import Flask, request, jsonify
from flask_cors import CORS
from dotenv import load_dotenv
from bloatectomy import bloatectomy
from werkzeug.utils import secure_filename
from langchain_groq import ChatGroq
from typing_extensions import TypedDict, NotRequired
# --- Logging ---
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("patient-assistant")
# --- Load environment ---
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    logger.error("GROQ_API_KEY not set in environment")
    exit(1)
# --- Flask app setup ---
BASE_DIR = Path(__file__).resolve().parent
REPORTS_ROOT = Path(os.getenv("REPORTS_ROOT", str(BASE_DIR / "reports")))
static_folder = BASE_DIR / "static"
app = Flask(__name__, static_folder=str(static_folder), static_url_path="/static")
CORS(app)
# --- LLM setup ---
llm = ChatGroq(
    model=os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
    temperature=0.0,
    max_tokens=1024,
    api_key=GROQ_API_KEY,
)

def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
    """Helper function to clean up text using the bloatectomy library."""
    try:
        b = bloatectomy(text, style=style, output="html")
        tokens = getattr(b, "tokens", None)
        if not tokens:
            return text
        return "\n".join(tokens)
    except Exception:
        logger.exception("Bloatectomy cleaning failed; returning original text")
        return text
# --- Agent prompt instructions ---
PATIENT_ASSISTANT_PROMPT = """
You are a patient assistant helping to analyze medical records and reports. Your goal is to provide a comprehensive response based on the patient's health history and the current conversation.
Your tasks include:
- Analyzing medical records and reports to detect anomalies, redundant tests, or misleading treatments.
- Suggesting preventive care based on the overall patient health history.
- Optimizing healthcare costs by comparing past visits and treatments, helping patients make smarter choices.
- Offering personalized lifestyle recommendations, such as adopting healthier food habits, daily routines, and regular health checks.
- Generating a natural, helpful reply to the user.
You will be provided with the last user message, the conversation history, and a summary of the patient's medical reports. Use this information to give a tailored and informative response.
STRICT OUTPUT FORMAT (JSON ONLY):
Return a single JSON object with the following keys:
- assistant_reply: string // a natural language reply to the user (short, helpful, always present)
- patientDetails: object // keys may include name, problem, city, contact (update if user shared info)
- conversationSummary: string (optional) // short summary of conversation + relevant patient docs
Rules:
- ALWAYS include `assistant_reply` as a non-empty string.
- Do NOT produce any text outside the JSON object.
- Be concise in `assistant_reply`. If you need more details, ask a targeted follow-up question.
- Do not make up information that is not present in the provided medical reports or conversation history.
"""
# --- JSON extraction helper ---
def extract_json_from_llm_response(raw_response: str) -> dict:
    """Safely extracts a JSON object from a string that might contain extra text or markdown."""
    default = {
        "assistant_reply": "I'm sorry — I couldn't understand that. Could you please rephrase?",
        "patientDetails": {},
        "conversationSummary": "",
    }
    if not raw_response or not isinstance(raw_response, str):
        return default

    # Find the JSON object, ignoring any markdown code fences
    m = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", raw_response)
    json_string = m.group(1).strip() if m else raw_response

    # Find the first opening brace and the last closing brace
    first = json_string.find('{')
    last = json_string.rfind('}')
    if first == -1 or last == -1 or first >= last:
        try:
            return json.loads(json_string)
        except Exception:
            logger.warning("Could not locate JSON braces in LLM output. Falling back to default.")
            return default

    candidate = json_string[first:last + 1]
    # Remove trailing commas that might cause parsing issues
    candidate = re.sub(r',\s*(?=[}\]])', '', candidate)
    try:
        parsed = json.loads(candidate)
    except Exception as e:
        logger.warning("Failed to parse JSON from LLM output: %s", e)
        return default

    # Basic validation of the parsed JSON
    if isinstance(parsed, dict) and isinstance(parsed.get("assistant_reply"), str) and parsed["assistant_reply"].strip():
        parsed.setdefault("patientDetails", {})
        parsed.setdefault("conversationSummary", "")
        return parsed
    logger.warning("Parsed JSON missing 'assistant_reply' or invalid format. Returning default.")
    return default
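# For example, a fenced and slightly malformed model output such as
#   extract_json_from_llm_response('```json\n{"assistant_reply": "Hi",}\n```')
# should come back as {"assistant_reply": "Hi", "patientDetails": {}, "conversationSummary": ""},
# since the fence and the trailing comma are stripped before parsing.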
# --- Flask routes ---
@app.route("/", methods=["GET"])
def serve_frontend():
    """Serves the frontend HTML file."""
    try:
        return app.send_static_file("frontend2.html")
    except Exception:
        return "<h3>frontend2.html not found in static/ — please add your frontend2.html there.</h3>", 404
@app.route("/chat", methods=["POST"])
def chat():
    """Handles the chat conversation with the assistant."""
    data = request.get_json(force=True)
    if not isinstance(data, dict):
        return jsonify({"error": "invalid request body"}), 400

    patient_id = data.get("patient_id")
    if not patient_id:
        return jsonify({"error": "patient_id required"}), 400

    chat_history = data.get("chat_history") or []
    patient_state = data.get("patient_state") or {}

    # --- Read and parse patient reports ---
    patient_folder = REPORTS_ROOT / f"p_{patient_id}"
    combined_text_parts = []
    if patient_folder.exists() and patient_folder.is_dir():
        for fname in sorted(os.listdir(patient_folder)):
            file_path = patient_folder / fname
            page_text = ""
            if partition_pdf is not None and str(file_path).lower().endswith('.pdf'):
                try:
                    elements = partition_pdf(filename=str(file_path))
                    page_text = "\n".join([el.text for el in elements if hasattr(el, 'text') and el.text])
                except Exception:
                    logger.exception("Failed to parse PDF %s", file_path)
            else:
                try:
                    page_text = file_path.read_text(encoding='utf-8', errors='ignore')
                except Exception:
                    page_text = ""
            if page_text:
                cleaned = clean_notes_with_bloatectomy(page_text, style="remov")
                if cleaned:
                    combined_text_parts.append(cleaned)

    # --- Prepare the state for the LLM ---
    state = patient_state.copy()
    state["lastUserMessage"] = ""
    if chat_history:
        # Find the last user message
        for msg in reversed(chat_history):
            if msg.get("role") == "user" and msg.get("content"):
                state["lastUserMessage"] = msg["content"]
                break

    # Update the conversation summary with the parsed documents
    base_summary = state.get("conversationSummary", "") or ""
    docs_summary = "\n\n".join(combined_text_parts)
    if docs_summary:
        state["conversationSummary"] = (base_summary + "\n\n" + docs_summary).strip()
    else:
        state["conversationSummary"] = base_summary

    # --- Direct LLM Invocation ---
    user_prompt = f"""
Current patientDetails: {json.dumps(state.get("patientDetails", {}))}
Current conversationSummary: {state.get("conversationSummary", "")}
Last user message: {state.get("lastUserMessage", "")}
Return ONLY valid JSON with keys: assistant_reply, patientDetails, conversationSummary.
"""
    messages = [
        {"role": "system", "content": PATIENT_ASSISTANT_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    try:
        logger.info("Invoking LLM with prepared state and prompt...")
        llm_response = llm.invoke(messages)
        raw_response = llm_response.content if hasattr(llm_response, "content") else str(llm_response)
        logger.info("Raw LLM response: %s", raw_response)
        parsed_result = extract_json_from_llm_response(raw_response)
    except Exception as e:
        logger.exception("LLM invocation failed")
        return jsonify({"error": "LLM invocation failed", "detail": str(e)}), 500

    updated_state = parsed_result or {}
    assistant_reply = updated_state.get("assistant_reply")
    if not isinstance(assistant_reply, str) or not assistant_reply.strip():
        # Fallback to a polite message if the LLM response is invalid or empty
        assistant_reply = "I'm here to help — could you tell me more about your symptoms?"

    response_payload = {
        "assistant_reply": assistant_reply,
        "updated_state": updated_state,
    }
    return jsonify(response_payload)
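# Illustrative /chat exchange (field values are hypothetical; the route only requires
# "patient_id", while "chat_history" and "patient_state" default to empty):
#
#   POST /chat
#   {"patient_id": "123",
#    "chat_history": [{"role": "user", "content": "Do I need another blood test?"}],
#    "patient_state": {"patientDetails": {"name": "A. Patel"}, "conversationSummary": ""}}
#
#   Response: {"assistant_reply": "...", "updated_state": {...}}
#
# Reports are read from REPORTS_ROOT/p_<patient_id>/ if that folder exists.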
@app.route("/ping", methods=["GET"])
def ping():
    return jsonify({"status": "ok"})
if __name__ == "__main__":
port = int(os.getenv("PORT", 7860))
app.run(host="0.0.0.0", port=port, debug=True)