import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
import spaces

# Available models for content generation
MODEL_OPTIONS_CONTENT = {
    "MX02 (mixed)": {
        "model_id": "alakxender/flan-t5-corpora-mixed",
        "default_prompt": "Tell me about: "
    },
    "MX01 (articles)": {
        "model_id": "alakxender/flan-t5-news-articles",
        "default_prompt": "Create an article about: "
    }
}

# Cache for loaded models/tokenizers
MODEL_CACHE = {}

def get_model_and_tokenizer(model_choice):
    """Load the tokenizer/model for the chosen entry, caching them after the first load."""
    model_dir = MODEL_OPTIONS_CONTENT[model_choice]["model_id"]
    if model_dir not in MODEL_CACHE:
        print(f"Loading model: {model_dir}")
        tokenizer = T5Tokenizer.from_pretrained(model_dir)
        model = T5ForConditionalGeneration.from_pretrained(model_dir)
        # Run on GPU when available, otherwise fall back to CPU
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Moving model to device: {device}")
        model.to(device)
        MODEL_CACHE[model_dir] = (tokenizer, model)
    return MODEL_CACHE[model_dir]

def get_default_prompt(model_choice):
    return MODEL_OPTIONS_CONTENT[model_choice]["default_prompt"]

@spaces.GPU()
def generate_content(prompt, max_new_tokens, num_beams, repetition_penalty, no_repeat_ngram_size, do_sample, model_choice):
    """Generate text with the selected model, prefixing its default prompt to the user input."""
    tokenizer, model = get_model_and_tokenizer(model_choice)
    prompt = get_default_prompt(model_choice) + prompt

    # Tokenize and move the inputs to the same device as the model
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        do_sample=do_sample,
        early_stopping=False
    )
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Drop any incomplete trailing fragment by trimming to the last period
    if '.' in output_text:
        last_period = output_text.rfind('.')
        output_text = output_text[:last_period + 1]
    return output_text
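

# --- Example usage (a minimal sketch, not part of the original app) ---
# The spaces package documents @spaces.GPU() as a no-op outside ZeroGPU
# Spaces, so the function can also be called directly for local testing.
# The generation settings below are illustrative defaults, not values
# confirmed by the app itself.
if __name__ == "__main__":
    sample = generate_content(
        prompt="the history of the Maldives",  # appended to the model's default prompt
        max_new_tokens=128,
        num_beams=4,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        do_sample=False,
        model_choice="MX02 (mixed)",
    )
    print(sample)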