# t5-ft-demo / title_gen.py -- author: alakxender, last commit 82b0ab8 ("a")
import random
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import spaces
# Selectable title-generation models: display label -> Hugging Face Hub repo id.
# Insertion order is kept as-is in case a UI lists the labels in this order.
MODEL_OPTIONS_TITLE = {
    "V6 Model": "alakxender/t5-divehi-title-generation-v6",
    "XS Model": "alakxender/t5-dhivehi-title-generation-xs",
}

# Filled lazily by get_model_and_tokenizer: repo id -> (tokenizer, model).
MODEL_CACHE = {}
def get_model_and_tokenizer(model_dir):
    """Return the cached ``(tokenizer, model)`` pair for *model_dir*.

    On the first request for a given repo id, the tokenizer and seq2seq
    model are loaded with ``from_pretrained``. The model is moved to CUDA
    when it is available (CPU otherwise), and the pair is stored in
    ``MODEL_CACHE``. Later calls return the same cached tuple.
    """
    cached = MODEL_CACHE.get(model_dir)
    if cached is not None:
        return cached

    print(f"Loading model: {model_dir}")
    tok = AutoTokenizer.from_pretrained(model_dir)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_dir)

    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Moving model to device: {target}")
    seq2seq.to(target)

    MODEL_CACHE[model_dir] = (tok, seq2seq)
    return MODEL_CACHE[model_dir]
# Task prefix prepended to every input text (presumably the prefix the
# models were fine-tuned with -- confirm against the training setup).
prefix = "2title: "
# Token budgets: encoder input is truncated to the first value; the second
# value is passed to generate() as max_length for the generated title.
max_input_length, max_target_length = 512, 32
def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's RNGs with a single value.

    The seed is coerced to ``int`` and reduced modulo 2**32 first. Callers may
    pass a float (e.g. ``42.0`` from a numeric widget), and NumPy's legacy
    ``np.random.seed`` rejects floats and values outside ``[0, 2**32 - 1]``.
    In-range integer seeds are passed through unchanged.
    """
    seed = int(seed) % 2**32
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


@spaces.GPU()
def generate_title(content, seed, use_sampling, model_choice):
    """Generate a title for *content* with the selected T5 model.

    Args:
        content: Source text to title. ``None`` or whitespace-only text
            returns ``""`` without loading or running a model.
        seed: Seed for the Python/NumPy/PyTorch RNGs. It is coerced to
            ``int`` (mod 2**32) so float seeds are accepted.
        use_sampling: If true, sample with ``top_p=0.95`` and
            ``temperature=1.0``; otherwise use deterministic 4-beam search.
        model_choice: Key of ``MODEL_OPTIONS_TITLE`` selecting the model.

    Returns:
        The decoded title, with special tokens removed.

    Raises:
        KeyError: If *model_choice* is not a key of ``MODEL_OPTIONS_TITLE``
            (only reached for non-empty content).
    """
    # Nothing to title: skip tokenization and generation entirely.
    if content is None or not content.strip():
        return ""

    model_dir = MODEL_OPTIONS_TITLE[model_choice]
    tokenizer, model = get_model_and_tokenizer(model_dir)

    inputs = tokenizer(
        prefix + content.strip(),
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt",
    )
    # Send the inputs to wherever the cached model actually lives rather than
    # re-deriving the device, so inputs and weights can never disagree.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    gen_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_length": max_target_length,
        "no_repeat_ngram_size": 2,
    }
    if use_sampling:
        gen_kwargs.update(
            do_sample=True,
            temperature=1.0,
            top_p=0.95,
            num_return_sequences=1,
        )
    else:
        gen_kwargs.update(
            num_beams=4,
            do_sample=False,
            early_stopping=True,
        )

    # Seed immediately before generation. Loading a model on the first call
    # may consume RNG state, so seeding earlier could make the first sampled
    # title for a seed differ from later ones.
    _seed_everything(seed)

    with torch.no_grad():
        outputs = model.generate(**gen_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)