mrfakename committed
Commit 6f96c51 · verified · 1 Parent(s): 987b437

Update app.py

Files changed (1): app.py +30 -20
app.py CHANGED
@@ -14,29 +14,29 @@ model = AutoModel.from_pretrained(
 ).to(device).eval()
 
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-tokenizer.eos_token = "<|im_end|>" # Set EOS token
+tokenizer.eos_token = "<|im_end|>"
 
 @spaces.GPU
 def generate_code(query, temperature=0.4, top_p=0.95, max_new_tokens=256):
-    # Format prompt using ChatML template
+    # Format prompt using chat template
     messages = [
         {"role": "system", "content": "You are a helpful coding assistant."},
-        {"role": "user", "content": query.strip()},
-        {"role": "assistant", "content": ""} # Start of assistant response
+        {"role": "user", "content": query.strip()}
     ]
 
-    # Apply chat template
+    # Apply chat template - this creates the prompt but doesn't include assistant response
    prompt = tokenizer.apply_chat_template(
         messages,
         tokenize=False,
         add_generation_prompt=True
     )
 
+    # Tokenize only the prompt (without any assistant response)
     inputs = tokenizer(prompt, return_tensors="pt")
     input_ids = inputs.input_ids.to(device)
     attention_mask = inputs.attention_mask.to(device)
 
-    # Calculate initial prompt length
+    # Calculate initial prompt length - this is where the assistant response will start
     initial_prompt_len = input_ids.shape[1]
 
     # Track EOS status
@@ -44,7 +44,10 @@ def generate_code(query, temperature=0.4, top_p=0.95, max_new_tokens=256):
 
     # Generate with token streaming
     TOKEN_PER_STEP = 1
-    steps = max_new_tokens // TOKEN_PER_STEP
+    steps = min(max_new_tokens // TOKEN_PER_STEP, 512)  # Limit to max 512 steps
+
+    # This will accumulate only the assistant's response
+    assistant_response = ""
 
     for i in range(steps):
         if eos_detected:
@@ -63,20 +66,24 @@ def generate_code(query, temperature=0.4, top_p=0.95, max_new_tokens=256):
             alg_temp=0.,
         )
 
-        # Get all new tokens (after initial prompt)
-        new_tokens = output.sequences[0, initial_prompt_len:]
+        # Get only the new tokens generated in this step
+        new_token_ids = output.sequences[0, -TOKEN_PER_STEP:]
 
-        # Check for EOS token
-        if tokenizer.eos_token_id in new_tokens:
-            eos_index = (new_tokens == tokenizer.eos_token_id).nonzero(as_tuple=True)[0]
-            if eos_index.numel() > 0:
-                new_tokens = new_tokens[:eos_index[0]]
-                eos_detected = True
+        # Check for EOS token in the new tokens
+        if tokenizer.eos_token_id in new_token_ids:
+            # If EOS is found, stop after this token
+            eos_detected = True
+            # Remove EOS token from output
+            new_token_ids = new_token_ids[new_token_ids != tokenizer.eos_token_id]
+            if new_token_ids.numel() == 0:
+                # Only EOS was generated, nothing to add
+                break
 
-        # Decode new tokens
+        # Decode only the new tokens
         new_text = tokenizer.decode(
-            new_tokens,
-            skip_special_tokens=True
+            new_token_ids,
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=False
         )
 
         # Update input for next step
@@ -86,8 +93,11 @@ def generate_code(query, temperature=0.4, top_p=0.95, max_new_tokens=256):
             torch.ones(1, 1, dtype=attention_mask.dtype, device=device)
         ], dim=1)
 
-        # Yield current output
-        yield new_text.split('<|dlm_pad|>')[0].strip()
+        # Append to assistant response and yield
+        assistant_response += new_text
+        # Remove any trailing special tokens
+        clean_response = assistant_response.replace('<|dlm_pad|>', '').strip()
+        yield clean_response
 
         if eos_detected:
             break
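A few notes on the patterns this change relies on. First, apply_chat_template(..., add_generation_prompt=True) already ends the prompt with an opened assistant turn, which is why the explicit empty assistant message could be dropped. A minimal sketch with a ChatML-style tokenizer (Qwen/Qwen2.5-0.5B-Instruct is only an illustrative stand-in, not this Space's model):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # stand-in, not this Space's model
prompt = tok.apply_chat_template(
    [
        {"role": "system", "content": "You are a helpful coding assistant."},
        {"role": "user", "content": "Reverse a string in Python."},
    ],
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)
# A ChatML-style template ends with an opened assistant turn, roughly:
#   ...<|im_start|>assistant\n
# so generation continues from there without an empty assistant message.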
 
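Second, the loop now slices output.sequences[0, -TOKEN_PER_STEP:] and decodes only the tokens produced in the current step, accumulating text in assistant_response, instead of re-decoding everything after initial_prompt_len on every iteration. A self-contained sketch of that accumulate-as-you-decode pattern (the gpt2 tokenizer stands in for the real model):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer, for illustration only
generated = tok("def reverse(s): return s[::-1]", return_tensors="pt").input_ids[0]

TOKEN_PER_STEP = 1
assistant_response = ""
for step in range(0, generated.shape[0], TOKEN_PER_STEP):
    # Decode just this step's token(s) and append, mirroring the new loop
    new_token_ids = generated[step : step + TOKEN_PER_STEP]
    assistant_response += tok.decode(
        new_token_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
print(assistant_response)  # matches decoding the whole sequence at once for ASCII text

One caveat of per-token decoding: byte-level BPE vocabularies can split a multi-byte character across tokens, so intermediate strings may briefly contain replacement characters; for ASCII code output this rarely matters.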
 
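Third, EOS handling: the old code searched the whole generation for the first EOS index, while the new code only sees at most TOKEN_PER_STEP fresh ids per step, so a boolean mask suffices to drop the sentinel. In isolation (the id 2 below is hypothetical):

import torch

eos_token_id = 2  # hypothetical id, for illustration only
new_token_ids = torch.tensor([42, 7, 2])

eos_detected = eos_token_id in new_token_ids             # True
new_token_ids = new_token_ids[new_token_ids != eos_token_id]
print(new_token_ids)                                     # tensor([42, 7])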
 
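Fourth, the pad cleanup changed semantics: split('<|dlm_pad|>')[0] truncated the response at the first pad token, while replace('<|dlm_pad|>', '') strips every pad and keeps the text around them:

s = "def f():<|dlm_pad|><|dlm_pad|> pass"
print(s.split('<|dlm_pad|>')[0])     # 'def f():'       (old behavior: truncates)
print(s.replace('<|dlm_pad|>', ''))  # 'def f(): pass'  (new behavior: strips pads)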
 
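Finally, generate_code is a generator that yields the accumulated response after each step, so a streaming front end just iterates it. The Gradio wiring below is assumed for illustration; the Space's actual interface code sits outside this diff:

import gradio as gr

# Assumed wiring, not part of this commit
demo = gr.Interface(
    fn=generate_code,
    inputs=gr.Textbox(label="Prompt"),
    outputs=gr.Textbox(label="Generated code"),
)
demo.launch()

# Or consume it directly:
# for partial in generate_code("Write a function that reverses a string."):
#     print(partial)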