jayebaku commited on
Commit
b223991
verified
1 Parent(s): 3f4252c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +403 -403
app.py CHANGED
@@ -1,404 +1,404 @@
1
- import os
2
- import gradio as gr
3
- import pandas as pd
4
-
5
- from classifier import classify
6
- from statistics import mean
7
- from qa_summary import generate_answer
8
-
9
-
10
- HFTOKEN = os.environ["HF_TOKEN"]
11
-
12
-
13
-
14
- js = """
15
- async () => {
16
- // Load Twitter Widgets script
17
- const script = document.createElement("script");
18
- script.onload = () => console.log("Twitter Widgets.js loaded");
19
- script.src = "https://platform.twitter.com/widgets.js";
20
- document.head.appendChild(script);
21
-
22
- // Define a global function to reload Twitter widgets
23
- globalThis.reloadTwitterWidgets = () => {
24
- if (window.twttr && twttr.widgets) {
25
- twttr.widgets.load();
26
- }
27
- };
28
- }
29
- """
30
-
31
- def T_on_select(evt: gr.SelectData):
32
-
33
- if evt.index[1] == 3:
34
- html = """<blockquote class="twitter-tweet" data-dnt="true" data-theme="dark">""" + \
35
- f"""\n<a href="https://twitter.com/anyuser/status/{evt.value}"></a></blockquote>"""
36
- else:
37
- html = f"""<h2>{evt.value}</h2>"""
38
- return gr.update(value=html)
39
-
40
- def single_classification(text, event_model, threshold):
41
- res = classify(text, event_model, HFTOKEN, threshold)
42
- return res["event"], res["score"]
43
-
44
- def load_and_classify_csv(file, text_field, event_model, threshold):
45
- filepath = file.name
46
- if ".csv" in filepath:
47
- df = pd.read_csv(filepath)
48
- else:
49
- df = pd.read_table(filepath)
50
-
51
- if text_field not in df.columns:
52
- raise gr.Error(f"Error: Enter text column'{text_field}' not in CSV file.")
53
-
54
- labels, scores = [], []
55
- for post in df[text_field].to_list():
56
- res = classify(post, event_model, HFTOKEN, threshold)
57
- labels.append(res["event"])
58
- scores.append(res["score"])
59
-
60
- df["model_label"] = labels
61
- df["model_score"] = scores
62
-
63
- # model_confidence = round(mean(scores), 5)
64
- model_confidence = mean(scores)
65
- fire_related = gr.CheckboxGroup(choices=df[df["model_label"]=="fire"][text_field].to_list())
66
- flood_related = gr.CheckboxGroup(choices=df[df["model_label"]=="flood"][text_field].to_list())
67
- not_related = gr.CheckboxGroup(choices=df[df["model_label"]=="none"][text_field].to_list())
68
-
69
- return flood_related, fire_related, not_related, model_confidence, len(df[text_field].to_list()), df, gr.update(interactive=True), gr.update(interactive=True)
70
-
71
- def load_and_classify_csv_dataframe(file, text_field, event_model, threshold): #, filter
72
-
73
- filepath = file.name
74
- if ".csv" in filepath:
75
- df = pd.read_csv(filepath)
76
- else:
77
- df = pd.read_table(filepath)
78
-
79
- if text_field not in df.columns:
80
- raise gr.Error(f"Error: Enter text column'{text_field}' not in CSV file.")
81
-
82
- labels, scores = [], []
83
- for post in df[text_field].to_list():
84
- res = classify(post, event_model, HFTOKEN, threshold)
85
- labels.append(res["event"])
86
- scores.append(round(res["score"], 5))
87
-
88
- df["event_label"] = labels
89
- df["model_score"] = scores
90
-
91
- result_df = df[[text_field, "event_label", "model_score", "tweet_id"]].copy()
92
- result_df["tweet_id"] = result_df["tweet_id"].astype(str)
93
-
94
- filters = list(result_df["event_label"].unique())
95
- extra_filters = ['Not-'+x for x in filters]+['All']
96
-
97
- return gr.update(value=result_df), result_df, gr.update(choices=sorted(filters+extra_filters),
98
- value='All',
99
- label="Filter data by label",
100
- visible=True)
101
-
102
-
103
- def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts, text_field, data_df):
104
- posts = data_df[text_field].to_list()
105
- selections = flood_selections + fire_selections + none_selections
106
- eval = []
107
- for post in posts:
108
- if post in selections:
109
- eval.append("incorrect")
110
- else:
111
- eval.append("correct")
112
-
113
- data_df["model_eval"] = eval
114
- incorrect = len(selections)
115
- correct = num_posts - incorrect
116
- accuracy = (correct/num_posts)*100
117
-
118
- data_df.to_csv("output.csv")
119
- return incorrect, correct, accuracy, data_df, gr.DownloadButton(label=f"Download CSV", value="output.csv", visible=True)
120
-
121
- def init_queries(history):
122
- history = history or []
123
- if not history:
124
- history = [
125
- "What areas are being evacuated?",
126
- "What areas are predicted to be impacted?",
127
- "What areas are without power?",
128
- "What barriers are hindering response efforts?",
129
- "What events have been canceled?",
130
- "What preparations are being made?",
131
- "What regions have announced a state of emergency?",
132
- "What roads are blocked / closed?",
133
- "What services have been closed?",
134
- "What warnings are currently in effect?",
135
- "Where are emergency services deployed?",
136
- "Where are emergency services needed?",
137
- "Where are evacuations needed?",
138
- "Where are people needing rescued?",
139
- "Where are recovery efforts taking place?",
140
- "Where has building or infrastructure damage occurred?",
141
- "Where has flooding occured?"
142
- "Where are volunteers being requested?",
143
- "Where has road damage occured?",
144
- "What area has the wildfire burned?",
145
- "Where have homes been damaged or destroyed?"]
146
-
147
- return gr.CheckboxGroup(choices=history), history
148
-
149
- def add_query(to_add, history):
150
- if to_add not in history:
151
- history.append(to_add)
152
- return gr.CheckboxGroup(choices=history), history
153
-
154
- def qa_summarise(selected_queries, qa_llm_model, text_field, data_df):
155
-
156
- qa_input_df = data_df[data_df["model_label"] != "none"].reset_index()
157
- texts = qa_input_df[text_field].to_list()
158
-
159
- summary = generate_answer(qa_llm_model, texts, selected_queries[0], selected_queries, mode="multi_summarize")
160
-
161
- doc_df = pd.DataFrame()
162
- doc_df["number"] = [i+1 for i in range(len(texts))]
163
- doc_df["text"] = texts
164
-
165
- return summary, doc_df
166
-
167
-
168
- with gr.Blocks(fill_width=True) as demo:
169
-
170
- demo.load(None,None,None,js=js)
171
-
172
- event_models = ["jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier",
173
- "jayebaku/distilbert-base-multilingual-cased-weather-classifier-2",
174
- "jayebaku/twitter-xlm-roberta-base-crexdata-relevance-classifier",
175
- "jayebaku/twhin-bert-base-crexdata-relevance-classifier"]
176
-
177
- T_data_ss_state = gr.State(value=pd.DataFrame())
178
-
179
-
180
- with gr.Tab("Event Type Classification"):
181
- gr.Markdown(
182
- """
183
- # T4.5 Relevance Classifier Demo
184
- This is a demo created to explore floods and wildfire classification in social media posts.\n
185
- Usage:\n
186
- - Upload .tsv or .csv data file (must contain a text column with social media posts).\n
187
- - Next, type the name of the text column.\n
188
- - Then, choose a BERT classifier model from the drop down.\n
189
- - Finally, click the 'start prediction' buttton.\n
190
- """)
191
- with gr.Row():
192
- with gr.Column(scale=4):
193
- T_file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv'])
194
-
195
- with gr.Column(scale=6):
196
- T_text_field = gr.Textbox(label="Text field name", value="tweet_text")
197
- T_event_model = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
198
- T_predict_button = gr.Button("Start Prediction")
199
- with gr.Accordion("Prediction threshold", open=False):
200
- T_threshold = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False,
201
- info="This value sets a threshold by which texts classified flood or fire are accepted, \
202
- higher values makes the classifier stricter (CAUTION: A value of 1 will set all predictions as none)", interactive=True)
203
-
204
- with gr.Row():
205
- with gr.Column(scale=8):
206
- T_data = gr.DataFrame(wrap=True,
207
- show_fullscreen_button=True,
208
- show_copy_button=True,
209
- show_row_numbers=True,
210
- show_search="filter",
211
- column_widths=["49%","17%","17%","17%"])
212
-
213
- with gr.Column(scale=2):
214
- T_data_filter = gr.Dropdown(visible=False)
215
- T_tweet_embed = gr.HTML("<h1>Select a Tweet ID to view Tweet</h1>")
216
-
217
-
218
-
219
- with gr.Tab("Event Type Classification Eval"):
220
- gr.Markdown(
221
- """
222
- # T4.5 Relevance Classifier Demo
223
- This is a demo created to explore floods and wildfire classification in social media posts.\n
224
- Usage:\n
225
- - Upload .tsv or .csv data file (must contain a text column with social media posts).\n
226
- - Next, type the name of the text column.\n
227
- - Then, choose a BERT classifier model from the drop down.\n
228
- - Finally, click the 'start prediction' buttton.\n
229
- Evaluation:\n
230
- - To evaluate the model's accuracy select the INCORRECT classifications using the checkboxes in front of each post.\n
231
- - Then, click on the 'Calculate Accuracy' button.\n
232
- - Then, click on the 'Download data as CSV' to get the classifications and evaluation data as a .csv file.
233
- """)
234
- with gr.Row():
235
- with gr.Column(scale=4):
236
- file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv'])
237
-
238
- with gr.Column(scale=6):
239
- text_field = gr.Textbox(label="Text field name", value="tweet_text")
240
- event_model = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
241
- ETCE_predict_button = gr.Button("Start Prediction")
242
- with gr.Accordion("Prediction threshold", open=False):
243
- threshold = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False,
244
- info="This value sets a threshold by which texts classified flood or fire are accepted, \
245
- higher values makes the classifier stricter (CAUTION: A value of 1 will set all predictions as none)", interactive=True)
246
-
247
- with gr.Row(): # XXX confirm this is not a problem later --equal_height=True
248
- with gr.Column():
249
- gr.Markdown("""### Flood-related""")
250
- flood_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
251
-
252
- with gr.Column():
253
- gr.Markdown("""### Fire-related""")
254
- fire_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
255
-
256
- with gr.Column():
257
- gr.Markdown("""### None""")
258
- none_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)
259
-
260
- with gr.Row():
261
- with gr.Column(scale=5):
262
- gr.Markdown(r"""
263
- Accuracy: is the model's ability to make correct predicitons.
264
- It is the fraction of correct prediction out of the total predictions.
265
-
266
- $$
267
- \text{Accuracy} = \frac{\text{Correct predictions}}{\text{All predictions}} * 100
268
- $$
269
-
270
- Model Confidence: is the mean probabilty of each case
271
- belonging to their assigned classes. A value of 1 is best.
272
- """, latex_delimiters=[{ "left": "$$", "right": "$$", "display": True }])
273
- gr.Markdown("\n\n\n")
274
- model_confidence = gr.Number(label="Model Confidence")
275
-
276
- with gr.Column(scale=5):
277
- correct = gr.Number(label="Number of correct classifications")
278
- incorrect = gr.Number(label="Number of incorrect classifications")
279
- accuracy = gr.Number(label="Model Accuracy (%)")
280
-
281
- ETCE_accuracy_button = gr.Button("Calculate Accuracy")
282
- download_csv = gr.DownloadButton(visible=False)
283
- num_posts = gr.Number(visible=False)
284
- data = gr.DataFrame(visible=False)
285
- data_eval = gr.DataFrame(visible=False)
286
-
287
-
288
- qa_tab = gr.Tab("Question Answering")
289
- with qa_tab:
290
- gr.Markdown(
291
- """
292
- # Question Answering Demo
293
- This section uses RAG to answer questions about the relevant social media posts identified by the relevance classifier\n
294
- Usage:\n
295
- - Select queries from predefined\n
296
- - Parameters for QA can be editted in sidebar\n
297
-
298
- Note: QA process is disabled untill after the relevance classification is done
299
- """)
300
-
301
- with gr.Accordion("Parameters", open=False):
302
- with gr.Row():
303
- with gr.Column():
304
- qa_llm_model = gr.Dropdown(["mistral", "solar", "phi3mini"], label="QA model", value="phi3mini", interactive=True)
305
- aggregator = gr.Dropdown(["linear", "outrank"], label="Aggregation method", value="linear", interactive=True)
306
- with gr.Column():
307
- batch_size = gr.Slider(50, 500, value=150, step=1, label="Batch size", info="Choose between 50 and 500", interactive=True)
308
- topk = gr.Slider(1, 10, value=5, step=1, label="Number of results to retrieve", info="Choose between 1 and 10", interactive=True)
309
-
310
- selected_queries = gr.CheckboxGroup(label="Select at least one query using the checkboxes", interactive=True)
311
- queries_state = gr.State()
312
- qa_tab.select(init_queries, inputs=queries_state, outputs=[selected_queries, queries_state])
313
-
314
- query_inp = gr.Textbox(label="Add custom queries like the one above, one at a time")
315
- QA_addqry_button = gr.Button("Add to queries", interactive=False)
316
- QA_run_button = gr.Button("Start QA", interactive=False)
317
- hsummary = gr.Textbox(label="Summary")
318
-
319
- qa_df = gr.DataFrame()
320
-
321
-
322
- with gr.Tab("Single Text Classification"):
323
- gr.Markdown(
324
- """
325
- # Event Type Prediction Demo
326
- In this section you test the relevance classifier with written texts.\n
327
- Usage:\n
328
- - Type a tweet-like text in the textbox.\n
329
- - Then press Enter.\n
330
- """)
331
- with gr.Row():
332
- with gr.Column(scale=3):
333
- model_sing_classify = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
334
- with gr.Column(scale=7):
335
- threshold_sing_classify = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold",
336
- info="This value sets a threshold by which texts classified flood or fire are accepted, \
337
- higher values makes the classifier stricter (CAUTION: A value of 1 will set all predictions as none)", interactive=True)
338
-
339
- text_to_classify = gr.Textbox(label="Text", info="Enter tweet-like text", submit_btn=True)
340
- text_to_classify_examples = gr.Examples([["The streets are flooded, I can't leave #BostonStorm"],
341
- ["Controlado el incendio de Rodezno que ha obligado a desalojar a varias bodegas de la zona."],
342
- ["Cambrils:estaci贸 Renfe inundada 19 persones dins d'un tren. FGC a Capellades, petit descarrilament 5 passatgers #Inuncat @emergenciescat"],
343
- ["Anscheinend steht die komplette Neckarwiese unter Wasser! #Hochwasser"]], text_to_classify)
344
-
345
- with gr.Row():
346
- with gr.Column():
347
- classification = gr.Textbox(label="Classification")
348
- with gr.Column():
349
- classification_score = gr.Number(label="Classification Score")
350
-
351
-
352
-
353
-
354
-
355
-
356
-
357
-
358
- # Test event listeners
359
- T_predict_button.click(
360
- load_and_classify_csv_dataframe,
361
- inputs=[T_file_input, T_text_field, T_event_model, T_threshold],
362
- outputs=[T_data, T_data_ss_state, T_data_filter]
363
- )
364
-
365
- T_data.select(T_on_select, None, T_tweet_embed).then(fn=None, js="reloadTwitterWidgets()")
366
-
367
- @T_data_filter.input(inputs=[T_data_ss_state, T_data_filter], outputs=T_data)
368
- def filter_df(df, filter):
369
- if filter == "All":
370
- result_df = df.copy()
371
- elif filter.startswith("Not"):
372
- result_df = df[df["event_label"]!=filter.split('-')[1]].copy()
373
- else:
374
- result_df = df[df["event_label"]==filter].copy()
375
- return gr.update(value=result_df)
376
-
377
-
378
- # Button clicks ETC Eval
379
- ETCE_predict_button.click(
380
- load_and_classify_csv,
381
- inputs=[file_input, text_field, event_model, threshold],
382
- outputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, model_confidence, num_posts, data, QA_addqry_button, QA_run_button])
383
-
384
- ETCE_accuracy_button.click(
385
- calculate_accuracy,
386
- inputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, num_posts, text_field, data],
387
- outputs=[incorrect, correct, accuracy, data_eval, download_csv])
388
-
389
-
390
- # Button clicks QA
391
- QA_addqry_button.click(add_query, inputs=[query_inp, queries_state], outputs=[selected_queries, queries_state])
392
-
393
- QA_run_button.click(qa_summarise,
394
- inputs=[selected_queries, qa_llm_model, text_field, data], ## XXX fix text_field
395
- outputs=[hsummary, qa_df])
396
-
397
-
398
- # Event listener for single text classification
399
- text_to_classify.submit(
400
- single_classification,
401
- inputs=[text_to_classify, model_sing_classify, threshold_sing_classify],
402
- outputs=[classification, classification_score])
403
-
404
  demo.launch()
 
