File size: 34,575 Bytes
2dbf87d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
#!/usr/bin/env python
# coding: utf-8



from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS, Chroma
from langchain_community.llms import LlamaCpp
from langchain.chains import RetrievalQA, LLMChain
from langchain.prompts import ChatPromptTemplate
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser 
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
import google.generativeai as genai
import matplotlib.pyplot as plt
import os
from typing import Dict, List, Tuple
import json
import numpy as np
import seaborn as sns
from llama_cpp import Llama
import os

# API credentials are read from the process environment; either value is
# None when the corresponding variable is not set.
GEMINI_API_KEY = os.getenv('GEMINI')
HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGING')
# os.environ only accepts str values, so exporting a missing token would
# raise TypeError at import time. Only re-export it when it is present.
if HUGGINGFACEHUB_API_TOKEN is not None:
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = HUGGINGFACEHUB_API_TOKEN


# Initialize the Sentence Transformer model for intent classification.
# Both objects are loaded once at import time from the same checkpoint and are
# shared by the intent-embedding precompute loop and classify_medical_intent().
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")



# Symptom vocabulary for every supported medical intent.
_INTENT_KEYWORDS = {
    "stomach": [
        "pain", "cramps", "nausea", "indigestion", "diarrhea", "bloating",
        "acid reflux", "constipation", "ulcers", "food poisoning",
        "heartburn", "vomiting", "gassiness", "stomach flu", "gastritis",
        "stomach ache", "IBS", "peptic ulcers", "acidity", "flatulence",
        "abdominal pain", "acid reflux disease", "GERD",
        "feeling full quickly", "poor appetite", "belching",
        "sharp abdominal pain", "dull stomach ache", "epigastric pain",
        "dyspepsia", "gurgling stomach sounds",
    ],
    "skin": [
        "rash", "itch", "eczema", "dry skin", "acne", "redness", "hives",
        "fungal infection", "psoriasis", "dermatitis", "sunburn",
        "skin peeling", "discoloration", "swelling", "pimples", "spots",
        "cysts", "skin tags", "lesions", "ulcers", "allergic reaction",
        "rosacea", "warts", "athlete's foot", "moles", "boils",
        "flaky skin", "stretch marks", "pigmentation", "vitiligo",
        "skin irritation", "flaky scalp", "blisters", "cracked skin",
        "sensitivity", "prickly heat", "bruising", "scars",
    ],
    "bp": [
        "hypertension", "blood pressure", "dizziness", "headache",
        "fatigue", "high blood pressure", "low blood pressure",
        "hypotension", "chest pain", "palpitations", "fainting",
        "shortness of breath", "blurred vision", "confusion", "nosebleeds",
        "lightheadedness", "pounding heartbeat", "irregular heartbeat",
        "heart strain", "stroke risk", "renal issues",
        "diastolic pressure", "systolic pressure", "blood flow issues",
        "heart failure", "hypertension crisis", "high pulse rate",
        "low pulse rate", "cardiovascular disease", "pressure behind eyes",
    ],
    "diabetes": [
        "insulin", "sugar", "glucose", "thirst", "frequent urination",
        "weight loss", "blurred vision", "fatigue", "tingling", "numbness",
        "slow healing wounds", "dry mouth", "diabetic neuropathy", "hunger",
        "high blood sugar", "low blood sugar", "polyuria", "polydipsia",
        "glycemic index", "hyperglycemia", "hypoglycemia", "ketoacidosis",
        "foot ulcers", "nerve pain", "eye problems", "retinopathy",
        "nephropathy", "glucose intolerance", "sugar cravings", "sweating",
        "shakiness", "dizzy spells", "carb counting", "A1C levels",
        "prediabetes", "metabolic syndrome",
    ],
}

# Each intent pairs its keyword list with an "embedding_vector" slot that
# starts out as None and is meant to hold the intent's embedding.
intents = {
    intent_name: {"keywords": keyword_list, "embedding_vector": None}
    for intent_name, keyword_list in _INTENT_KEYWORDS.items()
}


# Precompute intent embeddings.
# Runs once at import time: each intent's keywords are joined into a single
# space-separated string, encoded with the shared tokenizer/model, and the
# model's pooler_output is stored (as a NumPy array) in the intent's
# "embedding_vector" slot for later cosine-similarity comparison.
for intent, data in intents.items():
    # truncation=True: presumably cuts very long keyword strings to the
    # tokenizer's maximum length - TODO confirm the limit for this checkpoint.
    tokens = tokenizer(' '.join(data['keywords']), return_tensors="pt", padding=True, truncation=True)
    # detach() drops the autograd graph before converting to NumPy.
    data['embedding_vector'] = model(**tokens).pooler_output.detach().numpy()


