hponepyae committed on
Commit
83ff66a
·
verified ·
1 Parent(s): db10b40

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -0
app.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import os
5
+
6
# Access token for the Hugging Face Hub, taken from the environment
# (None when HF_TOKEN is not set).
hf_token = os.environ.get("HF_TOKEN")

# MedGemma: a 4 billion parameter instruction-tuned model specialized for
# the medical domain.
model_id = "google/medgemma-4b-it"

# Use bfloat16 when a CUDA GPU is available; otherwise fall back to float32.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Fetch the tokenizer and model weights from Hugging Face. Any failure is
# printed and recorded in `model_loaded` so the UI can report it instead of
# the app crashing at import time.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    model_loaded = False
else:
    model_loaded = True
33
+
34
+ # This is the core function that will take the clinical text and return a code
35
def get_clinical_code(clinical_text):
    """Suggest an ICD-10 code for unstructured clinical text using MedGemma.

    Args:
        clinical_text: Free-text clinical note entered by the user.

    Returns:
        The text generated by the model (the prompt asks for an ICD-10 code
        and a brief description), or a human-readable message when the model
        failed to load or no clinical text was supplied.
    """
    if not model_loaded:
        return "Error: The model could not be loaded. Please check the logs."

    # Treat whitespace-only input the same as empty input.
    if not clinical_text or not clinical_text.strip():
        return "Please enter some clinical text."

    # Prompt in Gemma turn format. It is assembled explicitly (rather than as
    # a triple-quoted block) so that no stray leading newline or source-code
    # indentation leaks into what the model sees.
    prompt = (
        "<start_of_turn>user\n"
        "You are an expert medical coder. Your task is to analyze the "
        "following clinical text and determine the most appropriate ICD-10 "
        "code. Provide only the ICD-10 code and a brief description.\n"
        "\n"
        f'Clinical Text: "{clinical_text.strip()}"\n'
        "\n"
        "Provide the ICD-10 code and a brief description."
        "<end_of_turn>\n"
        "<start_of_turn>model\n"
    )

    # Tokenize and move the tensors to the device the model lives on.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    # Generate up to 256 new tokens, which should be sufficient for a code
    # and a short description.
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,  # A lower temperature makes the output more deterministic
    )

    # The returned sequence starts with the prompt tokens. Drop them by
    # position instead of searching the decoded text for
    # "<start_of_turn>model": that marker is a special token, so it is
    # removed by skip_special_tokens=True and can never be found.
    generated_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
81
+
82
# Sample clinical notes offered as one-click examples in the UI.
_EXAMPLE_NOTES = [
    ["The patient complains of a persistent cough and fever for the past three days. Chest X-ray shows signs of pneumonia."],
    ["45-year-old male with a history of hypertension presents with chest pain radiating to the left arm."],
    ["The patient has a history of type 2 diabetes and is here for a routine check-up. Blood sugar levels are elevated."],
]

# Build the Gradio interface: a heading, an input/output pair of text boxes,
# a submit button wired to get_clinical_code, and the example notes.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Clinical Code Generator with Google MedGemma
        Enter a piece of unstructured clinical text below, and the app will suggest an ICD-10 clinical code.
        *Disclaimer: This is a demonstration and not for professional medical use.*
        """
    )

    with gr.Row():
        # Free-text clinical note typed (or pasted) by the user.
        input_text = gr.Textbox(
            lines=10,
            label="Unstructured Clinical Text",
            placeholder="e.g., Patient presents with a severe headache and photophobia...",
        )

        # Read-only box that shows the model's suggestion.
        output_text = gr.Textbox(
            lines=5,
            interactive=False,
            label="Suggested Clinical Code (ICD-10)",
        )

    # Clicking the button runs the model on the current input text.
    submit_button = gr.Button("Get Clinical Code", variant="primary")
    submit_button.click(
        fn=get_clinical_code,
        inputs=input_text,
        outputs=output_text,
    )

    gr.Examples(
        examples=_EXAMPLE_NOTES,
        fn=get_clinical_code,
        inputs=input_text,
        outputs=output_text,
    )

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()