1
+ import os
2
+ import gradio as gr
3
+ import pandas as pd
4
+
5
+ from classifier import classify
6
+ from statistics import mean
7
+ from qa_summary import generate_answer
8
+
9
+
10
# Hugging Face Inference API token; classify() authenticates with it.
# Raises KeyError at import time if the HF_TOKEN env var is not set.
HFTOKEN = os.environ["HF_TOKEN"]


# JavaScript run once on app load (see demo.load below): injects Twitter's
# widgets.js and exposes a global reloadTwitterWidgets() helper so freshly
# inserted <blockquote class="twitter-tweet"> markup can be re-rendered.
js = """
async () => {
    // Load Twitter Widgets script
    const script = document.createElement("script");
    script.onload = () => console.log("Twitter Widgets.js loaded");
    script.src = "https://platform.twitter.com/widgets.js";
    document.head.appendChild(script);

    // Define a global function to reload Twitter widgets
    globalThis.reloadTwitterWidgets = () => {
        if (window.twttr && twttr.widgets) {
            twttr.widgets.load();
        }
    };
}
"""
30
+
31
def T_on_select(evt: gr.SelectData):
    """Render the cell selected in the results DataFrame.

    Column index 3 holds the tweet id, so a selection there becomes an
    embedded-tweet blockquote; any other cell is echoed as a heading.
    """
    if evt.index[1] == 3:
        opening = """<blockquote class="twitter-tweet" data-dnt="true" data-theme="dark">"""
        anchor = f"""\n<a href="https://twitter.com/anyuser/status/{evt.value}"></a></blockquote>"""
        markup = opening + anchor
    else:
        markup = f"""<h2>{evt.value}</h2>"""
    return gr.update(value=markup)
