Spaces:
Runtime error
Runtime error
# app.py — CodeNyx (StarCoderBase-1B) — full generation & FIM
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer | |
from threading import Thread | |
# ------------------------------------------------------------------ | |
# 1. 1B model — identical to official snippet
# ------------------------------------------------------------------ | |
# Hub identifier of the checkpoint served by this app.
CHECKPOINT = "bigcode/starcoderbase-1b"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Load the tokenizer and the causal LM once at import time, then move the
# model weights onto the selected device.
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(CHECKPOINT)
model = model.to(DEVICE)
# ------------------------------------------------------------------ | |
# 2. Branding | |
# ------------------------------------------------------------------ | |
# Display name used in the page title and heading.
BOT_NAME = "CodeNyx"

# Persona / instruction text: one identity sentence followed by one
# output-format sentence, joined into a single string.
_IDENTITY = f"You are {BOT_NAME}, an expert coding assistant trained on The Stack v1.2. "
_OUTPUT_RULE = "Return only complete, runnable code with a short comment."
SYSTEM = _IDENTITY + _OUTPUT_RULE
# ------------------------------------------------------------------ | |
# 3. Helper: full generation | |
# ------------------------------------------------------------------ | |
def full_generation(prompt: str) -> str:
    """Continue *prompt* with up to 512 sampled tokens and return the text.

    The whole output sequence is decoded, so the returned string is the
    prompt followed by the model's continuation, with special tokens removed.

    Args:
        prompt: Code or text the model should continue.

    Returns:
        The decoded prompt plus continuation, or ``""`` when *prompt* is empty.
    """
    # An empty textbox would hand generate() a zero-length input sequence
    # with nothing to condition on; return early instead of failing.
    if not prompt:
        return ""

    # Calling the tokenizer (rather than .encode) also yields the attention
    # mask that generate() needs.
    encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=encoded["input_ids"],
            # pad_token_id is the EOS id below, so padding cannot be inferred
            # from the ids alone; pass the mask explicitly.
            attention_mask=encoded["attention_mask"],
            max_new_tokens=512,
            temperature=0.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# ------------------------------------------------------------------ | |
# 4. Helper: fill-in-the-middle (FIM) | |
# ------------------------------------------------------------------ | |
def fim_generation(prefix: str, suffix: str) -> str:
    """Fill in the code between *prefix* and *suffix* and return the whole snippet.

    The prompt is laid out as ``<fim_prefix>…<fim_suffix>…<fim_middle>`` and the
    model generates up to 256 sampled tokens for the missing middle part.

    Args:
        prefix: Code that comes before the gap.
        suffix: Code that comes after the gap.

    Returns:
        ``prefix + generated_middle + suffix``.
    """
    fim_text = f"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"
    encoded = tokenizer(fim_text, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(
            input_ids=encoded["input_ids"],
            # pad_token_id is the EOS id below, so padding cannot be inferred
            # from the ids alone; pass the mask explicitly.
            attention_mask=encoded["attention_mask"],
            max_new_tokens=256,
            temperature=0.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. Decoding the full sequence with
    # skip_special_tokens=True drops the FIM sentinels and leaves the text in
    # prompt order (prefix, suffix, middle) instead of source order.
    prompt_len = encoded["input_ids"].shape[1]
    middle = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return f"{prefix}{middle}{suffix}"
# ------------------------------------------------------------------ | |
# 5. Gradio interface | |
# ------------------------------------------------------------------ | |
# Two-tab UI: one tab for plain continuation, one for fill-in-the-middle.
# Event handlers are registered inside the Blocks context so they attach to
# this app.
with gr.Blocks(title=f"{BOT_NAME} — StarCoderBase-1B") as demo:
    # Heading and tagline; the dash and robot emoji were mis-encoded
    # ("β", "π€") and are restored here.
    gr.Markdown(
        "\n"
        f"# 🤖 {BOT_NAME} — powered by StarCoderBase-1B (The Stack v1.2)\n"
        "*Ask for full code or let the model **fill-in-the-middle** of any snippet.*\n"
    )

    with gr.Tab("Full Generation"):
        prompt_in = gr.Textbox(label="Prompt", lines=3, placeholder="def fibonacci(n):")
        full_out = gr.Code(label="Generated Code", language="python")
        gen_btn = gr.Button("Generate")
        # Prompt textbox -> full_generation -> code viewer.
        gen_btn.click(full_generation, prompt_in, full_out)

    with gr.Tab("Fill-in-the-Middle"):
        # Prefix and suffix inputs sit side by side.
        with gr.Row():
            prefix_in = gr.Textbox(label="Prefix", lines=3, placeholder="def fibonacci(n):\n ")
            suffix_in = gr.Textbox(label="Suffix", lines=3, placeholder="\n return result")
        fim_out = gr.Code(label="Completed Code", language="python")
        fim_btn = gr.Button("Complete")
        # (prefix, suffix) -> fim_generation -> code viewer.
        fim_btn.click(fim_generation, [prefix_in, suffix_in], fim_out)
# ------------------------------------------------------------------ | |
# 6. Launch | |
# ------------------------------------------------------------------ | |
if __name__ == "__main__":
    # Listen on every interface at port 7860 and request a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    # Enable the request queue before starting the server.
    queued_app = demo.queue()
    queued_app.launch(**launch_options)