# mellum / app.py — Gradio demo for JetBrains/Mellum-4b-base code completion.
# (Hugging Face Spaces file; commit e24e14f.)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
# Lazily-initialized tokenizer/model cache. Loading a 4B-parameter checkpoint
# is expensive; doing it once at first request instead of on EVERY call is the
# difference between seconds and minutes per prediction.
_TOKENIZER = None
_MODEL = None


def _get_tokenizer_and_model():
    """Load and cache the Mellum tokenizer and model on first use."""
    global _TOKENIZER, _MODEL
    if _MODEL is None:
        _TOKENIZER = AutoTokenizer.from_pretrained('JetBrains/Mellum-4b-base')
        _MODEL = AutoModelForCausalLM.from_pretrained('JetBrains/Mellum-4b-base')
    return _TOKENIZER, _MODEL


def predict_code(input):
    """Generate up to 100 new tokens of code completion for *input*.

    Parameters
    ----------
    input : str
        The code prompt to complete.

    Returns
    -------
    str
        Only the newly generated continuation — the prompt tokens are
        sliced off before decoding.
    """
    # NOTE: the parameter name shadows the `input` builtin; kept as-is so
    # existing keyword callers are not broken.
    tokenizer, model = _get_tokenizer_and_model()
    encoded_input = tokenizer(input, return_tensors='pt', return_token_type_ids=False)
    input_len = len(encoded_input["input_ids"][0])
    out = model.generate(
        **encoded_input,
        max_new_tokens=100,
    )
    # Decode only the generated tail; skip special tokens (e.g. EOS) so the
    # returned text is clean code rather than tokenizer markup.
    prediction = tokenizer.decode(out[0][input_len:], skip_special_tokens=True)
    return prediction
def run(input):
    """Gradio entry point: delegate the prompt to ``predict_code``."""
    completion = predict_code(input)
    return completion
# Wire the prediction function into a minimal text-in / text-out Gradio UI
# and start the web server.
app = gr.Interface(fn=run, inputs=["text"], outputs=["text"])
app.launch()