WebashalarForML commited on
Commit
ceaa691
·
verified ·
1 Parent(s): efcd956

Update app.py

Browse files
Files changed (1)
  1. app.py +500 -499
app.py CHANGED
@@ -1,499 +1,500 @@
1
- #!/usr/bin/env python3
2
- # app.py - Health Reports processing agent (PDF -> cleaned text -> structured JSON)
3
- # Requires: bloatectomy, unstructured, langgraph, langchain_groq (ChatGroq), python-dotenv
4
-
5
- import os
6
- import json
7
- import logging
8
- import re
9
- from pathlib import Path
10
- from typing import List, Dict, Any
11
-
12
- from flask import Flask, request, jsonify
13
- from flask_cors import CORS
14
- from dotenv import load_dotenv
15
- from unstructured.partition.pdf import partition_pdf
16
-
17
- # Bloatectomy class (as per the source you provided)
18
- from bloatectomy import bloatectomy
19
-
20
- # LLM / agent
21
- from langchain_groq import ChatGroq
22
- from langgraph.prebuilt import create_react_agent
23
-
24
- # LangGraph imports
25
- from langgraph.graph import StateGraph, START, END
26
- from typing_extensions import TypedDict, NotRequired
27
-
28
- # --- Logging ---------------------------------------------------------------
29
- logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
30
- logger = logging.getLogger("health-agent")
31
-
32
- # --- Environment & config -------------------------------------------------
33
- load_dotenv()
34
- REPORTS_ROOT = Path(os.getenv("REPORTS_ROOT", r"D:\DEV PATEL\2025\HealthCareAI\reports")) # e.g. /app/reports/<patient_id>/<file.pdf>
35
- SSRI_FILE = Path(os.getenv("SSRI_FILE", r"D:\DEV PATEL\2025\HealthCareAI\medicationCategories\SSRI_list.txt"))
36
- MISC_FILE = Path(os.getenv("MISC_FILE", r"D:\DEV PATEL\2025\HealthCareAI\medicationCategories\MISC_list.txt"))
37
- GROQ_API_KEY = os.getenv("GROQ_API_KEY", None)
38
-
39
- # --- LLM setup -------------------------------------------------------------
40
- llm = ChatGroq(
41
- model=os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
42
- temperature=0.0,
43
- max_tokens=None,
44
- )
45
-
46
- # Top-level strict system prompt for report JSON pieces (each node will use a more specific prompt)
47
- NODE_BASE_INSTRUCTIONS = """
48
- You are HealthAI — a clinical assistant producing JSON for downstream processing.
49
- Produce only valid JSON (no extra text). Follow field types exactly. If missing data, return empty strings or empty arrays.
50
- Be conservative: do not assert diagnoses; provide suggestions and ask physician confirmation where needed.
51
- """
52
-
53
- # Build a generic agent and a JSON resolver agent (to fix broken JSON from LLM)
54
- agent = create_react_agent(model=llm, tools=[], prompt=NODE_BASE_INSTRUCTIONS)
55
- agent_json_resolver = create_react_agent(model=llm, tools=[], prompt="""
56
- You are a JSON fixer. Input: a possibly-malformed JSON-like text. Output: valid JSON only (enclosed in triple backticks).
57
- Fix missing quotes, trailing commas, unescaped newlines, stray assistant labels, and ensure schema compliance.
58
- """)
59
-
60
- # -------------------- JSON extraction / sanitizer ---------------------------
61
- def extract_json_from_llm_response(raw_response: str) -> dict:
62
- """
63
- Try extracting a JSON object from raw LLM text. Performs common cleanups seen in LLM outputs.
64
- Raises JSONDecodeError if parsing still fails.
65
- """
66
- # --- 1) Pull out the JSON code-block if present ---
67
- md = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", raw_response)
68
- json_string = md.group(1).strip() if md else raw_response
69
-
70
- # --- 2) Trim to the outermost { … } so we drop any prefix/suffix junk ---
71
- first, last = json_string.find('{'), json_string.rfind('}')
72
- if 0 <= first < last:
73
- json_string = json_string[first:last+1]
74
-
75
- # --- 3) PRE-CLEANUP: remove rogue assistant labels, fix boolean quotes ---
76
- json_string = re.sub(r'\b\w+\s*{', '{', json_string)
77
- json_string = re.sub(r'"assistant"\s*:', '', json_string)
78
- json_string = re.sub(r'\b(false|true)"', r'\1', json_string)
79
-
80
- # --- 4) Escape embedded quotes in long string fields (best-effort) ---
81
- def _esc(m):
82
- prefix, body = m.group(1), m.group(2)
83
- return prefix + body.replace('"', r'\"')
84
- json_string = re.sub(
85
- r'("logic"\s*:\s*")([\s\S]+?)(?=",\s*"[A-Za-z_]\w*"\s*:\s*)',
86
- _esc,
87
- json_string
88
- )
89
-
90
- # --- 5) Remove trailing commas before } or ] ---
91
- json_string = re.sub(r',\s*(?=[}\],])', '', json_string)
92
- json_string = re.sub(r',\s*,', ',', json_string)
93
-
94
- # --- 6) Balance braces if obvious excess ---
95
- ob, cb = json_string.count('{'), json_string.count('}')
96
- if cb > ob:
97
- excess = cb - ob
98
- json_string = json_string.rstrip()[:-excess]
99
-
100
- # --- 7) Escape literal newlines inside strings so json.loads can parse ---
101
- def _escape_newlines_in_strings(s: str) -> str:
102
- return re.sub(
103
- r'"((?:[^"\\]|\\.)*?)"',
104
- lambda m: '"' + m.group(1).replace('\n', '\\n').replace('\r', '\\r') + '"',
105
- s,
106
- flags=re.DOTALL
107
- )
108
- json_string = _escape_newlines_in_strings(json_string)
109
-
110
- # Final parse
111
- return json.loads(json_string)
112
-
113
- # -------------------- Utility: Bloatectomy wrapper ------------------------
114
- def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
115
- """
116
- Uses the bloatectomy class to remove duplicates.
117
- style: 'highlight'|'bold'|'remov' ; we use 'remov' to delete duplicates.
118
- Returns cleaned text (single string).
119
- """
120
- try:
121
- b = bloatectomy(text, style=style, output="html")
122
- tokens = getattr(b, "tokens", None)
123
- if not tokens:
124
- return text
125
- return "\n".join(tokens)
126
- except Exception:
127
- logger.exception("Bloatectomy cleaning failed; returning original text")
128
- return text
129
-
130
- # --------------- Utility: medication extraction (adapted) -----------------
131
- def readDrugs_from_file(path: Path):
132
- if not path.exists():
133
- return {}, []
134
- txt = path.read_text(encoding="utf-8", errors="ignore")
135
- generics = re.findall(r"^(.*?)\|", txt, re.MULTILINE)
136
- generics = [g.lower() for g in generics if g]
137
- lines = [ln.strip().lower() for ln in txt.splitlines() if ln.strip()]
138
- return dict(zip(generics, lines)), generics
139
-
140
- def addToDrugs_line(line: str, drugs_flags: List[int], listing: Dict[str,str], genList: List[str]) -> List[int]:
141
- gen_index = {g:i for i,g in enumerate(genList)}
142
- for generic, pattern_line in listing.items():
143
- try:
144
- if re.search(pattern_line, line, re.I):
145
- idx = gen_index.get(generic)
146
- if idx is not None:
147
- drugs_flags[idx] = 1
148
- except re.error:
149
- continue
150
- return drugs_flags
151
-
152
- def extract_medications_from_text(text: str) -> List[str]:
153
- ssri_map, ssri_generics = readDrugs_from_file(SSRI_FILE)
154
- misc_map, misc_generics = readDrugs_from_file(MISC_FILE)
155
- combined_map = {**ssri_map, **misc_map}
156
- combined_generics = []
157
- if ssri_generics:
158
- combined_generics.extend(ssri_generics)
159
- if misc_generics:
160
- combined_generics.extend(misc_generics)
161
-
162
- flags = [0]* len(combined_generics)
163
- meds_found = set()
164
- for ln in text.splitlines():
165
- ln = ln.strip()
166
- if not ln:
167
- continue
168
- if combined_map:
169
- flags = addToDrugs_line(ln, flags, combined_map, combined_generics)
170
- m = re.search(r"\b(Rx|Drug|Medication|Prescribed|Tablet)\s*[:\-]?\s*([A-Za-z0-9\-\s/\.]+)", ln, re.I)
171
- if m:
172
- meds_found.add(m.group(2).strip())
173
- m2 = re.findall(r"\b([A-Z][a-z0-9\-]{2,}\s*(?:[0-9]{1,4}\s*(?:mg|mcg|g|IU))?)", ln)
174
- for s in m2:
175
- if re.search(r"\b(mg|mcg|g|IU)\b", s, re.I):
176
- meds_found.add(s.strip())
177
- for i, f in enumerate(flags):
178
- if f == 1:
179
- meds_found.add(combined_generics[i])
180
- return list(meds_found)
181
-
182
- # -------------------- Node prompts --------------------------
183
- PATIENT_NODE_PROMPT = """
184
- You will extract patientDetails from the provided document texts.
185
- Return ONLY JSON with this exact shape:
186
- { "patientDetails": {"name": "", "age": "", "sex": "", "pid": ""} }
187
- Fill fields using text evidence or leave empty strings.
188
- """
189
-
190
- DOCTOR_NODE_PROMPT = """
191
- You will extract doctorDetails found in the documents.
192
- Return ONLY JSON with this exact shape:
193
- { "doctorDetails": {"referredBy": ""} }
194
- """
195
-
196
- TEST_REPORT_NODE_PROMPT = """
197
- You will extract per-test structured results from the documents.
198
- Return ONLY JSON with this exact shape:
199
- {
200
- "reports": [
201
- {
202
- "testName": "",
203
- "dateReported": "",
204
- "timeReported": "",
205
- "abnormalFindings": [
206
- {"investigation": "", "result": 0, "unit": "", "status": "", "referenceValue": ""}
207
- ],
208
- "interpretation": "",
209
- "trends": []
210
- }
211
- ]
212
- }
213
- - Include only findings that are outside reference ranges OR explicitly called 'abnormal' in the report.
214
- - For result numeric parsing, prefer numeric values; if not numeric, keep original string.
215
- - Use statuses: Low, High, Borderline, Positive, Negative, Normal.
216
- """
217
-
218
- ANALYSIS_NODE_PROMPT = """
219
- You will create an overallAnalysis based on the extracted reports (the agent will give you the 'reports' JSON).
220
- Return ONLY JSON:
221
- { "overallAnalysis": { "summary": "", "recommendations": "", "longTermTrends": "",""risk_prediction": "","drug_interaction": "" } }
222
- Be conservative, evidence-based, and suggest follow-up steps for physicians.
223
- """
224
-
225
- CONDITION_LOOP_NODE_PROMPT = """
226
- Validation and condition node:
227
- Input: partial JSON (patientDetails, doctorDetails, reports, overallAnalysis).
228
- Task: Check required keys exist and that each report has at least testName and abnormalFindings list.
229
- Return ONLY JSON:
230
- { "valid": true, "missing": [] }
231
- If missing fields, list keys in 'missing'. Do NOT modify content.
232
- """
233
-
234
- # -------------------- Node helpers -------------------------
235
- def call_node_agent(node_prompt: str, payload: dict) -> dict:
236
- """
237
- Call the generic agent with a targeted node prompt and the payload.
238
- Tries to parse JSON. If parsing fails, uses the JSON resolver agent once.
239
- """
240
- try:
241
- content = {
242
- "prompt": node_prompt,
243
- "payload": payload
244
- }
245
- resp = agent.invoke({"messages": [{"role": "user", "content": json.dumps(content)}]})
246
-
247
- # Extract raw text from AIMessage or other response types
248
- raw = None
249
- if isinstance(resp, str):
250
- raw = resp
251
- elif hasattr(resp, "content"): # AIMessage or similar
252
- raw = resp.content
253
- elif isinstance(resp, dict):
254
- msgs = resp.get("messages")
255
- if msgs:
256
- last_msg = msgs[-1]
257
- if isinstance(last_msg, str):
258
- raw = last_msg
259
- elif hasattr(last_msg, "content"):
260
- raw = last_msg.content
261
- elif isinstance(last_msg, dict):
262
- raw = last_msg.get("content", "")
263
- else:
264
- raw = str(last_msg)
265
- else:
266
- raw = json.dumps(resp)
267
- else:
268
- raw = str(resp)
269
-
270
- parsed = extract_json_from_llm_response(raw)
271
- return parsed
272
-
273
- except Exception as e:
274
- logger.warning("Node agent JSON parse failed: %s. Attempting JSON resolver.", e)
275
- try:
276
- resolver_prompt = f"Fix this JSON. Input:\n```json\n{raw}\n```\nReturn valid JSON only."
277
- r = agent_json_resolver.invoke({"messages": [{"role": "user", "content": resolver_prompt}]})
278
-
279
- rtxt = None
280
- if isinstance(r, str):
281
- rtxt = r
282
- elif hasattr(r, "content"):
283
- rtxt = r.content
284
- elif isinstance(r, dict):
285
- msgs = r.get("messages")
286
- if msgs:
287
- last_msg = msgs[-1]
288
- if isinstance(last_msg, str):
289
- rtxt = last_msg
290
- elif hasattr(last_msg, "content"):
291
- rtxt = last_msg.content
292
- elif isinstance(last_msg, dict):
293
- rtxt = last_msg.get("content", "")
294
- else:
295
- rtxt = str(last_msg)
296
- else:
297
- rtxt = json.dumps(r)
298
- else:
299
- rtxt = str(r)
300
-
301
- corrected = extract_json_from_llm_response(rtxt)
302
- return corrected
303
- except Exception as e2:
304
- logger.exception("JSON resolver also failed: %s", e2)
305
- return {}
306
-
307
- # -------------------- Define LangGraph State schema -------------------------
308
- class State(TypedDict):
309
- patient_meta: NotRequired[Dict[str, Any]]
310
- patient_id: str
311
- documents: List[Dict[str, Any]]
312
- medications: List[str]
313
- patientDetails: NotRequired[Dict[str, Any]]
314
- doctorDetails: NotRequired[Dict[str, Any]]
315
- reports: NotRequired[List[Dict[str, Any]]]
316
- overallAnalysis: NotRequired[Dict[str, Any]]
317
- valid: NotRequired[bool]
318
- missing: NotRequired[List[str]]
319
-
320
- # -------------------- Node implementations as LangGraph nodes -------------------------
321
- def patient_details_node(state: State) -> dict:
322
- payload = {
323
- "patient_meta": state.get("patient_meta", {}),
324
- "documents": state.get("documents", []),
325
- "medications": state.get("medications", [])
326
- }
327
- logger.info("Running patient_details_node")
328
- out = call_node_agent(PATIENT_NODE_PROMPT, payload)
329
- return {"patientDetails": out.get("patientDetails", {}) if isinstance(out, dict) else {}}
330
-
331
- def doctor_details_node(state: State) -> dict:
332
- payload = {
333
- "documents": state.get("documents", []),
334
- "medications": state.get("medications", [])
335
- }
336
- logger.info("Running doctor_details_node")
337
- out = call_node_agent(DOCTOR_NODE_PROMPT, payload)
338
- return {"doctorDetails": out.get("doctorDetails", {}) if isinstance(out, dict) else {}}
339
-
340
def test_report_node(state: State) -> dict:
    """LangGraph node: ask the agent for structured per-test results ('reports')."""
    logger.info("Running test_report_node")
    response = call_node_agent(
        TEST_REPORT_NODE_PROMPT,
        {
            "documents": state.get("documents", []),
            "medications": state.get("medications", []),
        },
    )
    reports = response.get("reports", []) if isinstance(response, dict) else []
    return {"reports": reports}
