# NOTE: removed Hugging Face Spaces status text ("Spaces:" / "Sleeping" /
# "Sleeping") that leaked into this file during export — it is not Python.
#!/usr/bin/env python3 | |
import os | |
import json | |
import logging | |
import re | |
from typing import Dict, Any | |
from pathlib import Path | |
from unstructured.partition.pdf import partition_pdf | |
from flask import Flask, request, jsonify | |
from flask_cors import CORS | |
from dotenv import load_dotenv | |
from bloatectomy import bloatectomy | |
from werkzeug.utils import secure_filename | |
from langchain_groq import ChatGroq | |
from typing_extensions import TypedDict, NotRequired | |
# --- Logging ---
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("patient-assistant")

# --- Load environment ---
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    logger.error("GROQ_API_KEY not set in environment")
    # raise SystemExit rather than calling the site-provided exit() builtin,
    # which is a REPL convenience and is not guaranteed to be available when
    # the module is imported (e.g. under a WSGI server or with python -S).
    raise SystemExit(1)

# --- Flask app setup ---
BASE_DIR = Path(__file__).resolve().parent
# Patient reports live under REPORTS_ROOT; overridable via the REPORTS_ROOT env var.
REPORTS_ROOT = Path(os.getenv("REPORTS_ROOT", str(BASE_DIR / "reports")))
static_folder = BASE_DIR / "static"
app = Flask(__name__, static_folder=str(static_folder), static_url_path="/static")
CORS(app)

# Ensure the reports directory exists
os.makedirs(REPORTS_ROOT, exist_ok=True)

# --- LLM setup ---
# temperature=0.0: replies must be strict JSON, so keep generation deterministic.
llm = ChatGroq(
    model=os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
    temperature=0.0,
    max_tokens=1024,
    api_key=GROQ_API_KEY,
)
def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
    """Clean clinical note text with the bloatectomy library.

    Returns the joined de-duplicated tokens, or the input text unchanged
    when bloatectomy produces nothing or fails for any reason (best-effort).
    """
    try:
        cleaned_tokens = getattr(bloatectomy(text, style=style, output="html"), "tokens", None)
        return "\n".join(cleaned_tokens) if cleaned_tokens else text
    except Exception:
        # Best-effort cleaning: never let a library failure break the caller.
        logger.exception("Bloatectomy cleaning failed; returning original text")
        return text
