Spaces:
Running
Running
Jeff Myers II
committed on
Commit
·
52321a8
1
Parent(s):
eca8be7
Update space
Browse files- Gemma_Model.py +102 -0
- News.py +62 -0
Gemma_Model.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, BitsAndBytesConfig, Gemma3ForCausalLM
|
2 |
+
import torch
|
3 |
+
import json
|
4 |
+
|
5 |
+
# NOTE(review): `__export__` is not a name Python recognizes; `__all__` is
# the convention that controls `from module import *`. Define `__all__` and
# keep the old name as an alias so any existing references keep working.
__all__ = ["GemmaLLM"]
__export__ = __all__
|
6 |
+
|
7 |
+
class GemmaLLM:
    """Wrapper around an 8-bit quantized Gemma-3 1B instruct model.

    Builds chat-template message lists for two tasks — summarizing a news
    article and generating multiple-choice questions from a summary — and
    runs them through ``model.generate``.
    """

    def __init__(self):
        """Load the quantized model and its tokenizer.

        NOTE(review): bitsandbytes 8-bit quantization normally requires a
        CUDA device; confirm ``device_map="cpu"`` actually works in the
        target runtime.
        """
        model_id = "google/gemma-3-1b-it"
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        self.model = Gemma3ForCausalLM.from_pretrained(
            model_id,
            device_map="cpu",
            quantization_config=quantization_config,
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
        ).eval()

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

    def generate(self, message) -> str:
        """Run one chat round for *message* and return the decoded reply.

        message: list of chat-template role/content dicts.
        Returns only the newly generated text (prompt tokens sliced off).
        """
        print("Generating...")
        inputs = self.tokenizer.apply_chat_template(
            message,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        # Prompt length in tokens, used to cut the echo of the prompt
        # out of the generated sequence below.
        input_length = inputs["input_ids"].shape[1]

        with torch.inference_mode():
            outputs = self.model.generate(
                **inputs, max_new_tokens=1024,
            )[0][input_length:]

        outputs = self.tokenizer.decode(outputs, skip_special_tokens=True)

        print("Completed generating!")

        return outputs

    def get_summary_message(self, article, num_paragraphs) -> list:
        """Build the chat messages asking the model to summarize *article*.

        article: JSON-serializable article data (embedded via json.dumps).
        num_paragraphs: requested paragraph count for the summary.
        Returns the system+user message list for ``generate``.
        """
        summarize = "You are a helpful assistant. Your main task is to summarize articles. You will be given an article that you will generate a summary for. The summary should include all the key points of the article. ONLY RESPOND WITH THE SUMMARY!!!"

        summary = f"Summarize the data in the following JSON into {num_paragraphs} paragraph(s) so that it is easy to read and understand:\n"

        message = [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": summarize},
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": summary + json.dumps(article, indent=4)},
                ],
            },
        ]

        return message

    def get_summary(self, message) -> str:
        """Generate and return the summary text for a prepared message list."""
        return self.generate(message)

    def get_questions_message(self, summary, num_questions, difficulty) -> list:
        """Build the chat messages asking for multiple-choice questions.

        summary: article summary text the questions are drawn from.
        num_questions: how many questions the model must produce.
        difficulty: difficulty adjective inserted into the user prompt.
        Returns the system+user message list for ``generate``.
        """
        # Three example entries show the model the exact JSON schema it must
        # emit; ``str.__name__`` yields the placeholder string "str".
        entry = dict(
            question=str.__name__,
            correct_answer=str.__name__,
            false_answers=[str.__name__, str.__name__, str.__name__],
        )
        schema = json.dumps([entry, entry, entry], indent=4)

        question = f"You are a helpful assistant. Your main task is to generate {num_questions} multiple choice questions from an article. Respond in the following JSON structure and schema:\n\njson\n```{schema}```\n\nThere should only be {num_questions} questions generated. Each question should only have 3 false answers and 1 correct answer. The correct answer should be the most relevant answer based on the context derived from the article. False answers should not contain the correct answer. False answers should contain false information but also be reasonably plausible for answering the question. ONLY RESPOND WITH RAW JSON!!!"

        questions = f"Generate {difficulty} questions and answers from the following article:\n"

        message = [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": question},
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": questions + summary},
                ],
            },
        ]

        return message

    @staticmethod
    def _extract_json(text: str):
        """Parse model output that may be wrapped in a ```json code fence.

        Unlike a global ``str.replace("json\\n", "")``, this only removes the
        fence markers and a single leading language tag, so payloads that
        legitimately contain "json" are not corrupted.
        """
        cleaned = text.strip().strip("`").strip()
        # Drop the "json" language tag left over from the opening fence.
        if cleaned.startswith("json"):
            cleaned = cleaned[len("json"):].lstrip()
        return json.loads(cleaned)

    def get_questions(self, message) -> list:
        """Generate quiz questions and return the parsed JSON reply.

        The schema requested in ``get_questions_message`` is a JSON array,
        so the parsed result is a list of question dicts.
        """
        return self._extract_json(self.generate(message))