348
-
349
def analysis_node(state: State) -> dict:
    """LangGraph node: produce overallAnalysis from everything extracted so far."""
    logger.info("Running analysis_node")
    response = call_node_agent(
        ANALYSIS_NODE_PROMPT,
        {
            "patientDetails": state.get("patientDetails", {}),
            "doctorDetails": state.get("doctorDetails", {}),
            "reports": state.get("reports", []),
            "medications": state.get("medications", []),
        },
    )
    analysis = response.get("overallAnalysis", {}) if isinstance(response, dict) else {}
    return {"overallAnalysis": analysis}
359
-
360
def condition_loop_node(state: State) -> dict:
    """
    Validation node: ask the agent whether the assembled JSON has the required
    keys; if the agent returns nothing usable, fall back to a local check that
    only requires patientDetails and reports to be non-empty.
    """
    logger.info("Running condition_loop_node (validation)")
    verdict = call_node_agent(
        CONDITION_LOOP_NODE_PROMPT,
        {
            "patientDetails": state.get("patientDetails", {}),
            "doctorDetails": state.get("doctorDetails", {}),
            "reports": state.get("reports", []),
            "overallAnalysis": state.get("overallAnalysis", {}),
        },
    )
    if isinstance(verdict, dict) and "valid" in verdict:
        return {"valid": bool(verdict.get("valid")), "missing": verdict.get("missing", [])}
    # Agent response was unusable — minimal local sanity check instead.
    missing = [key for key in ("patientDetails", "reports") if not state.get(key)]
    return {"valid": len(missing) == 0, "missing": missing}
