import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer and model once at startup; reloading the
# 4B-parameter checkpoint on every request would dominate latency.
tokenizer = AutoTokenizer.from_pretrained('JetBrains/Mellum-4b-base')
model = AutoModelForCausalLM.from_pretrained('JetBrains/Mellum-4b-base')


def predict_code(prompt):
    # Tokenize the prompt; token type IDs are dropped because
    # model.generate() does not accept them.
    encoded_input = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False)
    # Prompt length in tokens, used later to strip the echoed prompt.
    input_len = len(encoded_input["input_ids"][0])
|
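    # A possible speed-up, assuming a CUDA device is present: move the model
    # to the GPU once at startup (model.to("cuda")) and send the inputs
    # along with encoded_input.to("cuda") before calling generate().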
    out = model.generate(
        **encoded_input,
        max_new_tokens=100,
    )
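    # Note: generate() decodes greedily by default; as a sketch, sampling
    # options such as do_sample=True and temperature=0.2 could be passed
    # above for more varied completions.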
|
    # Decode only the newly generated tokens, skipping special tokens
    # such as the end-of-sequence marker.
    prediction = tokenizer.decode(out[0][input_len:], skip_special_tokens=True)
    return prediction
|
|
def run(prompt):
    # Thin wrapper that Gradio calls with the textbox contents.
    return predict_code(prompt)
|
|
app = gr.Interface(
    fn=run,
    inputs=["text"],
    outputs=["text"],
)
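# Hypothetical refinement: for syntax-highlighted completions, the output
# component could be swapped for gr.Code(language="python").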
|
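# launch() serves the app locally; passing share=True would also create a
# temporary public URL.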
app.launch()