import random
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import spaces

# Available models
MODEL_OPTIONS_TRANSLATE = {
    # "Flan-T5-A-Dhivehi-Latin Model": "alakxender/flan-t5-base-dhivehi-en-latin-v2",
    "Flan-T5-B-Dhivehi-Latin Model": "alakxender/flan-t5-base-dhivehi-en-latin",
    "MT5-B-Dhivehi-English Model": "alakxender/mt5-base-dv-en",
    "MT5-B1-Dhivehi-English Model": "alakxender/mt5-base-dv-en-md",
    "@politecat314-Dhivehi-Latin Model": "politecat314/flan-t5-base-dv2latin-mihaaru",
}

# Cache for loaded models/tokenizers so each checkpoint is loaded only once
MODEL_CACHE = {}


def get_model_and_tokenizer(model_dir):
    """Load (or fetch from cache) the tokenizer and model for a checkpoint."""
    if model_dir not in MODEL_CACHE:
        print(f"Loading model: {model_dir}")
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Moving model to device: {device}")
        model.to(device)
        MODEL_CACHE[model_dir] = (tokenizer, model)
    return MODEL_CACHE[model_dir]


max_input_length = 128
max_output_length = 128


@spaces.GPU()
def translate(instruction, input_text, model_choice, max_new_tokens=128,
              num_beams=4, repetition_penalty=1.2, no_repeat_ngram_size=3):
    """Translate `input_text` with the selected model, prefixing the instruction prompt."""
    model_dir = MODEL_OPTIONS_TRANSLATE[model_choice]
    tokenizer, model = get_model_and_tokenizer(model_dir)

    # Prepend the instruction to the input text (instruction-only if no input is given)
    combined_input = f"{instruction.strip()} {input_text.strip()}" if input_text else instruction.strip()

    inputs = tokenizer(
        combined_input,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length,
    )
    # Keep the encoded inputs on the same device as the model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    gen_kwargs = {
        **inputs,
        "max_new_tokens": max_new_tokens,
        "min_length": 10,
        "num_beams": num_beams,
        "early_stopping": True,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "repetition_penalty": repetition_penalty,
        "do_sample": False,  # deterministic beam search
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }

    with torch.no_grad():
        outputs = model.generate(**gen_kwargs)

    decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return decoded_output
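

# --- Illustrative usage (not part of the original app; a minimal local smoke test) ---
# The instruction string below is an assumption about the prompt format the
# fine-tuned checkpoints expect; adjust it to whatever the models were trained with.
if __name__ == "__main__":
    example = translate(
        instruction="Translate Dhivehi to English:",
        input_text="<Dhivehi sentence here>",
        model_choice="MT5-B-Dhivehi-English Model",
    )
    print(example)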