377
-
378
# -------------------- Build LangGraph StateGraph -------------------------
# The pipeline is a straight line (no conditional branching):
#   START -> patient_details -> doctor_details -> test_report -> analysis -> condition_loop -> END
graph_builder = StateGraph(State)

graph_builder.add_node("patient_details", patient_details_node)
graph_builder.add_node("doctor_details", doctor_details_node)
graph_builder.add_node("test_report", test_report_node)
graph_builder.add_node("analysis", analysis_node)
graph_builder.add_node("condition_loop", condition_loop_node)

graph_builder.add_edge(START, "patient_details")
graph_builder.add_edge("patient_details", "doctor_details")
graph_builder.add_edge("doctor_details", "test_report")
graph_builder.add_edge("test_report", "analysis")
graph_builder.add_edge("analysis", "condition_loop")
graph_builder.add_edge("condition_loop", END)

# Compiled once at import time; /process_reports invokes it per request.
graph = graph_builder.compile()
395
-
396
# -------------------- Flask app & endpoints -------------------------------
BASE_DIR = Path(__file__).resolve().parent
static_folder = BASE_DIR / "static"
app = Flask(__name__, static_folder=str(static_folder), static_url_path="/static")
CORS(app)  # dev convenience; lock down in production

# serve frontend root
@app.route("/", methods=["GET"])
def serve_frontend():
    """Serve static/frontend.html; return a 404 hint page when it is missing."""
    try:
        return app.send_static_file("frontend.html")
    except Exception:
        # send_static_file raises (NotFound) when the file is absent.
        return "<h3>frontend.html not found in static/ — drop your frontend.html there.</h3>", 404
