patient_bot / app.py
WebashalarForML's picture
Update app.py
77f14d0 verified
raw
history blame
19.8 kB
#!/usr/bin/env python3
# app.py - Health Reports processing agent (PDF -> cleaned text -> structured JSON)
import os
import json
import logging
import re
from pathlib import Path
from typing import List, Dict, Any
from flask import Flask, request, jsonify
from flask_cors import CORS
from dotenv import load_dotenv
from unstructured.partition.pdf import partition_pdf
# Bloatectomy class (as per the source you provided)
from bloatectomy import bloatectomy
# LLM / agent
from langchain_groq import ChatGroq
from langgraph.prebuilt import create_react_agent
# LangGraph imports
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict, NotRequired
# --- Logging ---------------------------------------------------------------
# Configure root logging once (basicConfig is a no-op if handlers already
# exist) and expose a named module logger used throughout this file.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("health-agent")
# --- Environment & config -------------------------------------------------
# Load variables from a .env file (if one is found) into os.environ before the
# os.getenv() configuration reads that follow.
load_dotenv()
# Filesystem locations (overridable via environment variables).
# Defaults use forward slashes, which pathlib accepts on every platform. The
# previous raw-string defaults such as r"app\reports" named a single directory
# literally called "app\reports" on POSIX, where "\" is an ordinary filename
# character.
REPORTS_ROOT = Path(os.getenv("REPORTS_ROOT", "app/reports"))  # e.g. app/reports/<patient_id>/<file.pdf>
SSRI_FILE = Path(os.getenv("SSRI_FILE", "app/medicationCategories/SSRI_list.txt"))
MISC_FILE = Path(os.getenv("MISC_FILE", "app/medicationCategories/MISC_list.txt"))
# NOTE(review): not passed to ChatGroq explicitly below; presumably the client
# reads GROQ_API_KEY from the environment itself — confirm.
GROQ_API_KEY = os.getenv("GROQ_API_KEY", None)
# --- LLM setup -------------------------------------------------------------
# Single Groq chat model shared by both agents below. temperature=0.0 requests
# low-variance decoding; the model id can be overridden via LLM_MODEL.
llm = ChatGroq(
    model=os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
    temperature=0.0,
    max_tokens=None,  # no explicit completion-length cap is requested here
)
# Top-level strict system prompt for report JSON pieces (each node will use a more specific prompt)
NODE_BASE_INSTRUCTIONS = """
You are HealthAI — a clinical assistant producing JSON for downstream processing.
Produce only valid JSON (no extra text). Follow field types exactly. If missing data, return empty strings or empty arrays.
Be conservative: do not assert diagnoses; provide suggestions and ask physician confirmation where needed.
"""
# Build a generic agent and a JSON resolver agent (to fix broken JSON from LLM).
# Both are tool-less ReAct agents over the same model and differ only in their
# system prompt.
agent = create_react_agent(model=llm, tools=[], prompt=NODE_BASE_INSTRUCTIONS)
# NOTE(review): this prompt asks for JSON "enclosed in triple backticks";
# extract_json_from_llm_response strips such a fence, so that is tolerated.
agent_json_resolver = create_react_agent(model=llm, tools=[], prompt="""
You are a JSON fixer. Input: a possibly-malformed JSON-like text. Output: valid JSON only (enclosed in triple backticks).
Fix missing quotes, trailing commas, unescaped newlines, stray assistant labels, and ensure schema compliance.
""")
# -------------------- JSON extraction / sanitizer ---------------------------
def extract_json_from_llm_response(raw_response: str) -> dict:
    """
    Try extracting a JSON object from raw LLM text.

    The text is first narrowed to a fenced ```json block (if any) and then to
    the outermost ``{ ... }`` span. If that already parses, it is returned
    as-is. Only when a plain parse fails are the best-effort cleanups for
    common LLM mistakes applied:
      - stray labels
      - a stray quote after bare booleans
      - unescaped quotes in "logic" fields
      - trailing commas
      - surplus closing braces
      - raw newlines inside strings

    Args:
        raw_response: Raw text produced by the model.

    Returns:
        The parsed JSON value (normally a dict).

    Raises:
        json.JSONDecodeError: if the text still cannot be parsed after cleanup.
    """
    # --- 1) Pull out the JSON code-block if present ---
    md = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", raw_response)
    json_string = md.group(1).strip() if md else raw_response

    # --- 2) Trim to the outermost { … } so we drop any prefix/suffix junk ---
    first, last = json_string.find('{'), json_string.rfind('}')
    if 0 <= first < last:
        json_string = json_string[first:last + 1]

    # Fast path: the text-level heuristics below can corrupt JSON that is
    # already valid (a string value "true", a "}" inside a string, "word {"
    # inside a string), so only run them when a plain parse fails.
    try:
        return json.loads(json_string)
    except json.JSONDecodeError:
        pass

    # --- 3) PRE-CLEANUP: remove rogue assistant labels, fix boolean quotes ---
    json_string = re.sub(r'\b\w+\s*{', '{', json_string)
    json_string = re.sub(r'"assistant"\s*:', '', json_string)
    # Drop a stray closing quote after a bare boolean value (`"k": true"`).
    # Anchored on the preceding colon so the quote that legitimately closes a
    # string ending in "true"/"false" is left alone.
    json_string = re.sub(r'(:\s*)(false|true)"', r'\1\2', json_string)

    # --- 4) Escape embedded quotes in long string fields (best-effort) ---
    def _esc(m):
        prefix, body = m.group(1), m.group(2)
        # Escape only quotes that are not already escaped; turning `\"` into
        # `\\"` would terminate the string early.
        return prefix + re.sub(r'(?<!\\)"', r'\\"', body)

    json_string = re.sub(
        r'("logic"\s*:\s*")([\s\S]+?)(?=",\s*"[A-Za-z_]\w*"\s*:\s*)',
        _esc,
        json_string,
    )

    # --- 5) Remove trailing commas before } or ] (and collapse doubled commas) ---
    json_string = re.sub(r',\s*(?=[}\],])', '', json_string)
    json_string = re.sub(r',\s*,', ',', json_string)

    # --- 6) Balance braces if obvious excess ---
    # Strip only real trailing '}' characters (skipping whitespace between
    # them). Blindly chopping N characters could remove non-brace content.
    excess = json_string.count('}') - json_string.count('{')
    if excess > 0:
        trimmed = json_string.rstrip()
        while excess > 0 and trimmed.endswith('}'):
            trimmed = trimmed[:-1].rstrip()
            excess -= 1
        json_string = trimmed

    # --- 7) Escape literal newlines inside strings so json.loads can parse ---
    def _escape_newlines_in_strings(s: str) -> str:
        return re.sub(
            r'"((?:[^"\\]|\\.)*?)"',
            lambda m: '"' + m.group(1).replace('\n', '\\n').replace('\r', '\\r') + '"',
            s,
            flags=re.DOTALL,
        )

    json_string = _escape_newlines_in_strings(json_string)

    # Final parse (raises json.JSONDecodeError on failure)
    return json.loads(json_string)