# Text preprocessing
def preprocess_text(text: str) -> str:
    """Normalize free-text symptoms for keyword and embedding matching.

    The text is lowercased, every character that is neither alphanumeric nor
    whitespace is dropped, and the stand-alone abbreviations ``bp``, ``ut``
    and ``hr`` are expanded to their full medical terms.

    Args:
        text: Raw user-provided text.

    Returns:
        The normalized text.
    """
    import re

    replacements = {'bp': 'blood pressure', 'ut': 'urinary tract', 'hr': 'heart rate'}

    text = text.lower()
    text = ''.join(char for char in text if char.isalnum() or char.isspace())

    # Expand an abbreviation only when it is not embedded in a longer word.
    # A plain str.replace() would also rewrite the "hr" in "chronic"/"throat"
    # and the "ut" in "but"/"about", corrupting ordinary words.
    pattern = r'(?<![a-z])(' + '|'.join(re.escape(abbr) for abbr in replacements) + r')(?![a-z])'
    return re.sub(pattern, lambda match: replacements[match.group(1)], text)


# Intent classification
def classify_medical_intent(symptoms: str) -> str:
    """Pick the intent whose keywords and embedding best match the symptoms.

    For every entry in the module-level ``intents`` mapping the score is
    ``0.7 * cosine_similarity + 0.3 * keyword_hits``, where ``keyword_hits``
    counts the intent keywords that occur as substrings of the normalized
    symptom text.

    Args:
        symptoms: Free-text description of the patient's symptoms.

    Returns:
        The key of the best-scoring intent in ``intents``.
    """
    preprocessed_symptoms = preprocess_text(symptoms)
    tokens = tokenizer(preprocessed_symptoms, return_tensors="pt", padding=True, truncation=True)
    symptoms_embedding = model(**tokens).pooler_output.detach().numpy()

    intent_names = list(intents)
    ensemble_scores = []
    for intent_data in intents.values():
        similarity = cosine_similarity(symptoms_embedding, intent_data['embedding_vector'])[0][0]

        # The symptom text has been lowercased and stripped of punctuation,
        # so keywords must go through the same normalization. Otherwise
        # entries such as "IBS", "GERD" or "athlete's foot" can never match.
        normalized_keywords = [preprocess_text(keyword) for keyword in intent_data['keywords']]
        keyword_match = sum(
            keyword in preprocessed_symptoms
            for keyword in normalized_keywords
            if keyword
        )

        ensemble_scores.append(0.7 * similarity + 0.3 * keyword_match)

    return intent_names[int(np.argmax(ensemble_scores))]



def plot_enhanced_fishbone(disease: str, causes: Dict[str, List[str]]) -> plt.Figure:
    """
    Render a fishbone (Ishikawa) diagram for a disease and its causes.

    Args:
        disease: Label placed at the head of the spine arrow (upper-cased).
        causes: Mapping of cause category to its list of sub-causes. Each
            category becomes a rib, alternating above and below the spine,
            and each sub-cause becomes a short horizontal twig on that rib.

    Returns:
        The matplotlib figure containing the diagram.
    """
    sns.set_theme(style="whitegrid")
    fig, ax = plt.subplots(figsize=(20, 12), facecolor='#f0f2f6')

    # Horizontal spine, drawn as an arrow pointing at the disease label
    spine_start, spine_length = 2, 14
    spine_end = spine_start + spine_length
    ax.arrow(spine_start, 0, spine_length, 0,
             head_width=0.4, head_length=0.6,
             fc='black', ec='black', lw=2)

    # Rib layout: evenly spaced along the spine, tilted by 45 degrees
    num_causes = len(causes)
    spacing = spine_length / (num_causes + 1)
    branch_length = 3.5
    tilt = np.deg2rad(45)
    dx = branch_length * np.cos(tilt)
    dy = branch_length * np.sin(tilt)

    # One pastel colour per cause category
    palette = plt.cm.Pastel1(np.linspace(0, 1, num_causes))

    for idx, (cause, subcauses) in enumerate(causes.items()):
        rib_color = palette[idx]
        x_pos = spine_start + (idx + 1) * spacing

        # Even ribs point up, odd ribs point down
        points_up = idx % 2 == 0
        y_end = dy if points_up else -dy
        label_shift = 0.5 if points_up else -0.5
        label_va = 'bottom' if points_up else 'top'

        # Rib line
        ax.plot([x_pos, x_pos + dx], [0, y_end],
                color=rib_color, lw=2, zorder=2)

        # Category label at the tip of the rib
        ax.text(x_pos + dx, y_end + label_shift,
                cause.upper(),
                ha='center',
                va=label_va,
                fontsize=10,
                fontweight='bold',
                bbox={'facecolor': 'white',
                      'edgecolor': rib_color,
                      'boxstyle': 'round,pad=0.5',
                      'alpha': 0.9})

        # Twigs: spread the sub-causes evenly along the rib
        for position, subcause in enumerate(subcauses, start=1):
            sub_ratio = position / (len(subcauses) + 1)
            sub_x = x_pos + dx * sub_ratio
            sub_y = y_end * sub_ratio
            twig_end = sub_x + dx / 2

            ax.plot([sub_x, twig_end],
                    [sub_y, sub_y],
                    color=rib_color, lw=1, zorder=2)

            ax.text(twig_end + 0.1,
                    sub_y,
                    subcause,
                    ha='left',
                    va='center',
                    fontsize=8,
                    bbox={'facecolor': 'white',
                          'edgecolor': rib_color,
                          'alpha': 0.7,
                          'boxstyle': 'round,pad=0.3'})

        # Marker where the rib meets the spine
        ax.plot(x_pos, 0, 'o', color=rib_color, markersize=6, zorder=3)

    # Problem statement at the head of the arrow
    ax.text(spine_end + 0.7, 0,
            disease.upper(),
            ha='left',
            va='center',
            fontsize=12,
            fontweight='bold',
            bbox={'facecolor': 'lightgray',
                  'edgecolor': 'gray',
                  'boxstyle': 'round,pad=0.5'})

    plt.title('Enhanced Root Cause Analysis (Ishikawa Diagram)',
              pad=20,
              fontsize=14,
              fontweight='bold')

    # Frame the drawing and hide the axes
    margin = 2
    ax.set_xlim(0, spine_end + 4)
    ax.set_ylim(-branch_length - margin, branch_length + margin)
    ax.axis('off')

    plt.tight_layout()
    return fig



