#!/usr/bin/env python3
# app.py - Health Reports processing agent (PDF -> cleaned text -> structured JSON)
import os
import json
import logging
import re
from pathlib import Path
from typing import List, Dict, Any

from flask import Flask, request, jsonify
from flask_cors import CORS
from dotenv import load_dotenv
from unstructured.partition.pdf import partition_pdf

# Bloatectomy duplicate-removal utility
from bloatectomy import bloatectomy

# LLM / agent
from langchain_groq import ChatGroq
from langgraph.prebuilt import create_react_agent

# LangGraph imports
from langgraph.graph import StateGraph, START, END
from typing_extensions import TypedDict, NotRequired
# --- Logging ---------------------------------------------------------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("health-agent")

# --- Environment & config ---------------------------------------------------
load_dotenv()
REPORTS_ROOT = Path(os.getenv("REPORTS_ROOT", "app/reports"))  # e.g. app/reports/<patient_id>/<file.pdf>
SSRI_FILE = Path(os.getenv("SSRI_FILE", "app/medicationCategories/SSRI_list.txt"))
MISC_FILE = Path(os.getenv("MISC_FILE", "app/medicationCategories/MISC_list.txt"))
GROQ_API_KEY = os.getenv("GROQ_API_KEY")  # ChatGroq also picks this up from the environment
# --- LLM setup ---------------------------------------------------------------
llm = ChatGroq(
    model=os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
    temperature=0.0,
    max_tokens=None,
)
# Top-level strict system prompt for report JSON pieces (each node will use a more specific prompt)
NODE_BASE_INSTRUCTIONS = """
You are HealthAI — a clinical assistant producing JSON for downstream processing.
Produce only valid JSON (no extra text). Follow field types exactly. If data is missing, return empty strings or empty arrays.
Be conservative: do not assert diagnoses; provide suggestions and ask for physician confirmation where needed.
"""

# Build a generic agent and a JSON resolver agent (to fix broken JSON from the LLM)
agent = create_react_agent(model=llm, tools=[], prompt=NODE_BASE_INSTRUCTIONS)
agent_json_resolver = create_react_agent(model=llm, tools=[], prompt="""
You are a JSON fixer. Input: possibly-malformed JSON-like text. Output: valid JSON only (enclosed in triple backticks).
Fix missing quotes, trailing commas, unescaped newlines, and stray assistant labels, and ensure schema compliance.
""")
# -------------------- JSON extraction / sanitizer ---------------------------
def extract_json_from_llm_response(raw_response: str) -> dict:
    """
    Try to extract a JSON object from raw LLM text, applying common cleanups
    for artifacts seen in LLM outputs. Raises json.JSONDecodeError if parsing
    still fails after all cleanups.
    """
    # --- 1) Pull out the JSON code-block if present ---
    md = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", raw_response)
    json_string = md.group(1).strip() if md else raw_response

    # --- 2) Trim to the outermost { ... } so we drop any prefix/suffix junk ---
    first, last = json_string.find('{'), json_string.rfind('}')
    if 0 <= first < last:
        json_string = json_string[first:last + 1]

    # --- 3) Pre-cleanup: remove rogue assistant labels, fix quoted booleans ---
    json_string = re.sub(r'\b\w+\s*{', '{', json_string)
    json_string = re.sub(r'"assistant"\s*:', '', json_string)
    json_string = re.sub(r'\b(false|true)"', r'\1', json_string)

    # --- 4) Escape embedded quotes in long string fields (best-effort) ---
    def _esc(m):
        prefix, body = m.group(1), m.group(2)
        return prefix + body.replace('"', r'\"')

    json_string = re.sub(
        r'("logic"\s*:\s*")([\s\S]+?)(?=",\s*"[A-Za-z_]\w*"\s*:\s*)',
        _esc,
        json_string,
    )

    # --- 5) Remove trailing commas before } or ], and collapse double commas ---
    json_string = re.sub(r',\s*(?=[}\],])', '', json_string)
    json_string = re.sub(r',\s*,', ',', json_string)

    # --- 6) Balance braces if there is an obvious excess of closing braces ---
    ob, cb = json_string.count('{'), json_string.count('}')
    if cb > ob:
        excess = cb - ob
        json_string = json_string.rstrip()[:-excess]

    # --- 7) Escape literal newlines inside strings so json.loads can parse ---
    def _escape_newlines_in_strings(s: str) -> str:
        return re.sub(
            r'"((?:[^"\\]|\\.)*?)"',
            lambda m: '"' + m.group(1).replace('\n', '\\n').replace('\r', '\\r') + '"',
            s,
            flags=re.DOTALL,
        )

    json_string = _escape_newlines_in_strings(json_string)

    # Final parse
    return json.loads(json_string)
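
# Illustrative round trip:
#   extract_json_from_llm_response('```json\n{"a": 1,}\n```')  ->  {"a": 1}
# (the fenced block is unwrapped, then the trailing comma is stripped before parsing)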
# -------------------- Utility: Bloatectomy wrapper --------------------------
def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
    """
    Uses the bloatectomy class to remove duplicates.
    style: 'highlight' | 'bold' | 'remov'; we use 'remov' to delete duplicates.
    Returns cleaned text (single string).
    """
    try:
        b = bloatectomy(text, style=style, output="html")
        tokens = getattr(b, "tokens", None)
        if not tokens:
            return text
        return "\n".join(tokens)
    except Exception:
        logger.exception("Bloatectomy cleaning failed; returning original text")
        return text
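
# Illustrative use (exact behavior depends on bloatectomy's tokenization):
#   clean_notes_with_bloatectomy("CBC normal. CBC normal. Na 140 mEq/L")
# should come back with the repeated "CBC normal." collapsed to one copy.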
# --------------- Utility: medication extraction (adapted) -------------------
def readDrugs_from_file(path: Path):
    """
    Parse a medication list file. Each useful line looks like
    'generic|brand1|brand2|...': the text before the first '|' is the generic
    name, and the full line doubles as a regex of name alternatives.
    """
    if not path.exists():
        return {}, []
    txt = path.read_text(encoding="utf-8", errors="ignore")
    listing: Dict[str, str] = {}
    for ln in txt.splitlines():
        ln = ln.strip().lower()
        if not ln or "|" not in ln:
            continue  # skip blanks and lines without a 'generic|...' pattern
        generic = ln.split("|", 1)[0]
        if generic:
            listing[generic] = ln
    return listing, list(listing.keys())


def addToDrugs_line(line: str, drugs_flags: List[int], listing: Dict[str, str], genList: List[str]) -> List[int]:
    gen_index = {g: i for i, g in enumerate(genList)}
    for generic, pattern_line in listing.items():
        try:
            # The '|'-separated line acts as a regex of alternative drug names.
            if re.search(pattern_line, line, re.I):
                idx = gen_index.get(generic)
                if idx is not None:
                    drugs_flags[idx] = 1
        except re.error:
            continue
    return drugs_flags
def extract_medications_from_text(text: str) -> List[str]:
    ssri_map, ssri_generics = readDrugs_from_file(SSRI_FILE)
    misc_map, misc_generics = readDrugs_from_file(MISC_FILE)
    combined_map = {**ssri_map, **misc_map}
    combined_generics = ssri_generics + misc_generics
    flags = [0] * len(combined_generics)
    meds_found = set()
    for ln in text.splitlines():
        ln = ln.strip()
        if not ln:
            continue
        if combined_map:
            flags = addToDrugs_line(ln, flags, combined_map, combined_generics)
        # Lines such as "Rx: Metformin 500 mg" or "Medication - Atorvastatin"
        m = re.search(r"\b(Rx|Drug|Medication|Prescribed|Tablet)\s*[:\-]?\s*([A-Za-z0-9\-\s/\.]+)", ln, re.I)
        if m:
            meds_found.add(m.group(2).strip())
        # Capitalized tokens optionally followed by a dose, e.g. "Amlodipine 5 mg"
        m2 = re.findall(r"\b([A-Z][a-z0-9\-]{2,}\s*(?:[0-9]{1,4}\s*(?:mg|mcg|g|IU))?)", ln)
        for s in m2:
            if re.search(r"\b(mg|mcg|g|IU)\b", s, re.I):
                meds_found.add(s.strip())
    for i, f in enumerate(flags):
        if f == 1:
            meds_found.add(combined_generics[i])
    return list(meds_found)
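
# Illustrative: for a line "Rx: Metformin 500 mg" the function returns
# something like ["Metformin 500 mg"], plus any generics whose SSRI/MISC
# pattern line matched elsewhere in the text.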
# -------------------- Node prompts --------------------------
PATIENT_NODE_PROMPT = """
You will extract patientDetails from the provided document texts.
Return ONLY JSON with this exact shape:
{ "patientDetails": {"name": "", "age": "", "sex": "", "pid": ""} }
Fill fields using text evidence or leave empty strings.
"""

DOCTOR_NODE_PROMPT = """
You will extract doctorDetails found in the documents.
Return ONLY JSON with this exact shape:
{ "doctorDetails": {"referredBy": ""} }
"""

TEST_REPORT_NODE_PROMPT = """
You will extract per-test structured results from the documents.
Return ONLY JSON with this exact shape:
{
  "reports": [
    {
      "testName": "",
      "dateReported": "",
      "timeReported": "",
      "abnormalFindings": [
        {"investigation": "", "result": 0, "unit": "", "status": "", "referenceValue": ""}
      ],
      "interpretation": "",
      "trends": []
    }
  ]
}
- Include only findings that are outside reference ranges OR explicitly called 'abnormal' in the report.
- Prefer numeric values for 'result'; if a value is not numeric, keep the original string.
- Use statuses: Low, High, Borderline, Positive, Negative, Normal.
"""

ANALYSIS_NODE_PROMPT = """
You will create an overallAnalysis based on the extracted reports (the agent will give you the 'reports' JSON).
Return ONLY JSON:
{ "overallAnalysis": { "summary": "", "recommendations": "", "longTermTrends": "", "risk_prediction": "", "drug_interaction": "" } }
Be conservative, evidence-based, and suggest follow-up steps for physicians.
"""

CONDITION_LOOP_NODE_PROMPT = """
Validation and condition node:
Input: partial JSON (patientDetails, doctorDetails, reports, overallAnalysis).
Task: check that the required keys exist and that each report has at least testName and an abnormalFindings list.
Return ONLY JSON:
{ "valid": true, "missing": [] }
If fields are missing, list their keys in 'missing'. Do NOT modify content.
"""
# -------------------- Node helpers -------------------------
def _extract_raw_text(resp: Any) -> str:
    """Pull the raw text out of an agent response (str, AIMessage, or dict)."""
    if isinstance(resp, str):
        return resp
    if hasattr(resp, "content"):  # AIMessage or similar
        return resp.content
    if isinstance(resp, dict):
        msgs = resp.get("messages")
        if msgs:
            last_msg = msgs[-1]
            if isinstance(last_msg, str):
                return last_msg
            if hasattr(last_msg, "content"):
                return last_msg.content
            if isinstance(last_msg, dict):
                return last_msg.get("content", "")
            return str(last_msg)
        return json.dumps(resp)
    return str(resp)


def call_node_agent(node_prompt: str, payload: dict) -> dict:
    """
    Call the generic agent with a targeted node prompt and the payload.
    Tries to parse JSON. If parsing fails, uses the JSON resolver agent once.
    """
    raw = ""  # keep defined so the resolver fallback can reference it
    try:
        content = {"prompt": node_prompt, "payload": payload}
        resp = agent.invoke({"messages": [{"role": "user", "content": json.dumps(content)}]})
        raw = _extract_raw_text(resp)
        return extract_json_from_llm_response(raw)
    except Exception as e:
        logger.warning("Node agent JSON parse failed: %s. Attempting JSON resolver.", e)
        try:
            resolver_prompt = f"Fix this JSON. Input:\n```json\n{raw}\n```\nReturn valid JSON only."
            r = agent_json_resolver.invoke({"messages": [{"role": "user", "content": resolver_prompt}]})
            return extract_json_from_llm_response(_extract_raw_text(r))
        except Exception as e2:
            logger.exception("JSON resolver also failed: %s", e2)
            return {}
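
# Illustrative use: call_node_agent(PATIENT_NODE_PROMPT, {"documents": [...]})
# normally yields {"patientDetails": {...}}; it returns {} only when both the
# direct parse and the resolver pass fail.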
# -------------------- Define LangGraph State schema --------------------------
class State(TypedDict):
    patient_meta: NotRequired[Dict[str, Any]]
    patient_id: str
    documents: List[Dict[str, Any]]
    medications: List[str]
    patientDetails: NotRequired[Dict[str, Any]]
    doctorDetails: NotRequired[Dict[str, Any]]
    reports: NotRequired[List[Dict[str, Any]]]
    overallAnalysis: NotRequired[Dict[str, Any]]
    valid: NotRequired[bool]
    missing: NotRequired[List[str]]
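
# A minimal initial state (hypothetical values):
#   {"patient_id": "p001",
#    "documents": [{"filename": "cbc.pdf", "raw_text": "...", "cleaned_text": "..."}],
#    "medications": ["sertraline"]}
# The NotRequired keys are filled in by the graph nodes as they run.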
# -------------------- Node implementations as LangGraph nodes ----------------
def patient_details_node(state: State) -> dict:
    payload = {
        "patient_meta": state.get("patient_meta", {}),
        "documents": state.get("documents", []),
        "medications": state.get("medications", []),
    }
    logger.info("Running patient_details_node")
    out = call_node_agent(PATIENT_NODE_PROMPT, payload)
    return {"patientDetails": out.get("patientDetails", {}) if isinstance(out, dict) else {}}


def doctor_details_node(state: State) -> dict:
    payload = {
        "documents": state.get("documents", []),
        "medications": state.get("medications", []),
    }
    logger.info("Running doctor_details_node")
    out = call_node_agent(DOCTOR_NODE_PROMPT, payload)
    return {"doctorDetails": out.get("doctorDetails", {}) if isinstance(out, dict) else {}}


def test_report_node(state: State) -> dict:
    payload = {
        "documents": state.get("documents", []),
        "medications": state.get("medications", []),
    }
    logger.info("Running test_report_node")
    out = call_node_agent(TEST_REPORT_NODE_PROMPT, payload)
    return {"reports": out.get("reports", []) if isinstance(out, dict) else []}


def analysis_node(state: State) -> dict:
    payload = {
        "patientDetails": state.get("patientDetails", {}),
        "doctorDetails": state.get("doctorDetails", {}),
        "reports": state.get("reports", []),
        "medications": state.get("medications", []),
    }
    logger.info("Running analysis_node")
    out = call_node_agent(ANALYSIS_NODE_PROMPT, payload)
    return {"overallAnalysis": out.get("overallAnalysis", {}) if isinstance(out, dict) else {}}


def condition_loop_node(state: State) -> dict:
    payload = {
        "patientDetails": state.get("patientDetails", {}),
        "doctorDetails": state.get("doctorDetails", {}),
        "reports": state.get("reports", []),
        "overallAnalysis": state.get("overallAnalysis", {}),
    }
    logger.info("Running condition_loop_node (validation)")
    out = call_node_agent(CONDITION_LOOP_NODE_PROMPT, payload)
    if isinstance(out, dict) and "valid" in out:
        return {"valid": bool(out.get("valid")), "missing": out.get("missing", [])}
    # Fallback validation if the LLM response was unusable
    missing = []
    if not state.get("patientDetails"):
        missing.append("patientDetails")
    if not state.get("reports"):
        missing.append("reports")
    return {"valid": len(missing) == 0, "missing": missing}
# -------------------- Build LangGraph StateGraph -----------------------------
graph_builder = StateGraph(State)
graph_builder.add_node("patient_details", patient_details_node)
graph_builder.add_node("doctor_details", doctor_details_node)
graph_builder.add_node("test_report", test_report_node)
graph_builder.add_node("analysis", analysis_node)
graph_builder.add_node("condition_loop", condition_loop_node)

graph_builder.add_edge(START, "patient_details")
graph_builder.add_edge("patient_details", "doctor_details")
graph_builder.add_edge("doctor_details", "test_report")
graph_builder.add_edge("test_report", "analysis")
graph_builder.add_edge("analysis", "condition_loop")
graph_builder.add_edge("condition_loop", END)

graph = graph_builder.compile()
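
# Resulting topology (linear; despite its name, "condition_loop" only
# validates and does not loop back):
#   START -> patient_details -> doctor_details -> test_report -> analysis
#         -> condition_loop -> END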
# -------------------- Flask app & endpoints ----------------------------------
BASE_DIR = Path(__file__).resolve().parent
static_folder = BASE_DIR / "static"
app = Flask(__name__, static_folder=str(static_folder), static_url_path="/static")
CORS(app)  # dev convenience; lock down in production


# Serve the frontend at the site root
@app.route("/")
def serve_frontend():
    try:
        return app.send_static_file("frontend.html")
    except Exception:
        return "<h3>frontend.html not found in static/ — drop your frontend.html there.</h3>", 404
@app.route("/process_reports", methods=["POST"])  # route path assumed; match your frontend
def process_reports():
    data = request.get_json(force=True)
    patient_id = data.get("patient_id")
    filenames = data.get("filenames", [])
    extra_patient_meta = data.get("patientDetails", {})
    if not patient_id or not filenames:
        return jsonify({"error": "missing patient_id or filenames"}), 400

    patient_folder = REPORTS_ROOT / str(patient_id)
    if not patient_folder.exists() or not patient_folder.is_dir():
        return jsonify({"error": f"patient folder not found: {patient_folder}"}), 404

    documents = []
    combined_text_parts = []
    for fname in filenames:
        file_path = patient_folder / fname
        if not file_path.exists():
            logger.warning("file not found: %s", file_path)
            continue
        try:
            elements = partition_pdf(filename=str(file_path))
            page_text = "\n".join([el.text for el in elements if hasattr(el, "text") and el.text])
        except Exception:
            logger.exception("Failed to parse PDF %s", file_path)
            page_text = ""
        cleaned = clean_notes_with_bloatectomy(page_text, style="remov")
        documents.append({
            "filename": fname,
            "raw_text": page_text,
            "cleaned_text": cleaned,
        })
        combined_text_parts.append(cleaned)

    if not documents:
        return jsonify({"error": "no valid documents found"}), 400

    combined_text = "\n\n".join(combined_text_parts)
    meds = extract_medications_from_text(combined_text)

    initial_state = {
        "patient_meta": extra_patient_meta,
        "patient_id": patient_id,
        "documents": documents,
        "medications": meds,
    }
    try:
        result_state = graph.invoke(initial_state)
        # Validate and fill placeholders if needed
        if not result_state.get("valid", True):
            missing = result_state.get("missing", [])
            logger.info("Validation failed; missing keys: %s", missing)
            if "patientDetails" in missing:
                result_state["patientDetails"] = extra_patient_meta or {"name": "", "age": "", "sex": "", "pid": patient_id}
            if "reports" in missing:
                result_state["reports"] = []
            # Re-run analysis node to keep overallAnalysis consistent
            result_state.update(analysis_node(result_state))
            # Re-validate
            cond = condition_loop_node(result_state)
            result_state.update(cond)
        safe_response = {
            "patientDetails": result_state.get("patientDetails", {"name": "", "age": "", "sex": "", "pid": patient_id}),
            "doctorDetails": result_state.get("doctorDetails", {"referredBy": ""}),
            "reports": result_state.get("reports", []),
            "overallAnalysis": result_state.get("overallAnalysis", {"summary": "", "recommendations": "", "longTermTrends": ""}),
            "_pre_extracted_medications": result_state.get("medications", []),
            "_validation": {
                "valid": result_state.get("valid", True),
                "missing": result_state.get("missing", []),
            },
        }
        return jsonify(safe_response), 200
    except Exception as e:
        logger.exception("Node pipeline failed")
        return jsonify({"error": "Node pipeline failed", "detail": str(e)}), 500
# Simple health check
@app.route("/ping")
def ping():
    return jsonify({"status": "ok"})


if __name__ == "__main__":
    port = int(os.getenv("PORT", 5000))
    app.run(host="0.0.0.0", port=port, debug=True)  # debug=True for development only