# -------------------- Utility: Bloatectomy wrapper ------------------------
def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
    """
    Remove duplicated passages from *text* with the bloatectomy class.

    style: 'highlight' | 'bold' | 'remov'; process_reports passes 'remov'
    so duplicates are deleted.
    Returns the resulting tokens joined by newlines, or *text* unchanged when
    bloatectomy yields no tokens or raises (the failure is logged).
    """
    try:
        cleaner = bloatectomy(text, style=style, output="html")
        cleaned_tokens = getattr(cleaner, "tokens", None)
        return "\n".join(cleaned_tokens) if cleaned_tokens else text
    except Exception:
        logger.exception("Bloatectomy cleaning failed; returning original text")
        return text
# --------------- Utility: medication extraction (adapted) -----------------
def readDrugs_from_file(path: Path):
    """
    Load a medication category list.

    Each relevant line is expected to look like ``generic|alias1|alias2...``.
    The lower-cased line is later used as a case-insensitive regex by
    addToDrugs_line, so the ``|`` separators act as alternation. Lines that do
    not contain ``|`` or start with ``|`` (no generic name) are skipped.

    Args:
        path: Path to the list file.

    Returns:
        ``(listing, generics)`` where:
          - ``listing`` maps each lower-cased generic name to its pattern line.
          - ``generics`` lists the generic names in file order.
        ``({}, [])`` when *path* is not an existing regular file.
    """
    if not path.is_file():
        return {}, []
    txt = path.read_text(encoding="utf-8", errors="ignore")
    listing: Dict[str, str] = {}
    generics: List[str] = []
    for raw_line in txt.splitlines():
        line = raw_line.strip().lower()
        generic, sep, _ = line.partition("|")
        generic = generic.strip()
        # Pair every generic with its *own* line. Zipping all "generic|"
        # matches against all non-empty lines misaligned every entry after the
        # first line lacking a generic.
        if not sep or not generic:
            continue
        # Drop empty alternatives (e.g. from a trailing '|'): an empty regex
        # branch matches any line and would flag the drug everywhere.
        pattern = "|".join(part for part in line.split("|") if part.strip())
        listing[generic] = pattern
        generics.append(generic)
    return listing, generics