39
+
40
def single_classification(text, event_model, threshold):
    """Classify one text and return its (event label, confidence score)."""
    prediction = classify(text, event_model, HFTOKEN, threshold)
    return prediction["event"], prediction["score"]
43
+
44
def load_and_classify_csv(file, text_field, event_model, threshold):
    """Load a CSV/TSV upload, classify every post, and build the eval widgets.

    Returns, in order: CheckboxGroup updates for flood / fire / none
    predictions, the mean model confidence, the number of posts, the
    labelled DataFrame, and two gr.update() objects enabling the QA buttons.

    Raises gr.Error if `text_field` is not a column of the uploaded file.
    """
    filepath = file.name
    # endswith instead of substring test: a path merely *containing* ".csv"
    # (e.g. "data.csv.tsv" or a temp dir named x.csv) must not be mis-read.
    if filepath.endswith(".csv"):
        df = pd.read_csv(filepath)
    else:
        # read_table defaults to tab separation, covering the .tsv case.
        df = pd.read_table(filepath)

    if text_field not in df.columns:
        raise gr.Error(f"Error: Enter text column'{text_field}' not in CSV file.")

    posts = df[text_field].to_list()
    labels, scores = [], []
    for post in posts:
        res = classify(post, event_model, HFTOKEN, threshold)
        labels.append(res["event"])
        scores.append(res["score"])

    df["model_label"] = labels
    df["model_score"] = scores

    # Mean probability of each post's assigned class (1.0 is best).
    # Guard the empty upload: statistics.mean([]) raises StatisticsError.
    model_confidence = mean(scores) if scores else 0.0

    fire_related = gr.CheckboxGroup(choices=df[df["model_label"]=="fire"][text_field].to_list())
    flood_related = gr.CheckboxGroup(choices=df[df["model_label"]=="flood"][text_field].to_list())
    not_related = gr.CheckboxGroup(choices=df[df["model_label"]=="none"][text_field].to_list())

    return flood_related, fire_related, not_related, model_confidence, len(posts), df, gr.update(interactive=True), gr.update(interactive=True)
