Darwinkel committed on
Commit
72bb8ba
·
verified ·
1 Parent(s): 14a2e71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -15
app.py CHANGED
@@ -5,13 +5,12 @@ from datasets import load_dataset
5
  from peft import PeftModel
6
  import os
7
 
8
- title = "Ask Rick a Question"
9
  description = """
10
- The bot was trained to answer questions based on Rick and Morty dialogues. Ask Rick anything!
11
- <img src="https://huggingface.co/spaces/course-demos/Rick_and_Morty_QA/resolve/main/rick.png" width=200px>
12
  """
13
 
14
- article = "Check out [the original Rick and Morty Bot](https://huggingface.co/spaces/kingabzpro/Rick_and_Morty_Bot) that this demo is based off of."
15
 
16
  model_id = "google/gemma-2b"
17
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
@@ -22,7 +21,7 @@ model = model.merge_and_unload()
22
 
23
 
24
  dataset = load_dataset("allenai/sciq")
25
- random_test_samples = dataset["test"].select(random.sample(range(0, len(dataset["test"])), 1))
26
 
27
  examples = []
28
  for row in random_test_samples:
@@ -52,38 +51,36 @@ def predict(context, answer):
52
  split_outputs[5],
53
  )
54
 
55
- return None
56
 
57
 
58
  support_gr = gr.TextArea(
59
  label="Context",
60
- info="Make sure you use proper punctuation.",
61
  value="Bananas are yellow and curved."
62
  )
63
 
64
  answer_gr = gr.Text(
65
- label="Answer optional",
66
- info="Make sure you use proper punctuation.",
67
  value="yellow"
68
  )
69
 
70
  context_output_gr = gr.Text(
71
- label="Output"
72
  )
73
  answer_output_gr = gr.Text(
74
- label="Output"
75
  )
76
  question_output_gr = gr.Text(
77
- label="Output"
78
  )
79
  distractor1_output_gr = gr.Text(
80
- label="Output"
81
  )
82
  distractor2_output_gr = gr.Text(
83
- label="Output"
84
  )
85
  distractor3_output_gr = gr.Text(
86
- label="Output"
87
  )
88
 
89
  gr.Interface(
 
5
  from peft import PeftModel
6
  import os
7
 
8
+ title = "Gemma-2b SciQ"
9
  description = """
10
+ Gemma-2b fine-tuned on SciQ
 
11
  """
12
 
13
+ article = "GitHub repository: <ur>"
14
 
15
  model_id = "google/gemma-2b"
16
  tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
 
21
 
22
 
23
  dataset = load_dataset("allenai/sciq")
24
+ random_test_samples = dataset["test"].select(random.sample(range(0, len(dataset["test"])), 10))
25
 
26
  examples = []
27
  for row in random_test_samples:
 
51
  split_outputs[5],
52
  )
53
 
54
+ return ("ERROR: " + decoded_outputs, None, None, None, None, None)
55
 
56
 
57
  support_gr = gr.TextArea(
58
  label="Context",
 
59
  value="Bananas are yellow and curved."
60
  )
61
 
62
  answer_gr = gr.Text(
63
+ label="Answer (optional)",
 
64
  value="yellow"
65
  )
66
 
67
  context_output_gr = gr.Text(
68
+ label="Context"
69
  )
70
  answer_output_gr = gr.Text(
71
+ label="Answer"
72
  )
73
  question_output_gr = gr.Text(
74
+ label="Question"
75
  )
76
  distractor1_output_gr = gr.Text(
77
+ label="Distractor 1"
78
  )
79
  distractor2_output_gr = gr.Text(
80
+ label="Distractor 2"
81
  )
82
  distractor3_output_gr = gr.Text(
83
+ label="Distractor 3"
84
  )
85
 
86
  gr.Interface(