hibble commited on
Commit
3c7902c
Β·
verified Β·
1 Parent(s): 9850940

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +824 -9
app.py CHANGED
@@ -1,11 +1,826 @@
1
- import torch
2
- from transformers import pipeline
 
 
 
 
 
 
 
3
 
4
- chat = [
5
- {"role": "system", "content": "You are a sassy, wise-cracking robot as imagined by Hollywood circa 1986."},
6
- {"role": "user", "content": "Hey, can you tell me any fun things to do in New York?"}
7
- ]
 
 
 
 
8
 
9
- pipeline = pipeline(task="text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16, device_map="auto")
10
- response = pipeline(chat, max_new_tokens=512)
11
- print(response[0]["generated_text"][-1]["content"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import uuid
4
+ from base64 import b64encode
5
+ from datetime import datetime
6
+ from mimetypes import guess_type
7
+ from pathlib import Path
8
+ from typing import Optional
9
+ import json
10
 
11
+ import spaces
12
+ import spaces
13
+ import gradio as gr
14
+ from feedback import save_feedback, scheduler
15
+ from gradio.components.chatbot import OptionDict
16
+ from huggingface_hub import InferenceClient
17
+ from pandas import DataFrame
18
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
19
 
20
+
21
+ BASE_MODEL = os.getenv("MODEL", "google/gemma-3-12b-pt")
22
+ ZERO_GPU = (
23
+ bool(os.getenv("ZERO_GPU", False)) or True
24
+ if str(os.getenv("ZERO_GPU")).lower() == "true"
25
+ else False
26
+ )
27
+ TEXT_ONLY = (
28
+ bool(os.getenv("TEXT_ONLY", False)) or True
29
+ if str(os.getenv("TEXT_ONLY")).lower() == "true"
30
+ else False
31
+ )
32
+
33
+
34
+ def create_inference_client(
35
+ model: Optional[str] = None, base_url: Optional[str] = None
36
+ ) -> InferenceClient | dict:
37
+ """Create an InferenceClient instance with the given model or environment settings.
38
+ This function will run the model locally if ZERO_GPU is set to True.
39
+ This function will run the model locally if ZERO_GPU is set to True.
40
+
41
+ Args:
42
+ model: Optional model identifier to use. If not provided, will use environment settings.
43
+ base_url: Optional base URL for the inference API.
44
+
45
+ Returns:
46
+ Either an InferenceClient instance or a dictionary with pipeline and tokenizer
47
+ """
48
+ if ZERO_GPU:
49
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
50
+ model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, load_in_8bit=True)
51
+ return {
52
+ "pipeline": pipeline(
53
+ "text-generation",
54
+ model=model,
55
+ tokenizer=tokenizer,
56
+ max_new_tokens=2000,
57
+ ),
58
+ "tokenizer": tokenizer
59
+ }
60
+ else:
61
+ return InferenceClient(
62
+ token=os.getenv("HF_TOKEN"),
63
+ model=model if model else (BASE_MODEL if not base_url else None),
64
+ base_url=base_url,
65
+ )
66
+
67
+
68
+ CLIENT = create_inference_client()
69
+
70
+
71
+ def get_persistent_storage_path(filename: str) -> tuple[Path, bool]:
72
+ """Check if persistent storage is available and return the appropriate path.
73
+
74
+ Args:
75
+ filename: The name of the file to check/create
76
+
77
+ Returns:
78
+ A tuple containing (file_path, is_persistent)
79
+ """
80
+ persistent_path = Path("/data") / filename
81
+ local_path = Path(__file__).parent / filename
82
+
83
+ # Check if persistent storage is available and writable
84
+ use_persistent = False
85
+ if Path("/data").exists() and Path("/data").is_dir():
86
+ try:
87
+ # Test if we can write to the directory
88
+ test_file = Path("/data/write_test.tmp")
89
+ test_file.touch()
90
+ test_file.unlink() # Remove the test file
91
+ use_persistent = True
92
+ except (PermissionError, OSError):
93
+ print("Persistent storage exists but is not writable, falling back to local storage")
94
+ use_persistent = False
95
+
96
+ return (persistent_path if use_persistent else local_path, use_persistent)
97
+
98
+
99
+ def load_languages() -> dict[str, str]:
100
+ """Load languages from JSON file or persistent storage"""
101
+ languages_path, use_persistent = get_persistent_storage_path("languages.json")
102
+ local_path = Path(__file__).parent / "languages.json"
103
+
104
+ # If persistent storage is available but file doesn't exist yet, copy the local file to persistent storage
105
+ if use_persistent and not languages_path.exists():
106
+ try:
107
+ if local_path.exists():
108
+ import shutil
109
+ shutil.copy(local_path, languages_path)
110
+ print(f"Copied languages to persistent storage at {languages_path}")
111
+ else:
112
+ with open(languages_path, "w", encoding="utf-8") as f:
113
+ json.dump({"English": "You are a helpful assistant."}, f, ensure_ascii=False, indent=2)
114
+ print(f"Created new languages file in persistent storage at {languages_path}")
115
+ except Exception as e:
116
+ print(f"Error setting up persistent storage: {e}")
117
+ languages_path = local_path # Fall back to local path if any error occurs
118
+
119
+ if not languages_path.exists() and local_path.exists():
120
+ languages_path = local_path
121
+
122
+ if languages_path.exists():
123
+ with open(languages_path, "r", encoding="utf-8") as f:
124
+ return json.load(f)
125
+ else:
126
+ default_languages = {"English": "You are a helpful assistant."}
127
+ return default_languages
128
+
129
+ LANGUAGES = load_languages()
130
+
131
+ USER_AGREEMENT = """
132
+ You have been asked to participate in a research study conducted by Lingo Lab from the Computer Science and Artificial Intelligence Laboratory at the Massachusetts Institute of Technology (M.I.T.), together with huggingface.
133
+
134
+ The purpose of this study is the collection of multilingual human feedback to improve language models. As part of this study you will interat with a language model in a langugage of your choice, and provide indication to wether its reponses are helpful or not.
135
+
136
+ Your name and personal data will never be recorded. You may decline further participation, at any time, without adverse consequences.There are no foreseeable risks or discomforts for participating in this study. Note participating in the study may pose risks that are currently unforeseeable. If you have questions or concerns about the study, you can contact the researchers at leshem@mit.edu. If you have any questions about your rights as a participant in this research (E-6610), feel you have been harmed, or wish to discuss other study-related concerns with someone who is not part of the research team, you can contact the M.I.T. Committee on the Use of Humans as Experimental Subjects (COUHES) by phone at (617) 253-8420, or by email at couhes@mit.edu.
137
+
138
+ Clicking on the next button at the bottom of this page indicates that you are at least 18 years of age and willingly agree to participate in the research voluntarily.
139
+ """
140
+
141
+
142
+ def add_user_message(history, message):
143
+ if isinstance(message, dict) and "files" in message:
144
+ for x in message["files"]:
145
+ history.append({"role": "user", "content": {"path": x}})
146
+ if message["text"] is not None:
147
+ history.append({"role": "user", "content": message["text"]})
148
+ else:
149
+ history.append({"role": "user", "content": message})
150
+ return history, gr.Textbox(value=None, interactive=False)
151
+
152
+
153
+ def format_system_message(language: str, history: list):
154
+ system_message = [
155
+ {
156
+ "role": "system",
157
+ "content": LANGUAGES.get(language, LANGUAGES["English"]),
158
+ }
159
+ ]
160
+ if history and history[0]["role"] == "system":
161
+ history = history[1:]
162
+ history = system_message + history
163
+ return history
164
+
165
+
166
+ def format_history_as_messages(history: list):
167
+ messages = []
168
+ current_role = None
169
+ current_message_content = []
170
+
171
+ if TEXT_ONLY:
172
+ for entry in history:
173
+ messages.append({"role": entry["role"], "content": entry["content"]})
174
+ return messages
175
+
176
+ if TEXT_ONLY:
177
+ for entry in history:
178
+ messages.append({"role": entry["role"], "content": entry["content"]})
179
+ return messages
180
+
181
+ for entry in history:
182
+ content = entry["content"]
183
+
184
+ if entry["role"] != current_role:
185
+ if current_role is not None:
186
+ messages.append(
187
+ {"role": current_role, "content": current_message_content}
188
+ )
189
+ current_role = entry["role"]
190
+ current_message_content = []
191
+
192
+ if isinstance(content, tuple): # Handle file paths
193
+ for temp_path in content:
194
+ if space_host := os.getenv("SPACE_HOST"):
195
+ url = f"https://{space_host}/gradio_api/file%3D{temp_path}"
196
+ else:
197
+ url = _convert_path_to_data_uri(temp_path)
198
+ current_message_content.append(
199
+ {"type": "image_url", "image_url": {"url": url}}
200
+ )
201
+ elif isinstance(content, str): # Handle text
202
+ current_message_content.append({"type": "text", "text": content})
203
+
204
+ if current_role is not None:
205
+ messages.append({"role": current_role, "content": current_message_content})
206
+
207
+ return messages
208
+
209
+
210
+ def _convert_path_to_data_uri(path) -> str:
211
+ mime_type, _ = guess_type(path)
212
+ with open(path, "rb") as image_file:
213
+ data = image_file.read()
214
+ data_uri = f"data:{mime_type};base64," + b64encode(data).decode("utf-8")
215
+ return data_uri
216
+
217
+
218
+ def _is_file_safe(path) -> bool:
219
+ try:
220
+ return Path(path).is_file()
221
+ except Exception:
222
+ return ""
223
+
224
+
225
+ def _process_content(content) -> str | list[str]:
226
+ if isinstance(content, str) and _is_file_safe(content):
227
+ return _convert_path_to_data_uri(content)
228
+ elif isinstance(content, list) or isinstance(content, tuple):
229
+ return _convert_path_to_data_uri(content[0])
230
+ return content
231
+
232
+
233
+ def _process_rating(rating) -> int:
234
+ if isinstance(rating, str):
235
+ return 0
236
+ elif isinstance(rating, int):
237
+ return rating
238
+ else:
239
+ raise ValueError(f"Invalid rating: {rating}")
240
+
241
+
242
+ def add_fake_like_data(
243
+ history: list,
244
+ conversation_id: str,
245
+ session_id: str,
246
+ language: str,
247
+ liked: bool = False,
248
+ ) -> None:
249
+ data = {
250
+ "index": len(history) - 1,
251
+ "value": history[-1],
252
+ "liked": liked,
253
+ }
254
+ _, dataframe = wrangle_like_data(
255
+ gr.LikeData(target=None, data=data), history.copy()
256
+ )
257
+ submit_conversation(
258
+ dataframe=dataframe,
259
+ conversation_id=conversation_id,
260
+ session_id=session_id,
261
+ language=language,
262
+ )
263
+
264
+
265
+ @spaces.GPU
266
+ def call_pipeline(messages: list, language: str):
267
+ """Call the appropriate model pipeline based on configuration"""
268
+ if ZERO_GPU:
269
+ tokenizer = CLIENT["tokenizer"]
270
+ formatted_prompt = tokenizer.apply_chat_template(
271
+ messages,
272
+ tokenize=False,
273
+ )
274
+
275
+ response = CLIENT["pipeline"](
276
+ formatted_prompt,
277
+ clean_up_tokenization_spaces=False,
278
+ max_length=2000,
279
+ return_full_text=False,
280
+ )
281
+
282
+ return response[0]["generated_text"]
283
+ else:
284
+ response = CLIENT(
285
+ messages,
286
+ clean_up_tokenization_spaces=False,
287
+ max_length=2000,
288
+ )
289
+ return response[0]["generated_text"][-1]["content"]
290
+
291
+
292
+ def respond(
293
+ history: list,
294
+ language: str,
295
+ temperature: Optional[float] = None,
296
+ seed: Optional[int] = None,
297
+ ) -> list:
298
+ """Respond to the user message with a system message
299
+
300
+ Return the history with the new message"""
301
+ messages = format_history_as_messages(history)
302
+
303
+ if ZERO_GPU:
304
+ content = call_pipeline(messages, language)
305
+ else:
306
+ response = CLIENT.chat.completions.create(
307
+ messages=messages,
308
+ max_tokens=2000,
309
+ stream=False,
310
+ seed=seed,
311
+ temperature=temperature,
312
+ )
313
+ content = response.choices[0].message.content
314
+
315
+ message = gr.ChatMessage(role="assistant", content=content)
316
+ history.append(message)
317
+ return history
318
+
319
+
320
+ def update_dataframe(dataframe: DataFrame, history: list) -> DataFrame:
321
+ """Update the dataframe with the new message"""
322
+ data = {
323
+ "index": 9999,
324
+ "value": None,
325
+ "liked": False,
326
+ }
327
+ _, dataframe = wrangle_like_data(
328
+ gr.LikeData(target=None, data=data), history.copy()
329
+ )
330
+ return dataframe
331
+
332
+
333
+ def wrangle_like_data(x: gr.LikeData, history) -> DataFrame:
334
+ """Wrangle conversations and liked data into a DataFrame"""
335
+
336
+ if isinstance(x.index, int):
337
+ liked_index = x.index
338
+ else:
339
+ liked_index = x.index[0]
340
+
341
+ output_data = []
342
+ for idx, message in enumerate(history):
343
+ if isinstance(message, gr.ChatMessage):
344
+ message = message.__dict__
345
+ if idx == liked_index:
346
+ if x.liked is True:
347
+ message["metadata"] = {"title": "liked"}
348
+ elif x.liked is False:
349
+ message["metadata"] = {"title": "disliked"}
350
+
351
+ if message["metadata"] is None:
352
+ message["metadata"] = {}
353
+ elif not isinstance(message["metadata"], dict):
354
+ message["metadata"] = message["metadata"].__dict__
355
+
356
+ rating = message["metadata"].get("title")
357
+ if rating == "liked":
358
+ message["rating"] = 1
359
+ elif rating == "disliked":
360
+ message["rating"] = -1
361
+ else:
362
+ message["rating"] = 0
363
+
364
+ message["chosen"] = ""
365
+ message["rejected"] = ""
366
+ if message["options"]:
367
+ for option in message["options"]:
368
+ if not isinstance(option, dict):
369
+ option = option.__dict__
370
+ message[option["label"]] = option["value"]
371
+ else:
372
+ if message["rating"] == 1:
373
+ message["chosen"] = message["content"]
374
+ elif message["rating"] == -1:
375
+ message["rejected"] = message["content"]
376
+
377
+ output_data.append(
378
+ dict(
379
+ [(k, v) for k, v in message.items() if k not in ["metadata", "options"]]
380
+ )
381
+ )
382
+
383
+ return history, DataFrame(data=output_data)
384
+
385
+
386
+ def wrangle_edit_data(
387
+ x: gr.EditData,
388
+ history: list,
389
+ dataframe: DataFrame,
390
+ conversation_id: str,
391
+ session_id: str,
392
+ language: str,
393
+ ) -> list:
394
+ """Edit the conversation and add negative feedback if assistant message is edited, otherwise regenerate the message
395
+
396
+ Return the history with the new message"""
397
+ if isinstance(x.index, int):
398
+ index = x.index
399
+ else:
400
+ index = x.index[0]
401
+
402
+ original_message = gr.ChatMessage(
403
+ role="assistant", content=dataframe.iloc[index]["content"]
404
+ ).__dict__
405
+
406
+ if history[index]["role"] == "user":
407
+ # Add feedback on original and corrected message
408
+ add_fake_like_data(
409
+ history=history[: index + 2],
410
+ conversation_id=conversation_id,
411
+ session_id=session_id,
412
+ language=language,
413
+ liked=True,
414
+ )
415
+ add_fake_like_data(
416
+ history=history[: index + 1] + [original_message],
417
+ conversation_id=conversation_id,
418
+ session_id=session_id,
419
+ language=language,
420
+ )
421
+ history = respond(
422
+ history=history[: index + 1],
423
+ language=language,
424
+ temperature=random.randint(1, 100) / 100,
425
+ seed=random.randint(0, 1000000),
426
+ )
427
+ return history
428
+ else:
429
+ add_fake_like_data(
430
+ history=history[: index + 1],
431
+ conversation_id=conversation_id,
432
+ session_id=session_id,
433
+ language=language,
434
+ liked=True,
435
+ )
436
+ add_fake_like_data(
437
+ history=history[:index] + [original_message],
438
+ conversation_id=conversation_id,
439
+ session_id=session_id,
440
+ language=language,
441
+ )
442
+ history = history[: index + 1]
443
+ history[-1]["options"] = [
444
+ OptionDict(label="chosen", value=x.value),
445
+ OptionDict(label="rejected", value=original_message["content"]),
446
+ ]
447
+ return history
448
+
449
+
450
+ def wrangle_retry_data(
451
+ x: gr.RetryData,
452
+ history: list,
453
+ dataframe: DataFrame,
454
+ conversation_id: str,
455
+ session_id: str,
456
+ language: str,
457
+ ) -> list:
458
+ """Respond to the user message with a system message and add negative feedback on the original message
459
+
460
+ Return the history with the new message"""
461
+ add_fake_like_data(
462
+ history=history,
463
+ conversation_id=conversation_id,
464
+ session_id=session_id,
465
+ language=language,
466
+ )
467
+
468
+ # Return the history without a new message
469
+ history = respond(
470
+ history=history[:-1],
471
+ language=language,
472
+ temperature=random.randint(1, 100) / 100,
473
+ seed=random.randint(0, 1000000),
474
+ )
475
+ return history, update_dataframe(dataframe, history)
476
+
477
+
478
+ def submit_conversation(dataframe, conversation_id, session_id, language):
479
+ """ "Submit the conversation to dataset repo"""
480
+ if dataframe.empty or len(dataframe) < 2:
481
+ gr.Info("No feedback to submit.")
482
+ return (gr.Dataframe(value=None, interactive=False), [])
483
+
484
+ dataframe["content"] = dataframe["content"].apply(_process_content)
485
+ dataframe["rating"] = dataframe["rating"].apply(_process_rating)
486
+ conversation = dataframe.to_dict(orient="records")
487
+ conversation_data = {
488
+ "conversation": conversation,
489
+ "timestamp": datetime.now().isoformat(),
490
+ "session_id": session_id,
491
+ "conversation_id": conversation_id,
492
+ "language": language,
493
+ }
494
+ save_feedback(input_object=conversation_data)
495
+ return (gr.Dataframe(value=None, interactive=False), [])
496
+
497
+
498
+ def open_add_language_modal():
499
+ return gr.Group(visible=True)
500
+
501
+ def close_add_language_modal():
502
+ return gr.Group(visible=False)
503
+
504
+ def save_new_language(lang_name, system_prompt):
505
+ """Save the new language and system prompt to persistent storage if available, otherwise to local file."""
506
+ global LANGUAGES
507
+
508
+ languages_path, use_persistent = get_persistent_storage_path("languages.json")
509
+ local_path = Path(__file__).parent / "languages.json"
510
+
511
+ if languages_path.exists():
512
+ with open(languages_path, "r", encoding="utf-8") as f:
513
+ data = json.load(f)
514
+ else:
515
+ data = {}
516
+
517
+ data[lang_name] = system_prompt
518
+
519
+ with open(languages_path, "w", encoding="utf-8") as f:
520
+ json.dump(data, f, ensure_ascii=False, indent=2)
521
+
522
+ if use_persistent and local_path != languages_path:
523
+ try:
524
+ with open(local_path, "w", encoding="utf-8") as f:
525
+ json.dump(data, f, ensure_ascii=False, indent=2)
526
+ except Exception as e:
527
+ print(f"Error updating local backup: {e}")
528
+
529
+ LANGUAGES.update({lang_name: system_prompt})
530
+ return gr.Group(visible=False), gr.HTML("<script>window.location.reload();</script>"), gr.Dropdown(choices=list(LANGUAGES.keys()))
531
+
532
+
533
+ css = """
534
+ .options.svelte-pcaovb {
535
+ display: none !important;
536
+ }
537
+ .option.svelte-pcaovb {
538
+ display: none !important;
539
+ }
540
+ .retry-btn {
541
+ display: none !important;
542
+ }
543
+ /* Style for the add language button */
544
+ button#add-language-btn {
545
+ padding: 0 !important;
546
+ font-size: 30px !important;
547
+ font-weight: bold !important;
548
+ }
549
+ /* Style for the user agreement container */
550
+ .user-agreement-container {
551
+ box-shadow: 0 2px 5px rgba(0,0,0,0.1) !important;
552
+ max-height: 300px;
553
+ overflow-y: auto;
554
+ padding: 10px;
555
+ border: 1px solid #ddd;
556
+ border-radius: 5px;
557
+ margin-bottom: 10px;
558
+ }
559
+ /* Style for the consent modal */
560
+ .consent-modal {
561
+ position: fixed !important;
562
+ top: 50% !important;
563
+ left: 50% !important;
564
+ transform: translate(-50%, -50%) !important;
565
+ z-index: 9999 !important;
566
+ background: white !important;
567
+ padding: 10px !important;
568
+ border-radius: 8px !important;
569
+ box-shadow: 0 4px 10px rgba(0,0,0,0.2) !important;
570
+ max-width: 90% !important;
571
+ width: 600px !important;
572
+ }
573
+ /* Overlay for the consent modal */
574
+ .modal-overlay {
575
+ position: fixed !important;
576
+ top: 0 !important;
577
+ left: 0 !important;
578
+ width: 100% !important;
579
+ height: 100% !important;
580
+ background-color: rgba(0, 0, 0, 0.7) !important;
581
+ z-index: 9998 !important;
582
+ }
583
+ """
584
+
585
+ def get_config(request: gr.Request):
586
+ """Get configuration from cookies"""
587
+ config = {"feel_consent": False}
588
+
589
+ if request and hasattr(request, 'cookies'):
590
+ for key in config.keys():
591
+ if key in request.cookies:
592
+ config[key] = request.cookies[key] == 'true'
593
+
594
+ return config["feel_consent"]
595
+
596
+ js = '''function js(){
597
+ window.set_cookie = function(key, value){
598
+ document.cookie = key+'='+value+'; Path=/; SameSite=Strict';
599
+ return [value];
600
+ }
601
+ }'''
602
+
603
+
604
+
605
+ with gr.Blocks(css=css, js=js) as demo:
606
+ # State variable to track if user has consented
607
+ user_consented = gr.State(value=False)
608
+
609
+ # Main application interface (initially visible but will be conditionally shown)
610
+ with gr.Group() as main_app: # Remove explicit visible=True to let it be controlled dynamically
611
+ ##############################
612
+ # Chatbot
613
+ ##############################
614
+ gr.Markdown("""
615
+ # ♾️ FeeL - a real-time Feedback Loop for LMs
616
+ """)
617
+
618
+ with gr.Accordion("About") as explanation:
619
+ gr.Markdown(f"""
620
+ FeeL is a collaboration between Hugging Face and MIT.
621
+ It is a community-driven project to provide a real-time feedback loop for VLMs, where your feedback is continuously used to fine-tune the underlying models.
622
+ The [dataset](https://huggingface.co/datasets/{scheduler.repo_id}), [code](https://github.com/huggingface/feel) and [models](https://huggingface.co/collections/feel-fl/feel-models-67a9b6ef0fdd554315e295e8) are public.
623
+
624
+ Start by selecting your language, chat with the model with text and images and provide feedback in different ways.
625
+
626
+ - ✏️ Edit a message
627
+ - πŸ‘/πŸ‘Ž Like or dislike a message
628
+ - πŸ”„ Regenerate a message
629
+
630
+ """)
631
+
632
+ with gr.Column():
633
+ gr.Markdown("Select your language or add a new one:")
634
+ with gr.Row():
635
+ language = gr.Dropdown(
636
+ choices=list(load_languages().keys()),
637
+ container=False,
638
+ show_label=False,
639
+ scale=8
640
+ )
641
+ add_language_btn = gr.Button(
642
+ "+",
643
+ elem_id="add-language-btn",
644
+ size="sm"
645
+ )
646
+
647
+
648
+ # Create a hidden group instead of a modal
649
+ with gr.Group(visible=False) as add_language_modal:
650
+ gr.Markdown("&nbsp;Add New Language")
651
+ new_lang_name = gr.Textbox(label="Language Name", lines=1)
652
+ new_system_prompt = gr.Textbox(label="System Prompt", lines=4)
653
+ with gr.Row():
654
+ with gr.Column(scale=1):
655
+ save_language_btn = gr.Button("Save")
656
+ with gr.Column(scale=1):
657
+ cancel_language_btn = gr.Button("Cancel")
658
+
659
+ refresh_html = gr.HTML(visible=False)
660
+
661
+ session_id = gr.Textbox(
662
+ interactive=False,
663
+ value=str(uuid.uuid4()),
664
+ visible=False,
665
+ )
666
+
667
+ conversation_id = gr.Textbox(
668
+ interactive=False,
669
+ value=str(uuid.uuid4()),
670
+ visible=False,
671
+ )
672
+
673
+ chatbot = gr.Chatbot(
674
+ elem_id="chatbot",
675
+ editable="all",
676
+ value=[
677
+ {
678
+ "role": "system",
679
+ "content": LANGUAGES[language.value],
680
+ }
681
+ ],
682
+ type="messages",
683
+ feedback_options=["Like", "Dislike"],
684
+ )
685
+
686
+ chat_input = gr.Textbox(
687
+ interactive=True,
688
+ placeholder="Enter message or upload file...",
689
+ show_label=False,
690
+ submit_btn=True,
691
+ )
692
+
693
+ with gr.Accordion("Collected feedback", open=False):
694
+ dataframe = gr.Dataframe(wrap=True, label="Collected feedback")
695
+
696
+ submit_btn = gr.Button(value="πŸ’Ύ Submit conversation", visible=False)
697
+
698
+ # Overlay for the consent modal
699
+ with gr.Group(elem_classes=["modal-overlay"], visible=False) as consent_overlay:
700
+ pass
701
+
702
+ # Consent popup
703
+ with gr.Group(elem_classes=["consent-modal"], visible=False) as consent_modal:
704
+ gr.Markdown("# User Agreement")
705
+ with gr.Group(elem_classes=["user-agreement-container"]):
706
+ gr.Markdown(USER_AGREEMENT)
707
+ consent_btn = gr.Button("I agree")
708
+
709
+ # Check consent on page load and show/hide components appropriately
710
+ def initialize_consent_status():
711
+ # This function will be called when the app loads
712
+ return False # Default to not consented
713
+
714
+ def update_visibility(has_consent):
715
+ # Show/hide components based on consent status
716
+ return (
717
+ gr.Group(visible=True), # main_app
718
+ gr.Group(visible=not has_consent), # consent_overlay
719
+ gr.Group(visible=not has_consent) # consent_modal
720
+ )
721
+
722
+ # Initialize app with consent checking
723
+ demo.load(fn=get_config, js=js, outputs=user_consented).then(
724
+ fn=update_visibility,
725
+ inputs=user_consented,
726
+ outputs=[main_app, consent_overlay, consent_modal]
727
+ )
728
+
729
+ # Function to handle consent button click
730
+ def handle_consent():
731
+ return True
732
+
733
+ consent_btn.click(
734
+ fn=handle_consent,
735
+ outputs=user_consented,
736
+ js="(value) => set_cookie('feel_consent', 'true')"
737
+ ).then(
738
+ fn=update_visibility,
739
+ inputs=user_consented,
740
+ outputs=[main_app, consent_overlay, consent_modal]
741
+ )
742
+
743
+ ##############################
744
+ # Deal with feedback
745
+ ##############################
746
+
747
+ language.change(
748
+ fn=format_system_message,
749
+ inputs=[language, chatbot],
750
+ outputs=[chatbot],
751
+ )
752
+
753
+ chat_input.submit(
754
+ fn=add_user_message,
755
+ inputs=[chatbot, chat_input],
756
+ outputs=[chatbot, chat_input],
757
+ ).then(respond, inputs=[chatbot, language], outputs=[chatbot]).then(
758
+ lambda: gr.Textbox(interactive=True), None, [chat_input]
759
+ ).then(update_dataframe, inputs=[dataframe, chatbot], outputs=[dataframe]).then(
760
+ submit_conversation,
761
+ inputs=[dataframe, conversation_id, session_id, language],
762
+ )
763
+
764
+ chatbot.like(
765
+ fn=wrangle_like_data,
766
+ inputs=[chatbot],
767
+ outputs=[chatbot, dataframe],
768
+ like_user_message=False,
769
+ ).then(
770
+ submit_conversation,
771
+ inputs=[dataframe, conversation_id, session_id, language],
772
+ )
773
+
774
+ chatbot.retry(
775
+ fn=wrangle_retry_data,
776
+ inputs=[chatbot, dataframe, conversation_id, session_id, language],
777
+ outputs=[chatbot, dataframe],
778
+ )
779
+
780
+ chatbot.edit(
781
+ fn=wrangle_edit_data,
782
+ inputs=[chatbot, dataframe, conversation_id, session_id, language],
783
+ outputs=[chatbot],
784
+ ).then(update_dataframe, inputs=[dataframe, chatbot], outputs=[dataframe])
785
+
786
+ gr.on(
787
+ triggers=[submit_btn.click, chatbot.clear],
788
+ fn=submit_conversation,
789
+ inputs=[dataframe, conversation_id, session_id, language],
790
+ outputs=[dataframe, chatbot],
791
+ ).then(
792
+ fn=lambda x: str(uuid.uuid4()),
793
+ inputs=[conversation_id],
794
+ outputs=[conversation_id],
795
+ )
796
+
797
+ def on_app_load():
798
+ global LANGUAGES
799
+ LANGUAGES = load_languages()
800
+ language_choices = list(LANGUAGES.keys())
801
+
802
+ return str(uuid.uuid4()), gr.Dropdown(choices=language_choices, value=language_choices[0])
803
+
804
+ demo.load(
805
+ fn=on_app_load,
806
+ inputs=None,
807
+ outputs=[session_id, language]
808
+ )
809
+
810
+ add_language_btn.click(
811
+ fn=lambda: gr.Group(visible=True),
812
+ outputs=[add_language_modal]
813
+ )
814
+
815
+ cancel_language_btn.click(
816
+ fn=lambda: gr.Group(visible=False),
817
+ outputs=[add_language_modal]
818
+ )
819
+
820
+ save_language_btn.click(
821
+ fn=save_new_language,
822
+ inputs=[new_lang_name, new_system_prompt],
823
+ outputs=[add_language_modal, refresh_html, language]
824
+ )
825
+
826
+ demo.launch()