Spaces: Running on Zero
import random

import numpy as np
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

import spaces
# Available models
MODEL_OPTIONS_INSTRUCT = {
    "I2 Model": "alakxender/flan-t5-base-alpaca-dv5",
    "I1 Model": "alakxender/flan-t5-base-alpaca-dv",
}

# Cache for loaded models/tokenizers
MODEL_CACHE = {}
def get_model_and_tokenizer(model_dir):
    # Load each tokenizer/model pair once; later calls hit the cache.
    if model_dir not in MODEL_CACHE:
        print(f"Loading model: {model_dir}")
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = T5ForConditionalGeneration.from_pretrained(model_dir)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Moving model to device: {device}")
        model.to(device)
        MODEL_CACHE[model_dir] = (tokenizer, model)
    return MODEL_CACHE[model_dir]
max_input_length = 256
max_output_length = 256
@spaces.GPU  # ZeroGPU: allocate a GPU for the duration of this call
def generate_response(instruction, input_text, seed, use_sampling, model_choice, max_tokens, num_beams):
    # Seed every RNG so a given seed reproduces the same generation.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    model_dir = MODEL_OPTIONS_INSTRUCT[model_choice]
    tokenizer, model = get_model_and_tokenizer(model_dir)
    # Alpaca-style prompt: the instruction, optionally followed by input text.
    combined_input = f"{instruction.strip()} {input_text.strip()}" if input_text else instruction.strip()

    inputs = tokenizer(
        combined_input,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    gen_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        # Setting both max_new_tokens and max_length conflicts in transformers;
        # cap the user's token budget at max_output_length instead.
        "max_new_tokens": min(max_tokens, max_output_length),
        "no_repeat_ngram_size": 3,
        "repetition_penalty": 1.5,
    }
    if use_sampling:
        gen_kwargs.update({
            "do_sample": True,
            "temperature": 0.1,
            "num_return_sequences": 1,
            "num_beams": 1,  # plain sampling, no beam search
        })
    else:
        gen_kwargs.update({
            "num_beams": num_beams,  # was hardcoded to 8, ignoring the caller's setting
            "do_sample": False,
            "early_stopping": True,
        })
    with torch.no_grad():
        outputs = model.generate(**gen_kwargs)
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Trim any trailing fragment after the last period
    if '.' in decoded_output:
        last_period = decoded_output.rfind('.')
        decoded_output = decoded_output[:last_period + 1]
    return decoded_output
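
# Usage sketch (an assumption, not part of the original listing): the Space
# presumably drives generate_response from a UI, but the function can also be
# called directly as below. The instruction string is a hypothetical
# placeholder; the "-dv" suffix on the model names suggests Dhivehi
# instruction-tuned checkpoints, so real prompts would likely be in Thaana script.
if __name__ == "__main__":
    response = generate_response(
        instruction="Describe the weather today.",  # hypothetical prompt
        input_text="",
        seed=42,
        use_sampling=False,  # deterministic beam search
        model_choice="I1 Model",
        max_tokens=128,
        num_beams=4,
    )
    print(response)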