#!/usr/bin/env python3
import os
import json
import logging
import re
from pathlib import Path

from flask import Flask, request, jsonify
from flask_cors import CORS
from dotenv import load_dotenv
from bloatectomy import bloatectomy
from langchain_groq import ChatGroq

# unstructured is treated as optional: if the import fails, PDF parsing is
# skipped and report files are read as plain text instead.
try:
    from unstructured.partition.pdf import partition_pdf
except ImportError:
    partition_pdf = None

# --- Logging ---
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("patient-assistant")

# --- Load environment ---
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    logger.error("GROQ_API_KEY not set in environment")
    raise SystemExit(1)

# --- Flask app setup ---
BASE_DIR = Path(__file__).resolve().parent
REPORTS_ROOT = Path(os.getenv("REPORTS_ROOT", str(BASE_DIR / "reports")))
static_folder = BASE_DIR / "static"
app = Flask(__name__, static_folder=str(static_folder), static_url_path="/static")
CORS(app)

# --- LLM setup ---
llm = ChatGroq(
    model=os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
    temperature=0.0,
    max_tokens=1024,
    api_key=GROQ_API_KEY,
)
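# Sanity-check sketch (not executed at import time): ChatGroq follows the
# standard LangChain chat-model interface, so a minimal call looks roughly like
#
#   reply = llm.invoke([{"role": "user", "content": "Say hi in one word."}])
#   print(reply.content)  # the generated text lives on the message's .content
#
# which is the same invoke/.content pattern the chat endpoint below relies on.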
def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
    """Helper function to clean up text using the bloatectomy library."""
    try:
        b = bloatectomy(text, style=style, output="html")
        tokens = getattr(b, "tokens", None)
        if not tokens:
            return text
        return "\n".join(tokens)
    except Exception:
        logger.exception("Bloatectomy cleaning failed; returning original text")
        return text
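# Usage sketch: bloatectomy's "remov" style is designed to strip duplicated
# (copy-pasted) passages from clinical notes, so a note with repeated lines
# should come back deduplicated; on any failure the original text is returned.
# The note text below is invented for illustration:
#
#   note = "BP stable on current meds.\nBP stable on current meds.\nNew rash on left arm."
#   print(clean_notes_with_bloatectomy(note))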
# --- Agent prompt instructions ---
PATIENT_ASSISTANT_PROMPT = """
You are a patient assistant helping to analyze medical records and reports. Your goal is to provide a comprehensive response based on the patient's health history and the current conversation.

Your tasks include:
- Analyzing medical records and reports to detect anomalies, redundant tests, or misleading treatments.
- Suggesting preventive care based on the overall patient health history.
- Optimizing healthcare costs by comparing past visits and treatments, helping patients make smarter choices.
- Offering personalized lifestyle recommendations, such as adopting healthier food habits, daily routines, and regular health checks.
- Generating a natural, helpful reply to the user.

You will be provided with the last user message, the conversation history, and a summary of the patient's medical reports. Use this information to give a tailored and informative response.

STRICT OUTPUT FORMAT (JSON ONLY):
Return a single JSON object with the following keys:
- assistant_reply: string // a natural language reply to the user (short, helpful, always present)
- patientDetails: object // keys may include name, problem, city, contact (update if user shared info)
- conversationSummary: string (optional) // short summary of conversation + relevant patient docs

Rules:
- ALWAYS include `assistant_reply` as a non-empty string.
- Do NOT produce any text outside the JSON object.
- Be concise in `assistant_reply`. If you need more details, ask a targeted follow-up question.
- Do not make up information that is not present in the provided medical reports or conversation history.
"""
# --- JSON extraction helper ---
def extract_json_from_llm_response(raw_response: str) -> dict:
    """Safely extract a JSON object from a string that might contain extra text or markdown."""
    default = {
        "assistant_reply": "I'm sorry — I couldn't understand that. Could you please rephrase?",
        "patientDetails": {},
        "conversationSummary": "",
    }
    if not raw_response or not isinstance(raw_response, str):
        return default

    # Prefer the contents of a markdown code fence, if one is present.
    m = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", raw_response)
    json_string = m.group(1).strip() if m else raw_response

    # Narrow to the span between the first opening and last closing brace.
    first = json_string.find("{")
    last = json_string.rfind("}")
    if first == -1 or last == -1 or first >= last:
        try:
            return json.loads(json_string)
        except Exception:
            logger.warning("Could not locate JSON braces in LLM output. Falling back to default.")
            return default
    candidate = json_string[first:last + 1]

    # Remove trailing commas that would trip up the parser.
    candidate = re.sub(r",\s*(?=[}\]])", "", candidate)
    try:
        parsed = json.loads(candidate)
    except Exception as e:
        logger.warning("Failed to parse JSON from LLM output: %s", e)
        return default

    # Basic validation: a non-empty assistant_reply string is required.
    if (
        isinstance(parsed, dict)
        and isinstance(parsed.get("assistant_reply"), str)
        and parsed["assistant_reply"].strip()
    ):
        parsed.setdefault("patientDetails", {})
        parsed.setdefault("conversationSummary", "")
        return parsed
    logger.warning("Parsed JSON missing 'assistant_reply' or invalid format. Returning default.")
    return default
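# Usage sketch: the helper tolerates markdown fences and trailing commas, e.g.
#
#   raw = '```json\n{"assistant_reply": "Hi!", "patientDetails": {},}\n```'
#   extract_json_from_llm_response(raw)
#   # -> {"assistant_reply": "Hi!", "patientDetails": {}, "conversationSummary": ""}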
# --- Flask routes ---
@app.route("/")  # route path assumed; the decorator was missing from the source
def serve_frontend():
    """Serves the frontend HTML file."""
    try:
        return app.send_static_file("frontend2.html")
    except Exception:
        return "<h3>frontend2.html not found in static/ — please add your frontend2.html there.</h3>", 404
@app.route("/chat", methods=["POST"])  # route path assumed; the decorator was missing from the source
def chat():
    """Handles the chat conversation with the assistant."""
    data = request.get_json(force=True)
    if not isinstance(data, dict):
        return jsonify({"error": "invalid request body"}), 400
    patient_id = data.get("patient_id")
    if not patient_id:
        return jsonify({"error": "patient_id required"}), 400
    chat_history = data.get("chat_history") or []
    patient_state = data.get("patient_state") or {}

    # --- Read and parse patient reports ---
    patient_folder = REPORTS_ROOT / f"p_{patient_id}"
    combined_text_parts = []
    if patient_folder.exists() and patient_folder.is_dir():
        for fname in sorted(os.listdir(patient_folder)):
            file_path = patient_folder / fname
            page_text = ""
            if partition_pdf is not None and str(file_path).lower().endswith(".pdf"):
                try:
                    elements = partition_pdf(filename=str(file_path))
                    page_text = "\n".join([el.text for el in elements if hasattr(el, "text") and el.text])
                except Exception:
                    logger.exception("Failed to parse PDF %s", file_path)
            else:
                try:
                    page_text = file_path.read_text(encoding="utf-8", errors="ignore")
                except Exception:
                    page_text = ""
            if page_text:
                cleaned = clean_notes_with_bloatectomy(page_text, style="remov")
                if cleaned:
                    combined_text_parts.append(cleaned)

    # --- Prepare the state for the LLM ---
    state = patient_state.copy()
    state["lastUserMessage"] = ""
    if chat_history:
        # Find the most recent user message.
        for msg in reversed(chat_history):
            if msg.get("role") == "user" and msg.get("content"):
                state["lastUserMessage"] = msg["content"]
                break

    # Fold the parsed documents into the conversation summary.
    base_summary = state.get("conversationSummary", "") or ""
    docs_summary = "\n\n".join(combined_text_parts)
    if docs_summary:
        state["conversationSummary"] = (base_summary + "\n\n" + docs_summary).strip()
    else:
        state["conversationSummary"] = base_summary

    # --- Direct LLM invocation ---
    user_prompt = f"""
Current patientDetails: {json.dumps(state.get("patientDetails", {}))}
Current conversationSummary: {state.get("conversationSummary", "")}
Last user message: {state.get("lastUserMessage", "")}

Return ONLY valid JSON with keys: assistant_reply, patientDetails, conversationSummary.
"""
    messages = [
        {"role": "system", "content": PATIENT_ASSISTANT_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    try:
        logger.info("Invoking LLM with prepared state and prompt...")
        llm_response = llm.invoke(messages)
        raw_response = llm_response.content if hasattr(llm_response, "content") else str(llm_response)
        logger.info("Raw LLM response: %s", raw_response)
        parsed_result = extract_json_from_llm_response(raw_response)
    except Exception as e:
        logger.exception("LLM invocation failed")
        return jsonify({"error": "LLM invocation failed", "detail": str(e)}), 500

    updated_state = parsed_result or {}
    assistant_reply = updated_state.get("assistant_reply")
    if not isinstance(assistant_reply, str) or not assistant_reply.strip():
        # Fall back to a polite message if the LLM response is invalid or empty.
        assistant_reply = "I'm here to help — could you tell me more about your symptoms?"

    response_payload = {
        "assistant_reply": assistant_reply,
        "updated_state": updated_state,
    }
    return jsonify(response_payload)
@app.route("/ping")  # route path assumed; simple health check
def ping():
    return jsonify({"status": "ok"})
if __name__ == "__main__":
    port = int(os.getenv("PORT", 7860))
    app.run(host="0.0.0.0", port=port, debug=True)
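# Example request against a local run (assumes the /chat route path registered
# above and the default PORT of 7860; the payload keys match what chat() reads):
#
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"patient_id": "123", "chat_history": [{"role": "user", "content": "I keep getting headaches."}]}'
#
# The JSON response carries "assistant_reply" plus "updated_state" as produced
# by the model and validated by extract_json_from_llm_response.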