409
-
410
@app.route("/process_reports", methods=["POST"])
def process_reports():
    """
    Process a patient's PDF reports into structured JSON.

    Request body (JSON): {"patient_id": str, "filenames": [str, ...],
    "patientDetails": {...optional caller-supplied metadata...}}.
    Pipeline: PDF -> text (unstructured.partition_pdf) -> de-dup (bloatectomy)
    -> medication pre-extraction -> LangGraph agent pipeline -> validated JSON.
    Returns 400 on bad input, 404 for an unknown patient folder, 500 when the
    agent pipeline fails.

    Security fix: patient_id and filenames come straight from the request body
    and were previously joined into filesystem paths unsanitized, allowing
    path traversal (e.g. "../../other"). Directory components are now stripped.
    """
    data = request.get_json(force=True)
    patient_id = data.get("patient_id")
    filenames = data.get("filenames", [])
    extra_patient_meta = data.get("patientDetails", {})

    if not patient_id or not filenames:
        return jsonify({"error": "missing patient_id or filenames"}), 400

    # Path-traversal guard: keep only the final path component of the id.
    patient_folder = REPORTS_ROOT / Path(str(patient_id)).name
    if not patient_folder.exists() or not patient_folder.is_dir():
        return jsonify({"error": f"patient folder not found: {patient_folder}"}), 404

    documents = []
    combined_text_parts = []

    for fname in filenames:
        # Same traversal guard per file name.
        file_path = patient_folder / Path(fname).name
        if not file_path.exists():
            logger.warning("file not found: %s", file_path)
            continue
        try:
            elements = partition_pdf(filename=str(file_path))
            page_text = "\n".join([el.text for el in elements if hasattr(el, "text") and el.text])
        except Exception:
            # A broken PDF should not abort the batch; carry on with empty text.
            logger.exception("Failed to parse PDF %s", file_path)
            page_text = ""
        cleaned = clean_notes_with_bloatectomy(page_text, style="remov")
        documents.append({
            "filename": fname,
            "raw_text": page_text,
            "cleaned_text": cleaned
        })
        combined_text_parts.append(cleaned)

    if not documents:
        return jsonify({"error": "no valid documents found"}), 400

    combined_text = "\n\n".join(combined_text_parts)
    meds = extract_medications_from_text(combined_text)

    initial_state = {
        "patient_meta": extra_patient_meta,
        "patient_id": patient_id,
        "documents": documents,
        "medications": meds
    }

    try:
        result_state = graph.invoke(initial_state)

        # Validate and fill placeholders if the pipeline left required keys empty.
        if not result_state.get("valid", True):
            missing = result_state.get("missing", [])
            logger.info("Validation failed; missing keys: %s", missing)
            if "patientDetails" in missing:
                result_state["patientDetails"] = extra_patient_meta or {"name": "", "age": "", "sex": "", "pid": patient_id}
            if "reports" in missing:
                result_state["reports"] = []
            # Re-run analysis so overallAnalysis stays consistent with the fills.
            result_state.update(analysis_node(result_state))
            # Re-validate so the reported _validation block reflects the fills.
            result_state.update(condition_loop_node(result_state))

        safe_response = {
            "patientDetails": result_state.get("patientDetails", {"name": "", "age": "", "sex": "", "pid": patient_id}),
            "doctorDetails": result_state.get("doctorDetails", {"referredBy": ""}),
            "reports": result_state.get("reports", []),
            "overallAnalysis": result_state.get("overallAnalysis", {"summary": "", "recommendations": "", "longTermTrends": ""}),
            "_pre_extracted_medications": result_state.get("medications", []),
            "_validation": {
                "valid": result_state.get("valid", True),
                "missing": result_state.get("missing", [])
            }
        }
        return jsonify(safe_response), 200

    except Exception as e:
        logger.exception("Node pipeline failed")
        return jsonify({"error": "Node pipeline failed", "detail": str(e)}), 500
492
-
493
@app.route("/ping", methods=["GET"])
def ping():
    """Liveness probe for the frontend / deployment healthchecks."""
    return jsonify({"status": "ok"})

if __name__ == "__main__":
    # NOTE(review): debug=True and host 0.0.0.0 are development settings —
    # do not ship this configuration to production.
    port = int(os.getenv("PORT", 5000))
    app.run(host="0.0.0.0", port=port, debug=True)
 
 