def generate_interview_questions(initial_symptoms: str, category: str, gemini_model) -> List[str]:
    """Generate 5 specific interview questions using Gemini API based on initial symptoms and category.

    Args:
        initial_symptoms: The patient's own description of their symptoms.
        category: Intent/category name (e.g. "bp", "diabetes", "skin", "stomach").
        gemini_model: Object exposing ``generate_content(prompt)`` whose result
            has a ``text`` attribute.

    Returns:
        A list of exactly five question strings. Category-specific fallback
        questions are returned when the model call fails or its JSON cannot be
        parsed; short answers are padded with generic questions.
    """
    import re

    prompt = f"""Given a patient with {category}-related symptoms: '{initial_symptoms}',

generate exactly 5 specific medical interview questions to understand their condition better.

Focus on gathering important diagnostic information for {category} conditions.



Return ONLY a JSON array of 5 questions in this exact format:

[

    "Question 1 text here",

    "Question 2 text here",

    "Question 3 text here",

    "Question 4 text here",

    "Question 5 text here"

]"""

    generic_questions = [
        "How long have you been experiencing these symptoms?",
        "Have you noticed any patterns or triggers?",
        "Are there any other symptoms you've experienced?",
        "Have you made any recent lifestyle changes?",
        "Have you tried any treatments or medications?"
    ]

    # Fallback questions based on category
    fallback_questions = {
        "bp": [
            "How often do you check your blood pressure?",
            "Do you experience headaches or dizziness?",
            "What is your typical salt intake?",
            "Do you have a family history of hypertension?",
            "What is your current exercise routine?"
        ],
        "diabetes": [
            "When did you last check your blood sugar?",
            "Have you noticed increased thirst or urination?",
            "What is your typical daily diet?",
            "Do you have a family history of diabetes?",
            "How often do you exercise?"
        ],
        "skin": [
            "How long have you had this skin condition?",
            "Is there any itching or pain?",
            "Have you used any new products recently?",
            "Does the condition worsen at any particular time?",
            "Have you noticed any triggers?"
        ],
        "stomach": [
            "When did your stomach problems begin?",
            "How would you describe the pain or discomfort?",
            "Are symptoms related to eating specific foods?",
            "Have you noticed any changes in appetite?",
            "Do you experience nausea or vomiting?"
        ]
    }
    category_fallback = fallback_questions.get(str(category).lower(), generic_questions)

    # Any failure of the remote call (network error, blocked response, ...)
    # degrades to the canned questions instead of aborting the interview.
    try:
        response = gemini_model.generate_content(prompt)
        response_text = (response.text or "").strip()
    except Exception as e:
        print(f"Error generating interview questions: {str(e)}")
        return list(category_fallback)

    # Greedy match from the first '[' to the last ']' so that surrounding
    # prose or code fences are ignored and brackets inside a question do not
    # cut the array short.
    json_match = re.search(r'\[.*\]', response_text, re.DOTALL)
    if json_match:
        try:
            parsed = json.loads(json_match.group(0))
        except json.JSONDecodeError:
            return list(category_fallback)
        # Keep only usable question strings
        questions = [
            item.strip() for item in parsed
            if isinstance(item, str) and item.strip()
        ]
    else:
        # No JSON array at all: treat each non-empty line as a question
        questions = [line.strip() for line in response_text.split('\n') if line.strip()]

    # Ensure exactly 5 questions
    if len(questions) < 5:
        questions.extend(generic_questions[:5 - len(questions)])

    return questions[:5]


