File size: 3,889 Bytes
64fde65
c247815
eebf9a7
64fde65
 
 
 
 
 
 
 
 
c247815
 
64fde65
 
 
 
 
 
 
c247815
64fde65
 
eebf9a7
64fde65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from transformers import AutoTokenizer, Gemma3ForCausalLM
from huggingface_hub import login
import spaces
import torch
import json
import os

# Public API of this module. `__all__` is the name Python consults for
# `from module import *`; `__export__` has no special meaning to the
# interpreter and is kept only as an alias for any code that reads it by name.
__all__ = ["GemmaLLM"]
__export__ = __all__

class GemmaLLM:
    """Wrapper around the Gemma 3 1B instruction-tuned model for summaries and quizzes.

    ``__init__`` loads the tokenizer and model once. ``generate`` runs a single
    chat-formatted conversation through the model. ``get_summary`` and
    ``get_questions`` build task-specific prompts on top of ``generate``.
    """

    def __init__(self):
        """Authenticate with the Hugging Face Hub (if a token is set) and load the model.

        Reads the access token from the ``GEMMA_TOKEN`` environment variable.
        The model is placed on CUDA when available, otherwise on CPU.
        """
        # Authenticate only when a token is configured: login(token=None) falls
        # back to an interactive prompt, which cannot be answered in a headless
        # deployment. Without a token, previously cached credentials (if any)
        # are still used by from_pretrained.
        token = os.environ.get("GEMMA_TOKEN")
        if token:
            login(token=token)

        model_id = "google/gemma-3-1b-it"
        on_gpu = torch.cuda.is_available()

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = Gemma3ForCausalLM.from_pretrained(
            model_id,
            device_map="cuda" if on_gpu else "cpu",
            # Half precision only pays off on a GPU; on CPU, float16 kernels may be
            # missing or very slow, so fall back to float32 there.
            # NOTE(review): Gemma 3 is commonly run in bfloat16 on GPU; consider
            # switching if float16 produces NaNs/garbage output.
            torch_dtype=torch.float16 if on_gpu else torch.float32,
        ).eval()

    # Presumably requests a GPU from the Hugging Face `spaces` runtime for the
    # duration of each call — confirm against the deployment (ZeroGPU).
    @spaces.GPU
    def generate(self, message: list[dict], max_new_tokens: int = 1024) -> str:
        """Run one chat conversation through the model and return the reply text.

        Args:
            message: Chat turns in the format accepted by the tokenizer's
                ``apply_chat_template`` (a list of ``{"role", "content"}`` dicts).
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded continuation only; the prompt tokens are sliced off and
            special tokens are skipped.
        """
        inputs = self.tokenizer.apply_chat_template(
            message,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        # generate() returns prompt + continuation; remember the prompt length
        # so only the newly generated tokens are decoded.
        prompt_length = inputs["input_ids"].shape[1]

        with torch.inference_mode():
            sequences = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

        new_tokens = sequences[0][prompt_length:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)

    @staticmethod
    def _build_chat(system_text: str, user_text: str) -> list[dict]:
        """Return a two-turn (system + user) chat in the structured-content format."""
        return [
            {"role": "system", "content": [{"type": "text", "text": system_text}]},
            {"role": "user", "content": [{"type": "text", "text": user_text}]},
        ]

    def _get_summary_message(self, article, num_paragraphs) -> list[dict]:
        """Build the chat asking the model to summarize *article*.

        *article* is serialized with ``json.dumps``, so it must be JSON-serializable
        (presumably a dict of article fields — confirm against callers).
        """
        system_text = "You are a helpful assistant. Your main task is to summarize articles. You will be given an article that you will generate a summary for. The summary should include all the key points of the article. ONLY RESPOND WITH THE SUMMARY!!!"

        instruction = f"Summarize the data in the following JSON into {num_paragraphs} paragraph(s) so that it is easy to read and understand:\n"

        # ensure_ascii=False keeps non-ASCII text as real characters instead of
        # \uXXXX escapes, which are harder for the model to read.
        article_json = json.dumps(article, indent=4, ensure_ascii=False)

        return self._build_chat(system_text, instruction + article_json)

    def get_summary(self, article, num_paragraphs) -> str:
        """Return a model-written summary of *article* in *num_paragraphs* paragraph(s)."""
        return self.generate(self._get_summary_message(article, num_paragraphs))

    def _get_questions_message(self, summary, num_questions, difficulty) -> list[dict]:
        """Build the chat asking for *num_questions* multiple-choice questions.

        Args:
            summary: Article text the questions are generated from.
            num_questions: Number of questions requested in the prompt.
            difficulty: Difficulty label (a string; it is lower-cased into the prompt).
        """
        # Schema shown to the model: each value is the type name "str".
        example_item = {
            "question": str.__name__,
            "correct_answer": str.__name__,
            "false_answers": [str.__name__, str.__name__, str.__name__],
        }
        schema = json.dumps([example_item] * 3, indent=4)

        system_text = (
            f"You are a helpful assistant. Your main task is to generate {num_questions} multiple choice questions from an article. "
            f"Respond in the following JSON structure and schema:\n\n```json\n{schema}\n```\n\n"
            f"There should only be {num_questions} questions generated. "
            "Each question should only have 3 false answers and 1 correct answer. "
            "The correct answer should be the most relevant answer based on the context derived from the article. "
            "False answers should not contain the correct answer. "
            "False answers should contain false information but also be reasonably plausible for answering the question. "
            "ONLY RESPOND WITH RAW JSON!!!"
        )

        instruction = f"Generate {difficulty.lower()} questions and answers from the following article:\n"

        return self._build_chat(system_text, instruction + summary)

    @staticmethod
    def _parse_json_response(response: str):
        """Decode JSON from a model reply, tolerating Markdown fences and a "json" tag.

        Accepted shapes include bare JSON, ```` ```json ... ``` ````,
        ```` json\\n```...``` ````, and fenced JSON surrounded by prose or
        whitespace. When a fence is present, the first fenced block is used.

        Raises:
            json.JSONDecodeError: if the extracted text is not valid JSON.
        """
        text = response.strip()
        fence = "```"
        start = text.find(fence)
        if start != -1:
            start += len(fence)
            end = text.find(fence, start)
            # An unterminated fence still yields everything after the opener.
            text = text[start:] if end == -1 else text[start:end]
            text = text.strip()
        # Drop a language tag such as "json" (valid JSON never starts with it).
        if text[:4].lower() == "json":
            text = text[4:].lstrip()
        return json.loads(text)

    def get_questions(self, summary, num_questions, difficulty) -> list[dict]:
        """Generate multiple-choice questions from *summary* and parse the JSON reply.

        Returns:
            The decoded JSON. The prompt requests a list of
            ``{"question", "correct_answer", "false_answers"}`` objects, but the
            shape is whatever the model actually produced.

        Raises:
            json.JSONDecodeError: if the model reply does not contain valid JSON.
        """
        message = self._get_questions_message(summary, num_questions, difficulty)
        response = self.generate(message)
        return self._parse_json_response(response)