70
+
71
def load_and_classify_csv_dataframe(file, text_field, event_model, threshold):
    """Load a CSV/TSV upload, classify every post, and return the results table.

    Returns a gr.update for the results DataFrame component, the raw results
    DataFrame (held in state for the label filter), and a gr.update that
    populates and reveals the filter dropdown.

    Raises gr.Error if `text_field` is not a column of the uploaded file.
    """
    filepath = file.name
    # endswith instead of substring test: a path merely *containing* ".csv"
    # must not be mis-read as comma-separated.
    if filepath.endswith(".csv"):
        df = pd.read_csv(filepath)
    else:
        # read_table defaults to tab separation, covering the .tsv case.
        df = pd.read_table(filepath)

    if text_field not in df.columns:
        raise gr.Error(f"Error: Enter text column'{text_field}' not in CSV file.")

    labels, scores = [], []
    for post in df[text_field].to_list():
        res = classify(post, event_model, HFTOKEN, threshold)
        labels.append(res["event"])
        scores.append(round(res["score"], 5))

    df["event_label"] = labels
    df["model_score"] = scores

    # NOTE(review): assumes every upload has a "tweet_id" column; a missing
    # column raises an unhandled KeyError here -- confirm expected inputs.
    result_df = df[[text_field, "event_label", "model_score", "tweet_id"]].copy()
    # Keep tweet ids as strings so large ids are not mangled as numerics.
    result_df["tweet_id"] = result_df["tweet_id"].astype(str)

    filters = list(result_df["event_label"].unique())
    # Offer each label, its complement ("Not-<label>"), and an 'All' passthrough.
    extra_filters = ['Not-'+x for x in filters]+['All']

    return gr.update(value=result_df), result_df, gr.update(choices=sorted(filters+extra_filters),
                                                            value='All',
                                                            label="Filter data by label",
                                                            visible=True)