# In[20]:


def conduct_interview(questions: List[str], category: str, llm, initial_symptoms: str = "") -> Dict:
    """Conduct the interview using provided questions and gather responses.

    Each question is printed, the answer is read from standard input, and an
    empathetic reply is requested from ``llm`` (via ``get_llm_response``) and
    printed. Only the patient's answers are recorded.

    Args:
        questions: Interview questions to ask, in order. May be empty.
        category: Intent/category name stored under ``'intent'``.
        llm: Language-model handle passed through to ``get_llm_response``.
        initial_symptoms: The patient's original symptom description. When
            omitted, the first question is stored instead (legacy behaviour),
            or an empty string if there are no questions.

    Returns:
        Dict with keys ``'intent'``, ``'initial_symptoms'`` and
        ``'detailed_responses'`` (mapping ``"Q<n>"`` to the patient's answer).
    """
    # Legacy callers do not pass the symptoms; fall back to the first
    # question without raising IndexError on an empty question list.
    if not initial_symptoms and questions:
        initial_symptoms = questions[0]

    interview_data = {
        'intent': category,
        'initial_symptoms': initial_symptoms,
        'detailed_responses': {}
    }

    print("\nStarting detailed medical interview...\n")

    for i, question in enumerate(questions, 1):
        print(f"Question {i}: {question}")
        user_response = input("Your answer: ").strip()

        # Prompt is explicit about the required response format
        prompt = f"""As a medical professional, respond to this patient statement with empathy: '{user_response}'



Requirements:

- Respond directly to the patient

- Show understanding of their situation

- Keep response to 2-3 sentences

- Do not include any instructions or labels

- Start response with 'I' or 'Thank you' or similar direct phrases



For example, if patient says they have a headache, respond like:

"I understand you're experiencing head pain. Let's work together to identify the cause and find appropriate relief."



Now provide your response:"""

        # Get response and handle empty cases
        chatbot_response = get_llm_response(llm, prompt)
        clean_response = clean_llm_response(chatbot_response)

        # If we got an empty response after cleaning, use a fallback response
        if not clean_response:
            clean_response = generate_fallback_response(user_response)

        print(f"\nAssistant: {clean_response}\n")
        interview_data['detailed_responses'][f"Q{i}"] = user_response

    return interview_data

def clean_llm_response(response: str) -> str:
    """Return the first line of an LLM reply that looks like a real answer.

    A line is rejected when it is shorter than 10 characters, starts with a
    speaker/label prefix, contains wording typical of echoed instructions, or
    contains bracket/brace characters (template placeholders).

    Args:
        response: Raw text returned by the LLM; may be empty or None.

    Returns:
        The first acceptable line with surrounding whitespace removed, or an
        empty string when no line qualifies.
    """
    if not response:
        return ""

    # Labels that mark a line as meta-output rather than the reply itself
    label_prefixes = (
        "assistant:", "ai:", "chatbot:", "response:", "answer:",
        "example:", "note:", "question:",
    )
    # Phrases that indicate instruction text rather than an actual response
    instruction_markers = (
        "instructions:", "example:", "note:", "response should", "requirements:",
        "remember to", "make sure to", "the response", "your response",
        "if they", "if patient", "keep it", "be brief", "respond with",
    )
    placeholder_chars = "[]{}"

    def _is_usable(candidate: str) -> bool:
        lowered = candidate.lower()
        return (
            len(candidate) >= 10
            and not lowered.startswith(label_prefixes)
            and not any(marker in lowered for marker in instruction_markers)
            and not any(symbol in candidate for symbol in placeholder_chars)
        )

    stripped_lines = (piece.strip() for piece in response.strip().split('\n'))
    return next((line for line in stripped_lines if _is_usable(line)), "")

