Jeff Myers II commited on
Commit
54e639b
·
1 Parent(s): 8b78b70

Replaced Gemma for OpenAI API and commented out visualize_quiz_answers (seems to be breaking the script).

Browse files
Gemma.py CHANGED
@@ -3,12 +3,14 @@ from huggingface_hub import login
3
  import spaces
4
  import json
5
  import os
 
6
 
7
  __export__ = ["GemmaLLM"]
8
 
9
  class GemmaLLM:
10
  def __init__(self):
11
  login(token=os.environ.get("GEMMA_TOKEN"))
 
12
 
13
  # quant_config = quantization_config.BitsAndBytesConfig(
14
  # load_in_8bit=True,
@@ -16,21 +18,23 @@ class GemmaLLM:
16
  # llm_int8_has_fp16_weight=False,
17
  # )
18
 
19
- model_id = "google/gemma-3-4b-it"
20
  # model_id = "google/gemma-3n-E4B-it-litert-preview"
21
  # model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quant_config)
22
  # tokenizer = AutoTokenizer.from_pretrained(model_id)
23
 
24
- self.model = pipeline("text-generation", model=model_id)
25
 
26
  @spaces.GPU
27
  def generate(self, message) -> str:
28
- outputs = self.model(message, max_new_tokens=1024)[0]["generated_text"]
29
-
30
- return outputs
 
 
 
31
 
32
  def _get_summary_message(self, article, num_paragraphs) -> dict:
33
-
34
  summarize = "You are a helpful assistant. Your main task is to summarize articles. You will be given an article that you will generate a summary for. The summary should include all the key points of the article. ONLY RESPOND WITH THE SUMMARY!!!"
35
 
36
  summary = f"Summarize the data in the following JSON into {num_paragraphs} paragraph(s) so that it is easy to read and understand:\n"
@@ -43,8 +47,7 @@ class GemmaLLM:
43
  def get_summary(self, article, num_paragraphs) -> str:
44
  message = self._get_summary_message(article, num_paragraphs)
45
  summary = self.generate(message)
46
-
47
- return summary[2]["content"]
48
 
49
  def _get_questions_message(self, summary, num_questions, difficulty) -> dict:
50
  question = f"""
@@ -62,7 +65,33 @@ class GemmaLLM:
62
  return message
63
 
64
  def get_questions(self, summary, num_questions, difficulty) -> dict:
 
65
  message = self._get_questions_message(summary, num_questions, difficulty)
 
66
  questions = self.generate(message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- return json.loads(questions[2]["content"].strip("```").replace("json\n", ""))
 
 
 
 
 
 
3
  import spaces
4
  import json
5
  import os
6
+ import openai
7
 
8
  __export__ = ["GemmaLLM"]
9
 
10
  class GemmaLLM:
11
  def __init__(self):
12
  login(token=os.environ.get("GEMMA_TOKEN"))
13
+ openai.api_key = os.environ.get("OPENAI_API_KEY")
14
 
15
  # quant_config = quantization_config.BitsAndBytesConfig(
16
  # load_in_8bit=True,
 
18
  # llm_int8_has_fp16_weight=False,
19
  # )
20
 
21
+ # model_id = "google/gemma-3-4b-it"
22
  # model_id = "google/gemma-3n-E4B-it-litert-preview"
23
  # model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quant_config)
24
  # tokenizer = AutoTokenizer.from_pretrained(model_id)
25
 
26
+ # self.model = pipeline("text-generation", model=model_id)
27
 
28
  @spaces.GPU
29
  def generate(self, message) -> str:
30
+ response = openai.chat.completions.create(
31
+ model="gpt-4.1", # You can use "gpt-4" if you have access
32
+ messages=message)
33
+ return response.choices[0].message.content
34
+ # outputs = self.model(message, max_new_tokens=1024)[0]["generated_text"]
35
+ # return outputs
36
 
37
  def _get_summary_message(self, article, num_paragraphs) -> dict:
 
38
  summarize = "You are a helpful assistant. Your main task is to summarize articles. You will be given an article that you will generate a summary for. The summary should include all the key points of the article. ONLY RESPOND WITH THE SUMMARY!!!"
39
 
40
  summary = f"Summarize the data in the following JSON into {num_paragraphs} paragraph(s) so that it is easy to read and understand:\n"
 
47
  def get_summary(self, article, num_paragraphs) -> str:
48
  message = self._get_summary_message(article, num_paragraphs)
49
  summary = self.generate(message)
50
+ return summary
 
51
 
52
  def _get_questions_message(self, summary, num_questions, difficulty) -> dict:
53
  question = f"""
 
65
  return message
66
 
67
  def get_questions(self, summary, num_questions, difficulty) -> dict:
68
+ # print("Getting questions message...")
69
  message = self._get_questions_message(summary, num_questions, difficulty)
70
+ # print("Generating questions...")
71
  questions = self.generate(message)
