# app.py – CodeNyx (StarCoderBase-1B) – full generation & FIM
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread

# ------------------------------------------------------------------
# 1. 1B model – identical to the official snippet
# ------------------------------------------------------------------
CHECKPOINT = "bigcode/starcoderbase-1b"
DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model     = AutoModelForCausalLM.from_pretrained(CHECKPOINT).to(DEVICE)
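
# Optional (an assumption about the host, not part of the official snippet):
# on a CUDA device the model can be loaded in half precision to roughly halve
# memory, e.g.
#   model = AutoModelForCausalLM.from_pretrained(
#       CHECKPOINT, torch_dtype=torch.float16
#   ).to(DEVICE)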

# ------------------------------------------------------------------
# 2. Branding
# ------------------------------------------------------------------
BOT_NAME = "CodeNyx"
SYSTEM   = (
    f"You are {BOT_NAME}, an expert coding assistant trained on The Stack v1.2. "
    "Return only complete, runnable code with a short comment."
)
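
# Note: StarCoderBase is a base model, not instruction-tuned, so SYSTEM is not
# sent to the model by the helpers below; it only documents the intended
# persona. A minimal sketch, if one wanted to bias completions anyway, would be
# to prepend it as a source comment inside full_generation, e.g.
#   prompt = f"# {SYSTEM}\n{prompt}"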

# ------------------------------------------------------------------
# 3. Helper: full generation
# ------------------------------------------------------------------
def full_generation(prompt: str):
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=512,
            temperature=0.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
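
# A minimal streaming variant (a sketch, not wired into the UI below) that uses
# the TextIteratorStreamer/Thread imports above; swap it for `full_generation`
# in `gen_btn.click` if token-by-token output is preferred. The generation
# parameters simply mirror full_generation and are assumptions, not tuned values.
def full_generation_stream(prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=512,
        temperature=0.2,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    # generate() blocks, so run it in a thread while the streamer yields text
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text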

# ------------------------------------------------------------------
# 4. Helper: fill-in-the-middle (FIM)
# ------------------------------------------------------------------
def fim_generation(prefix: str, suffix: str):
    # StarCoderBase FIM format: the model sees
    # <fim_prefix>PREFIX<fim_suffix>SUFFIX<fim_middle>
    # and generates the missing middle after the <fim_middle> token.
    fim_text = f"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"
    inputs = tokenizer.encode(fim_text, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=256,
            temperature=0.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens (the middle) and stitch the
    # snippet back together in source order: prefix + middle + suffix.
    middle = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
    return prefix + middle + suffix

# ------------------------------------------------------------------
# 5. Gradio interface
# ------------------------------------------------------------------
with gr.Blocks(title=f"{BOT_NAME} – StarCoderBase-1B") as demo:
    gr.Markdown(f"""
    # 🤖 {BOT_NAME} – powered by StarCoderBase-1B (The Stack v1.2)
    *Ask for full code or let the model **fill-in-the-middle** of any snippet.*
    """)

    with gr.Tab("Full Generation"):
        prompt_in   = gr.Textbox(label="Prompt", lines=3, placeholder="def fibonacci(n):")
        full_out    = gr.Code(label="Generated Code", language="python")
        gen_btn     = gr.Button("Generate")
        gen_btn.click(full_generation, prompt_in, full_out)

    with gr.Tab("Fill-in-the-Middle"):
        with gr.Row():
            prefix_in = gr.Textbox(label="Prefix", lines=3, placeholder="def fibonacci(n):\n    ")
            suffix_in = gr.Textbox(label="Suffix", lines=3, placeholder="\n    return result")
        fim_out = gr.Code(label="Completed Code", language="python")
        fim_btn = gr.Button("Complete")
        fim_btn.click(fim_generation, [prefix_in, suffix_in], fim_out)

# ------------------------------------------------------------------
# 6. Launch
# ------------------------------------------------------------------
if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)