WebashalarForML committed on
Commit 42e73f2 · verified · 1 Parent(s): bc47476

Update app.py

Files changed (1)
  1. app.py +178 -543
app.py CHANGED
@@ -1,117 +1,48 @@
  #!/usr/bin/env python3
- # app.py - Health Reports processing agent (PDF -> cleaned text -> structured JSON)
-
  import os
  import json
  import logging
  import re
  from pathlib import Path
- from typing import List, Dict, Any
- from werkzeug.utils import secure_filename
  from flask import Flask, request, jsonify
  from flask_cors import CORS
  from dotenv import load_dotenv
- from unstructured.partition.pdf import partition_pdf
-
- # Bloatectomy class (as per the source you provided)
  from bloatectomy import bloatectomy
-
- # LLM / agent
  from langchain_groq import ChatGroq
- from langgraph.prebuilt import create_react_agent
-
- # LangGraph imports
- from langgraph.graph import StateGraph, START, END
  from typing_extensions import TypedDict, NotRequired

- # --- Logging ---------------------------------------------------------------
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
- logger = logging.getLogger("health-agent")

- # --- Environment & config -------------------------------------------------
  load_dotenv()
- from pathlib import Path
- REPORTS_ROOT = Path(os.getenv("REPORTS_ROOT", "reports")).resolve() # e.g. /app/reports/<patient_id>/<file.pdf>
- SSRI_FILE = Path(os.getenv("SSRI_FILE", "app/medicationCategories/SSRI_list.txt")).resolve()
- MISC_FILE = Path(os.getenv("MISC_FILE", "app/medicationCategories/MISC_list.txt")).resolve()
- GROQ_API_KEY = os.getenv("GROQ_API_KEY", None)
- ALLOWED_EXTENSIONS = {"pdf"}
- # --- LLM setup -------------------------------------------------------------
  llm = ChatGroq(
      model=os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
      temperature=0.0,
-     max_tokens=None,
  )

- # Top-level strict system prompt for report JSON pieces (each node will use a more specific prompt)
- NODE_BASE_INSTRUCTIONS = """
- You are HealthAI — a clinical assistant producing JSON for downstream processing.
- Produce only valid JSON (no extra text). Follow field types exactly. If missing data, return empty strings or empty arrays.
- Be conservative: do not assert diagnoses; provide suggestions and ask physician confirmation where needed.
- """
-
- # Build a generic agent and a JSON resolver agent (to fix broken JSON from LLM)
- agent = create_react_agent(model=llm, tools=[], prompt=NODE_BASE_INSTRUCTIONS)
- agent_json_resolver = create_react_agent(model=llm, tools=[], prompt="""
- You are a JSON fixer. Input: a possibly-malformed JSON-like text. Output: valid JSON only (enclosed in triple backticks).
- Fix missing quotes, trailing commas, unescaped newlines, stray assistant labels, and ensure schema compliance.
- """)
-
- # -------------------- JSON extraction / sanitizer ---------------------------
- def extract_json_from_llm_response(raw_response: str) -> dict:
-     try:
-         # --- 1) Pull out the JSON code-block if present ---
-         md = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", raw_response)
-         json_string = md.group(1).strip() if md else raw_response
-
-         # --- 2) Trim to the outermost { … } so we drop any prefix/suffix junk ---
-         first, last = json_string.find('{'), json_string.rfind('}')
-         if 0 <= first < last:
-             json_string = json_string[first:last+1]
-
-         # --- 3) PRE-CLEANUP: remove rogue assistant labels, fix boolean quotes ---
-         json_string = re.sub(r'\b\w+\s*{', '{', json_string)
-         json_string = re.sub(r'"assistant"\s*:', '', json_string)
-         json_string = re.sub(r'\b(false|true)"', r'\1', json_string)
-
-         # --- 4) Escape embedded quotes in long string fields (best-effort) ---
-         def _esc(m):
-             prefix, body = m.group(1), m.group(2)
-             return prefix + body.replace('"', r'\"')
-         json_string = re.sub(
-             r'("logic"\s*:\s*")([\s\S]+?)(?=",\s*"[A-Za-z_]\w*"\s*:\s*)',
-             _esc,
-             json_string
-         )
-
-         # --- 5) Remove trailing commas before } or ] ---
-         json_string = re.sub(r',\s*(?=[}\],])', '', json_string)
-         json_string = re.sub(r',\s*,', ',', json_string)
-
-         # --- 6) Balance braces if obvious excess ---
-         ob, cb = json_string.count('{'), json_string.count('}')
-         if cb > ob:
-             excess = cb - ob
-             json_string = json_string.rstrip()[:-excess]
-
-         # --- 7) Escape literal newlines inside strings so json.loads can parse ---
-         def _escape_newlines_in_strings(s: str) -> str:
-             return re.sub(
-                 r'"((?:[^"\\]|\\.)*?)"',
-                 lambda m: '"' + m.group(1).replace('\n', '\\n').replace('\r', '\\r') + '"',
-                 s,
-                 flags=re.DOTALL
-             )
-         json_string = _escape_newlines_in_strings(json_string)
-
-         # Final parse
-         return json.loads(json_string)
-     except Exception as e:
-         logger.error(f"Failed to extract JSON from LLM response: {e}")
-         raise
-
- # -------------------- Utility: Bloatectomy wrapper ------------------------
  def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
      try:
          b = bloatectomy(text, style=style, output="html")
          tokens = getattr(b, "tokens", None)
@@ -122,480 +53,185 @@ def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
          logger.exception("Bloatectomy cleaning failed; returning original text")
          return text

- # --------------- Utility: medication extraction (adapted) -----------------
- def readDrugs_from_file(path: Path):
-     try:
-         if not path.exists():
-             return {}, []
-         txt = path.read_text(encoding="utf-8", errors="ignore")
-         generics = re.findall(r"^(.*?)\|", txt, re.MULTILINE)
-         generics = [g.lower() for g in generics if g]
-         lines = [ln.strip().lower() for ln in txt.splitlines() if ln.strip()]
-         return dict(zip(generics, lines)), generics
-     except Exception:
-         logger.exception(f"Failed to read drugs from file: {path}")
-         return {}, []
-
- def addToDrugs_line(line: str, drugs_flags: List[int], listing: Dict[str,str], genList: List[str]) -> List[int]:
-     try:
-         gen_index = {g:i for i,g in enumerate(genList)}
-         for generic, pattern_line in listing.items():
-             try:
-                 if re.search(pattern_line, line, re.I):
-                     idx = gen_index.get(generic)
-                     if idx is not None:
-                         drugs_flags[idx] = 1
-             except re.error:
-                 continue
-         return drugs_flags
-     except Exception:
-         logger.exception("Error in addToDrugs_line")
-         return drugs_flags
-
- def extract_medications_from_text(text: str) -> List[str]:
-     try:
-         ssri_map, ssri_generics = readDrugs_from_file(SSRI_FILE)
-         misc_map, misc_generics = readDrugs_from_file(MISC_FILE)
-         combined_map = {**ssri_map, **misc_map}
-         combined_generics = []
-         if ssri_generics:
-             combined_generics.extend(ssri_generics)
-         if misc_generics:
-             combined_generics.extend(misc_generics)
-
-         flags = [0]* len(combined_generics)
-         meds_found = set()
-         for ln in text.splitlines():
-             ln = ln.strip()
-             if not ln:
-                 continue
-             if combined_map:
-                 flags = addToDrugs_line(ln, flags, combined_map, combined_generics)
-             m = re.search(r"\b(Rx|Drug|Medication|Prescribed|Tablet)\s*[:\-]?\s*([A-Za-z0-9\-\s/\.]+)", ln, re.I)
-             if m:
-                 meds_found.add(m.group(2).strip())
-             m2 = re.findall(r"\b([A-Z][a-z0-9\-]{2,}\s*(?:[0-9]{1,4}\s*(?:mg|mcg|g|IU))?)", ln)
-             for s in m2:
-                 if re.search(r"\b(mg|mcg|g|IU)\b", s, re.I):
-                     meds_found.add(s.strip())
-         for i, f in enumerate(flags):
-             if f == 1:
-                 meds_found.add(combined_generics[i])
-         return list(meds_found)
-     except Exception:
-         logger.exception("Failed to extract medications from text")
-         return []
-
- # -------------------- Node prompts --------------------------
- PATIENT_NODE_PROMPT = """
- You will extract patientDetails from the provided document texts.
- Return ONLY JSON with this exact shape:
- { "patientDetails": {"name": "", "age": "", "sex": "", "pid": ""} }
- Fill fields using text evidence or leave empty strings.
  """

- DOCTOR_NODE_PROMPT = """
- You will extract doctorDetails found in the documents.
- Return ONLY JSON with this exact shape:
- { "doctorDetails": {"referredBy": ""} }
- """

- TEST_REPORT_NODE_PROMPT = """
- You will extract per-test structured results from the documents.
- Return ONLY JSON with this exact shape:
- {
-   "reports": [
-     {
-       "testName": "",
-       "dateReported": "",
-       "timeReported": "",
-       "abnormalFindings": [
-         {"investigation": "", "result": 0, "unit": "", "status": "", "referenceValue": ""}
-       ],
-       "interpretation": "",
-       "trends": []
-     }
-   ]
- }
- - Include only findings that are outside reference ranges OR explicitly called 'abnormal' in the report.
- - For result numeric parsing, prefer numeric values; if not numeric, keep original string.
- - Use statuses: Low, High, Borderline, Positive, Negative, Normal.
- """

- ANALYSIS_NODE_PROMPT = """
- You will create an overallAnalysis based on the extracted reports (the agent will give you the 'reports' JSON).
- Return ONLY JSON:
- { "overallAnalysis": { "summary": "", "recommendations": "", "longTermTrends": "", "risk_prediction": "", "drug_interaction": "" } }
- Be conservative, evidence-based, and suggest follow-up steps for physicians.
- """

- CONDITION_LOOP_NODE_PROMPT = """
- Validation and condition node:
- Input: partial JSON (patientDetails, doctorDetails, reports, overallAnalysis).
- Task: Check required keys exist and that each report has at least testName and abnormalFindings list.
- Return ONLY JSON:
- { "valid": true, "missing": [] }
- If missing fields, list keys in 'missing'. Do NOT modify content.
- """

- # -------------------- Node helpers -------------------------
- def call_node_agent(node_prompt: str, payload: dict) -> dict:
-     """
-     Call the generic agent with a targeted node prompt and the payload.
-     Tries to parse JSON. If parsing fails, uses the JSON resolver agent once.
-     """
      try:
-         content = {
-             "prompt": node_prompt,
-             "payload": payload
-         }
-         resp = agent.invoke({"messages": [{"role": "user", "content": json.dumps(content)}]})
-
-         # Extract raw text from AIMessage or other response types
-         raw = None
-         if isinstance(resp, str):
-             raw = resp
-         elif hasattr(resp, "content"): # AIMessage or similar
-             raw = resp.content
-         elif isinstance(resp, dict):
-             msgs = resp.get("messages")
-             if msgs:
-                 last_msg = msgs[-1]
-                 if isinstance(last_msg, str):
-                     raw = last_msg
-                 elif hasattr(last_msg, "content"):
-                     raw = last_msg.content
-                 elif isinstance(last_msg, dict):
-                     raw = last_msg.get("content", "")
-                 else:
-                     raw = str(last_msg)
-             else:
-                 raw = json.dumps(resp)
-         else:
-             raw = str(resp)

-         parsed = extract_json_from_llm_response(raw)
          return parsed

-     except Exception as e:
-         logger.warning("Node agent JSON parse failed: %s. Attempting JSON resolver.", e)
-         try:
-             resolver_prompt = f"Fix this JSON. Input:\n```json\n{raw}\n```\nReturn valid JSON only."
-             r = agent_json_resolver.invoke({"messages": [{"role": "user", "content": resolver_prompt}]})
-
-             rtxt = None
-             if isinstance(r, str):
-                 rtxt = r
-             elif hasattr(r, "content"):
-                 rtxt = r.content
-             elif isinstance(r, dict):
-                 msgs = r.get("messages")
-                 if msgs:
-                     last_msg = msgs[-1]
-                     if isinstance(last_msg, str):
-                         rtxt = last_msg
-                     elif hasattr(last_msg, "content"):
-                         rtxt = last_msg.content
-                     elif isinstance(last_msg, dict):
-                         rtxt = last_msg.get("content", "")
-                     else:
-                         rtxt = str(last_msg)
-                 else:
-                     rtxt = json.dumps(r)
-             else:
-                 rtxt = str(r)
-
-             corrected = extract_json_from_llm_response(rtxt)
-             return corrected
-         except Exception as e2:
-             logger.exception("JSON resolver also failed: %s", e2)
-             return {}
-
- # -------------------- Define LangGraph State schema -------------------------
- class State(TypedDict):
-     patient_meta: NotRequired[Dict[str, Any]]
-     patient_id: str
-     documents: List[Dict[str, Any]]
-     medications: List[str]
-     patientDetails: NotRequired[Dict[str, Any]]
-     doctorDetails: NotRequired[Dict[str, Any]]
-     reports: NotRequired[List[Dict[str, Any]]]
-     overallAnalysis: NotRequired[Dict[str, Any]]
-     valid: NotRequired[bool]
-     missing: NotRequired[List[str]]
-
- # -------------------- Node implementations as LangGraph nodes -------------------------
- def patient_details_node(state: State) -> dict:
-     payload = {
-         "patient_meta": state.get("patient_meta", {}),
-         "documents": state.get("documents", []),
-         "medications": state.get("medications", [])
-     }
-     logger.info("Running patient_details_node")
-     out = call_node_agent(PATIENT_NODE_PROMPT, payload)
-     return {"patientDetails": out.get("patientDetails", {}) if isinstance(out, dict) else {}}
-
- def doctor_details_node(state: State) -> dict:
-     payload = {
-         "documents": state.get("documents", []),
-         "medications": state.get("medications", [])
-     }
-     logger.info("Running doctor_details_node")
-     out = call_node_agent(DOCTOR_NODE_PROMPT, payload)
-     return {"doctorDetails": out.get("doctorDetails", {}) if isinstance(out, dict) else {}}
-
- def test_report_node(state: State) -> dict:
-     payload = {
-         "documents": state.get("documents", []),
-         "medications": state.get("medications", [])
-     }
-     logger.info("Running test_report_node")
-     out = call_node_agent(TEST_REPORT_NODE_PROMPT, payload)
-     return {"reports": out.get("reports", []) if isinstance(out, dict) else []}
-
- def analysis_node(state: State) -> dict:
-     payload = {
-         "patientDetails": state.get("patientDetails", {}),
-         "doctorDetails": state.get("doctorDetails", {}),
-         "reports": state.get("reports", []),
-         "medications": state.get("medications", [])
-     }
-     logger.info("Running analysis_node")
-     out = call_node_agent(ANALYSIS_NODE_PROMPT, payload)
-     return {"overallAnalysis": out.get("overallAnalysis", {}) if isinstance(out, dict) else {}}
-
- def condition_loop_node(state: State) -> dict:
-     payload = {
-         "patientDetails": state.get("patientDetails", {}),
-         "doctorDetails": state.get("doctorDetails", {}),
-         "reports": state.get("reports", []),
-         "overallAnalysis": state.get("overallAnalysis", {})
-     }
-     logger.info("Running condition_loop_node (validation)")
-     out = call_node_agent(CONDITION_LOOP_NODE_PROMPT, payload)
-     if isinstance(out, dict) and "valid" in out:
-         return {"valid": bool(out.get("valid")), "missing": out.get("missing", [])}
-     missing = []
-     if not state.get("patientDetails"):
-         missing.append("patientDetails")
-     if not state.get("reports"):
-         missing.append("reports")
-     return {"valid": len(missing) == 0, "missing": missing}
-
- # -------------------- Build LangGraph StateGraph -------------------------
- graph_builder = StateGraph(State)
-
- graph_builder.add_node("patient_details", patient_details_node)
- graph_builder.add_node("doctor_details", doctor_details_node)
- graph_builder.add_node("test_report", test_report_node)
- graph_builder.add_node("analysis", analysis_node)
- graph_builder.add_node("condition_loop", condition_loop_node)
-
- graph_builder.add_edge(START, "patient_details")
- graph_builder.add_edge("patient_details", "doctor_details")
- graph_builder.add_edge("doctor_details", "test_report")
- graph_builder.add_edge("test_report", "analysis")
- graph_builder.add_edge("analysis", "condition_loop")
- graph_builder.add_edge("condition_loop", END)
-
- graph = graph_builder.compile()
-
- # -------------------- Flask app & endpoints -------------------------------
- # -------------------- Flask app & endpoints -------------------------------
- BASE_DIR = Path(__file__).resolve().parent
- static_folder = BASE_DIR / "static"
- app = Flask(__name__, static_folder=str(static_folder), static_url_path="/static")
- CORS(app) # dev convenience; lock down in production
-
- # serve frontend root
  @app.route("/", methods=["GET"])
412
  def serve_frontend():
 
413
  try:
414
- return app.send_static_file("frontend.html")
415
- except Exception as e:
416
- logger.error(f"Failed to serve frontend.html: {e}")
417
- return "<h3>frontend.html not found in static/ — drop your frontend.html there.</h3>", 404
418
 
419
- @app.route("/process_reports", methods=["POST"])
420
- def process_reports():
421
- try:
422
- data = request.get_json(force=True)
423
- except Exception as e:
424
- logger.error(f"Failed to parse JSON request: {e}")
425
- return jsonify({"error": "Invalid JSON request"}), 400
426
 
427
  patient_id = data.get("patient_id")
428
- filenames = data.get("filenames", [])
429
- extra_patient_meta = data.get("patientDetails", {})
430
 
431
- if not patient_id or not filenames:
432
- return jsonify({"error": "missing patient_id or filenames"}), 400
433
 
434
- patient_folder = REPORTS_ROOT / str(patient_id)
435
- if not patient_folder.exists() or not patient_folder.is_dir():
436
- return jsonify({"error": f"patient folder not found: {patient_folder}"}), 404
437
-
438
- documents = []
439
  combined_text_parts = []
440
-
441
- for fname in filenames:
442
- file_path = patient_folder / fname
443
- if not file_path.exists():
444
- logger.warning("file not found: %s", file_path)
445
- continue
446
- try:
447
- elements = partition_pdf(filename=str(file_path))
448
- page_text = "\n".join([el.text for el in elements if hasattr(el, "text") and el.text])
449
- except Exception:
450
- logger.exception(f"Failed to parse PDF {file_path}")
451
  page_text = ""
452
- try:
453
- cleaned = clean_notes_with_bloatectomy(page_text, style="remov")
454
- except Exception:
455
- logger.exception("Failed to clean notes with bloatectomy")
456
- cleaned = page_text
457
- documents.append({
458
- "filename": fname,
459
- "raw_text": page_text,
460
- "cleaned_text": cleaned
461
- })
462
- combined_text_parts.append(cleaned)
463
-
464
- if not documents:
465
- return jsonify({"error": "no valid documents found"}), 400
466
-
467
- combined_text = "\n\n".join(combined_text_parts)
468
- try:
469
- meds = extract_medications_from_text(combined_text)
470
- except Exception:
471
- logger.exception("Failed to extract medications")
472
- meds = []
473
-
474
- initial_state = {
475
- "patient_meta": extra_patient_meta,
476
- "patient_id": patient_id,
477
- "documents": documents,
478
- "medications": meds
479
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
480
 
 
 
 
 
 
481
      try:
-         result_state = graph.invoke(initial_state)
-
-         # Validate and fill placeholders if needed
-         if not result_state.get("valid", True):
-             missing = result_state.get("missing", [])
-             logger.info(f"Validation failed; missing keys: {missing}")
-             if "patientDetails" in missing:
-                 result_state["patientDetails"] = extra_patient_meta or {"name": "", "age": "", "sex": "", "pid": patient_id}
-             if "reports" in missing:
-                 result_state["reports"] = []
-             # Re-run analysis node to keep overallAnalysis consistent
-             result_state.update(analysis_node(result_state))
-             # Re-validate
-             cond = condition_loop_node(result_state)
-             result_state.update(cond)
-
-         safe_response = {
-             "patientDetails": result_state.get("patientDetails", {"name": "", "age": "", "sex": "", "pid": patient_id}),
-             "doctorDetails": result_state.get("doctorDetails", {"referredBy": ""}),
-             "reports": result_state.get("reports", []),
-             "overallAnalysis": result_state.get("overallAnalysis", {"summary": "", "recommendations": "", "longTermTrends": ""}),
-             "_pre_extracted_medications": result_state.get("medications", []),
-             "_validation": {
-                 "valid": result_state.get("valid", True),
-                 "missing": result_state.get("missing", [])
-             }
-         }
-         return jsonify(safe_response), 200

      except Exception as e:
-         logger.exception("Node pipeline failed")
-         return jsonify({"error": "Node pipeline failed", "detail": str(e)}), 500
-
- @app.route("/upload_reports", methods=["POST"])
- def upload_reports():
-     """
-     Upload one or more files for a patient.
-
-     Expects multipart/form-data with:
-       - patient_id (form field)
-       - files (one or multiple files; use the same field name 'files' for each file)
-
-     Example curl:
-       curl -X POST http://localhost:7860/upload_reports \
-         -F "patient_id=12345" \
-         -F "files[]=@/path/to/report1.pdf" \
-         -F "files[]=@/path/to/report2.pdf"
-     """
-     try:
-         # patient id can be in form or args (for convenience)
-         patient_id = request.form.get("patient_id") or request.args.get("patient_id")
-         if not patient_id:
-             return jsonify({"error": "patient_id form field required"}), 400
-
-         # get uploaded files (support both files and files[] naming)
-         uploaded_files = request.files.getlist("files")
-         if not uploaded_files:
-             # fallback: single file under name 'file'
-             single = request.files.get("file")
-             if single:
-                 uploaded_files = [single]
-
-         if not uploaded_files:
-             return jsonify({"error": "no files uploaded (use form field 'files')"}), 400
-
-         # create patient folder under REPORTS_ROOT/<patient_id>
-         patient_folder = REPORTS_ROOT / str(patient_id)
-         patient_folder.mkdir(parents=True, exist_ok=True)
-
-         saved = []
-         skipped = []
-
-         for file_storage in uploaded_files:
-             orig_name = getattr(file_storage, "filename", "") or ""
-             filename = secure_filename(orig_name)
-             if not filename:
-                 skipped.append({"filename": orig_name, "reason": "invalid filename"})
-                 continue
-
-             # extension check
-             ext = filename.rsplit(".", 1)[1].lower() if "." in filename else ""
-             if ext not in ALLOWED_EXTENSIONS:
-                 skipped.append({"filename": filename, "reason": f"extension '{ext}' not allowed"})
-                 continue
-
-             # avoid overwriting: if collision, add numeric suffix
-             dest = patient_folder / filename
-             if dest.exists():
-                 base, dot, extension = filename.rpartition(".")
-                 # if no base (e.g. ".bashrc") fallback
-                 base = base or filename
-                 i = 1
-                 while True:
-                     candidate = f"{base}__{i}.{extension}" if extension else f"{base}__{i}"
-                     dest = patient_folder / candidate
-                     if not dest.exists():
-                         filename = candidate
-                         break
-                     i += 1
-
-             try:
-                 file_storage.save(str(dest))
-                 saved.append(filename)
-             except Exception as e:
-                 logger.exception("Failed to save uploaded file %s: %s", filename, e)
-                 skipped.append({"filename": filename, "reason": f"save failed: {e}"})
-
-         return jsonify({
-             "patient_id": str(patient_id),
-             "saved": saved,
-             "skipped": skipped,
-             "patient_folder": str(patient_folder)
-         }), 200
-
-     except Exception as exc:
-         logger.exception("Upload failed: %s", exc)
-         return jsonify({"error": "upload failed", "detail": str(exc)}), 500

  @app.route("/ping", methods=["GET"])
601
  def ping():
@@ -604,4 +240,3 @@ def ping():
604
  if __name__ == "__main__":
605
  port = int(os.getenv("PORT", 7860))
606
  app.run(host="0.0.0.0", port=port, debug=True)
607
-
 
  #!/usr/bin/env python3
  import os
  import json
  import logging
  import re
+ from typing import Dict, Any
  from pathlib import Path
+ from unstructured.partition.pdf import partition_pdf
  from flask import Flask, request, jsonify
  from flask_cors import CORS
  from dotenv import load_dotenv
  from bloatectomy import bloatectomy
+ from werkzeug.utils import secure_filename
  from langchain_groq import ChatGroq
  from typing_extensions import TypedDict, NotRequired

+ # --- Logging ---
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
+ logger = logging.getLogger("patient-assistant")

+ # --- Load environment ---
  load_dotenv()
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+ if not GROQ_API_KEY:
+     logger.error("GROQ_API_KEY not set in environment")
+     exit(1)
+
+ # --- Flask app setup ---
+ BASE_DIR = Path(__file__).resolve().parent
+ REPORTS_ROOT = Path(os.getenv("REPORTS_ROOT", str(BASE_DIR / "reports")))
+ static_folder = BASE_DIR / "static"
+
+ app = Flask(__name__, static_folder=str(static_folder), static_url_path="/static")
+ CORS(app)
+
+ # --- LLM setup ---
  llm = ChatGroq(
      model=os.getenv("LLM_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct"),
      temperature=0.0,
+     max_tokens=1024,
+     api_key=GROQ_API_KEY,
  )

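The rewritten app is configured entirely through the environment variables read above; a minimal `.env` sketch that load_dotenv() would pick up (all values are placeholders; only GROQ_API_KEY is required, the rest fall back to the defaults visible in the code):

    # .env (placeholder values)
    GROQ_API_KEY=your_groq_api_key
    REPORTS_ROOT=/data/reports
    LLM_MODEL=meta-llama/llama-4-scout-17b-16e-instruct
    PORT=7860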
  def clean_notes_with_bloatectomy(text: str, style: str = "remov") -> str:
+     """Helper function to clean up text using the bloatectomy library."""
      try:
          b = bloatectomy(text, style=style, output="html")
          tokens = getattr(b, "tokens", None)

          logger.exception("Bloatectomy cleaning failed; returning original text")
          return text

+ # --- Agent prompt instructions ---
+ PATIENT_ASSISTANT_PROMPT = """
+ You are a patient assistant helping to analyze medical records and reports. Your goal is to provide a comprehensive response based on the patient's health history and the current conversation.
+
+ Your tasks include:
+ - Analyzing medical records and reports to detect anomalies, redundant tests, or misleading treatments.
+ - Suggesting preventive care based on the overall patient health history.
+ - Optimizing healthcare costs by comparing past visits and treatments, helping patients make smarter choices.
+ - Offering personalized lifestyle recommendations, such as adopting healthier food habits, daily routines, and regular health checks.
+ - Generating a natural, helpful reply to the user.
+
+ You will be provided with the last user message, the conversation history, and a summary of the patient's medical reports. Use this information to give a tailored and informative response.
+
+ STRICT OUTPUT FORMAT (JSON ONLY):
+ Return a single JSON object with the following keys:
+ - assistant_reply: string // a natural language reply to the user (short, helpful, always present)
+ - patientDetails: object // keys may include name, problem, city, contact (update if user shared info)
+ - conversationSummary: string (optional) // short summary of conversation + relevant patient docs
+
+ Rules:
+ - ALWAYS include `assistant_reply` as a non-empty string.
+ - Do NOT produce any text outside the JSON object.
+ - Be concise in `assistant_reply`. If you need more details, ask a targeted follow-up question.
+ - Do not make up information that is not present in the provided medical reports or conversation history.
  """

+ # --- JSON extraction helper ---
+ def extract_json_from_llm_response(raw_response: str) -> dict:
+     """Safely extracts a JSON object from a string that might contain extra text or markdown."""
+     default = {
+         "assistant_reply": "I'm sorry — I couldn't understand that. Could you please rephrase?",
+         "patientDetails": {},
+         "conversationSummary": "",
+     }
+
+     if not raw_response or not isinstance(raw_response, str):
+         return default
+
+     # Find the JSON object, ignoring any markdown code fences
+     m = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", raw_response)
+     json_string = m.group(1).strip() if m else raw_response
+
+     # Find the first opening brace and the last closing brace
+     first = json_string.find('{')
+     last = json_string.rfind('}')
+     if first == -1 or last == -1 or first >= last:
+         try:
+             return json.loads(json_string)
+         except Exception:
+             logger.warning("Could not locate JSON braces in LLM output. Falling back to default.")
+             return default
+
+     candidate = json_string[first:last+1]
+     # Remove trailing commas that might cause parsing issues
+     candidate = re.sub(r',\s*(?=[}\]])', '', candidate)
+
      try:
+         parsed = json.loads(candidate)
+     except Exception as e:
+         logger.warning("Failed to parse JSON from LLM output: %s", e)
+         return default
+
+     # Basic validation of the parsed JSON
+     if isinstance(parsed, dict) and "assistant_reply" in parsed and isinstance(parsed["assistant_reply"], str) and parsed["assistant_reply"].strip():
+         parsed.setdefault("patientDetails", {})
+         parsed.setdefault("conversationSummary", "")
          return parsed
+     else:
+         logger.warning("Parsed JSON missing 'assistant_reply' or invalid format. Returning default.")
+         return default
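A quick sketch of how this helper behaves on a typical fenced LLM reply (the input string is made up for illustration):

    raw = '```json\n{"assistant_reply": "Hello! How can I help today?", "patientDetails": {},}\n```'
    parsed = extract_json_from_llm_response(raw)
    # The code fence and the trailing comma are stripped, and the missing optional key is backfilled:
    # parsed["assistant_reply"] == "Hello! How can I help today?"
    # parsed["conversationSummary"] == ""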

+ # --- Flask routes ---
  @app.route("/", methods=["GET"])
  def serve_frontend():
+     """Serves the frontend HTML file."""
      try:
+         return app.send_static_file("frontend2.html")
+     except Exception:
+         return "<h3>frontend2.html not found in static/ — please add your frontend2.html there.</h3>", 404

+ @app.route("/chat", methods=["POST"])
137
+ def chat():
138
+ """Handles the chat conversation with the assistant."""
139
+ data = request.get_json(force=True)
140
+ if not isinstance(data, dict):
141
+ return jsonify({"error": "invalid request body"}), 400
 
142
 
143
  patient_id = data.get("patient_id")
144
+ if not patient_id:
145
+ return jsonify({"error": "patient_id required"}), 400
146
 
147
+ chat_history = data.get("chat_history") or []
148
+ patient_state = data.get("patient_state") or {}
149
 
150
+ # --- Read and parse patient reports ---
151
+ patient_folder = REPORTS_ROOT / f"p_{patient_id}"
 
 
 
152
  combined_text_parts = []
153
+ if patient_folder.exists() and patient_folder.is_dir():
154
+ for fname in sorted(os.listdir(patient_folder)):
155
+ file_path = patient_folder / fname
 
 
 
 
 
 
 
 
156
  page_text = ""
157
+ if partition_pdf is not None and str(file_path).lower().endswith('.pdf'):
158
+ try:
159
+ elements = partition_pdf(filename=str(file_path))
160
+ page_text = "\n".join([el.text for el in elements if hasattr(el, 'text') and el.text])
161
+ except Exception:
162
+ logger.exception("Failed to parse PDF %s", file_path)
163
+ else:
164
+ try:
165
+ page_text = file_path.read_text(encoding='utf-8', errors='ignore')
166
+ except Exception:
167
+ page_text = ""
168
+
169
+ if page_text:
170
+ cleaned = clean_notes_with_bloatectomy(page_text, style="remov")
171
+ if cleaned:
172
+ combined_text_parts.append(cleaned)
173
+
174
+ # --- Prepare the state for the LLM ---
175
+ state = patient_state.copy()
176
+ state["lastUserMessage"] = ""
177
+ if chat_history:
178
+ # Find the last user message
179
+ for msg in reversed(chat_history):
180
+ if msg.get("role") == "user" and msg.get("content"):
181
+ state["lastUserMessage"] = msg["content"]
182
+ break
183
+
184
+ # Update the conversation summary with the parsed documents
185
+ base_summary = state.get("conversationSummary", "") or ""
186
+ docs_summary = "\n\n".join(combined_text_parts)
187
+ if docs_summary:
188
+ state["conversationSummary"] = (base_summary + "\n\n" + docs_summary).strip()
189
+ else:
190
+ state["conversationSummary"] = base_summary
191
+
192
+ # --- Direct LLM Invocation ---
193
+ user_prompt = f"""
194
+ Current patientDetails: {json.dumps(state.get("patientDetails", {}))}
195
+ Current conversationSummary: {state.get("conversationSummary", "")}
196
+ Last user message: {state.get("lastUserMessage", "")}
197
+
198
+ Return ONLY valid JSON with keys: assistant_reply, patientDetails, conversationSummary.
199
+ """
200
 
201
+ messages = [
202
+ {"role": "system", "content": PATIENT_ASSISTANT_PROMPT},
203
+ {"role": "user", "content": user_prompt}
204
+ ]
205
+
206
  try:
207
+ logger.info("Invoking LLM with prepared state and prompt...")
208
+ llm_response = llm.invoke(messages)
209
+ raw_response = ""
210
+ if hasattr(llm_response, "content"):
211
+ raw_response = llm_response.content
212
+ else:
213
+ raw_response = str(llm_response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
+ logger.info(f"Raw LLM response: {raw_response}")
216
+ parsed_result = extract_json_from_llm_response(raw_response)
217
+
218
  except Exception as e:
219
+ logger.exception("LLM invocation failed")
220
+ return jsonify({"error": "LLM invocation failed", "detail": str(e)}), 500
221
+
222
+ updated_state = parsed_result or {}
223
+
224
+ assistant_reply = updated_state.get("assistant_reply")
225
+ if not assistant_reply or not isinstance(assistant_reply, str) or not assistant_reply.strip():
226
+ # Fallback to a polite message if the LLM response is invalid or empty
227
+ assistant_reply = "I'm here to help — could you tell me more about your symptoms?"
228
+
229
+ response_payload = {
230
+ "assistant_reply": assistant_reply,
231
+ "updated_state": updated_state,
232
+ }
233
+
234
+ return jsonify(response_payload)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
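For reference, a minimal client-side sketch of the request shape this route expects (the patient id, history, and state values are hypothetical, and it assumes the default port 7860 configured at the bottom of the file):

    import requests  # client-side only; not a dependency of app.py

    payload = {
        "patient_id": "12345",  # reports are read from REPORTS_ROOT / "p_12345"
        "chat_history": [
            {"role": "user", "content": "I keep getting headaches in the evening."}
        ],
        "patient_state": {"patientDetails": {"name": "Asha"}, "conversationSummary": ""},
    }
    resp = requests.post("http://localhost:7860/chat", json=payload)
    print(resp.json()["assistant_reply"])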
  @app.route("/ping", methods=["GET"])
237
  def ping():
 
240
  if __name__ == "__main__":
241
  port = int(os.getenv("PORT", 7860))
242
  app.run(host="0.0.0.0", port=port, debug=True)