def addToDrugs_line(line: str, drugs_flags: List[int], listing: Dict[str,str], genList: List[str]) -> List[int]:
    """
    Flag every generic whose pattern matches *line*.

    For each ``generic -> pattern`` pair in *listing*, a case-insensitive
    re.search of the pattern against *line* sets the flag at that generic's
    position in *genList* to 1; if a name repeats in *genList*, its last
    position is used. Patterns that are not valid regexes are skipped.
    *drugs_flags* is modified in place and also returned.
    """
    position_of = {}
    for position, name in enumerate(genList):
        position_of[name] = position
    for name, pattern in listing.items():
        try:
            hit = re.search(pattern, line, re.I)
        except re.error:
            continue
        if not hit:
            continue
        slot = position_of.get(name)
        if slot is not None:
            drugs_flags[slot] = 1
    return drugs_flags
def extract_medications_from_text(text: str) -> List[str]:
    """
    Collect candidate medication mentions from report text.

    Three heuristics are combined, line by line:
      - category-list matches from SSRI_FILE / MISC_FILE (via readDrugs_from_file
        and addToDrugs_line), reported by their generic name;
      - the text following a keyword such as "Rx:", "Medication -" or "Tablet";
      - capitalised words immediately followed by a dose such as "500mg".

    Returns:
        De-duplicated mentions, sorted. Sorting makes the result (and the LLM
        payloads and API responses built from it) stable across runs. A plain
        list(set) of strings is ordered differently in each process because of
        string hash randomisation.
    """
    ssri_map, ssri_generics = readDrugs_from_file(SSRI_FILE)
    misc_map, misc_generics = readDrugs_from_file(MISC_FILE)
    combined_map = {**ssri_map, **misc_map}
    combined_generics = list(ssri_generics) + list(misc_generics)
    flags = [0] * len(combined_generics)

    # Compiled once per call instead of per line.
    keyword_re = re.compile(r"\b(Rx|Drug|Medication|Prescribed|Tablet)\s*[:\-]?\s*([A-Za-z0-9\-\s/\.]+)", re.I)
    dosed_name_re = re.compile(r"\b([A-Z][a-z0-9\-]{2,}\s*(?:[0-9]{1,4}\s*(?:mg|mcg|g|IU))?)")
    unit_re = re.compile(r"\b(mg|mcg|g|IU)\b", re.I)

    meds_found = set()
    for ln in text.splitlines():
        ln = ln.strip()
        if not ln:
            continue
        if combined_map:
            flags = addToDrugs_line(ln, flags, combined_map, combined_generics)
        m = keyword_re.search(ln)
        if m:
            candidate = m.group(2).strip()
            if candidate:  # a keyword followed only by punctuation/space yields ""
                meds_found.add(candidate)
        for s in dosed_name_re.findall(ln):
            # Keep only capitalised words that actually carry a dose unit.
            if unit_re.search(s):
                meds_found.add(s.strip())
    meds_found.update(g for g, f in zip(combined_generics, flags) if f == 1)
    return sorted(meds_found)
