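"""Gradio demo: streaming code generation with Apple's DiffuCoder-7B-cpGRPO.

DiffuCoder is a diffusion language model: text is produced by iterative
denoising rather than plain left-to-right sampling. To stream output, this
app calls the model's custom `diffusion_generate` method one token at a time
and yields the cumulative text. The `spaces` import and `@spaces.GPU`
decorator are the Hugging Face ZeroGPU Spaces pattern for requesting a GPU
per call.
"""
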
import spaces
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer

# Load model and tokenizer
model_path = "apple/DiffuCoder-7B-cpGRPO"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).to(device).eval()

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.eos_token = "<|im_end|>"  # ChatML end-of-turn token doubles as EOS

@spaces.GPU
def generate_code(query, temperature=0.4, top_p=0.95, max_new_tokens=256):
    # Build the ChatML conversation. add_generation_prompt=True (below)
    # appends the assistant header, so no empty assistant turn is needed.
    messages = [
        {"role": "system", "content": "You are a helpful coding assistant."},
        {"role": "user", "content": query.strip()}
    ]
    
    # Apply chat template
    prompt = tokenizer.apply_chat_template(
        messages, 
        tokenize=False, 
        add_generation_prompt=True
    )
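    # With add_generation_prompt=True, the rendered ChatML prompt ends with
    # an open assistant turn for the model to fill in:
    #   <|im_start|>system ... <|im_end|>
    #   <|im_start|>user ... <|im_end|>
    #   <|im_start|>assistant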
    
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    
    # Calculate initial prompt length
    initial_prompt_len = input_ids.shape[1]
    
    # Track EOS status
    eos_detected = False
    
    # Generate with token streaming
    TOKEN_PER_STEP = 1
    steps = max_new_tokens // TOKEN_PER_STEP
    
    for _ in range(steps):
        output = model.diffusion_generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=TOKEN_PER_STEP,
            output_history=True,
            return_dict_in_generate=True,
            steps=1,
            temperature=temperature,
            top_p=top_p,
            alg="entropy",
            alg_temp=0.,
        )
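        # One denoising step (steps=1) yielding one token per call.
        # alg="entropy" is the entropy-based confidence ordering used in the
        # DiffuCoder examples (most-confident masked positions are committed
        # first); alg_temp=0. makes that ordering greedy/deterministic.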
        
        # Get all new tokens (after initial prompt)
        new_tokens = output.sequences[0, initial_prompt_len:]
        
        # Truncate at the first EOS token and stop after yielding this step
        eos_positions = (new_tokens == tokenizer.eos_token_id).nonzero(as_tuple=True)[0]
        if eos_positions.numel() > 0:
            new_tokens = new_tokens[:eos_positions[0]]
            eos_detected = True
        
        # Decode everything generated so far (the stream is cumulative)
        new_text = tokenizer.decode(
            new_tokens, 
            skip_special_tokens=True
        )
        
        # Feed the extended sequence back in; with a single unpadded prompt
        # the attention mask is simply all ones at the new length
        input_ids = output.sequences
        attention_mask = torch.ones_like(input_ids)
        
        # Yield cumulative text, trimming DiffuCoder's <|dlm_pad|> filler
        yield new_text.split('<|dlm_pad|>')[0].strip()
        
        if eos_detected:
            break

# Create Gradio interface
demo = gr.Interface(
    fn=generate_code,
    inputs=[
        gr.Textbox(label="Code Request", lines=3,
                   placeholder="Describe the code you want..."),
        gr.Slider(0.1, 1.0, value=0.4, label="Temperature"),
        gr.Slider(0.5, 1.0, value=0.95, label="Top-p"),
        gr.Slider(32, 512, value=256, step=32, label="Max Tokens")
    ],
    outputs=gr.Textbox(label="Generated Code", lines=10),
    title="🧠 DiffuCoder Code Generator",
    description="Generate code with Apple's DiffuCoder-7B model",
    examples=[
        ["Write a Python function to calculate factorial", 0.4, 0.95, 256],
        ["Create a function to merge two sorted lists", 0.4, 0.95, 256],
        ["How to reverse a string in JavaScript?", 0.4, 0.95, 256]
    ]
)

# Run the demo; queue() enables streaming of the generator's partial outputs
if __name__ == "__main__":
    demo.queue().launch()
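
# Local-run note (assumption: this file targets a Hugging Face ZeroGPU Space):
# outside Spaces, drop `import spaces` and the @spaces.GPU decorator, then run
# `python app.py`. Dependencies: gradio, torch, transformers; passing
# trust_remote_code=True pulls in DiffuCoder's custom modeling code, which is
# what provides `diffusion_generate`.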