Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,117 +1,48 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
-
# app.py - Health Reports processing agent (PDF -> cleaned text -> structured JSON)
|
3 |
-
|
4 |
import os
|
5 |
import json
|
6 |
import logging
|
7 |
import re
|
|
|
8 |
from pathlib import Path
|
9 |
-
from
|
10 |
-
from werkzeug.utils import secure_filename
|
11 |
from flask import Flask, request, jsonify
|
12 |
from flask_cors import CORS
|
13 |
from dotenv import load_dotenv
|
14 |
-
from unstructured.partition.pdf import partition_pdf
|
15 |
-
|
16 |
-
# Bloatectomy class (as per the source you provided)
|
17 |
from bloatectomy import bloatectomy
|
18 |
-
|
19 |
-
# LLM / agent
|
20 |
from langchain_groq import ChatGroq
|
21 |
-
from langgraph.prebuilt import create_react_agent
|
22 |
-
|
23 |
-
# LangGraph imports
|
24 |
-
from langgraph.graph import StateGraph, START, END
|
25 |
from typing_extensions import TypedDict, NotRequired
|
26 |
|
27 |
-
# --- Logging
|
28 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
29 |
-
logger = logging.getLogger("
|
30 |
|
31 |
-
# ---
|
32 |
load_dotenv()
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
llm = ChatGroq(
|
41 |
model=os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
|
42 |
temperature=0.0,
|
43 |
-
max_tokens=
|
|
|
44 |
)
|
45 |
|
46 |
-
# Top-level strict system prompt for report JSON pieces (each node will use a more specific prompt)
|
47 |
-
NODE_BASE_INSTRUCTIONS = """
|
48 |
-
You are HealthAI — a clinical assistant producing JSON for downstream processing.
|
49 |
-
Produce only valid JSON (no extra text). Follow field types exactly. If missing data, return empty strings or empty arrays.
|
50 |
-
Be conservative: do not assert diagnoses; provide suggestions and ask physician confirmation where needed.
|
51 |
-
"""
|
52 |
-
|
53 |
-
# Build a generic agent and a JSON resolver agent (to fix broken JSON from LLM)
|
54 |
-
agent = create_react_agent(model=llm, tools=[], prompt=NODE_BASE_INSTRUCTIONS)
|
55 |
-
agent_json_resolver = create_react_agent(model=llm, tools=[], prompt="""
|
56 |
-
You are a JSON fixer. Input: a possibly-malformed JSON-like text. Output: valid JSON only (enclosed in triple backticks).
|
57 |
-
Fix missing quotes, trailing commas, unescaped newlines, stray assistant labels, and ensure schema compliance.
|
58 |
-
""")
|
59 |
-
|
60 |
-
# -------------------- JSON extraction / sanitizer ---------------------------
|
61 |
-
def extract_json_from_llm_response(raw_response: str) -> dict:
|
62 |
-
try:
|
63 |
-
# --- 1) Pull out the JSON code-block if present ---
|
64 |
-
md = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", raw_response)
|
65 |
-
json_string = md.group(1).strip() if md else raw_response
|
66 |
-
|
67 |
-
# --- 2) Trim to the outermost { … } so we drop any prefix/suffix junk ---
|
68 |
-
first, last = json_string.find('{'), json_string.rfind('}')
|
69 |
-
if 0 <= first < last:
|
70 |
-
json_string = json_string[first:last+1]
|
71 |
-
|
72 |
-
# --- 3) PRE-CLEANUP: remove rogue assistant labels, fix boolean quotes ---
|
73 |
-
json_string = re.sub(r'\b\w+\s*{', '{', json_string)
|
74 |
-
json_string = re.sub(r'"assistant"\s*:', '', json_string)
|
75 |
-
json_string = re.sub(r'\b(false|true)"', r'\1', json_string)
|
76 |
-
|
77 |
-
# --- 4) Escape embedded quotes in long string fields (best-effort) ---
|
78 |
-
def _esc(m):
|
79 |
-
prefix, body = m.group(1), m.group(2)
|
80 |
-
return prefix + body.replace('"', r'\"')
|
81 |
-
json_string = re.sub(
|
82 |
-
r'("logic"\s*:\s*")([\s\S]+?)(?=",\s*"[A-Za-z_]\w*"\s*:\s*)',
|
83 |
-
_esc,
|
84 |
-
json_string
|
85 |
-
)
|
86 |
-
|
87 |
-
# --- 5) Remove trailing commas before } or ] ---
|
88 |
-
json_string = re.sub(r',\s*(?=[}\],])', '', json_string)
|
89 |
-
json_string = re.sub(r',\s*,', ',', json_string)
|
90 |
-
|
91 |
-
# --- 6) Balance braces if obvious excess ---
|
92 |
-
ob, cb = json_string.count('{'), json_string.count('}')
|
93 |
-
if cb > ob:
|
94 |
-
excess = cb - ob
|
95 |
-
json_string = json_string.rstrip()[:-excess]
|
96 |
-
|
97 |
-
# --- 7) Escape literal newlines inside strings so json.loads can parse ---
|
98 |
-
def _escape_newlines_in_strings(s: str) -> str:
|
99 |
-
return re.sub(
|
100 |
-
r'"((?:[^"\\]|\\.)*?)"',
|
101 |
-
lambda m: '"' + m.group(1).replace('\n', '\\n').replace('\r', '\\r') + '"',
|
102 |
-
s,
|
103 |
-
flags=re.DOTALL
|
104 |
-
)
|
105 |
-
json_string = _escape_newlines_in_strings(json_string)
|
106 |
-
|
107 |
-
# Final parse
|
108 |
-
return json.loads(json_string)
|
109 |
-
except Exception as e:
|
110 |
-
logger.error(f"Failed to extract JSON from LLM response: {e}")
|
111 |
-
raise
|
112 |
-
|
113 |
-
# -------------------- Utility: Bloatectomy wrapper ------------------------
|
114 |
def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
|
|
|
115 |
try:
|
116 |
b = bloatectomy(text, style=style, output="html")
|
117 |
tokens = getattr(b, "tokens", None)
|
@@ -122,480 +53,185 @@ def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
|
|
122 |
logger.exception("Bloatectomy cleaning failed; returning original text")
|
123 |
return text
|
124 |
|
125 |
-
#
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
continue
|
150 |
-
return drugs_flags
|
151 |
-
except Exception:
|
152 |
-
logger.exception("Error in addToDrugs_line")
|
153 |
-
return drugs_flags
|
154 |
-
|
155 |
-
def extract_medications_from_text(text: str) -> List[str]:
|
156 |
-
try:
|
157 |
-
ssri_map, ssri_generics = readDrugs_from_file(SSRI_FILE)
|
158 |
-
misc_map, misc_generics = readDrugs_from_file(MISC_FILE)
|
159 |
-
combined_map = {**ssri_map, **misc_map}
|
160 |
-
combined_generics = []
|
161 |
-
if ssri_generics:
|
162 |
-
combined_generics.extend(ssri_generics)
|
163 |
-
if misc_generics:
|
164 |
-
combined_generics.extend(misc_generics)
|
165 |
-
|
166 |
-
flags = [0]* len(combined_generics)
|
167 |
-
meds_found = set()
|
168 |
-
for ln in text.splitlines():
|
169 |
-
ln = ln.strip()
|
170 |
-
if not ln:
|
171 |
-
continue
|
172 |
-
if combined_map:
|
173 |
-
flags = addToDrugs_line(ln, flags, combined_map, combined_generics)
|
174 |
-
m = re.search(r"\b(Rx|Drug|Medication|Prescribed|Tablet)\s*[:\-]?\s*([A-Za-z0-9\-\s/\.]+)", ln, re.I)
|
175 |
-
if m:
|
176 |
-
meds_found.add(m.group(2).strip())
|
177 |
-
m2 = re.findall(r"\b([A-Z][a-z0-9\-]{2,}\s*(?:[0-9]{1,4}\s*(?:mg|mcg|g|IU))?)", ln)
|
178 |
-
for s in m2:
|
179 |
-
if re.search(r"\b(mg|mcg|g|IU)\b", s, re.I):
|
180 |
-
meds_found.add(s.strip())
|
181 |
-
for i, f in enumerate(flags):
|
182 |
-
if f == 1:
|
183 |
-
meds_found.add(combined_generics[i])
|
184 |
-
return list(meds_found)
|
185 |
-
except Exception:
|
186 |
-
logger.exception("Failed to extract medications from text")
|
187 |
-
return []
|
188 |
-
|
189 |
-
# -------------------- Node prompts --------------------------
|
190 |
-
PATIENT_NODE_PROMPT = """
|
191 |
-
You will extract patientDetails from the provided document texts.
|
192 |
-
Return ONLY JSON with this exact shape:
|
193 |
-
{ "patientDetails": {"name": "", "age": "", "sex": "", "pid": ""} }
|
194 |
-
Fill fields using text evidence or leave empty strings.
|
195 |
"""
|
196 |
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
"""
|
|
|
|
|
|
|
202 |
|
203 |
-
|
204 |
-
|
205 |
-
Return ONLY JSON with this exact shape:
|
206 |
-
{
|
207 |
-
"reports": [
|
208 |
-
{
|
209 |
-
"testName": "",
|
210 |
-
"dateReported": "",
|
211 |
-
"timeReported": "",
|
212 |
-
"abnormalFindings": [
|
213 |
-
{"investigation": "", "result": 0, "unit": "", "status": "", "referenceValue": ""}
|
214 |
-
],
|
215 |
-
"interpretation": "",
|
216 |
-
"trends": []
|
217 |
-
}
|
218 |
-
]
|
219 |
-
}
|
220 |
-
- Include only findings that are outside reference ranges OR explicitly called 'abnormal' in the report.
|
221 |
-
- For result numeric parsing, prefer numeric values; if not numeric, keep original string.
|
222 |
-
- Use statuses: Low, High, Borderline, Positive, Negative, Normal.
|
223 |
-
"""
|
224 |
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
{ "overallAnalysis": { "summary": "", "recommendations": "", "longTermTrends": "",""risk_prediction": "","drug_interaction": "" } }
|
229 |
-
Be conservative, evidence-based, and suggest follow-up steps for physicians.
|
230 |
-
"""
|
231 |
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
""
|
|
|
|
|
|
|
|
|
|
|
240 |
|
241 |
-
# -------------------- Node helpers -------------------------
|
242 |
-
def call_node_agent(node_prompt: str, payload: dict) -> dict:
|
243 |
-
"""
|
244 |
-
Call the generic agent with a targeted node prompt and the payload.
|
245 |
-
Tries to parse JSON. If parsing fails, uses the JSON resolver agent once.
|
246 |
-
"""
|
247 |
try:
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
resp = agent.invoke({"messages": [{"role": "user", "content": json.dumps(content)}]})
|
253 |
-
|
254 |
-
# Extract raw text from AIMessage or other response types
|
255 |
-
raw = None
|
256 |
-
if isinstance(resp, str):
|
257 |
-
raw = resp
|
258 |
-
elif hasattr(resp, "content"): # AIMessage or similar
|
259 |
-
raw = resp.content
|
260 |
-
elif isinstance(resp, dict):
|
261 |
-
msgs = resp.get("messages")
|
262 |
-
if msgs:
|
263 |
-
last_msg = msgs[-1]
|
264 |
-
if isinstance(last_msg, str):
|
265 |
-
raw = last_msg
|
266 |
-
elif hasattr(last_msg, "content"):
|
267 |
-
raw = last_msg.content
|
268 |
-
elif isinstance(last_msg, dict):
|
269 |
-
raw = last_msg.get("content", "")
|
270 |
-
else:
|
271 |
-
raw = str(last_msg)
|
272 |
-
else:
|
273 |
-
raw = json.dumps(resp)
|
274 |
-
else:
|
275 |
-
raw = str(resp)
|
276 |
|
277 |
-
|
|
|
|
|
|
|
278 |
return parsed
|
|
|
|
|
|
|
279 |
|
280 |
-
|
281 |
-
logger.warning("Node agent JSON parse failed: %s. Attempting JSON resolver.", e)
|
282 |
-
try:
|
283 |
-
resolver_prompt = f"Fix this JSON. Input:\n```json\n{raw}\n```\nReturn valid JSON only."
|
284 |
-
r = agent_json_resolver.invoke({"messages": [{"role": "user", "content": resolver_prompt}]})
|
285 |
-
|
286 |
-
rtxt = None
|
287 |
-
if isinstance(r, str):
|
288 |
-
rtxt = r
|
289 |
-
elif hasattr(r, "content"):
|
290 |
-
rtxt = r.content
|
291 |
-
elif isinstance(r, dict):
|
292 |
-
msgs = r.get("messages")
|
293 |
-
if msgs:
|
294 |
-
last_msg = msgs[-1]
|
295 |
-
if isinstance(last_msg, str):
|
296 |
-
rtxt = last_msg
|
297 |
-
elif hasattr(last_msg, "content"):
|
298 |
-
rtxt = last_msg.content
|
299 |
-
elif isinstance(last_msg, dict):
|
300 |
-
rtxt = last_msg.get("content", "")
|
301 |
-
else:
|
302 |
-
rtxt = str(last_msg)
|
303 |
-
else:
|
304 |
-
rtxt = json.dumps(r)
|
305 |
-
else:
|
306 |
-
rtxt = str(r)
|
307 |
-
|
308 |
-
corrected = extract_json_from_llm_response(rtxt)
|
309 |
-
return corrected
|
310 |
-
except Exception as e2:
|
311 |
-
logger.exception("JSON resolver also failed: %s", e2)
|
312 |
-
return {}
|
313 |
-
|
314 |
-
# -------------------- Define LangGraph State schema -------------------------
|
315 |
-
class State(TypedDict):
|
316 |
-
patient_meta: NotRequired[Dict[str, Any]]
|
317 |
-
patient_id: str
|
318 |
-
documents: List[Dict[str, Any]]
|
319 |
-
medications: List[str]
|
320 |
-
patientDetails: NotRequired[Dict[str, Any]]
|
321 |
-
doctorDetails: NotRequired[Dict[str, Any]]
|
322 |
-
reports: NotRequired[List[Dict[str, Any]]]
|
323 |
-
overallAnalysis: NotRequired[Dict[str, Any]]
|
324 |
-
valid: NotRequired[bool]
|
325 |
-
missing: NotRequired[List[str]]
|
326 |
-
|
327 |
-
# -------------------- Node implementations as LangGraph nodes -------------------------
|
328 |
-
def patient_details_node(state: State) -> dict:
|
329 |
-
payload = {
|
330 |
-
"patient_meta": state.get("patient_meta", {}),
|
331 |
-
"documents": state.get("documents", []),
|
332 |
-
"medications": state.get("medications", [])
|
333 |
-
}
|
334 |
-
logger.info("Running patient_details_node")
|
335 |
-
out = call_node_agent(PATIENT_NODE_PROMPT, payload)
|
336 |
-
return {"patientDetails": out.get("patientDetails", {}) if isinstance(out, dict) else {}}
|
337 |
-
|
338 |
-
def doctor_details_node(state: State) -> dict:
|
339 |
-
payload = {
|
340 |
-
"documents": state.get("documents", []),
|
341 |
-
"medications": state.get("medications", [])
|
342 |
-
}
|
343 |
-
logger.info("Running doctor_details_node")
|
344 |
-
out = call_node_agent(DOCTOR_NODE_PROMPT, payload)
|
345 |
-
return {"doctorDetails": out.get("doctorDetails", {}) if isinstance(out, dict) else {}}
|
346 |
-
|
347 |
-
def test_report_node(state: State) -> dict:
|
348 |
-
payload = {
|
349 |
-
"documents": state.get("documents", []),
|
350 |
-
"medications": state.get("medications", [])
|
351 |
-
}
|
352 |
-
logger.info("Running test_report_node")
|
353 |
-
out = call_node_agent(TEST_REPORT_NODE_PROMPT, payload)
|
354 |
-
return {"reports": out.get("reports", []) if isinstance(out, dict) else []}
|
355 |
-
|
356 |
-
def analysis_node(state: State) -> dict:
|
357 |
-
payload = {
|
358 |
-
"patientDetails": state.get("patientDetails", {}),
|
359 |
-
"doctorDetails": state.get("doctorDetails", {}),
|
360 |
-
"reports": state.get("reports", []),
|
361 |
-
"medications": state.get("medications", [])
|
362 |
-
}
|
363 |
-
logger.info("Running analysis_node")
|
364 |
-
out = call_node_agent(ANALYSIS_NODE_PROMPT, payload)
|
365 |
-
return {"overallAnalysis": out.get("overallAnalysis", {}) if isinstance(out, dict) else {}}
|
366 |
-
|
367 |
-
def condition_loop_node(state: State) -> dict:
|
368 |
-
payload = {
|
369 |
-
"patientDetails": state.get("patientDetails", {}),
|
370 |
-
"doctorDetails": state.get("doctorDetails", {}),
|
371 |
-
"reports": state.get("reports", []),
|
372 |
-
"overallAnalysis": state.get("overallAnalysis", {})
|
373 |
-
}
|
374 |
-
logger.info("Running condition_loop_node (validation)")
|
375 |
-
out = call_node_agent(CONDITION_LOOP_NODE_PROMPT, payload)
|
376 |
-
if isinstance(out, dict) and "valid" in out:
|
377 |
-
return {"valid": bool(out.get("valid")), "missing": out.get("missing", [])}
|
378 |
-
missing = []
|
379 |
-
if not state.get("patientDetails"):
|
380 |
-
missing.append("patientDetails")
|
381 |
-
if not state.get("reports"):
|
382 |
-
missing.append("reports")
|
383 |
-
return {"valid": len(missing) == 0, "missing": missing}
|
384 |
-
|
385 |
-
# -------------------- Build LangGraph StateGraph -------------------------
|
386 |
-
graph_builder = StateGraph(State)
|
387 |
-
|
388 |
-
graph_builder.add_node("patient_details", patient_details_node)
|
389 |
-
graph_builder.add_node("doctor_details", doctor_details_node)
|
390 |
-
graph_builder.add_node("test_report", test_report_node)
|
391 |
-
graph_builder.add_node("analysis", analysis_node)
|
392 |
-
graph_builder.add_node("condition_loop", condition_loop_node)
|
393 |
-
|
394 |
-
graph_builder.add_edge(START, "patient_details")
|
395 |
-
graph_builder.add_edge("patient_details", "doctor_details")
|
396 |
-
graph_builder.add_edge("doctor_details", "test_report")
|
397 |
-
graph_builder.add_edge("test_report", "analysis")
|
398 |
-
graph_builder.add_edge("analysis", "condition_loop")
|
399 |
-
graph_builder.add_edge("condition_loop", END)
|
400 |
-
|
401 |
-
graph = graph_builder.compile()
|
402 |
-
|
403 |
-
# -------------------- Flask app & endpoints -------------------------------
|
404 |
-
# -------------------- Flask app & endpoints -------------------------------
|
405 |
-
BASE_DIR = Path(__file__).resolve().parent
|
406 |
-
static_folder = BASE_DIR / "static"
|
407 |
-
app = Flask(__name__, static_folder=str(static_folder), static_url_path="/static")
|
408 |
-
CORS(app) # dev convenience; lock down in production
|
409 |
-
|
410 |
-
# serve frontend root
|
411 |
@app.route("/", methods=["GET"])
|
412 |
def serve_frontend():
|
|
|
413 |
try:
|
414 |
-
return app.send_static_file("
|
415 |
-
except Exception
|
416 |
-
|
417 |
-
return "<h3>frontend.html not found in static/ — drop your frontend.html there.</h3>", 404
|
418 |
|
419 |
-
@app.route("/
|
420 |
-
def
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
return jsonify({"error": "Invalid JSON request"}), 400
|
426 |
|
427 |
patient_id = data.get("patient_id")
|
428 |
-
|
429 |
-
|
430 |
|
431 |
-
|
432 |
-
|
433 |
|
434 |
-
|
435 |
-
|
436 |
-
return jsonify({"error": f"patient folder not found: {patient_folder}"}), 404
|
437 |
-
|
438 |
-
documents = []
|
439 |
combined_text_parts = []
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
if not file_path.exists():
|
444 |
-
logger.warning("file not found: %s", file_path)
|
445 |
-
continue
|
446 |
-
try:
|
447 |
-
elements = partition_pdf(filename=str(file_path))
|
448 |
-
page_text = "\n".join([el.text for el in elements if hasattr(el, "text") and el.text])
|
449 |
-
except Exception:
|
450 |
-
logger.exception(f"Failed to parse PDF {file_path}")
|
451 |
page_text = ""
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
463 |
-
|
464 |
-
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
|
469 |
-
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
480 |
|
|
|
|
|
|
|
|
|
|
|
481 |
try:
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
if
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
result_state["patientDetails"] = extra_patient_meta or {"name": "", "age": "", "sex": "", "pid": patient_id}
|
490 |
-
if "reports" in missing:
|
491 |
-
result_state["reports"] = []
|
492 |
-
# Re-run analysis node to keep overallAnalysis consistent
|
493 |
-
result_state.update(analysis_node(result_state))
|
494 |
-
# Re-validate
|
495 |
-
cond = condition_loop_node(result_state)
|
496 |
-
result_state.update(cond)
|
497 |
-
|
498 |
-
safe_response = {
|
499 |
-
"patientDetails": result_state.get("patientDetails", {"name": "", "age": "", "sex": "", "pid": patient_id}),
|
500 |
-
"doctorDetails": result_state.get("doctorDetails", {"referredBy": ""}),
|
501 |
-
"reports": result_state.get("reports", []),
|
502 |
-
"overallAnalysis": result_state.get("overallAnalysis", {"summary": "", "recommendations": "", "longTermTrends": ""}),
|
503 |
-
"_pre_extracted_medications": result_state.get("medications", []),
|
504 |
-
"_validation": {
|
505 |
-
"valid": result_state.get("valid", True),
|
506 |
-
"missing": result_state.get("missing", [])
|
507 |
-
}
|
508 |
-
}
|
509 |
-
return jsonify(safe_response), 200
|
510 |
|
|
|
|
|
|
|
511 |
except Exception as e:
|
512 |
-
logger.exception("
|
513 |
-
return jsonify({"error": "
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
""
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
526 |
-
|
527 |
-
|
528 |
-
-F "files[]=@/path/to/report2.pdf"
|
529 |
-
"""
|
530 |
-
try:
|
531 |
-
# patient id can be in form or args (for convenience)
|
532 |
-
patient_id = request.form.get("patient_id") or request.args.get("patient_id")
|
533 |
-
if not patient_id:
|
534 |
-
return jsonify({"error": "patient_id form field required"}), 400
|
535 |
-
|
536 |
-
# get uploaded files (support both files and files[] naming)
|
537 |
-
uploaded_files = request.files.getlist("files")
|
538 |
-
if not uploaded_files:
|
539 |
-
# fallback: single file under name 'file'
|
540 |
-
single = request.files.get("file")
|
541 |
-
if single:
|
542 |
-
uploaded_files = [single]
|
543 |
-
|
544 |
-
if not uploaded_files:
|
545 |
-
return jsonify({"error": "no files uploaded (use form field 'files')"}), 400
|
546 |
-
|
547 |
-
# create patient folder under REPORTS_ROOT/<patient_id>
|
548 |
-
patient_folder = REPORTS_ROOT / str(patient_id)
|
549 |
-
patient_folder.mkdir(parents=True, exist_ok=True)
|
550 |
-
|
551 |
-
saved = []
|
552 |
-
skipped = []
|
553 |
-
|
554 |
-
for file_storage in uploaded_files:
|
555 |
-
orig_name = getattr(file_storage, "filename", "") or ""
|
556 |
-
filename = secure_filename(orig_name)
|
557 |
-
if not filename:
|
558 |
-
skipped.append({"filename": orig_name, "reason": "invalid filename"})
|
559 |
-
continue
|
560 |
-
|
561 |
-
# extension check
|
562 |
-
ext = filename.rsplit(".", 1)[1].lower() if "." in filename else ""
|
563 |
-
if ext not in ALLOWED_EXTENSIONS:
|
564 |
-
skipped.append({"filename": filename, "reason": f"extension '{ext}' not allowed"})
|
565 |
-
continue
|
566 |
-
|
567 |
-
# avoid overwriting: if collision, add numeric suffix
|
568 |
-
dest = patient_folder / filename
|
569 |
-
if dest.exists():
|
570 |
-
base, dot, extension = filename.rpartition(".")
|
571 |
-
# if no base (e.g. ".bashrc") fallback
|
572 |
-
base = base or filename
|
573 |
-
i = 1
|
574 |
-
while True:
|
575 |
-
candidate = f"{base}__{i}.{extension}" if extension else f"{base}__{i}"
|
576 |
-
dest = patient_folder / candidate
|
577 |
-
if not dest.exists():
|
578 |
-
filename = candidate
|
579 |
-
break
|
580 |
-
i += 1
|
581 |
-
|
582 |
-
try:
|
583 |
-
file_storage.save(str(dest))
|
584 |
-
saved.append(filename)
|
585 |
-
except Exception as e:
|
586 |
-
logger.exception("Failed to save uploaded file %s: %s", filename, e)
|
587 |
-
skipped.append({"filename": filename, "reason": f"save failed: {e}"})
|
588 |
-
|
589 |
-
return jsonify({
|
590 |
-
"patient_id": str(patient_id),
|
591 |
-
"saved": saved,
|
592 |
-
"skipped": skipped,
|
593 |
-
"patient_folder": str(patient_folder)
|
594 |
-
}), 200
|
595 |
-
|
596 |
-
except Exception as exc:
|
597 |
-
logger.exception("Upload failed: %s", exc)
|
598 |
-
return jsonify({"error": "upload failed", "detail": str(exc)}), 500
|
599 |
|
600 |
@app.route("/ping", methods=["GET"])
|
601 |
def ping():
|
@@ -604,4 +240,3 @@ def ping():
|
|
604 |
if __name__ == "__main__":
|
605 |
port = int(os.getenv("PORT", 7860))
|
606 |
app.run(host="0.0.0.0", port=port, debug=True)
|
607 |
-
|
|
|
1 |
#!/usr/bin/env python3
|
|
|
|
|
2 |
import os
|
3 |
import json
|
4 |
import logging
|
5 |
import re
|
6 |
+
from typing import Dict, Any
|
7 |
from pathlib import Path
|
8 |
+
from unstructured.partition.pdf import partition_pdf
|
|
|
9 |
from flask import Flask, request, jsonify
|
10 |
from flask_cors import CORS
|
11 |
from dotenv import load_dotenv
|
|
|
|
|
|
|
12 |
from bloatectomy import bloatectomy
|
13 |
+
from werkzeug.utils import secure_filename
|
|
|
14 |
from langchain_groq import ChatGroq
|
|
|
|
|
|
|
|
|
15 |
from typing_extensions import TypedDict, NotRequired
|
16 |
|
17 |
+
# --- Logging ---
|
18 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
|
19 |
+
logger = logging.getLogger("patient-assistant")
|
20 |
|
21 |
+
# --- Load environment ---
|
22 |
load_dotenv()
|
23 |
+
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
24 |
+
if not GROQ_API_KEY:
|
25 |
+
logger.error("GROQ_API_KEY not set in environment")
|
26 |
+
exit(1)
|
27 |
+
|
28 |
+
# --- Flask app setup ---
|
29 |
+
BASE_DIR = Path(__file__).resolve().parent
|
30 |
+
REPORTS_ROOT = Path(os.getenv("REPORTS_ROOT", str(BASE_DIR / "reports")))
|
31 |
+
static_folder = BASE_DIR / "static"
|
32 |
+
|
33 |
+
app = Flask(__name__, static_folder=str(static_folder), static_url_path="/static")
|
34 |
+
CORS(app)
|
35 |
+
|
36 |
+
# --- LLM setup ---
|
37 |
llm = ChatGroq(
|
38 |
model=os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
|
39 |
temperature=0.0,
|
40 |
+
max_tokens=1024,
|
41 |
+
api_key=GROQ_API_KEY,
|
42 |
)
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
|
45 |
+
"""Helper function to clean up text using the bloatectomy library."""
|
46 |
try:
|
47 |
b = bloatectomy(text, style=style, output="html")
|
48 |
tokens = getattr(b, "tokens", None)
|
|
|
53 |
logger.exception("Bloatectomy cleaning failed; returning original text")
|
54 |
return text
|
55 |
|
56 |
+
# --- Agent prompt instructions ---
|
57 |
+
PATIENT_ASSISTANT_PROMPT = """
|
58 |
+
You are a patient assistant helping to analyze medical records and reports. Your goal is to provide a comprehensive response based on the patient's health history and the current conversation.
|
59 |
+
|
60 |
+
Your tasks include:
|
61 |
+
- Analyzing medical records and reports to detect anomalies, redundant tests, or misleading treatments.
|
62 |
+
- Suggesting preventive care based on the overall patient health history.
|
63 |
+
- Optimizing healthcare costs by comparing past visits and treatments, helping patients make smarter choices.
|
64 |
+
- Offering personalized lifestyle recommendations, such as adopting healthier food habits, daily routines, and regular health checks.
|
65 |
+
- Generating a natural, helpful reply to the user.
|
66 |
+
|
67 |
+
You will be provided with the last user message, the conversation history, and a summary of the patient's medical reports. Use this information to give a tailored and informative response.
|
68 |
+
|
69 |
+
STRICT OUTPUT FORMAT (JSON ONLY):
|
70 |
+
Return a single JSON object with the following keys:
|
71 |
+
- assistant_reply: string // a natural language reply to the user (short, helpful, always present)
|
72 |
+
- patientDetails: object // keys may include name, problem, city, contact (update if user shared info)
|
73 |
+
- conversationSummary: string (optional) // short summary of conversation + relevant patient docs
|
74 |
+
|
75 |
+
Rules:
|
76 |
+
- ALWAYS include `assistant_reply` as a non-empty string.
|
77 |
+
- Do NOT produce any text outside the JSON object.
|
78 |
+
- Be concise in `assistant_reply`. If you need more details, ask a targeted follow-up question.
|
79 |
+
- Do not make up information that is not present in the provided medical reports or conversation history.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
"""
|
81 |
|
82 |
+
# --- JSON extraction helper ---
|
83 |
+
def extract_json_from_llm_response(raw_response: str) -> dict:
|
84 |
+
"""Safely extracts a JSON object from a string that might contain extra text or markdown."""
|
85 |
+
default = {
|
86 |
+
"assistant_reply": "I'm sorry — I couldn't understand that. Could you please rephrase?",
|
87 |
+
"patientDetails": {},
|
88 |
+
"conversationSummary": "",
|
89 |
+
}
|
90 |
|
91 |
+
if not raw_response or not isinstance(raw_response, str):
|
92 |
+
return default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
+
# Find the JSON object, ignoring any markdown code fences
|
95 |
+
m = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", raw_response)
|
96 |
+
json_string = m.group(1).strip() if m else raw_response
|
|
|
|
|
|
|
97 |
|
98 |
+
# Find the first opening brace and the last closing brace
|
99 |
+
first = json_string.find('{')
|
100 |
+
last = json_string.rfind('}')
|
101 |
+
if first == -1 or last == -1 or first >= last:
|
102 |
+
try:
|
103 |
+
return json.loads(json_string)
|
104 |
+
except Exception:
|
105 |
+
logger.warning("Could not locate JSON braces in LLM output. Falling back to default.")
|
106 |
+
return default
|
107 |
+
|
108 |
+
candidate = json_string[first:last+1]
|
109 |
+
# Remove trailing commas that might cause parsing issues
|
110 |
+
candidate = re.sub(r',\s*(?=[}\]])', '', candidate)
|
111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
try:
|
113 |
+
parsed = json.loads(candidate)
|
114 |
+
except Exception as e:
|
115 |
+
logger.warning("Failed to parse JSON from LLM output: %s", e)
|
116 |
+
return default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
+
# Basic validation of the parsed JSON
|
119 |
+
if isinstance(parsed, dict) and "assistant_reply" in parsed and isinstance(parsed["assistant_reply"], str) and parsed["assistant_reply"].strip():
|
120 |
+
parsed.setdefault("patientDetails", {})
|
121 |
+
parsed.setdefault("conversationSummary", "")
|
122 |
return parsed
|
123 |
+
else:
|
124 |
+
logger.warning("Parsed JSON missing 'assistant_reply' or invalid format. Returning default.")
|
125 |
+
return default
|
126 |
|
127 |
+
# --- Flask routes ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
@app.route("/", methods=["GET"])
|
129 |
def serve_frontend():
|
130 |
+
"""Serves the frontend HTML file."""
|
131 |
try:
|
132 |
+
return app.send_static_file("frontend2.html")
|
133 |
+
except Exception:
|
134 |
+
return "<h3>frontend2.html not found in static/ — please add your frontend2.html there.</h3>", 404
|
|
|
135 |
|
136 |
+
@app.route("/chat", methods=["POST"])
|
137 |
+
def chat():
|
138 |
+
"""Handles the chat conversation with the assistant."""
|
139 |
+
data = request.get_json(force=True)
|
140 |
+
if not isinstance(data, dict):
|
141 |
+
return jsonify({"error": "invalid request body"}), 400
|
|
|
142 |
|
143 |
patient_id = data.get("patient_id")
|
144 |
+
if not patient_id:
|
145 |
+
return jsonify({"error": "patient_id required"}), 400
|
146 |
|
147 |
+
chat_history = data.get("chat_history") or []
|
148 |
+
patient_state = data.get("patient_state") or {}
|
149 |
|
150 |
+
# --- Read and parse patient reports ---
|
151 |
+
patient_folder = REPORTS_ROOT / f"p_{patient_id}"
|
|
|
|
|
|
|
152 |
combined_text_parts = []
|
153 |
+
if patient_folder.exists() and patient_folder.is_dir():
|
154 |
+
for fname in sorted(os.listdir(patient_folder)):
|
155 |
+
file_path = patient_folder / fname
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
page_text = ""
|
157 |
+
if partition_pdf is not None and str(file_path).lower().endswith('.pdf'):
|
158 |
+
try:
|
159 |
+
elements = partition_pdf(filename=str(file_path))
|
160 |
+
page_text = "\n".join([el.text for el in elements if hasattr(el, 'text') and el.text])
|
161 |
+
except Exception:
|
162 |
+
logger.exception("Failed to parse PDF %s", file_path)
|
163 |
+
else:
|
164 |
+
try:
|
165 |
+
page_text = file_path.read_text(encoding='utf-8', errors='ignore')
|
166 |
+
except Exception:
|
167 |
+
page_text = ""
|
168 |
+
|
169 |
+
if page_text:
|
170 |
+
cleaned = clean_notes_with_bloatectomy(page_text, style="remov")
|
171 |
+
if cleaned:
|
172 |
+
combined_text_parts.append(cleaned)
|
173 |
+
|
174 |
+
# --- Prepare the state for the LLM ---
|
175 |
+
state = patient_state.copy()
|
176 |
+
state["lastUserMessage"] = ""
|
177 |
+
if chat_history:
|
178 |
+
# Find the last user message
|
179 |
+
for msg in reversed(chat_history):
|
180 |
+
if msg.get("role") == "user" and msg.get("content"):
|
181 |
+
state["lastUserMessage"] = msg["content"]
|
182 |
+
break
|
183 |
+
|
184 |
+
# Update the conversation summary with the parsed documents
|
185 |
+
base_summary = state.get("conversationSummary", "") or ""
|
186 |
+
docs_summary = "\n\n".join(combined_text_parts)
|
187 |
+
if docs_summary:
|
188 |
+
state["conversationSummary"] = (base_summary + "\n\n" + docs_summary).strip()
|
189 |
+
else:
|
190 |
+
state["conversationSummary"] = base_summary
|
191 |
+
|
192 |
+
# --- Direct LLM Invocation ---
|
193 |
+
user_prompt = f"""
|
194 |
+
Current patientDetails: {json.dumps(state.get("patientDetails", {}))}
|
195 |
+
Current conversationSummary: {state.get("conversationSummary", "")}
|
196 |
+
Last user message: {state.get("lastUserMessage", "")}
|
197 |
+
|
198 |
+
Return ONLY valid JSON with keys: assistant_reply, patientDetails, conversationSummary.
|
199 |
+
"""
|
200 |
|
201 |
+
messages = [
|
202 |
+
{"role": "system", "content": PATIENT_ASSISTANT_PROMPT},
|
203 |
+
{"role": "user", "content": user_prompt}
|
204 |
+
]
|
205 |
+
|
206 |
try:
|
207 |
+
logger.info("Invoking LLM with prepared state and prompt...")
|
208 |
+
llm_response = llm.invoke(messages)
|
209 |
+
raw_response = ""
|
210 |
+
if hasattr(llm_response, "content"):
|
211 |
+
raw_response = llm_response.content
|
212 |
+
else:
|
213 |
+
raw_response = str(llm_response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
|
215 |
+
logger.info(f"Raw LLM response: {raw_response}")
|
216 |
+
parsed_result = extract_json_from_llm_response(raw_response)
|
217 |
+
|
218 |
except Exception as e:
|
219 |
+
logger.exception("LLM invocation failed")
|
220 |
+
return jsonify({"error": "LLM invocation failed", "detail": str(e)}), 500
|
221 |
+
|
222 |
+
updated_state = parsed_result or {}
|
223 |
+
|
224 |
+
assistant_reply = updated_state.get("assistant_reply")
|
225 |
+
if not assistant_reply or not isinstance(assistant_reply, str) or not assistant_reply.strip():
|
226 |
+
# Fallback to a polite message if the LLM response is invalid or empty
|
227 |
+
assistant_reply = "I'm here to help — could you tell me more about your symptoms?"
|
228 |
+
|
229 |
+
response_payload = {
|
230 |
+
"assistant_reply": assistant_reply,
|
231 |
+
"updated_state": updated_state,
|
232 |
+
}
|
233 |
+
|
234 |
+
return jsonify(response_payload)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
235 |
|
236 |
@app.route("/ping", methods=["GET"])
|
237 |
def ping():
|
|
|
240 |
if __name__ == "__main__":
|
241 |
port = int(os.getenv("PORT", 7860))
|
242 |
app.run(host="0.0.0.0", port=port, debug=True)
|
|