1
+ #!/usr/bin/env python3
2
+ # app.py - Health Reports processing agent (PDF -> cleaned text -> structured JSON)
3
+ # Requires: bloatectomy, unstructured, langgraph, langchain_groq (ChatGroq), python-dotenv
4
+
5
+ import os
6
+ import json
7
+ import logging
8
+ import re
9
+ from pathlib import Path
10
+ from typing import List, Dict, Any
11
+
12
+ from flask import Flask, request, jsonify
13
+ from flask_cors import CORS
14
+ from dotenv import load_dotenv
15
+ from unstructured.partition.pdf import partition_pdf
16
+
17
+ # Bloatectomy class (as per the source you provided)
18
+ from bloatectomy import bloatectomy
19
+
20
+ # LLM / agent
21
+ from langchain_groq import ChatGroq
22
+ from langgraph.prebuilt import create_react_agent
23
+
24
+ # LangGraph imports
25
+ from langgraph.graph import StateGraph, START, END
26
+ from typing_extensions import TypedDict, NotRequired
27
+
28
+ # --- Logging ---------------------------------------------------------------
29
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
30
+ logger = logging.getLogger("health-agent")
31
+
32
+ # --- Environment & config -------------------------------------------------
33
+ load_dotenv()
34
+ from pathlib import Path
35
+ REPORTS_ROOT = Path(os.getenv("REPORTS_ROOT", r"app\reports")) # e.g. /app/reports/<patient_id>/<file.pdf>
36
+ SSRI_FILE = Path(os.getenv("SSRI_FILE", r"app\medicationCategories\SSRI_list.txt"))
37
+ MISC_FILE = Path(os.getenv("MISC_FILE", r"app\medicationCategories\MISC_list.txt"))
38
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY", None)
39
+
40
+ # --- LLM setup -------------------------------------------------------------
41
+ llm = ChatGroq(
42
+ model=os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
43
+ temperature=0.0,
44
+ max_tokens=None,
45
+ )
46
+
47
+ # Top-level strict system prompt for report JSON pieces (each node will use a more specific prompt)
48
+ NODE_BASE_INSTRUCTIONS = """
49
+ You are HealthAI a clinical assistant producing JSON for downstream processing.
50
+ Produce only valid JSON (no extra text). Follow field types exactly. If missing data, return empty strings or empty arrays.
51
+ Be conservative: do not assert diagnoses; provide suggestions and ask physician confirmation where needed.
52
+ """
53
+
54
+ # Build a generic agent and a JSON resolver agent (to fix broken JSON from LLM)
55
+ agent = create_react_agent(model=llm, tools=[], prompt=NODE_BASE_INSTRUCTIONS)
56
+ agent_json_resolver = create_react_agent(model=llm, tools=[], prompt="""
57
+ You are a JSON fixer. Input: a possibly-malformed JSON-like text. Output: valid JSON only (enclosed in triple backticks).
58
+ Fix missing quotes, trailing commas, unescaped newlines, stray assistant labels, and ensure schema compliance.
59
+ """)
60
+
61
+ # -------------------- JSON extraction / sanitizer ---------------------------
62
+ def extract_json_from_llm_response(raw_response: str) -> dict:
63
+ """
64
+ Try extracting a JSON object from raw LLM text. Performs common cleanups seen in LLM outputs.
65
+ Raises JSONDecodeError if parsing still fails.
66
+ """
67
+ # --- 1) Pull out the JSON code-block if present ---
68
+ md = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", raw_response)
69
+ json_string = md.group(1).strip() if md else raw_response
70
+
71
+ # --- 2) Trim to the outermost { } so we drop any prefix/suffix junk ---
72
+ first, last = json_string.find('{'), json_string.rfind('}')
73
+ if 0 <= first < last:
74
+ json_string = json_string[first:last+1]
75
+
76
+ # --- 3) PRE-CLEANUP: remove rogue assistant labels, fix boolean quotes ---
77
+ json_string = re.sub(r'\b\w+\s*{', '{', json_string)
78
+ json_string = re.sub(r'"assistant"\s*:', '', json_string)
79
+ json_string = re.sub(r'\b(false|true)"', r'\1', json_string)
80
+
81
+ # --- 4) Escape embedded quotes in long string fields (best-effort) ---
82
+ def _esc(m):
83
+ prefix, body = m.group(1), m.group(2)
84
+ return prefix + body.replace('"', r'\"')
85
+ json_string = re.sub(
86
+ r'("logic"\s*:\s*")([\s\S]+?)(?=",\s*"[A-Za-z_]\w*"\s*:\s*)',
87
+ _esc,
88
+ json_string
89
+ )
90
+
91
+ # --- 5) Remove trailing commas before } or ] ---
92
+ json_string = re.sub(r',\s*(?=[}\],])', '', json_string)
93
+ json_string = re.sub(r',\s*,', ',', json_string)
94
+
95
+ # --- 6) Balance braces if obvious excess ---
96
+ ob, cb = json_string.count('{'), json_string.count('}')
97
+ if cb > ob:
98
+ excess = cb - ob
99
+ json_string = json_string.rstrip()[:-excess]
100
+
101
+ # --- 7) Escape literal newlines inside strings so json.loads can parse ---
102
+ def _escape_newlines_in_strings(s: str) -> str:
103
+ return re.sub(
104
+ r'"((?:[^"\\]|\\.)*?)"',
105
+ lambda m: '"' + m.group(1).replace('\n', '\\n').replace('\r', '\\r') + '"',
106
+ s,
107
+ flags=re.DOTALL
108
+ )
109
+ json_string = _escape_newlines_in_strings(json_string)
110
+
111
+ # Final parse
112
+ return json.loads(json_string)
113
+
114
+ # -------------------- Utility: Bloatectomy wrapper ------------------------
115
+ def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
116
+ """
117
+ Uses the bloatectomy class to remove duplicates.
118
+ style: 'highlight'|'bold'|'remov' ; we use 'remov' to delete duplicates.
119
+ Returns cleaned text (single string).
120
+ """
121
+ try:
122
+ b = bloatectomy(text, style=style, output="html")
123
+ tokens = getattr(b, "tokens", None)
124
+ if not tokens:
125
+ return text
126
+ return "\n".join(tokens)
127
+ except Exception:
128
+ logger.exception("Bloatectomy cleaning failed; returning original text")
129
+ return text
130
+
131
+ # --------------- Utility: medication extraction (adapted) -----------------
132
+ def readDrugs_from_file(path: Path):
133
+ if not path.exists():
134
+ return {}, []
135
+ txt = path.read_text(encoding="utf-8", errors="ignore")
136
+ generics = re.findall(r"^(.*?)\|", txt, re.MULTILINE)
137
+ generics = [g.lower() for g in generics if g]
138
+ lines = [ln.strip().lower() for ln in txt.splitlines() if ln.strip()]
139
+ return dict(zip(generics, lines)), generics
140
+
141
+ def addToDrugs_line(line: str, drugs_flags: List[int], listing: Dict[str,str], genList: List[str]) -> List[int]:
142
+ gen_index = {g:i for i,g in enumerate(genList)}
143
+ for generic, pattern_line in listing.items():
144
+ try:
145
+ if re.search(pattern_line, line, re.I):
146
+ idx = gen_index.get(generic)
147
+ if idx is not None:
148
+ drugs_flags[idx] = 1
149
+ except re.error:
150
+ continue
151
+ return drugs_flags
152
+
153
+ def extract_medications_from_text(text: str) -> List[str]:
154
+ ssri_map, ssri_generics = readDrugs_from_file(SSRI_FILE)
155
+ misc_map, misc_generics = readDrugs_from_file(MISC_FILE)
156
+ combined_map = {**ssri_map, **misc_map}
157
+ combined_generics = []
158
+ if ssri_generics:
159
+ combined_generics.extend(ssri_generics)
160
+ if misc_generics:
161
+ combined_generics.extend(misc_generics)
162
+
163
+ flags = [0]* len(combined_generics)
164
+ meds_found = set()
165
+ for ln in text.splitlines():
166
+ ln = ln.strip()
167
+ if not ln:
168
+ continue
169
+ if combined_map:
170
+ flags = addToDrugs_line(ln, flags, combined_map, combined_generics)
171
+ m = re.search(r"\b(Rx|Drug|Medication|Prescribed|Tablet)\s*[:\-]?\s*([A-Za-z0-9\-\s/\.]+)", ln, re.I)
172
+ if m:
173
+ meds_found.add(m.group(2).strip())
174
+ m2 = re.findall(r"\b([A-Z][a-z0-9\-]{2,}\s*(?:[0-9]{1,4}\s*(?:mg|mcg|g|IU))?)", ln)
175
+ for s in m2:
176
+ if re.search(r"\b(mg|mcg|g|IU)\b", s, re.I):
177
+ meds_found.add(s.strip())
178
+ for i, f in enumerate(flags):
179
+ if f == 1:
180
+ meds_found.add(combined_generics[i])
181
+ return list(meds_found)
182
+
183
+ # -------------------- Node prompts --------------------------
184
+ PATIENT_NODE_PROMPT = """
185
+ You will extract patientDetails from the provided document texts.
186
+ Return ONLY JSON with this exact shape:
187
+ { "patientDetails": {"name": "", "age": "", "sex": "", "pid": ""} }
188
+ Fill fields using text evidence or leave empty strings.
189
+ """
190
+
191
+ DOCTOR_NODE_PROMPT = """
192
+ You will extract doctorDetails found in the documents.
193
+ Return ONLY JSON with this exact shape:
194
+ { "doctorDetails": {"referredBy": ""} }
195
+ """
196
+
197
+ TEST_REPORT_NODE_PROMPT = """
198
+ You will extract per-test structured results from the documents.
199
+ Return ONLY JSON with this exact shape:
200
+ {
201
+ "reports": [
202
+ {
203
+ "testName": "",
204
+ "dateReported": "",
205
+ "timeReported": "",
206
+ "abnormalFindings": [
207
+ {"investigation": "", "result": 0, "unit": "", "status": "", "referenceValue": ""}
208
+ ],
209
+ "interpretation": "",
210
+ "trends": []
211
+ }
212
+ ]
213
+ }
214
+ - Include only findings that are outside reference ranges OR explicitly called 'abnormal' in the report.
215
+ - For result numeric parsing, prefer numeric values; if not numeric, keep original string.
216
+ - Use statuses: Low, High, Borderline, Positive, Negative, Normal.
217
+ """
218
+
219
# Fix: the JSON example previously contained '"",""risk_prediction"' — a doubled
# quote that made the example invalid JSON, which undermines a prompt whose whole
# point is to pin an exact JSON shape.
ANALYSIS_NODE_PROMPT = """
You will create an overallAnalysis based on the extracted reports (the agent will give you the 'reports' JSON).
Return ONLY JSON:
{ "overallAnalysis": { "summary": "", "recommendations": "", "longTermTrends": "", "risk_prediction": "", "drug_interaction": "" } }
Be conservative, evidence-based, and suggest follow-up steps for physicians.
"""
225
+
226
+ CONDITION_LOOP_NODE_PROMPT = """
227
+ Validation and condition node:
228
+ Input: partial JSON (patientDetails, doctorDetails, reports, overallAnalysis).
229
+ Task: Check required keys exist and that each report has at least testName and abnormalFindings list.
230
+ Return ONLY JSON:
231
+ { "valid": true, "missing": [] }
232
+ If missing fields, list keys in 'missing'. Do NOT modify content.
233
+ """
234
+
235
+ # -------------------- Node helpers -------------------------
236
+ def call_node_agent(node_prompt: str, payload: dict) -> dict:
237
+ """
238
+ Call the generic agent with a targeted node prompt and the payload.
239
+ Tries to parse JSON. If parsing fails, uses the JSON resolver agent once.
240
+ """
241
+ try:
242
+ content = {
243
+ "prompt": node_prompt,
244
+ "payload": payload
245
+ }
246
+ resp = agent.invoke({"messages": [{"role": "user", "content": json.dumps(content)}]})
247
+
248
+ # Extract raw text from AIMessage or other response types
249
+ raw = None
250
+ if isinstance(resp, str):
251
+ raw = resp
252
+ elif hasattr(resp, "content"): # AIMessage or similar
253
+ raw = resp.content
254
+ elif isinstance(resp, dict):
255
+ msgs = resp.get("messages")
256
+ if msgs:
257
+ last_msg = msgs[-1]
258
+ if isinstance(last_msg, str):
259
+ raw = last_msg
260
+ elif hasattr(last_msg, "content"):
261
+ raw = last_msg.content
262
+ elif isinstance(last_msg, dict):
263
+ raw = last_msg.get("content", "")
264
+ else:
265
+ raw = str(last_msg)
266
+ else:
267
+ raw = json.dumps(resp)
268
+ else:
269
+ raw = str(resp)
270
+
271
+ parsed = extract_json_from_llm_response(raw)
272
+ return parsed
273
+
274
+ except Exception as e:
275
+ logger.warning("Node agent JSON parse failed: %s. Attempting JSON resolver.", e)
276
+ try:
277
+ resolver_prompt = f"Fix this JSON. Input:\n```json\n{raw}\n```\nReturn valid JSON only."
278
+ r = agent_json_resolver.invoke({"messages": [{"role": "user", "content": resolver_prompt}]})
279
+
280
+ rtxt = None
281
+ if isinstance(r, str):
282
+ rtxt = r
283
+ elif hasattr(r, "content"):
284
+ rtxt = r.content
285
+ elif isinstance(r, dict):
286
+ msgs = r.get("messages")
287
+ if msgs:
288
+ last_msg = msgs[-1]
289
+ if isinstance(last_msg, str):
290
+ rtxt = last_msg
291
+ elif hasattr(last_msg, "content"):
292
+ rtxt = last_msg.content
293
+ elif isinstance(last_msg, dict):
294
+ rtxt = last_msg.get("content", "")
295
+ else:
296
+ rtxt = str(last_msg)
297
+ else:
298
+ rtxt = json.dumps(r)
299
+ else:
300
+ rtxt = str(r)
301
+
302
+ corrected = extract_json_from_llm_response(rtxt)
303
+ return corrected
304
+ except Exception as e2:
305
+ logger.exception("JSON resolver also failed: %s", e2)
306
+ return {}
307
+
308
+ # -------------------- Define LangGraph State schema -------------------------
309
class State(TypedDict):
    """Shared state flowing through the LangGraph pipeline.

    The input keys are populated by the Flask endpoint before ``graph.invoke``;
    the remaining keys are written by the individual graph nodes.
    """
    # --- inputs (set by /process_reports before invoking the graph) ---
    patient_meta: NotRequired[Dict[str, Any]]  # caller-supplied patient details; may be empty
    patient_id: str
    documents: List[Dict[str, Any]]  # one entry per PDF: {filename, raw_text, cleaned_text}
    medications: List[str]  # medications pre-extracted from the combined cleaned text
    # --- outputs (each written by its corresponding node) ---
    patientDetails: NotRequired[Dict[str, Any]]
    doctorDetails: NotRequired[Dict[str, Any]]
    reports: NotRequired[List[Dict[str, Any]]]
    overallAnalysis: NotRequired[Dict[str, Any]]
    valid: NotRequired[bool]  # set by condition_loop_node (validation verdict)
    missing: NotRequired[List[str]]  # keys reported missing during validation