101
+
102
+
103
def calculate_accuracy(flood_selections, fire_selections, none_selections, num_posts, text_field, data_df):
    """Score the model against the posts the user marked as misclassified.

    Every selected post counts as incorrect; everything else as correct.
    Writes the annotated DataFrame to "output.csv" and returns the counts,
    the accuracy percentage, the DataFrame, and a visible download button.
    """
    posts = data_df[text_field].to_list()
    selections = flood_selections + fire_selections + none_selections
    # Set lookup: O(1) membership per post instead of scanning the list.
    selected = set(selections)
    # 'evaluation', not 'eval': avoid shadowing the builtin.
    evaluation = ["incorrect" if post in selected else "correct" for post in posts]

    data_df["model_eval"] = evaluation
    incorrect = len(selections)
    correct = num_posts - incorrect
    # Guard the zero-post case to avoid ZeroDivisionError.
    accuracy = (correct/num_posts)*100 if num_posts else 0.0

    data_df.to_csv("output.csv")
    return incorrect, correct, accuracy, data_df, gr.DownloadButton(label="Download CSV", value="output.csv", visible=True)
120
+
121
def init_queries(history):
    """Seed the QA query checkbox group with default queries.

    If `history` is empty/None it is replaced by the built-in query list;
    returns the refreshed CheckboxGroup and the history list (kept in state).
    """
    history = history or []
    if not history:
        history = [
            "What areas are being evacuated?",
            "What areas are predicted to be impacted?",
            "What areas are without power?",
            "What barriers are hindering response efforts?",
            "What events have been canceled?",
            "What preparations are being made?",
            "What regions have announced a state of emergency?",
            "What roads are blocked / closed?",
            "What services have been closed?",
            "What warnings are currently in effect?",
            "Where are emergency services deployed?",
            "Where are emergency services needed?",
            "Where are evacuations needed?",
            "Where are people needing rescued?",
            "Where are recovery efforts taking place?",
            "Where has building or infrastructure damage occurred?",
            # BUG FIX: a missing trailing comma here used to concatenate this
            # string with the next one, fusing two queries into a single entry.
            "Where has flooding occured?",
            "Where are volunteers being requested?",
            "Where has road damage occured?",
            "What area has the wildfire burned?",
            "Where have homes been damaged or destroyed?"]

    return gr.CheckboxGroup(choices=history), history