def generate_fallback_response(user_response: str) -> str:
    """Generate a fallback response when the LLM response is empty or invalid.

    Args:
        user_response: The patient's answer; may be empty or None.

    Returns:
        A canned empathetic reply chosen by simple keyword checks, in this
        order: symptom mention, negative/empty answer, medication mention,
        generic acknowledgement.
    """
    # Normalize once. Guarding against None here keeps .lower() from raising
    # before the empty-answer check below can run.
    response_lower = (user_response or "").strip().lower()

    # Check for symptoms first (more specific)
    symptoms = ('pain', 'ache', 'hurt', 'dizzy', 'nausea', 'sick', 'fever',
                'cough', 'tired', 'exhausted', 'headache', 'sore')
    if any(symptom in response_lower for symptom in symptoms):
        return "I hear that you're not feeling well, and I want you to know that your symptoms are being taken seriously. We'll work together to understand what's happening and find the right approach to help you feel better."

    # Check for negative responses; trailing sentence punctuation is ignored
    # so that "No." is handled like "no".
    if response_lower.rstrip('.!') in ('', 'no', 'none', 'n/a', 'nope', 'nothing'):
        return "Thank you for letting me know. Please don't hesitate to mention if you experience any new symptoms or concerns. Your health is our priority."

    # Check for medication or treatment related responses
    if any(word in response_lower for word in ('medicine', 'medication', 'pill', 'drug', 'treatment')):
        return "Thank you for sharing these details about your medication history. This information is very helpful for understanding your situation and planning appropriate care."

    # Default response
    return "I appreciate you sharing this information with me. It helps us better understand your situation so we can provide the most appropriate care for your needs."



def generate_comprehensive_analysis(interview_data: Dict, category: str, gemini_model, llm) -> Dict:
    """Generate comprehensive medical analysis using both Gemini and Llama.

    Three Gemini requests are made: a markdown medical analysis, a markdown
    root cause analysis, and a JSON root cause breakdown. Each one falls back
    to a canned result when the request or the parsing fails.

    Args:
        interview_data: Dict with ``'intent'``, ``'initial_symptoms'`` and
            ``'detailed_responses'`` keys.
        category: Intent/category name used in prompts and for fallbacks.
        gemini_model: Object exposing ``generate_content(prompt)`` whose result
            has a ``text`` attribute.
        llm: Accepted for interface compatibility; not used by this function.

    Returns:
        Dict with keys ``"medical_analysis"``, ``"root_cause_data"`` and
        ``"root_cause_data1"``. All three keys are always present.
    """
    import re

    responses_text = chr(10).join(
        [f"{k}: {v}" for k, v in interview_data['detailed_responses'].items()]
    )

    analysis_prompt = f"""Medical Analysis Request:

Patient Concern: {interview_data['intent'].capitalize()} Related Health Issue

Initial Symptoms: {interview_data['initial_symptoms']}

Detailed Interview Responses:

{responses_text}



Provide a detailed medical analysis in exactly this format using markdown:



**Possible Medical Diagnoses**

• First possible diagnosis with brief explanation

• Second possible diagnosis with brief explanation

• Third possible diagnosis with brief explanation



**Recommended Medical Tests**

• First recommended test with brief explanation

• Second recommended test with brief explanation

• Third recommended test with brief explanation



**Lifestyle and Dietary Recommendations**

• First lifestyle recommendation with brief explanation

• Second lifestyle recommendation with brief explanation

• Third lifestyle recommendation with brief explanation



**Signs Requiring Immediate Attention**

• First warning sign with brief explanation

• Second warning sign with brief explanation

• Third warning sign with brief explanation



**Treatment Approaches**

• First treatment approach with brief explanation

• Second treatment approach with brief explanation

• Third treatment approach with brief explanation"""

    try:
        # Get analysis from Gemini
        analysis_response = gemini_model.generate_content(analysis_prompt)
        response_text = analysis_response.text.strip()

        # Parse the response into structured format
        medical_analysis = parse_formatted_response(response_text)

        # If parsing failed or returned empty, use fallback
        if not medical_analysis:
            medical_analysis = get_fallback_analysis(category)

    except Exception as e:
        print(f"Error generating medical analysis: {str(e)}")
        medical_analysis = get_fallback_analysis(category)

    # Generate root cause analysis
    root_cause_prompt = f"""Based on the patient's {category} condition and responses:

{responses_text}



Provide a root cause analysis in exactly this format:



**Dietary Factors**

• First dietary factor with explanation

• Second dietary factor with explanation



**Stress Factors**

• First stress factor with explanation

• Second stress factor with explanation



**Lifestyle Factors**

• First lifestyle factor with explanation

• Second lifestyle factor with explanation"""

    try:
        root_cause_response = gemini_model.generate_content(root_cause_prompt)
        root_cause_text = root_cause_response.text.strip()
        root_cause_data = parse_formatted_response(root_cause_text)

        if not root_cause_data:
            root_cause_data = get_root_cause_template(category)

    except Exception as e:
        print(f"Error generating root cause analysis: {str(e)}")
        root_cause_data = get_root_cause_template(category)

    # Structured (JSON) root cause breakdown based on the category template
    root_cause_prompt1 = f"""Based on the patient's {category} condition and responses:

{json.dumps(interview_data['detailed_responses'], indent=2)}



Return EXACTLY this JSON structure for {category} with 2-3 specific items per category:

{json.dumps(get_root_cause_template(category), indent=2)}"""

    try:
        root_cause_response = gemini_model.generate_content(root_cause_prompt1)
        response_text = root_cause_response.text.strip()

        # Pull the JSON object out of any surrounding prose or code fences
        json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
        root_cause_data1 = json.loads(json_match.group(0)) if json_match else None

        # No object found, or the model returned something other than an object
        if not isinstance(root_cause_data1, dict):
            root_cause_data1 = get_root_cause_template(category)

    except Exception as e:
        print(f"Error generating structured root cause data: {str(e)}")
        root_cause_data1 = get_root_cause_template(category)

    # Ensure we return valid dictionaries. Every key is always present so
    # consumers can index the result without checking for missing keys.
    result = {
        "medical_analysis": medical_analysis or get_fallback_analysis(category),
        "root_cause_data": root_cause_data or get_root_cause_template(category),
        "root_cause_data1": root_cause_data1
    }

    return result