320
+
321
+ # -------------------- Node implementations as LangGraph nodes -------------------------
322
def patient_details_node(state: State) -> dict:
    """Ask the LLM for patient demographics; returns the patientDetails slice."""
    logger.info("Running patient_details_node")
    node_input = {
        "patient_meta": state.get("patient_meta", {}),
        "documents": state.get("documents", []),
        "medications": state.get("medications", []),
    }
    result = call_node_agent(PATIENT_NODE_PROMPT, node_input)
    details = result.get("patientDetails", {}) if isinstance(result, dict) else {}
    return {"patientDetails": details}
331
+
332
def doctor_details_node(state: State) -> dict:
    """Ask the LLM for referring-doctor details; returns the doctorDetails slice."""
    logger.info("Running doctor_details_node")
    out = call_node_agent(
        DOCTOR_NODE_PROMPT,
        {
            "documents": state.get("documents", []),
            "medications": state.get("medications", []),
        },
    )
    return {"doctorDetails": out.get("doctorDetails", {}) if isinstance(out, dict) else {}}
340
+
341
def test_report_node(state: State) -> dict:
    """Ask the LLM to extract individual test reports; returns the reports slice."""
    logger.info("Running test_report_node")
    response = call_node_agent(
        TEST_REPORT_NODE_PROMPT,
        {
            "documents": state.get("documents", []),
            "medications": state.get("medications", []),
        },
    )
    reports = response.get("reports", []) if isinstance(response, dict) else []
    return {"reports": reports}
