# app.py – CodeNyx (StarCoderBase-1B) – full generation & FIM
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
# ------------------------------------------------------------------
# 1. Load the 1B model – same setup as the official snippet
# ------------------------------------------------------------------
CHECKPOINT = "bigcode/starcoderbase-1b"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(CHECKPOINT).to(DEVICE)
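# Optional memory saving – a sketch, not part of the original app: on CUDA the
# checkpoint can instead be loaded in float16, roughly halving VRAM use at a
# small numerical cost. Uncomment to try it:
# if DEVICE == "cuda":
#     model = AutoModelForCausalLM.from_pretrained(
#         CHECKPOINT, torch_dtype=torch.float16
#     ).to(DEVICE)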
# ------------------------------------------------------------------
# 2. Branding
# ------------------------------------------------------------------
BOT_NAME = "CodeNyx"
# Note: SYSTEM is branding text only. StarCoderBase is a base model (not
# instruction-tuned), so this string is not injected into the prompts below.
SYSTEM = (
    f"You are {BOT_NAME}, an expert coding assistant trained on The Stack v1.2. "
    "Return only complete, runnable code with a short comment."
)
# ------------------------------------------------------------------
# 3. Helper: full generation
# ------------------------------------------------------------------
def full_generation(prompt: str):
    """Complete a free-form prompt; returns the prompt plus the sampled continuation."""
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=512,
            temperature=0.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
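# Streaming variant – a sketch, not wired into the UI by default. The
# TextIteratorStreamer and Thread imports above suggest this was the intent:
# generate() runs on a background thread while the streamer yields decoded
# text, so a Gradio generator handler can update the output incrementally.
# The name full_generation_stream is ours, not from the original app.
def full_generation_stream(prompt: str):
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs=inputs,
        max_new_tokens=512,
        temperature=0.2,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    text = ""
    for chunk in streamer:  # yields the prompt first, then the new tokens
        text += chunk
        yield text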
# ------------------------------------------------------------------
# 4. Helper: fill-in-the-middle (FIM)
# ------------------------------------------------------------------
def fim_generation(prefix: str, suffix: str):
    """Fill-in-the-middle using StarCoder's FIM sentinel tokens."""
    fim_text = f"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"
    inputs = tokenizer.encode(fim_text, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=256,
            temperature=0.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens (the middle); decoding the whole
    # sequence would show prefix, suffix, then middle in that scrambled order.
    middle = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
    return prefix + middle + suffix
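# Hypothetical usage (not part of the Space UI):
# fim_generation("def add(a, b):\n    ", "\n") asks the model for just the
# missing middle and returns the snippet reassembled in source order.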
# ------------------------------------------------------------------
# 5. Gradio interface
# ------------------------------------------------------------------
with gr.Blocks(title=f"{BOT_NAME} – StarCoderBase-1B") as demo:
    gr.Markdown(f"""
    # 🤖 {BOT_NAME} – powered by StarCoderBase-1B (The Stack v1.2)
    *Ask for full code or let the model **fill-in-the-middle** of any snippet.*
    """)
    with gr.Tab("Full Generation"):
        prompt_in = gr.Textbox(label="Prompt", lines=3, placeholder="def fibonacci(n):")
        full_out = gr.Code(label="Generated Code", language="python")
        gen_btn = gr.Button("Generate")
        gen_btn.click(full_generation, prompt_in, full_out)
    with gr.Tab("Fill-in-the-Middle"):
        with gr.Row():
            prefix_in = gr.Textbox(label="Prefix", lines=3, placeholder="def fibonacci(n):\n    ")
            suffix_in = gr.Textbox(label="Suffix", lines=3, placeholder="\n    return result")
        fim_out = gr.Code(label="Completed Code", language="python")
        fim_btn = gr.Button("Complete")
        fim_btn.click(fim_generation, [prefix_in, suffix_in], fim_out)
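    # Hypothetical wiring for the streaming sketch above: Gradio accepts
    # generator functions as handlers, so pointing the click handler at
    # full_generation_stream would stream partial text into the Code box
    # (the queue() call below is what enables this).
    # gen_btn.click(full_generation_stream, prompt_in, full_out)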
# ------------------------------------------------------------------
# 6. Launch
# ------------------------------------------------------------------
if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)