def parse_formatted_response(text: str) -> Dict:
    """Parse the formatted text response into a structured dictionary.

    Section headers are lines wrapped in ``**``; the lines below a header
    that start with a bullet marker become that section's items. Accepted
    bullet markers are ``•`` and the markdown markers ``*`` / ``-`` followed
    by whitespace.

    Args:
        text: Model output in the "**Header**" + bullet-list format.

    Returns:
        Dict mapping the lowercased, underscore-joined header to its list of
        items. Sections without items are omitted. Returns None when ``text``
        is empty or contains no section with items.
    """
    if not text:
        return None

    result = {}
    current_section = None
    current_items = []

    for raw_line in text.split('\n'):
        line = raw_line.strip()
        if not line:
            continue

        # Check for section headers (bolded text)
        if line.startswith('**') and line.endswith('**'):
            # Save previous section if it exists
            if current_section and current_items:
                result[current_section.lower().replace(' ', '_')] = current_items

            # Start new section. Always start from an empty list so bullets
            # seen before the first header (or under an empty header) cannot
            # leak into this section.
            current_section = line.strip('*').strip()
            current_items = []
            continue

        # Check for bullet points
        if line[0] == '•':
            item = line[1:].strip()
        elif line[0] in '-*' and line[1:2].isspace():
            # Markdown bullet. The whitespace requirement keeps inline bold
            # text such as "**Note:** ..." from being read as a bullet.
            item = line[2:].strip()
        else:
            continue

        if item:  # Only add non-empty items
            current_items.append(item)

    # Don't forget to save the last section
    if current_section and current_items:
        result[current_section.lower().replace(' ', '_')] = current_items

    # Verify we got some content
    return result if result else None

def display_analysis_results(analysis_results: Dict):
    """Display the analysis results in a formatted way.

    Prints the ``"medical_analysis"`` sections followed by the
    ``"root_cause_data"`` sections. Each section is shown as a bold title
    with one bullet per item.

    Args:
        analysis_results: Dict holding ``"medical_analysis"`` and, optionally,
            ``"root_cause_data"``; each maps section keys to a list of items
            or to a single value.
    """
    if not analysis_results or not analysis_results.get("medical_analysis"):
        print("\nError: No analysis results available.")
        return

    def _print_sections(sections: Dict) -> None:
        # Render every section as a bold title with bulleted items
        for key, value in sections.items():
            section_title = key.replace('_', ' ').title()
            print(f"\n**{section_title}**")
            items = value if isinstance(value, list) else [value]
            for item in items:
                print(f"• {item}")

    print("\n=== Comprehensive Medical Analysis ===\n")
    _print_sections(analysis_results["medical_analysis"])

    print("\n=== Root Cause Analysis ===\n")
    # A missing or empty "root_cause_data" entry prints the heading with no
    # sections instead of raising KeyError.
    _print_sections(analysis_results.get("root_cause_data") or {})