# --- Agent prompt instructions ---
# System prompt for the Groq chat model. It tells the model to (1) obtain the
# patient ID before anything else and (2) answer with a single strict-JSON
# object. The keys it mandates (assistant_reply, patientDetails,
# conversationSummary) must stay in sync with extract_json_from_llm_response().
PATIENT_ASSISTANT_PROMPT = """
You are a patient assistant helping to analyze medical records and reports. Your primary task is to get the patient ID (PID) from the user at the start of the conversation.
Once you have the PID, you will be provided with a summary of the patient's medical reports. Use this information, along with the conversation history, to provide a comprehensive response.
Your tasks include:
- **First, ask for the patient ID.** Do not proceed with any other task until you have the PID.
- Analyzing medical records and reports to detect anomalies, redundant tests, or misleading treatments.
- Suggesting preventive care based on the overall patient health history.
- Optimizing healthcare costs by comparing past visits and treatments.
- Offering personalized lifestyle recommendations.
- Generating a natural, helpful reply to the user.
STRICT OUTPUT FORMAT (JSON ONLY):
Return a single JSON object with the following keys:
- assistant_reply: string // a natural language reply to the user (short, helpful, always present)
- patientDetails: object // keys may include name, problem, pid (patient ID), city, contact (update if user shared info)
- conversationSummary: string (optional) // short summary of conversation + relevant patient docs
Rules:
- ALWAYS include `assistant_reply` as a non-empty string.
- Do NOT produce any text outside the JSON object.
- Be concise in `assistant_reply`. If you need more details, ask a targeted follow-up question.
- Do not make up information that is not present in the provided medical reports or conversation history.
"""
# --- JSON extraction helper ---
def extract_json_from_llm_response(raw_response: str) -> dict:
    """Safely extracts a JSON object from a string that might contain extra text or markdown."""
    fallback = {
        "assistant_reply": "I'm sorry — I couldn't understand that. Could you please rephrase?",
        "patientDetails": {},
        "conversationSummary": "",
    }
    if not isinstance(raw_response, str) or not raw_response:
        return fallback

    # Prefer the contents of a ```json ...``` (or plain ```) fence when present.
    fence = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", raw_response)
    text = fence.group(1).strip() if fence else raw_response

    # Narrow to the outermost {...} span, if one exists.
    start = text.find('{')
    end = text.rfind('}')
    if start == -1 or end == -1 or start >= end:
        # No brace-delimited object found; try the text as-is before giving up.
        try:
            return json.loads(text)
        except Exception:
            logger.warning("Could not locate JSON braces in LLM output. Falling back to default.")
            return fallback

    # Strip trailing commas (",}" / ",]") that would trip the parser.
    snippet = re.sub(r',\s*(?=[}\]])', '', text[start:end + 1])
    try:
        parsed = json.loads(snippet)
    except Exception as err:
        logger.warning("Failed to parse JSON from LLM output: %s", err)
        return fallback

    # Accept only a dict carrying a non-empty assistant_reply string.
    reply = parsed.get("assistant_reply") if isinstance(parsed, dict) else None
    if isinstance(reply, str) and reply.strip():
        parsed.setdefault("patientDetails", {})
        parsed.setdefault("conversationSummary", "")
        return parsed
    logger.warning("Parsed JSON missing 'assistant_reply' or invalid format. Returning default.")
    return fallback
# --- Flask routes ---
# NOTE(review): no @app.route decorators appear anywhere in this file although
# every handler uses `request` — they look stripped. Restored below; confirm paths.
@app.route("/")
def serve_frontend():
    """Serve the single-page frontend (static/frontend.html)."""
    try:
        return app.send_static_file("frontend.html")
    except Exception:
        # Message corrected: it previously said "frontend2.html" while the code
        # actually serves "frontend.html".
        return "<h3>frontend.html not found in static/ — please add your frontend.html there.</h3>", 404
@app.route("/upload_report", methods=["POST"])  # NOTE(review): decorator restored — confirm path
def upload_report():
    """Handle the upload of a single PDF report for a specific patient.

    Expects multipart/form-data with a 'report' file field and a
    'patient_id' form field. Saves under REPORTS_ROOT/p_<patient_id>/.
    """
    if 'report' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400
    file = request.files['report']
    patient_id = request.form.get("patient_id")
    if file.filename == '' or not patient_id:
        return jsonify({"error": "No selected file or patient ID"}), 400
    filename = secure_filename(file.filename)
    patient_folder = REPORTS_ROOT / f"p_{patient_id}"
    os.makedirs(patient_folder, exist_ok=True)
    file.save(patient_folder / filename)
    # Bug fix: the message previously contained the literal text '(unknown)'
    # instead of the saved filename.
    return jsonify({"message": f"File '{filename}' uploaded successfully for patient ID '{patient_id}'."}), 200
