mrfakename committed
Commit 987b437 · verified · 1 Parent(s): 9cc2d55

Update app.py

Files changed (1):
  1. app.py +41 -14
app.py CHANGED
@@ -14,27 +14,42 @@ model = AutoModel.from_pretrained(
 ).to(device).eval()
 
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+tokenizer.eos_token = "<|im_end|>" # Set EOS token
 
 @spaces.GPU
 def generate_code(query, temperature=0.4, top_p=0.95, max_new_tokens=256):
-    # Format prompt using chat template
-    prompt = f"""<|im_start|>system
-You are a helpful coding assistant.<|im_end|>
-<|im_start|>user
-{query.strip()}<|im_end|>
-<|im_start|>assistant
-"""
+    # Format prompt using ChatML template
+    messages = [
+        {"role": "system", "content": "You are a helpful coding assistant."},
+        {"role": "user", "content": query.strip()},
+        {"role": "assistant", "content": ""} # Start of assistant response
+    ]
+
+    # Apply chat template
+    prompt = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
+    )
 
     inputs = tokenizer(prompt, return_tensors="pt")
     input_ids = inputs.input_ids.to(device)
     attention_mask = inputs.attention_mask.to(device)
 
+    # Calculate initial prompt length
+    initial_prompt_len = input_ids.shape[1]
+
+    # Track EOS status
+    eos_detected = False
+
     # Generate with token streaming
     TOKEN_PER_STEP = 1
     steps = max_new_tokens // TOKEN_PER_STEP
 
-    full_output = ""
-    for _ in range(steps):
+    for i in range(steps):
+        if eos_detected:
+            break
+
         output = model.diffusion_generate(
             input_ids,
             attention_mask=attention_mask,
@@ -48,9 +63,19 @@ You are a helpful coding assistant.<|im_end|>
             alg_temp=0.,
         )
 
+        # Get all new tokens (after initial prompt)
+        new_tokens = output.sequences[0, initial_prompt_len:]
+
+        # Check for EOS token
+        if tokenizer.eos_token_id in new_tokens:
+            eos_index = (new_tokens == tokenizer.eos_token_id).nonzero(as_tuple=True)[0]
+            if eos_index.numel() > 0:
+                new_tokens = new_tokens[:eos_index[0]]
+                eos_detected = True
+
         # Decode new tokens
-        new_tokens = tokenizer.decode(
-            output.sequences[0, -TOKEN_PER_STEP:].tolist(),
+        new_text = tokenizer.decode(
+            new_tokens,
             skip_special_tokens=True
         )
 
@@ -61,9 +86,11 @@ You are a helpful coding assistant.<|im_end|>
             torch.ones(1, 1, dtype=attention_mask.dtype, device=device)
         ], dim=1)
 
-        # Append to full output and stream
-        full_output += new_tokens
-        yield full_output.split('<|dlm_pad|>')[0].strip()
+        # Yield current output
+        yield new_text.split('<|dlm_pad|>')[0].strip()
+
+        if eos_detected:
+            break
 
 # Create Gradio interface
 demo = gr.Interface(
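
Two notes on the first hunk, with a runnable sketch. Assigning tokenizer.eos_token the string "<|im_end|>" works because transformers resolves the eos_token_id property through a vocabulary lookup, which the loop's EOS check later relies on. And apply_chat_template with add_generation_prompt=True should reproduce the ChatML string the old f-string built by hand (the commit also appends an empty assistant turn, which most ChatML templates make redundant). A minimal sketch with a stand-in ChatML tokenizer, since the Space's actual model_path is not shown in this diff:

    from transformers import AutoTokenizer

    # Stand-in tokenizer with a ChatML chat template; not the Space's model.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

    # A string assignment is resolved to an id via the vocab.
    tokenizer.eos_token = "<|im_end|>"
    assert tokenizer.eos_token_id == tokenizer.convert_tokens_to_ids("<|im_end|>")

    messages = [
        {"role": "system", "content": "You are a helpful coding assistant."},
        {"role": "user", "content": "Write hello world in Python."},
    ]
    print(tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    ))
    # <|im_start|>system
    # You are a helpful coding assistant.<|im_end|>
    # <|im_start|>user
    # Write hello world in Python.<|im_end|>
    # <|im_start|>assistant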
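
The EOS handling in the second hunk slices the generated ids at the first end-of-sequence token. The same logic as a self-contained helper; truncate_at_eos is a name introduced here for illustration, not part of the commit:

    import torch

    def truncate_at_eos(new_tokens: torch.Tensor, eos_token_id: int):
        # Return the tokens before the first EOS, plus whether EOS was found.
        hits = (new_tokens == eos_token_id).nonzero(as_tuple=True)[0]
        if hits.numel() > 0:
            return new_tokens[:hits[0]], True
        return new_tokens, False

    # EOS (id 2 here) appears mid-sequence, so the trailing tokens are dropped.
    tokens = torch.tensor([5, 9, 2, 7])
    trimmed, seen = truncate_at_eos(tokens, eos_token_id=2)
    assert trimmed.tolist() == [5, 9] and seen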
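
The context lines around the third hunk show the attention mask being extended by one column per step, matching TOKEN_PER_STEP = 1: each iteration commits one more token, so the mask must grow to cover it. In isolation:

    import torch

    # A 5-token prompt; one generated token per step extends the mask by one.
    attention_mask = torch.ones(1, 5, dtype=torch.long)
    attention_mask = torch.cat([
        attention_mask,
        torch.ones(1, 1, dtype=attention_mask.dtype)
    ], dim=1)
    assert attention_mask.shape == (1, 6)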
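
The diff cuts off at the gr.Interface( constructor. Because generate_code is a generator, gr.Interface streams each yielded string to the output component as it arrives. A sketch of plausible wiring; component labels and slider ranges are illustrative, not taken from the commit:

    import gradio as gr

    demo = gr.Interface(
        fn=generate_code,  # the generator above; every yield updates the output
        inputs=[
            gr.Textbox(label="Prompt"),
            gr.Slider(0.0, 1.0, value=0.4, label="Temperature"),
            gr.Slider(0.0, 1.0, value=0.95, label="Top-p"),
            gr.Slider(16, 1024, value=256, step=16, label="Max new tokens"),
        ],
        outputs=gr.Textbox(label="Generated code"),
    )

    if __name__ == "__main__":
        demo.launch()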