72
+ # print(questions)
73
+ questions = questions.strip("```")
74
+ questions = questions.replace("json\n", "")
75
+ questions = json.loads(questions)
76
+ print(questions)
77
+ return questions
78
+
79
+ if __name__ == "__main__":
80
+ gemma = GemmaLLM()
81
+ summary = gemma.get_summary('''
82
+ Iran could reach an agreement similar to the 2015 nuclear deal to end its current conflict with Israel, the Iranian foreign minister said Saturday, according to state media.
83
+
84
+ “It is clear the Israelis are against diplomacy and do not want the current crisis to be resolved peacefully,” Abbas Araghchi said, according to state-run news agency IRNA.
85
+
86
+ “We are, however, certainly prepared to reach a negotiated solution, similar to the one we reached in 2015,” the foreign minister said, according to IRNA.
87
+
88
+ Remember: The 2015 deal, formally known as the Joint Comprehensive Plan of Action, required Iran to limit its nuclear program in return for provisions including relief on sanctions. US President Donald Trump withdrew from the landmark agreement in 2018 and Iran has increasingly grown its nuclear program since then.
89
+
90
+ More from the foreign minister: “That agreement was the result of two years of tireless negotiations, and when it was signed, it was welcomed by the entire world,” Araghchi said of the 2015 deal, according to IRNA.
91
 
92
+ “Diplomacy can be effective — just as it was effective in the past — and can be effective in the future as well. But to return to the path of diplomacy, the aggression must be halted,” he continued, according to the news agency.
93
+ ''', 1)
94
+ print(summary)
95
+ questions = gemma.get_questions(summary, 3, "Easy")
96
+ print(json.dumps(questions, indent=4))
97
+
__pycache__/Gemma.cpython-312.pyc ADDED
Binary file (6.78 kB). View file
 
__pycache__/News.cpython-312.pyc ADDED
Binary file (3.01 kB). View file
 
app.py CHANGED
@@ -3,7 +3,7 @@ import random
3
  import gradio as gr
4
  from News import News
5
  from Gemma import GemmaLLM
6
- import matplotlib.pyplot as plt
7
 
8
  # %%
9
  class Cooldown:
@@ -123,9 +123,6 @@ with gr.Blocks() as demo:
123
 
124
  options = random.shuffle([answer] + false_answers)
125
 
126
- print("Question:", question)
127
- print(f"Formatted options: {options}")
128
-
129
  return question, options, answer
130
 
131
  def hide_quiz(): ################################### Hide quiz-related components
@@ -142,12 +139,8 @@ with gr.Blocks() as demo:
142
  quiz = [(question, random.sample(distractors + [answer], 4), answer) for question, distractors, answer in quiz]
143
  questions, options, answers = zip(*quiz) if quiz else ([], [], [])
144
 
145
- print("options", len(options))
146
-
147
- quiz = [gr.Radio(label=f"{i + 1}: {questions[i]}", choices=options[i], visible=True) for i in range(len(mcqs))]\
148
- + [gr.Radio(visible=False) for _ in range(10 - len(mcqs))]
149
-
150
- print("quiz", len(quiz))
151
 
152
  submit = gr.Button("Submit Answers", interactive=bool(answers), visible=True)
153
 
@@ -271,25 +264,25 @@ with gr.Blocks() as demo:
271
 
272
  demo.launch()
273
 
274
- def visualize_quiz_answers(answers, *quiz_items):
275
- """
276
- Visualization of correct/incorrect answers from the quiz
277
- """
278
- if not answers:
279
- return None
280
 
281
- correct = 0
282
 
283
- for user_ans, question in zip(answers, quiz_items):
284
- if user_ans == question["correct_answer"]:
285
- correct += 1
286
 
287
- incorrect = len(answers) - correct
288
 
289
- fig, ax = plt.subplots()
290
- ax.bar(["Correct", "Incorrect"], [correct, incorrect])
291
- ax.set_ylabel("Questions")
292
- ax.set_title("Quiz Score Summary")
293
 
294
- return fig
295
 
 
3
  import gradio as gr
4
  from News import News
5
  from Gemma import GemmaLLM
6
+ # import matplotlib.pyplot as plt
7
 
8
  # %%
9
  class Cooldown:
 
123
 
124
  options = random.shuffle([answer] + false_answers)
125
 
 
 
 
126
  return question, options, answer
127
 
128
  def hide_quiz(): ################################### Hide quiz-related components
 
139
  quiz = [(question, random.sample(distractors + [answer], 4), answer) for question, distractors, answer in quiz]
140
  questions, options, answers = zip(*quiz) if quiz else ([], [], [])
141
 
142
+ quiz = [gr.Radio(label=f"{i + 1}: {questions[i]}", choices=options[i], visible=True) for i in range(len(mcqs))]
143
+ quiz += [gr.Radio(visible=False) for _ in range(10 - len(mcqs))]
 
 
 
 
144
 
145
  submit = gr.Button("Submit Answers", interactive=bool(answers), visible=True)
146
 
 
264
 
265
  demo.launch()
266
 
267
+ # def visualize_quiz_answers(answers, *quiz_items):
268
+ # """
269
+ # Visualization of correct/incorrect answers from the quiz
270
+ # """
271
+ # if not answers:
272
+ # return None
273
 
274
+ # correct = 0
275
 
276
+ # for user_ans, question in zip(answers, quiz_items):
277
+ # if user_ans == question["correct_answer"]:
278
+ # correct += 1
279
 
280
+ # incorrect = len(answers) - correct
281
 
282
+ # fig, ax = plt.subplots()
283
+ # ax.bar(["Correct", "Incorrect"], [correct, incorrect])
284
+ # ax.set_ylabel("Questions")
285
+ # ax.set_title("Quiz Score Summary")
286
 
287
+ # return fig
288