349
+
350
def analysis_node(state: State) -> dict:
    """Ask the LLM for the cross-report analysis; returns the overallAnalysis slice."""
    logger.info("Running analysis_node")
    node_input = {
        "patientDetails": state.get("patientDetails", {}),
        "doctorDetails": state.get("doctorDetails", {}),
        "reports": state.get("reports", []),
        "medications": state.get("medications", []),
    }
    result = call_node_agent(ANALYSIS_NODE_PROMPT, node_input)
    analysis = result.get("overallAnalysis", {}) if isinstance(result, dict) else {}
    return {"overallAnalysis": analysis}
360
+
361
def condition_loop_node(state: State) -> dict:
    """Validate the assembled output; fall back to a local structural check
    when the LLM verdict is unusable."""
    logger.info("Running condition_loop_node (validation)")
    node_input = {
        "patientDetails": state.get("patientDetails", {}),
        "doctorDetails": state.get("doctorDetails", {}),
        "reports": state.get("reports", []),
        "overallAnalysis": state.get("overallAnalysis", {}),
    }
    verdict = call_node_agent(CONDITION_LOOP_NODE_PROMPT, node_input)
    if isinstance(verdict, dict) and "valid" in verdict:
        return {"valid": bool(verdict.get("valid")), "missing": verdict.get("missing", [])}
    # No usable LLM verdict — check only the two keys the pipeline requires.
    missing = [key for key in ("patientDetails", "reports") if not state.get(key)]
    return {"valid": not missing, "missing": missing}