def get_fallback_analysis(category: str) -> Dict:
    """Provide fallback analysis if Gemini API fails.

    Args:
        category: Condition category. Matching is case-insensitive and
            ignores surrounding whitespace. Known categories are ``"bp"``,
            ``"diabetes"``, ``"skin"`` and ``"stomach"``.

    Returns:
        A dict with the keys ``possible_diagnoses``, ``recommended_tests``,
        ``lifestyle_recommendations``, ``immediate_attention_signs`` and
        ``treatment_approaches``, each mapping to a list of strings. An
        unknown, empty, ``None`` or non-string category yields a generic
        "seek medical evaluation" entry. The dict is built fresh on every
        call, so callers may mutate the result.
    """
    fallback_analyses = {
        "bp": {
            "possible_diagnoses": [
                "Essential Hypertension",
                "Secondary Hypertension",
                "White Coat Hypertension",
            ],
            "recommended_tests": [
                "24-hour Blood Pressure Monitoring",
                "ECG",
                "Basic Blood Work",
            ],
            "lifestyle_recommendations": [
                "Reduce Salt Intake",
                "Regular Exercise",
                "Stress Management",
            ],
            "immediate_attention_signs": [
                "Severe Headache",
                "Chest Pain",
                "Vision Problems",
            ],
            "treatment_approaches": [
                "Lifestyle Modifications",
                "Blood Pressure Medications",
                "Regular Monitoring",
            ],
        },
        "diabetes": {
            "possible_diagnoses": [
                "Type 2 Diabetes",
                "Prediabetes",
                "Insulin Resistance",
            ],
            "recommended_tests": [
                "HbA1c Test",
                "Fasting Blood Sugar",
                "Glucose Tolerance Test",
            ],
            "lifestyle_recommendations": [
                "Balanced Diet",
                "Regular Exercise",
                "Weight Management",
            ],
            "immediate_attention_signs": [
                "Very High Blood Sugar",
                "Severe Dehydration",
                "Confusion or Drowsiness",
            ],
            "treatment_approaches": [
                "Diet Control",
                "Oral Medications",
                "Blood Sugar Monitoring",
            ],
        },
        "skin": {
            "possible_diagnoses": [
                "Eczema",
                "Psoriasis",
                "Acne",
                "Skin Allergies",
                "Fungal Infections",
            ],
            "recommended_tests": [
                "Skin Patch Test",
                "Biopsy (if required)",
                "Allergy Test",
            ],
            "lifestyle_recommendations": [
                "Use of Gentle Cleansers",
                "Hydration and Moisturizing",
                "Avoiding Known Allergens",
            ],
            "immediate_attention_signs": [
                "Severe Rash with Swelling",
                "Skin Infection with Fever",
                "Rapidly Spreading Lesions",
            ],
            "treatment_approaches": [
                "Topical Ointments",
                "Antihistamines",
                "Prescription Medications (e.g., Steroids)",
            ],
        },
        "stomach": {
            "possible_diagnoses": [
                "Gastritis",
                "Irritable Bowel Syndrome (IBS)",
                "Acid Reflux (GERD)",
                "Peptic Ulcer",
                "Stomach Infection",
            ],
            "recommended_tests": [
                "Endoscopy",
                "Stool Test",
                "Helicobacter Pylori Test",
            ],
            "lifestyle_recommendations": [
                "Eating Smaller Meals",
                "Avoiding Spicy or Acidic Foods",
                "Stress Management",
            ],
            "immediate_attention_signs": [
                "Severe Abdominal Pain",
                "Blood in Stool or Vomit",
                "Unexplained Weight Loss",
            ],
            "treatment_approaches": [
                "Antacids or Acid Blockers",
                "Probiotics",
                "Medication for H. Pylori (if present)",
            ],
        },
    }

    generic_analysis = {
        "possible_diagnoses": ["Requires Medical Evaluation"],
        "recommended_tests": ["Consult Healthcare Provider"],
        "lifestyle_recommendations": ["Follow General Health Guidelines"],
        "immediate_attention_signs": ["Severe Symptoms", "Persistent Problems"],
        "treatment_approaches": ["Professional Medical Assessment"],
    }

    # A fallback must never fail itself: a missing or non-string category
    # simply selects the generic entry instead of raising AttributeError.
    lookup_key = category.strip().lower() if isinstance(category, str) else ""
    return fallback_analyses.get(lookup_key, generic_analysis)

def get_root_cause_template(category: str) -> Dict:
    """Return the template for root cause analysis with sample data.

    Args:
        category: Condition category. Matching is case-insensitive and
            ignores surrounding whitespace. Known categories are
            ``"stomach"``, ``"skin"``, ``"bp"`` and ``"diabetes"``.

    Returns:
        A dict mapping each root-cause factor name to a list of sample
        entries for that factor. An unknown, empty, ``None`` or non-string
        category yields a generic three-factor template. The dict is built
        fresh on every call, so callers may mutate the result.
    """
    templates = {
        "stomach": {
            "diet": ["Irregular eating patterns", "Spicy food consumption"],
            "stress": ["Work-related stress", "Anxiety"],
            "medication": ["Recent antibiotics", "NSAIDs"],
            "infection": ["Possible H. pylori", "Food-borne infection"],
            "lifestyle": ["Late night eating", "Fast food consumption"],
            "hydration": ["Inadequate water intake", "Excess caffeine"],
            "allergies": ["Food sensitivities", "Lactose intolerance"],
        },
        "skin": {
            "allergy": ["Contact dermatitis", "Environmental allergens"],
            "hygiene": ["Cleansing routine", "Product usage"],
            "environment": ["Sun exposure", "Pollution"],
            "genetics": ["Family history", "Predisposition"],
            "nutrition": ["Vitamin deficiency", "Diet impact"],
            "stress": ["Psychological factors", "Hormonal changes"],
            "cosmetic_use": ["Reaction to products", "Skin barrier damage"],
        },
        "bp": {
            "diet": ["Salt intake", "Processed foods"],
            "lifestyle": ["Sedentary behavior", "Smoking"],
            "stress": ["Work pressure", "Anxiety"],
            "physical_activity": ["Exercise routine", "Daily movement"],
            "salt_intake": ["Hidden sodium", "Dietary habits"],
            "sleep_disorders": ["Sleep apnea", "Insomnia"],
        },
        "diabetes": {
            "diet": ["Carbohydrate intake", "Sugar consumption"],
            "exercise": ["Activity level", "Fitness routine"],
            "genetics": ["Family history", "Genetic factors"],
            "insulin_resistance": ["Metabolic factors", "Body composition"],
            "obesity": ["Weight status", "Fat distribution"],
            "stress": ["Hormonal impact", "Lifestyle factors"],
            "medication": ["Current medications", "Treatment adherence"],
        },
    }

    generic_template = {
        "general_factors": ["To be evaluated", "Requires assessment"],
        "lifestyle": ["To be determined", "Needs analysis"],
        "medical": ["Pending evaluation", "Professional assessment needed"],
    }

    # Tolerate a missing or non-string category: fall through to the generic
    # template rather than raising AttributeError on ``.lower()``.
    lookup_key = category.strip().lower() if isinstance(category, str) else ""
    return templates.get(lookup_key, generic_template)




