Alessio Cocchieri committed on
Commit
9ff961b
·
1 Parent(s): aedda6d

Add application file

Browse files
Files changed (1) hide show
  1. app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spacy
2
+ import gradio as gr
3
+ import json
4
+ from typing import Dict, List, Tuple, Any
5
+
6
+ from zshot import PipelineConfig
7
+ from zshot.linker import LinkerSMXM
8
+ from zshot.utils.data_models import Entity
9
+
10
+ from spacy.cli import download
11
+ download("en_core_web_sm")
12
+
13
+ # Function to load the NER model
14
def load_model(
    entity_data: List[Dict[str, Any]],
    *,
    model_name: str = "disi-unibo-nlp/openbioner-base",
    device: str = "cpu",
):
    """Build a blank English spaCy pipeline with a zshot component attached.

    Args:
        entity_data: Entity definitions. Each item must provide ``"name"``
            and ``"description"`` and may provide ``"vocabulary"`` (read
            with ``.get`` so it defaults to ``None``).
        model_name: Model identifier handed to ``LinkerSMXM``.
        device: Device value forwarded to ``PipelineConfig``. Defaults to
            ``"cpu"``; presumably ``"cuda"`` selects a GPU when one is
            available -- verify against the zshot documentation.

    Returns:
        The spaCy pipeline with the ``"zshot"`` component added last.

    Raises:
        KeyError: If an item lacks ``"name"`` or ``"description"``.
    """
    entities = [
        Entity(
            name=item["name"],
            description=item["description"],
            vocabulary=item.get("vocabulary"),
        )
        for item in entity_data
    ]

    pipeline = spacy.blank("en")
    zshot_config = PipelineConfig(
        linker=LinkerSMXM(model_name=model_name),
        entities=entities,
        device=device,
    )
    pipeline.add_pipe("zshot", config=zshot_config, last=True)
    return pipeline
32
+
33
+ # Default entities - focusing on BACTERIUM example
34
# Entity definitions shown in the editor at start-up. The description is
# written as adjacent string literals purely to keep source lines short; the
# resulting value is one single string.
default_entities = [
    {
        "name": "BACTERIUM",
        "description": (
            "A bacterium refers to a type of microorganism that can exist as "
            "a single cell and may cause infections or play a role in various "
            "biological processes. Examples include species like "
            "Streptococcus pneumoniae and Streptomyces ahygroscopicus."
        ),
    },
]
40
+
41
+ # Initialize model with default entities
42
# Module-level pipeline shared by the request handler: process_text declares
# this name ``global`` and rebinds it when it rebuilds the model.
nlp = load_model(default_entities)
43
+
44
+ # Function to create HTML visualization of entities
45
def get_entity_html(doc) -> str:
    """Render ``doc.text`` as an HTML snippet with entity spans highlighted.

    Args:
        doc: Object exposing ``text`` (the full string) and ``ents`` (an
            iterable of spans with ``start_char``, ``end_char`` and
            ``label_``). Spans are consumed in iteration order and are
            assumed to be sorted and non-overlapping.

    Returns:
        A ``<div>`` containing the text, where each entity is wrapped in a
        coloured ``<span>`` followed by its label. Labels without a
        configured colour fall back to ``#ddd``.
    """
    # Function-scope import so the block does not rely on a file-level import.
    import html

    colors = {
        "BACTERIUM": "#8dd3c7",
        "CHEMICAL": "#fb8072",
        "DISEASE": "#80b1d3",
        "GENE": "#fdb462",
        "SPECIES": "#b3de69",
    }

    text = doc.text
    parts = []
    cursor = 0

    for ent in doc.ents:
        # The text and the labels both originate from user input (the text
        # box and the entity JSON), so everything interpolated into markup
        # is escaped to prevent HTML injection and broken rendering.
        parts.append(html.escape(text[cursor:ent.start_char]))

        color = colors.get(ent.label_, "#ddd")
        surface = html.escape(text[ent.start_char:ent.end_char])
        label = html.escape(ent.label_)
        parts.append(
            f'<span style="background-color: {color}; padding: 0.2em 0.3em; '
            f'border-radius: 0.35em; margin: 0 0.1em; font-weight: bold; color: #000;">'
            f'{surface}'
            f'<span style="font-size: 0.8em; font-weight: bold; margin-left: 0.5em">{label}</span>'
            f'</span>'
        )

        cursor = ent.end_char

    # Whatever follows the last entity (or the whole text if there is none).
    parts.append(html.escape(text[cursor:]))

    body = "".join(parts)
    return (
        '<div style="line-height: 1.5; padding: 10px; background: #222; '
        f'color: #fff; border-radius: 5px;">{body}</div>'
    )
80
+
81
+ # Function to get entity details including spans
82
def get_entity_details(doc) -> List[Dict[str, Any]]:
    """List every entity in ``doc.ents`` as a plain dictionary.

    Args:
        doc: Object whose ``ents`` yields spans exposing ``text``,
            ``label_``, ``start_char`` and ``end_char``.

    Returns:
        One dict per entity, in iteration order, with the keys ``"text"``,
        ``"type"``, ``"start"`` and ``"end"``.
    """
    return [
        {
            "text": ent.text,
            "type": ent.label_,
            "start": ent.start_char,
            "end": ent.end_char,
        }
        for ent in doc.ents
    ]