# -------------------- Node prompts --------------------------
# Each prompt embeds the exact JSON shape the node expects back; the shapes
# themselves must be valid JSON so the model is not shown a broken template.
PATIENT_NODE_PROMPT = """
You will extract patientDetails from the provided document texts.
Return ONLY JSON with this exact shape:
{ "patientDetails": {"name": "", "age": "", "sex": "", "pid": ""} }
Fill fields using text evidence or leave empty strings.
"""
DOCTOR_NODE_PROMPT = """
You will extract doctorDetails found in the documents.
Return ONLY JSON with this exact shape:
{ "doctorDetails": {"referredBy": ""} }
"""
TEST_REPORT_NODE_PROMPT = """
You will extract per-test structured results from the documents.
Return ONLY JSON with this exact shape:
{
"reports": [
{
"testName": "",
"dateReported": "",
"timeReported": "",
"abnormalFindings": [
{"investigation": "", "result": 0, "unit": "", "status": "", "referenceValue": ""}
],
"interpretation": "",
"trends": []
}
]
}
- Include only findings that are outside reference ranges OR explicitly called 'abnormal' in the report.
- For result numeric parsing, prefer numeric values; if not numeric, keep original string.
- Use statuses: Low, High, Borderline, Positive, Negative, Normal.
"""
# The shape previously read `"longTermTrends": "",""risk_prediction": ""`,
# which is not valid JSON and invited malformed model output.
ANALYSIS_NODE_PROMPT = """
You will create an overallAnalysis based on the extracted reports (the agent will give you the 'reports' JSON).
Return ONLY JSON:
{ "overallAnalysis": { "summary": "", "recommendations": "", "longTermTrends": "", "risk_prediction": "", "drug_interaction": "" } }
Be conservative, evidence-based, and suggest follow-up steps for physicians.
"""
CONDITION_LOOP_NODE_PROMPT = """
Validation and condition node:
Input: partial JSON (patientDetails, doctorDetails, reports, overallAnalysis).
Task: Check required keys exist and that each report has at least testName and abnormalFindings list.
Return ONLY JSON:
{ "valid": true, "missing": [] }
If missing fields, list keys in 'missing'. Do NOT modify content.
"""
# -------------------- Node helpers -------------------------
def _content_to_text(content: Any) -> str:
    """Flatten a chat-message ``content`` value into plain text.

    Handles str, None and lists of parts.
    NOTE(review): list parts are assumed to be str or {"text": ...} dicts
    (LangChain-style content blocks) — confirm.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict):
                pieces.append(str(part.get("text", "")))
            else:
                pieces.append(str(part))
        return "".join(pieces)
    return str(content)


def _response_text(resp: Any) -> str:
    """Return the text of an agent response, taken from its last message when present."""
    if isinstance(resp, str):
        return resp
    if hasattr(resp, "content"):  # AIMessage or similar
        return _content_to_text(resp.content)
    if isinstance(resp, dict):
        msgs = resp.get("messages")
        if not msgs:
            return json.dumps(resp, default=str)
        last_msg = msgs[-1]
        if isinstance(last_msg, str):
            return last_msg
        if hasattr(last_msg, "content"):
            return _content_to_text(last_msg.content)
        if isinstance(last_msg, dict):
            return _content_to_text(last_msg.get("content", ""))
        return str(last_msg)
    return str(resp)


def call_node_agent(node_prompt: str, payload: dict) -> dict:
    """
    Call the generic agent with a targeted node prompt and the payload.

    The prompt and payload are sent together as one JSON-encoded user
    message. The reply is parsed with extract_json_from_llm_response. If that
    fails, the JSON resolver agent is asked once to repair the reply text.

    Returns:
        The parsed JSON (normally a dict), or ``{}`` in either case:
          - the agent call itself fails;
          - neither the reply nor the repaired reply can be parsed.
    """
    raw = None
    try:
        # ensure_ascii=False keeps non-ASCII report text (units, names)
        # readable to the model instead of \uXXXX escapes; default=str avoids
        # failing on values json cannot encode natively.
        message = json.dumps(
            {"prompt": node_prompt, "payload": payload},
            ensure_ascii=False,
            default=str,
        )
        resp = agent.invoke({"messages": [{"role": "user", "content": message}]})
        raw = _response_text(resp)
        return extract_json_from_llm_response(raw)
    except Exception as e:
        if raw is None:
            # The request or agent call failed before any text came back, so
            # there is nothing for the resolver to repair.
            logger.exception("Node agent call failed: %s", e)
            return {}
        logger.warning("Node agent JSON parse failed: %s. Attempting JSON resolver.", e)
    try:
        resolver_prompt = f"Fix this JSON. Input:\n```json\n{raw}\n```\nReturn valid JSON only."
        r = agent_json_resolver.invoke({"messages": [{"role": "user", "content": resolver_prompt}]})
        return extract_json_from_llm_response(_response_text(r))
    except Exception as e2:
        logger.exception("JSON resolver also failed: %s", e2)
        return {}
# -------------------- Define LangGraph State schema -------------------------
class State(TypedDict):
    """
    Shared LangGraph state for the report pipeline.

    process_reports seeds patient_meta, patient_id, documents and medications.
    The remaining NotRequired keys are written by the graph nodes.
    """
    patient_meta: NotRequired[Dict[str, Any]]  # the request body's "patientDetails" value
    patient_id: str
    documents: List[Dict[str, Any]]  # one dict per file: filename, raw_text, cleaned_text
    medications: List[str]  # pre-extracted by extract_medications_from_text
    patientDetails: NotRequired[Dict[str, Any]]  # written by patient_details_node
    doctorDetails: NotRequired[Dict[str, Any]]  # written by doctor_details_node
    reports: NotRequired[List[Dict[str, Any]]]  # written by test_report_node
    overallAnalysis: NotRequired[Dict[str, Any]]  # written by analysis_node
    valid: NotRequired[bool]  # written by condition_loop_node
    missing: NotRequired[List[str]]  # keys condition_loop_node reported as missing
# -------------------- Node implementations as LangGraph nodes -------------------------
def patient_details_node(state: State) -> dict:
    """LangGraph node: ask the LLM for ``patientDetails`` and return it as a state update."""
    logger.info("Running patient_details_node")
    result = call_node_agent(
        PATIENT_NODE_PROMPT,
        {
            "patient_meta": state.get("patient_meta", {}),
            "documents": state.get("documents", []),
            "medications": state.get("medications", []),
        },
    )
    if not isinstance(result, dict):
        return {"patientDetails": {}}
    return {"patientDetails": result.get("patientDetails", {})}
def doctor_details_node(state: State) -> dict:
    """LangGraph node: ask the LLM for ``doctorDetails`` (e.g. referredBy)."""
    logger.info("Running doctor_details_node")
    result = call_node_agent(
        DOCTOR_NODE_PROMPT,
        {
            "documents": state.get("documents", []),
            "medications": state.get("medications", []),
        },
    )
    if not isinstance(result, dict):
        return {"doctorDetails": {}}
    return {"doctorDetails": result.get("doctorDetails", {})}
def test_report_node(state: State) -> dict:
    """LangGraph node: ask the LLM for per-test ``reports`` with abnormal findings."""
    logger.info("Running test_report_node")
    result = call_node_agent(
        TEST_REPORT_NODE_PROMPT,
        {
            "documents": state.get("documents", []),
            "medications": state.get("medications", []),
        },
    )
    if not isinstance(result, dict):
        return {"reports": []}
    return {"reports": result.get("reports", [])}
def analysis_node(state: State) -> dict:
    """LangGraph node: build ``overallAnalysis`` from the details and reports already in state."""
    logger.info("Running analysis_node")
    result = call_node_agent(
        ANALYSIS_NODE_PROMPT,
        {
            "patientDetails": state.get("patientDetails", {}),
            "doctorDetails": state.get("doctorDetails", {}),
            "reports": state.get("reports", []),
            "medications": state.get("medications", []),
        },
    )
    if not isinstance(result, dict):
        return {"overallAnalysis": {}}
    return {"overallAnalysis": result.get("overallAnalysis", {})}
def condition_loop_node(state: State) -> dict:
    """
    LangGraph node: validate the assembled output.

    Asks the LLM validator first; its answer is used when it contains a
    "valid" key. A textual "true"/"false" verdict is parsed explicitly, and
    "missing" is normalised to a list of strings. When the validator gives no
    usable verdict, a deterministic check flags an empty ``patientDetails``
    or ``reports``.

    Returns:
        ``{"valid": bool, "missing": list[str]}`` to merge into the state.
    """
    payload = {
        "patientDetails": state.get("patientDetails", {}),
        "doctorDetails": state.get("doctorDetails", {}),
        "reports": state.get("reports", []),
        "overallAnalysis": state.get("overallAnalysis", {}),
    }
    logger.info("Running condition_loop_node (validation)")
    out = call_node_agent(CONDITION_LOOP_NODE_PROMPT, payload)
    if isinstance(out, dict) and "valid" in out:
        verdict = out.get("valid")
        if isinstance(verdict, str):
            # bool("false") is True, so textual booleans must be parsed explicitly.
            verdict = verdict.strip().lower() == "true"
        missing = out.get("missing") or []
        if isinstance(missing, str):
            missing = [missing]
        elif not isinstance(missing, list):
            missing = []
        return {"valid": bool(verdict), "missing": [str(key) for key in missing]}
    # Deterministic fallback when the validator produced no usable verdict.
    missing = []
    if not state.get("patientDetails"):
        missing.append("patientDetails")
    if not state.get("reports"):
        missing.append("reports")
    return {"valid": len(missing) == 0, "missing": missing}
# -------------------- Build LangGraph StateGraph -------------------------
# Linear pipeline:
#   START -> patient_details -> doctor_details -> test_report
#         -> analysis -> condition_loop -> END
graph_builder = StateGraph(State)
for _node_name, _node_fn in (
    ("patient_details", patient_details_node),
    ("doctor_details", doctor_details_node),
    ("test_report", test_report_node),
    ("analysis", analysis_node),
    ("condition_loop", condition_loop_node),
):
    graph_builder.add_node(_node_name, _node_fn)
_pipeline_order = [START, "patient_details", "doctor_details", "test_report", "analysis", "condition_loop", END]
for _src, _dst in zip(_pipeline_order, _pipeline_order[1:]):
    graph_builder.add_edge(_src, _dst)
del _node_name, _node_fn, _pipeline_order, _src, _dst
graph = graph_builder.compile()
# -------------------- Flask app & endpoints -------------------------------
# Static assets (including frontend.html) live in ./static next to this file
# and are served under the /static URL prefix.
BASE_DIR = Path(__file__).resolve().parent
static_folder = BASE_DIR.joinpath("static")
app = Flask(
    __name__,
    static_url_path="/static",
    static_folder=str(static_folder),
)
CORS(app)  # dev convenience; lock down in production
# serve frontend root
@app.route("/", methods=["GET"])
def serve_frontend():
    """Serve static/frontend.html, or a short HTML hint with status 404 when it cannot be served."""
    try:
        page = app.send_static_file("frontend.html")
    except Exception:
        return "<h3>frontend.html not found in static/ — drop your frontend.html there.</h3>", 404
    return page
def _resolve_inside(base: Path, name: str):
    """
    Resolve ``base / name`` and return it only if it lies strictly inside *base*.

    Returns None in any of these cases:
      - the name escapes *base* (``..`` segments, absolute paths, or symlinks
        resolving elsewhere);
      - the name resolves to *base* itself;
      - the name cannot be resolved.
    """
    try:
        base_resolved = base.resolve()
        candidate = (base / name).resolve()
        candidate.relative_to(base_resolved)  # raises ValueError if outside
    except (OSError, RuntimeError, ValueError):
        return None
    return None if candidate == base_resolved else candidate


def _load_documents(patient_folder: Path, filenames: List[str]) -> List[Dict[str, Any]]:
    """
    Parse each requested PDF inside *patient_folder*.

    Returns one dict per readable file with its raw text and bloatectomy-
    cleaned text. Names that escape the folder or do not refer to an existing
    file are logged and skipped. A PDF that fails to parse yields empty text.
    """
    documents = []
    for fname in filenames:
        file_path = _resolve_inside(patient_folder, fname)
        if file_path is None:
            logger.warning("rejected file name outside patient folder: %s", fname)
            continue
        if not file_path.is_file():
            logger.warning("file not found: %s", file_path)
            continue
        try:
            elements = partition_pdf(filename=str(file_path))
            page_text = "\n".join(el.text for el in elements if getattr(el, "text", None))
        except Exception:
            logger.exception("Failed to parse PDF %s", file_path)
            page_text = ""
        documents.append({
            "filename": fname,
            "raw_text": page_text,
            "cleaned_text": clean_notes_with_bloatectomy(page_text, style="remov"),
        })
    return documents


@app.route("/process_reports", methods=["POST"])
def process_reports():
    """
    POST /process_reports — run the extraction pipeline over a patient's PDFs.

    JSON body:
      - ``patient_id`` (required);
      - ``filenames`` (required; file names inside REPORTS_ROOT/<patient_id>/);
      - ``patientDetails`` (optional; seeds patient_meta and is the fallback
        when validation reports patientDetails missing).

    Responses:
      - 200 with the structured result;
      - 400 for a malformed body, path-escaping names, or no readable documents;
      - 404 when the patient folder does not exist;
      - 500 when the pipeline raises.
    """
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400
    patient_id = data.get("patient_id")
    filenames = data.get("filenames", [])
    extra_patient_meta = data.get("patientDetails", {})
    if not isinstance(extra_patient_meta, dict):
        extra_patient_meta = {}
    if not patient_id or not filenames:
        return jsonify({"error": "missing patient_id or filenames"}), 400
    if isinstance(filenames, str):
        # A single name sent as a string would otherwise be iterated per character.
        filenames = [filenames]
    if not isinstance(filenames, list) or not all(isinstance(f, str) for f in filenames):
        return jsonify({"error": "filenames must be a list of strings"}), 400

    # patient_id and filenames come straight from the client: confine every
    # path to REPORTS_ROOT so "../" segments or absolute paths cannot read
    # arbitrary files or other patients' folders.
    patient_folder = _resolve_inside(REPORTS_ROOT, str(patient_id))
    if patient_folder is None:
        return jsonify({"error": "invalid patient_id"}), 400
    if not patient_folder.is_dir():
        return jsonify({"error": f"patient folder not found: {REPORTS_ROOT / str(patient_id)}"}), 404

    documents = _load_documents(patient_folder, filenames)
    if not documents:
        return jsonify({"error": "no valid documents found"}), 400

    combined_text = "\n\n".join(doc["cleaned_text"] for doc in documents)
    meds = extract_medications_from_text(combined_text)
    initial_state = {
        "patient_meta": extra_patient_meta,
        "patient_id": patient_id,
        "documents": documents,
        "medications": meds,
    }
    try:
        result_state = graph.invoke(initial_state)
        # Validate and fill placeholders if needed
        if not result_state.get("valid", True):
            missing = result_state.get("missing", [])
            logger.info("Validation failed; missing keys: %s", missing)
            if "patientDetails" in missing:
                result_state["patientDetails"] = extra_patient_meta or {"name": "", "age": "", "sex": "", "pid": patient_id}
            if "reports" in missing:
                result_state["reports"] = []
            # Re-run analysis node to keep overallAnalysis consistent
            result_state.update(analysis_node(result_state))
            # Re-validate
            result_state.update(condition_loop_node(result_state))
        safe_response = {
            "patientDetails": result_state.get("patientDetails", {"name": "", "age": "", "sex": "", "pid": patient_id}),
            "doctorDetails": result_state.get("doctorDetails", {"referredBy": ""}),
            "reports": result_state.get("reports", []),
            # Default mirrors the ANALYSIS_NODE_PROMPT shape so clients always see the same keys.
            "overallAnalysis": result_state.get("overallAnalysis", {
                "summary": "",
                "recommendations": "",
                "longTermTrends": "",
                "risk_prediction": "",
                "drug_interaction": "",
            }),
            "_pre_extracted_medications": result_state.get("medications", []),
            "_validation": {
                "valid": result_state.get("valid", True),
                "missing": result_state.get("missing", []),
            },
        }
        return jsonify(safe_response), 200
    except Exception as e:
        logger.exception("Node pipeline failed")
        return jsonify({"error": "Node pipeline failed", "detail": str(e)}), 500
@app.route("/ping", methods=["GET"])
def ping():
    """Liveness check: always answers with {"status": "ok"}."""
    status_payload = {"status": "ok"}
    return jsonify(status_payload)
if __name__ == "__main__":
    # Development server entry point. Debug mode is opt-in (FLASK_DEBUG=1):
    # with debug=True on 0.0.0.0 the Werkzeug interactive debugger is reachable
    # from the network and allows arbitrary code execution.
    port = int(os.getenv("PORT", 5000))
    debug = os.getenv("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    app.run(host="0.0.0.0", port=port, debug=debug)