import random
import numpy as np
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
import spaces
# Available models
MODEL_OPTIONS_INSTRUCT = {
"I2 Model": "alakxender/flan-t5-base-alpaca-dv5",
"I1 Model": "alakxender/flan-t5-base-alpaca-dv"
}
# Cache for loaded models/tokenizers
MODEL_CACHE = {}
def get_model_and_tokenizer(model_dir):
    # Load and move to the GPU only on first use; later calls hit the cache
    if model_dir not in MODEL_CACHE:
        print(f"Loading model: {model_dir}")
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = T5ForConditionalGeneration.from_pretrained(model_dir)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Moving model to device: {device}")
        model.to(device)
        MODEL_CACHE[model_dir] = (tokenizer, model)
    return MODEL_CACHE[model_dir]
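# A hypothetical illustration (not in the original file) of what the cache
# buys: a repeated call returns the same objects instead of re-running
# from_pretrained, so the download/initialization cost is paid once per model.
#
#   tok_a, model_a = get_model_and_tokenizer("alakxender/flan-t5-base-alpaca-dv")
#   tok_b, model_b = get_model_and_tokenizer("alakxender/flan-t5-base-alpaca-dv")
#   assert tok_a is tok_b and model_a is model_b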
max_input_length = 256
max_output_length = 256
@spaces.GPU()
def generate_response(instruction, input_text, seed, use_sampling, model_choice, max_tokens, num_beams):
    # Seed every RNG involved so generation is reproducible for a given seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    model_dir = MODEL_OPTIONS_INSTRUCT[model_choice]
    tokenizer, model = get_model_and_tokenizer(model_dir)
    # Concatenate the instruction with the optional input text
    combined_input = f"{instruction.strip()} {input_text.strip()}" if input_text else instruction.strip()
    inputs = tokenizer(
        combined_input,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    gen_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        # max_new_tokens alone bounds the output length; passing max_length
        # alongside it conflicts and triggers a transformers warning
        "max_new_tokens": max_tokens,
        "no_repeat_ngram_size": 3,
        "repetition_penalty": 1.5,
    }
    if use_sampling:
        gen_kwargs.update({
            "do_sample": True,
            "temperature": 0.1,
            "num_return_sequences": 1,
            "num_beams": 1,
        })
    else:
        gen_kwargs.update({
            "num_beams": num_beams,  # beam width supplied by the caller
            "do_sample": False,
            "early_stopping": True,
        })
    with torch.no_grad():
        outputs = model.generate(**gen_kwargs)
    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Trim the output to the last complete sentence (last period)
    if '.' in decoded_output:
        last_period = decoded_output.rfind('.')
        decoded_output = decoded_output[:last_period + 1]
    return decoded_output
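# A minimal, hypothetical usage sketch (not part of the original Space file,
# which wires generate_response into a UI): a direct call with placeholder
# arguments, assuming the model weights can be fetched from the Hub.
if __name__ == "__main__":
    print(generate_response(
        instruction="Summarize the following text.",  # placeholder prompt
        input_text="",            # optional extra context; may be empty
        seed=42,                  # fixed seed for reproducible output
        use_sampling=False,       # take the deterministic beam-search branch
        model_choice="I1 Model",  # key into MODEL_OPTIONS_INSTRUCT
        max_tokens=128,
        num_beams=4,
    ))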