92
+
93
+ # Main processing function
94
# Entity definitions the global ``nlp`` pipeline was last rebuilt from by
# process_text; ``None`` until the first successful rebuild.
_active_entities = None


def process_text(text: str, entities_json: str) -> Tuple[str, List[Dict[str, Any]]]:
    """Run entity recognition on ``text`` using the given entity definitions.

    The global ``nlp`` pipeline is rebuilt only when the parsed definitions
    differ from the ones used for the previous rebuild, so repeated requests
    with an unchanged configuration do not reload the model.

    Args:
        text: The text to analyse.
        entities_json: JSON source of a list of objects, each with at least
            ``"name"`` and ``"description"``.

    Returns:
        A pair ``(html, details)``: the highlighted HTML and the list of
        entity dictionaries. When the configuration is unusable, ``html`` is
        an ``"Error: ..."`` message and ``details`` is an empty list.
    """
    global nlp, _active_entities

    try:
        entities = json.loads(entities_json)
    except (json.JSONDecodeError, TypeError):
        # TypeError covers a non-string value such as None from an empty editor.
        return "Error: Invalid JSON in entity configuration", []

    # Valid JSON with the wrong shape would otherwise surface as a KeyError or
    # TypeError from load_model; report it to the user instead.
    well_formed = isinstance(entities, list) and all(
        isinstance(item, dict) and "name" in item and "description" in item
        for item in entities
    )
    if not well_formed:
        return (
            "Error: Entity configuration must be a JSON list of objects "
            "with 'name' and 'description'",
            [],
        )

    # Rebuilding the pipeline is the expensive step, so skip it when the
    # definitions are unchanged.
    if entities != _active_entities:
        nlp = load_model(entities)
        _active_entities = entities

    doc = nlp(text)
    return get_entity_html(doc), get_entity_details(doc)
114
+
115
+ # Set theme to dark
116
# Colour overrides applied on top of the Soft theme to give the app its dark
# appearance; the keys are the keyword arguments passed to ``.set()``.
_DARK_THEME_OVERRIDES = {
    "body_background_fill": "#1a1a1a",
    "background_fill_primary": "#222",
    "background_fill_secondary": "#333",
    "border_color_primary": "#444",
    "block_background_fill": "#222",
    "block_label_background_fill": "#333",
    "block_label_text_color": "#fff",
    "block_title_text_color": "#fff",
    "body_text_color": "#fff",
    "button_primary_background_fill": "#2563eb",
    "button_primary_text_color": "#fff",
    "input_background_fill": "#333",
    "input_border_color": "#555",
    "input_placeholder_color": "#888",
    "panel_background_fill": "#222",
    "slider_color": "#2563eb",
}

theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="slate",
    neutral_hue="slate",
    text_size=gr.themes.sizes.text_md,
).set(**_DARK_THEME_OVERRIDES)
139
+
140
+ # Create Gradio interface with dark theme
141
# Sample sentences loaded into the text box by the "Quick Examples" buttons.
# The first one is also the initial content of the text box.
_EXAMPLE_ECOLI = (
    "Impact of cofactor - binding loop mutations on thermotolerance and "
    "activity of E. coli transketolase"
)
_EXAMPLE_INFECTION = (
    "The patient was diagnosed with pneumonia caused by Streptococcus "
    "pneumoniae and treated with antibiotics for 7 days."
)
_EXAMPLE_MULTI_SPECIES = (
    "We compared growth rates of E. coli, B. subtilis and S. aureus in "
    "various media containing different carbon sources."
)

with gr.Blocks(title="Named Entity Recognition", theme=theme) as demo:
    gr.Markdown("# OpenBioNER - Demo")

    # Row 1: editable entity definitions.
    with gr.Row():
        initial_entities_json = json.dumps(default_entities, indent=2)
        entities_input = gr.Code(
            label="Entity Definitions (JSON)",
            language="json",
            value=initial_entities_json,
            lines=6,
        )

    # Row 2: text to analyse (left) next to the example shortcuts (right).
    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Text to analyze",
                placeholder="Enter text to analyze...",
                value=_EXAMPLE_ECOLI,
                lines=3,
            )
            analyze_btn = gr.Button("Analyze Text", variant="primary")

        with gr.Column():
            gr.Markdown("### Quick Examples")
            example1_btn = gr.Button("E. coli research")
            example2_btn = gr.Button("Bacterial infection case")
            example3_btn = gr.Button("Multiple bacterial species")

    # Row 3: highlighted text (left) next to the span details (right).
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Recognized Entities")
            result_html = gr.HTML()

        with gr.Column():
            gr.Markdown("### Entity Details with Spans")
            entity_details = gr.JSON()

    # Run the recognizer on the current text and entity definitions.
    analyze_btn.click(
        fn=process_text,
        inputs=[text_input, entities_input],
        outputs=[result_html, entity_details],
    )

    # Each example button only fills the text box; it does not trigger
    # the analysis itself.
    example1_btn.click(fn=lambda: _EXAMPLE_ECOLI, inputs=None, outputs=text_input)
    example2_btn.click(fn=lambda: _EXAMPLE_INFECTION, inputs=None, outputs=text_input)
    example3_btn.click(fn=lambda: _EXAMPLE_MULTI_SPECIES, inputs=None, outputs=text_input)

# Start the web app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()