def get_llm_response(llm, prompt: str, max_tokens: int = 256, *, retry_chars: int = 500) -> str:
    """Get response from Llama model with proper formatting.

    The model is called once with the full prompt. If that call raises
    ``ValueError``, it is retried once with only the last ``retry_chars``
    characters of the prompt.

    Args:
        llm: Callable invoked as ``llm(prompt, max_tokens=...)`` that returns
            a mapping shaped like ``{"choices": [{"text": ...}]}``.
        prompt: Full prompt text.
        max_tokens: Generation limit forwarded to ``llm``.
        retry_chars: Number of trailing prompt characters kept for the retry.
            Must be positive.

    Returns:
        The text of the first choice with surrounding whitespace stripped.

    Raises:
        ValueError: If ``retry_chars`` is not positive, or if the retry with
            the shortened prompt fails as well.
    """
    if retry_chars <= 0:
        # prompt[-0:] would silently keep the *whole* prompt, defeating the retry.
        raise ValueError("retry_chars must be positive")

    try:
        response = llm(prompt, max_tokens=max_tokens)
    except ValueError:
        # NOTE(review): presumably raised when the prompt exceeds the model's
        # context window; keep only the tail, where the latest text sits.
        response = llm(prompt[-retry_chars:], max_tokens=max_tokens)
    return response['choices'][0]['text'].strip()



def main():
    """Run one interactive healthcare-assistant session on the console.

    Steps: load the local Llama model and configure Gemini, read the user's
    symptom description, classify it, generate and conduct the follow-up
    interview, produce the comprehensive analysis, print it, and finally
    draw the fishbone (root cause) diagram when data for it is available.
    """
    # Initialize models
    llm = Llama.from_pretrained(
        repo_id="tensorblock/Llama3-Aloe-8B-Alpha-GGUF",
        filename="Llama3-Aloe-8B-Alpha-Q2_K.gguf",
        n_ctx=2048
    )

    genai.configure(api_key=GEMINI_API_KEY)
    gemini_model = genai.GenerativeModel('gemini-1.5-flash')

    print("\nWelcome to the Enhanced Healthcare Assistant!")
    print("Please describe your symptoms:\n")

    initial_input = input("You: ").strip()
    detected_category = classify_medical_intent(initial_input)

    # Generate interview questions using Gemini
    questions = generate_interview_questions(initial_input, detected_category, gemini_model)

    # Conduct interview
    interview_data = conduct_interview(questions, detected_category, llm)

    # Generate comprehensive analysis
    analysis_results = generate_comprehensive_analysis(
        interview_data, detected_category, gemini_model, llm
    )

    # Prints the report, or an error line when there is nothing to show.
    display_analysis_results(analysis_results)

    # The analysis may be empty or lack the diagram data; skip the plot in
    # that case instead of crashing right after the report was printed.
    # NOTE(review): the report reads "root_cause_data" while the diagram
    # reads "root_cause_data1" - confirm against
    # generate_comprehensive_analysis that both keys are intentional.
    fishbone_data = analysis_results.get("root_cause_data1") if analysis_results else None
    if fishbone_data is None:
        print("\nNo root cause data available; skipping fishbone diagram.")
        return

    # Create and display fishbone diagram
    plot_enhanced_fishbone(
        detected_category.title(),
        fishbone_data
    )
    plt.show()



# Start the interactive session only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()