148
+
149
def add_query(to_add, history):
    """Append a custom query to the history (ignoring duplicates) and
    refresh the query checkbox group."""
    already_present = to_add in history
    if not already_present:
        history.append(to_add)
    return gr.CheckboxGroup(choices=history), history
153
+
154
def qa_summarise(selected_queries, qa_llm_model, text_field, data_df):
    """Answer the selected queries over the event-related posts.

    Filters out posts labelled "none", feeds the rest to generate_answer in
    multi-summarize mode, and returns the summary plus a numbered DataFrame
    of the source texts.
    """
    relevant = data_df[data_df["model_label"] != "none"].reset_index()
    texts = relevant[text_field].to_list()

    summary = generate_answer(qa_llm_model, texts, selected_queries[0], selected_queries, mode="multi_summarize")

    doc_df = pd.DataFrame()
    doc_df["number"] = list(range(1, len(texts) + 1))
    doc_df["text"] = texts

    return summary, doc_df
166
+
167
+
168
# ---------------------------------------------------------------------------
# Gradio UI: four tabs -- DataFrame-based classification, checkbox-based
# classification eval, RAG question answering, and single-text classification.
# ---------------------------------------------------------------------------
with gr.Blocks(fill_width=True) as demo:

    # Run the module-level `js` snippet once on page load (Twitter embeds).
    demo.load(None, None, None, js=js)

    # Classifier checkpoints offered in every model dropdown.
    event_models = ["jayebaku/distilbert-base-multilingual-cased-crexdata-relevance-classifier",
                    "jayebaku/distilbert-base-multilingual-cased-weather-classifier-2",
                    "jayebaku/twitter-xlm-roberta-base-crexdata-relevance-classifier",
                    "jayebaku/twhin-bert-base-crexdata-relevance-classifier"]

    # Server-side copy of the classified results, consumed by the label filter.
    T_data_ss_state = gr.State(value=pd.DataFrame())


    with gr.Tab("Event Type Classification"):
        gr.Markdown(
            """
            # T4.5 Relevance Classifier Demo
            This is a demo created to explore floods and wildfire classification in social media posts.\n
            Usage:\n
            - Upload .tsv or .csv data file (must contain a text column with social media posts).\n
            - Next, type the name of the text column.\n
            - Then, choose a BERT classifier model from the drop down.\n
            - Finally, click the 'start prediction' buttton.\n
            """)
        with gr.Row():
            with gr.Column(scale=4):
                T_file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv'])

            with gr.Column(scale=6):
                T_text_field = gr.Textbox(label="Text field name", value="tweet_text")
                T_event_model = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
                T_predict_button = gr.Button("Start Prediction")
                with gr.Accordion("Prediction threshold", open=False):
                    T_threshold = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False,
                                            info="This value sets a threshold by which texts classified flood or fire are accepted, \
                                            higher values makes the classifier stricter (CAUTION: A value of 1 will set all predictions as none)", interactive=True)

        with gr.Row():
            with gr.Column(scale=8):
                # Results table: text | event_label | model_score | tweet_id.
                T_data = gr.DataFrame(wrap=True,
                                      show_fullscreen_button=True,
                                      show_copy_button=True,
                                      show_row_numbers=True,
                                      show_search="filter",
                                      column_widths=["49%","17%","17%","17%"])

            with gr.Column(scale=2):
                # Hidden until predictions exist; populated by
                # load_and_classify_csv_dataframe's third output.
                T_data_filter = gr.Dropdown(visible=False)
                T_tweet_embed = gr.HTML("<h1>Select a Tweet ID to view Tweet</h1>")



    with gr.Tab("Event Type Classification Eval"):
        gr.Markdown(
            """
            # T4.5 Relevance Classifier Demo
            This is a demo created to explore floods and wildfire classification in social media posts.\n
            Usage:\n
            - Upload .tsv or .csv data file (must contain a text column with social media posts).\n
            - Next, type the name of the text column.\n
            - Then, choose a BERT classifier model from the drop down.\n
            - Finally, click the 'start prediction' buttton.\n
            Evaluation:\n
            - To evaluate the model's accuracy select the INCORRECT classifications using the checkboxes in front of each post.\n
            - Then, click on the 'Calculate Accuracy' button.\n
            - Then, click on the 'Download data as CSV' to get the classifications and evaluation data as a .csv file.
            """)
        with gr.Row():
            with gr.Column(scale=4):
                file_input = gr.File(label="Upload CSV or TSV File", file_types=['.tsv', '.csv'])

            with gr.Column(scale=6):
                text_field = gr.Textbox(label="Text field name", value="tweet_text")
                event_model = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
                ETCE_predict_button = gr.Button("Start Prediction")
                with gr.Accordion("Prediction threshold", open=False):
                    threshold = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold", show_label=False,
                                          info="This value sets a threshold by which texts classified flood or fire are accepted, \
                                          higher values makes the classifier stricter (CAUTION: A value of 1 will set all predictions as none)", interactive=True)

        # One checkbox column per predicted class; the user ticks misclassified posts.
        with gr.Row(): # XXX confirm this is not a problem later --equal_height=True
            with gr.Column():
                gr.Markdown("""### Flood-related""")
                flood_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

            with gr.Column():
                gr.Markdown("""### Fire-related""")
                fire_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

            with gr.Column():
                gr.Markdown("""### None""")
                none_checkbox_output = gr.CheckboxGroup(label="Select ONLY incorrect classifications", interactive=True)

        with gr.Row():
            with gr.Column(scale=5):
                gr.Markdown(r"""
                Accuracy: is the model's ability to make correct predicitons.
                It is the fraction of correct prediction out of the total predictions.

                $$
                \text{Accuracy} = \frac{\text{Correct predictions}}{\text{All predictions}} * 100
                $$

                Model Confidence: is the mean probabilty of each case
                belonging to their assigned classes. A value of 1 is best.
                """, latex_delimiters=[{ "left": "$$", "right": "$$", "display": True }])
                gr.Markdown("\n\n\n")
                model_confidence = gr.Number(label="Model Confidence")

            with gr.Column(scale=5):
                correct = gr.Number(label="Number of correct classifications")
                incorrect = gr.Number(label="Number of incorrect classifications")
                accuracy = gr.Number(label="Model Accuracy (%)")

        ETCE_accuracy_button = gr.Button("Calculate Accuracy")
        # Hidden plumbing: download target, post count, and the two DataFrames
        # passed between the predict and accuracy callbacks.
        download_csv = gr.DownloadButton(visible=False)
        num_posts = gr.Number(visible=False)
        data = gr.DataFrame(visible=False)
        data_eval = gr.DataFrame(visible=False)


    # Kept in a name so its .select event can (re)seed the default queries.
    qa_tab = gr.Tab("Question Answering")
    with qa_tab:
        gr.Markdown(
            """
            # Question Answering Demo
            This section uses RAG to answer questions about the relevant social media posts identified by the relevance classifier\n
            Usage:\n
            - Select queries from predefined\n
            - Parameters for QA can be editted in sidebar\n

            Note: QA process is disabled untill after the relevance classification is done
            """)

        with gr.Accordion("Parameters", open=False):
            with gr.Row():
                with gr.Column():
                    qa_llm_model = gr.Dropdown(["mistral", "solar", "phi3mini"], label="QA model", value="phi3mini", interactive=True)
                    # NOTE(review): aggregator, batch_size and topk are not wired
                    # to any event below -- presumably reserved for future use.
                    aggregator = gr.Dropdown(["linear", "outrank"], label="Aggregation method", value="linear", interactive=True)
                with gr.Column():
                    batch_size = gr.Slider(50, 500, value=150, step=1, label="Batch size", info="Choose between 50 and 500", interactive=True)
                    topk = gr.Slider(1, 10, value=5, step=1, label="Number of results to retrieve", info="Choose between 1 and 10", interactive=True)

        selected_queries = gr.CheckboxGroup(label="Select at least one query using the checkboxes", interactive=True)
        queries_state = gr.State()
        # Populate the default queries whenever this tab is opened.
        qa_tab.select(init_queries, inputs=queries_state, outputs=[selected_queries, queries_state])

        query_inp = gr.Textbox(label="Add custom queries like the one above, one at a time")
        # Disabled until classification has run (enabled by load_and_classify_csv).
        QA_addqry_button = gr.Button("Add to queries", interactive=False)
        QA_run_button = gr.Button("Start QA", interactive=False)
        hsummary = gr.Textbox(label="Summary")

        qa_df = gr.DataFrame()


    with gr.Tab("Single Text Classification"):
        gr.Markdown(
            """
            # Event Type Prediction Demo
            In this section you test the relevance classifier with written texts.\n
            Usage:\n
            - Type a tweet-like text in the textbox.\n
            - Then press Enter.\n
            """)
        with gr.Row():
            with gr.Column(scale=3):
                model_sing_classify = gr.Dropdown(event_models, value=event_models[0], label="Select classification model")
            with gr.Column(scale=7):
                threshold_sing_classify = gr.Slider(0, 1, value=0, step=0.01, label="Prediction threshold",
                                                    info="This value sets a threshold by which texts classified flood or fire are accepted, \
                                                    higher values makes the classifier stricter (CAUTION: A value of 1 will set all predictions as none)", interactive=True)

        text_to_classify = gr.Textbox(label="Text", info="Enter tweet-like text", submit_btn=True)
        text_to_classify_examples = gr.Examples([["The streets are flooded, I can't leave #BostonStorm"],
                                                 ["Controlado el incendio de Rodezno que ha obligado a desalojar a varias bodegas de la zona."],
                                                 ["Cambrils:estaci贸 Renfe inundada 19 persones dins d'un tren. FGC a Capellades, petit descarrilament 5 passatgers #Inuncat @emergenciescat"],
                                                 ["Anscheinend steht die komplette Neckarwiese unter Wasser! #Hochwasser"]], text_to_classify)

        with gr.Row():
            with gr.Column():
                classification = gr.Textbox(label="Classification")
            with gr.Column():
                classification_score = gr.Number(label="Classification Score")


    # Test event listeners
    T_predict_button.click(
        load_and_classify_csv_dataframe,
        inputs=[T_file_input, T_text_field, T_event_model, T_threshold],
        outputs=[T_data, T_data_ss_state, T_data_filter]
    )

    # Show the selected cell (tweet embed for the id column, heading otherwise).
    # NOTE(review): the .then(js="reloadTwitterWidgets()") chain is commented
    # out, so newly inserted tweet embeds may not render until a page reload.
    T_data.select(T_on_select, None, T_tweet_embed)#.then(fn=None, js="reloadTwitterWidgets()")

    @T_data_filter.input(inputs=[T_data_ss_state, T_data_filter], outputs=T_data)
    def filter_df(df, filter):
        # 'All' passes everything; 'Not-<label>' excludes that label;
        # any other value keeps only rows with that exact label.
        if filter == "All":
            result_df = df.copy()
        elif filter.startswith("Not"):
            result_df = df[df["event_label"]!=filter.split('-')[1]].copy()
        else:
            result_df = df[df["event_label"]==filter].copy()
        return gr.update(value=result_df)


    # Button clicks ETC Eval
    ETCE_predict_button.click(
        load_and_classify_csv,
        inputs=[file_input, text_field, event_model, threshold],
        outputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, model_confidence, num_posts, data, QA_addqry_button, QA_run_button])

    ETCE_accuracy_button.click(
        calculate_accuracy,
        inputs=[flood_checkbox_output, fire_checkbox_output, none_checkbox_output, num_posts, text_field, data],
        outputs=[incorrect, correct, accuracy, data_eval, download_csv])


    # Button clicks QA
    QA_addqry_button.click(add_query, inputs=[query_inp, queries_state], outputs=[selected_queries, queries_state])

    QA_run_button.click(qa_summarise,
                        inputs=[selected_queries, qa_llm_model, text_field, data], ## XXX fix text_field
                        outputs=[hsummary, qa_df])


    # Event listener for single text classification
    text_to_classify.submit(
        single_classification,
        inputs=[text_to_classify, model_sing_classify, threshold_sing_classify],
        outputs=[classification, classification_score])

demo.launch()