@app.route("/chat", methods=["POST"])  # NOTE(review): decorator restored — confirm path
def chat():
    """Handle one turn of the chat conversation with the assistant.

    Expects JSON: {"chat_history": [...], "patient_state": {...}}.
    Returns JSON: {"assistant_reply": str, "updated_state": dict}, or a
    500 payload when the LLM invocation itself fails.
    """
    data = request.get_json(force=True)
    if not isinstance(data, dict):
        return jsonify({"error": "invalid request body"}), 400

    chat_history = data.get("chat_history") or []
    patient_state = data.get("patient_state") or {}
    patient_id = patient_state.get("patientDetails", {}).get("pid")

    # --- Prepare the state for the LLM ---
    state = patient_state.copy()
    state["lastUserMessage"] = ""
    # Find the most recent user message, if any.
    for msg in reversed(chat_history):
        if msg.get("role") == "user" and msg.get("content"):
            state["lastUserMessage"] = msg["content"]
            break

    # Bug fix: the PID-workflow instruction used to be assigned to user_prompt
    # and then unconditionally overwritten by the final prompt below (dead
    # code). It is now carried separately and actually included in the prompt.
    pid_instruction = ""
    if not patient_id:
        # Very basic heuristic: treat the first number in the last user
        # message as the patient ID.
        pid_match = re.search(r'(\d+)', state.get("lastUserMessage", ""))
        if pid_match:
            inferred_pid = pid_match.group(1)
            state["patientDetails"] = {"pid": inferred_pid}
            patient_id = inferred_pid
            pid_instruction = f"The user provided a patient ID: {inferred_pid}. Please access their reports and respond."
        else:
            pid_instruction = "The patient has not provided a patient ID. Please ask them to provide it to proceed."

    # If a PID is known, load and clean the patient's stored reports.
    combined_text_parts = []
    if patient_id:
        patient_folder = REPORTS_ROOT / f"p_{patient_id}"
        if patient_folder.exists() and patient_folder.is_dir():
            for fname in sorted(os.listdir(patient_folder)):
                file_path = patient_folder / fname
                page_text = ""
                if partition_pdf is not None and str(file_path).lower().endswith('.pdf'):
                    try:
                        elements = partition_pdf(filename=str(file_path))
                        page_text = "\n".join(el.text for el in elements if hasattr(el, 'text') and el.text)
                    except Exception:
                        logger.exception("Failed to parse PDF %s", file_path)
                else:
                    # Non-PDF files are read as best-effort UTF-8 text.
                    try:
                        page_text = file_path.read_text(encoding='utf-8', errors='ignore')
                    except Exception:
                        page_text = ""
                if page_text:
                    cleaned = clean_notes_with_bloatectomy(page_text, style="remov")
                    if cleaned:
                        combined_text_parts.append(cleaned)

    # Fold the parsed documents into the conversation summary.
    base_summary = state.get("conversationSummary", "") or ""
    docs_summary = "\n\n".join(combined_text_parts)
    if docs_summary:
        state["conversationSummary"] = (base_summary + "\n\n" + docs_summary).strip()
    else:
        state["conversationSummary"] = base_summary

    # --- Direct LLM Invocation ---
    user_prompt = f"""
{pid_instruction}
Current patientDetails: {json.dumps(state.get("patientDetails", {}))}
Current conversationSummary: {state.get("conversationSummary", "")}
Last user message: {state.get("lastUserMessage", "")}
Return ONLY valid JSON with keys: assistant_reply, patientDetails, conversationSummary.
"""
    messages = [
        {"role": "system", "content": PATIENT_ASSISTANT_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    try:
        logger.info("Invoking LLM with prepared state and prompt...")
        llm_response = llm.invoke(messages)
        raw_response = llm_response.content if hasattr(llm_response, "content") else str(llm_response)
        # Lazy %-formatting keeps the f-string cost off the happy path.
        logger.info("Raw LLM response: %s", raw_response)
        parsed_result = extract_json_from_llm_response(raw_response)
    except Exception as e:
        logger.exception("LLM invocation failed")
        return jsonify({"error": "LLM invocation failed", "detail": str(e)}), 500

    updated_state = parsed_result or {}
    assistant_reply = updated_state.get("assistant_reply")
    if not assistant_reply or not isinstance(assistant_reply, str) or not assistant_reply.strip():
        # Fallback to a polite message if the LLM response is invalid or empty
        assistant_reply = "I'm here to help — could you tell me more about your symptoms?"

    return jsonify({
        "assistant_reply": assistant_reply,
        "updated_state": updated_state,
    })
# Extensions accepted for bulk upload.
# NOTE(review): the original referenced ALLOWED_EXTENSIONS, which is defined
# nowhere in this file (NameError on first upload). A conservative default is
# provided here — confirm the intended set with the author.
_ALLOWED_UPLOAD_EXTENSIONS = {"pdf", "txt", "doc", "docx", "png", "jpg", "jpeg"}


@app.route("/upload_reports", methods=["POST"])  # path taken from the docstring example below
def upload_reports():
    """
    Upload one or more files for a patient.
    Expects multipart/form-data with:
      - patient_id (form field)
      - files (one or multiple files; use the same field name 'files' for each file)
    Example curl:
      curl -X POST http://localhost:7860/upload_reports \
        -F "patient_id=12345" \
        -F "files[]=@/path/to/report1.pdf" \
        -F "files[]=@/path/to/report2.pdf"
    """
    try:
        # patient id can be in form or args (for convenience)
        patient_id = request.form.get("patient_id") or request.args.get("patient_id")
        if not patient_id:
            return jsonify({"error": "patient_id form field required"}), 400

        # get uploaded files (support both files and files[] naming)
        uploaded_files = request.files.getlist("files")
        if not uploaded_files:
            # fallback: single file under name 'file'
            single = request.files.get("file")
            if single:
                uploaded_files = [single]
        if not uploaded_files:
            return jsonify({"error": "no files uploaded (use form field 'files')"}), 400

        # Consistency fix: save under "p_<pid>" like upload_report()/chat().
        # The original wrote to REPORTS_ROOT/<pid>, a folder chat() never reads.
        patient_folder = REPORTS_ROOT / f"p_{patient_id}"
        patient_folder.mkdir(parents=True, exist_ok=True)

        saved = []
        skipped = []
        for file_storage in uploaded_files:
            orig_name = getattr(file_storage, "filename", "") or ""
            filename = secure_filename(orig_name)
            if not filename:
                skipped.append({"filename": orig_name, "reason": "invalid filename"})
                continue

            # extension check
            ext = filename.rsplit(".", 1)[1].lower() if "." in filename else ""
            if ext not in _ALLOWED_UPLOAD_EXTENSIONS:
                skipped.append({"filename": filename, "reason": f"extension '{ext}' not allowed"})
                continue

            # avoid overwriting: if collision, add numeric suffix
            dest = patient_folder / filename
            if dest.exists():
                base, _dot, extension = filename.rpartition(".")
                # if no base (e.g. ".bashrc") fallback
                base = base or filename
                i = 1
                while True:
                    candidate = f"{base}__{i}.{extension}" if extension else f"{base}__{i}"
                    dest = patient_folder / candidate
                    if not dest.exists():
                        filename = candidate
                        break
                    i += 1

            try:
                file_storage.save(str(dest))
                saved.append(filename)
            except Exception as e:
                logger.exception("Failed to save uploaded file %s: %s", filename, e)
                skipped.append({"filename": filename, "reason": f"save failed: {e}"})

        return jsonify({
            "patient_id": str(patient_id),
            "saved": saved,
            "skipped": skipped,
            "patient_folder": str(patient_folder)
        }), 200
    except Exception as exc:
        logger.exception("Upload failed: %s", exc)
        return jsonify({"error": "upload failed", "detail": str(exc)}), 500
@app.route("/ping", methods=["GET"])  # NOTE(review): decorator restored — confirm path
def ping():
    """Lightweight health-check endpoint."""
    return jsonify({"status": "ok"})
if __name__ == "__main__":
    # Port is configurable via PORT (default 7860, the Hugging Face Spaces port).
    port = int(os.getenv("PORT", 7860))
    # SECURITY NOTE(review): the original hard-coded debug=True, which exposes
    # the Werkzeug debugger. Default preserved for backward compatibility but
    # now overridable via FLASK_DEBUG=0; never run with debug on in production.
    debug = os.getenv("FLASK_DEBUG", "1").lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", port=port, debug=debug)