378
+
379
# -------------------- Build LangGraph StateGraph -------------------------
# Nodes run strictly in sequence:
#   START -> patient_details -> doctor_details -> test_report -> analysis
#         -> condition_loop -> END
graph_builder = StateGraph(State)

_PIPELINE = [
    ("patient_details", patient_details_node),
    ("doctor_details", doctor_details_node),
    ("test_report", test_report_node),
    ("analysis", analysis_node),
    ("condition_loop", condition_loop_node),
]

for _node_name, _node_fn in _PIPELINE:
    graph_builder.add_node(_node_name, _node_fn)

_ORDER = [START] + [name for name, _ in _PIPELINE] + [END]
for _src, _dst in zip(_ORDER, _ORDER[1:]):
    graph_builder.add_edge(_src, _dst)

graph = graph_builder.compile()
396
+
397
# -------------------- Flask app & endpoints -------------------------------
# Static assets (including frontend.html) are served from the ./static
# directory that sits next to this file.
BASE_DIR = Path(__file__).resolve().parent
static_folder = BASE_DIR / "static"
app = Flask(__name__, static_folder=str(static_folder), static_url_path="/static")
CORS(app)  # dev convenience; lock down in production
402
+
403
# serve frontend root
@app.route("/", methods=["GET"])
def serve_frontend():
    """Serve the single-page frontend, or a 404 hint when it is missing."""
    try:
        page = app.send_static_file("frontend.html")
    except Exception:
        return "<h3>frontend.html not found in static/ — drop your frontend.html there.</h3>", 404
    return page
410
+
411
@app.route("/process_reports", methods=["POST"])
def process_reports():
    """Run the full report-processing pipeline for one patient.

    Expects JSON: ``{"patient_id": str, "filenames": [str, ...],
    "patientDetails": {...} (optional)}``. Each named PDF under
    ``REPORTS_ROOT/<patient_id>/`` is parsed, cleaned, mined for medications,
    and pushed through the LangGraph pipeline. Returns the structured report
    JSON (with a ``_validation`` block), or an error payload with 4xx/500.
    """
    data = request.get_json(force=True)
    patient_id = data.get("patient_id")
    filenames = data.get("filenames", [])
    extra_patient_meta = data.get("patientDetails", {})

    if not patient_id or not filenames:
        return jsonify({"error": "missing patient_id or filenames"}), 400

    pid = str(patient_id)
    # SECURITY: patient_id is client-supplied; reject anything containing path
    # separators or '..' so the lookup cannot escape REPORTS_ROOT.
    if Path(pid).name != pid or pid in (".", ".."):
        return jsonify({"error": "invalid patient_id"}), 400

    patient_folder = REPORTS_ROOT / pid
    if not patient_folder.exists() or not patient_folder.is_dir():
        return jsonify({"error": f"patient folder not found: {patient_folder}"}), 404

    documents = []
    combined_text_parts = []

    for fname in filenames:
        # SECURITY: same path-traversal guard for each client-supplied filename.
        if Path(fname).name != fname:
            logger.warning("rejected unsafe filename: %s", fname)
            continue
        file_path = patient_folder / fname
        if not file_path.exists():
            logger.warning("file not found: %s", file_path)
            continue
        try:
            elements = partition_pdf(filename=str(file_path))
            page_text = "\n".join(el.text for el in elements if getattr(el, "text", None))
        except Exception:
            logger.exception("Failed to parse PDF %s", file_path)
            page_text = ""
        cleaned = clean_notes_with_bloatectomy(page_text, style="remov")
        documents.append({
            "filename": fname,
            "raw_text": page_text,
            "cleaned_text": cleaned,
        })
        combined_text_parts.append(cleaned)

    if not documents:
        return jsonify({"error": "no valid documents found"}), 400

    combined_text = "\n\n".join(combined_text_parts)
    meds = extract_medications_from_text(combined_text)

    initial_state = {
        "patient_meta": extra_patient_meta,
        "patient_id": patient_id,
        "documents": documents,
        "medications": meds,
    }

    try:
        result_state = graph.invoke(initial_state)

        # If validation failed, fill placeholders and re-run analysis +
        # validation so overallAnalysis stays consistent with the patched state.
        if not result_state.get("valid", True):
            missing = result_state.get("missing", [])
            logger.info("Validation failed; missing keys: %s", missing)
            if "patientDetails" in missing:
                result_state["patientDetails"] = extra_patient_meta or {"name": "", "age": "", "sex": "", "pid": patient_id}
            if "reports" in missing:
                result_state["reports"] = []
            result_state.update(analysis_node(result_state))
            result_state.update(condition_loop_node(result_state))

        safe_response = {
            "patientDetails": result_state.get("patientDetails", {"name": "", "age": "", "sex": "", "pid": patient_id}),
            "doctorDetails": result_state.get("doctorDetails", {"referredBy": ""}),
            "reports": result_state.get("reports", []),
            "overallAnalysis": result_state.get("overallAnalysis", {"summary": "", "recommendations": "", "longTermTrends": ""}),
            "_pre_extracted_medications": result_state.get("medications", []),
            "_validation": {
                "valid": result_state.get("valid", True),
                "missing": result_state.get("missing", []),
            },
        }
        return jsonify(safe_response), 200

    except Exception as e:
        logger.exception("Node pipeline failed")
        return jsonify({"error": "Node pipeline failed", "detail": str(e)}), 500
493
+
494
@app.route("/ping", methods=["GET"])
def ping():
    """Liveness probe used by the frontend / monitoring."""
    body = {"status": "ok"}
    return jsonify(body)
497
+
498
if __name__ == "__main__":
    # Port is configurable via $PORT. Debug mode stays on by default (matching
    # previous behavior for local development), but can now be disabled with
    # FLASK_DEBUG=0 — the Werkzeug debugger must never face the internet.
    port = int(os.getenv("PORT", 5000))
    debug = os.getenv("FLASK_DEBUG", "1").lower() not in ("0", "false", "no")
    app.run(host="0.0.0.0", port=port, debug=debug)