|
News.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from newsapi import NewsApiClient
|
2 |
+
from newspaper import Article
|
3 |
+
import os
|
4 |
+
|
5 |
+
# NOTE(review): `__export__` is not a name Python recognizes; `__all__` is
# the convention that controls `from module import *`. Define `__all__` and
# keep the old name as an alias so any existing references keep working.
__all__ = ["News"]
__export__ = __all__
|
6 |
+
|
7 |
+
class News:
    """Fetch headlines via NewsAPI and pull full article text with newspaper.

    Requires the ``NEWS_API_KEY`` environment variable for live requests.
    """

    # Source display names excluded from every query.
    __EX_SOURCES__ = {"ABC News", "Bloomberg", "The Hill", "Fox Sports", "Google News"}
    # NewsAPI's fixed set of top-headline categories.
    __CATEGORIES__ = {
        "business",
        "entertainment",
        "general",
        "health",
        "science",
        "sports",
        "technology",
    }

    def __init__(self):
        # Key is read from the environment so it is never hard-coded;
        # requests will fail until NEWS_API_KEY is set.
        newsapi_key = os.environ.get("NEWS_API_KEY")
        self.newsapi = NewsApiClient(api_key=newsapi_key)

    def get_sources(self, category=None):
        """Return the set of allowed English/US source names.

        category: optional NewsAPI category (see ``__CATEGORIES__``) to
        restrict the source listing.
        NOTE(review): these are display *names*; NewsAPI documents the
        ``sources`` request parameter as comma-separated source *ids*
        (e.g. "abc-news") — confirm name-based filtering works as intended.
        """
        sources = self.newsapi.get_sources(language="en", country="us", category=category)["sources"]
        return {source["name"] for source in sources if source["name"] not in self.__EX_SOURCES__}

    def get_top_headlines(self, num_headlines=None, category=None):
        """Return top-headline articles from the allowed sources.

        num_headlines: page size; None leaves the API default.
        category: used only to narrow which sources are queried (NewsAPI
        rejects mixing ``sources`` with ``category`` directly).
        """
        sources = self.get_sources(category=category)

        headlines = self.newsapi.get_top_headlines(
            # Comma-separated with no spaces, per NewsAPI parameter format.
            sources=",".join(sources),
            page_size=num_headlines,
        )["articles"]

        return headlines

    def get_headlines(self, num_headlines=None, query=None):
        """Search all articles matching *query* across the allowed sources."""
        sources = self.get_sources()

        headlines = self.newsapi.get_everything(
            q=query,
            sources=",".join(sources),
            page_size=num_headlines,
        )["articles"]

        return headlines

    def get_articles_from_headlines(self, headlines):
        """Download and parse each headline's full text, mutating in place.

        For every headline dict: adds "content" (full article text),
        flattens "source" to its name, and removes the raw "author" field.
        Returns the same list.
        """
        for headline in headlines:
            article = Article(headline["url"])
            article.download()
            article.parse()
            headline["content"] = article.text
            headline["source"] = headline["source"]["name"]
            # pop() instead of `del`: some articles arrive without an
            # "author" key, and `del` would raise KeyError.
            headline.pop("author", None)

        return headlines
|