Luigi commited on
Commit
3350989
·
1 Parent(s): 9c6147f

bugfix for diverse search with zero diversity_penalty

Browse files
Files changed (1) hide show
  1. app.py +33 -117
app.py CHANGED
@@ -26,30 +26,19 @@ MODEL_LIST = [
26
  ]
27
 
28
  def merge_common_prefixes(suggestions, min_len=2):
29
- """
30
- 合併具有共同前綴的建議:
31
- - 找出所有長度 ≥ min_len 的共同前綴
32
- - 將這些前綴作為新建議,移除原有被合併的項目
33
- """
34
  prefixes = []
35
  to_remove = set()
36
-
37
  for i in range(len(suggestions)):
38
  for j in range(i+1, len(suggestions)):
39
  s1, s2 = suggestions[i], suggestions[j]
40
- # 計算字元級共同前綴
41
  common = ''.join(c1 for c1, c2 in zip(s1, s2) if c1 == c2)
42
  if len(common) >= min_len:
43
  prefixes.append(common)
44
  to_remove.update([s1, s2])
45
-
46
- # 去重前綴
47
  unique_prefixes = []
48
  for p in prefixes:
49
  if p not in unique_prefixes:
50
  unique_prefixes.append(p)
51
-
52
- # 剩下沒有被合併的建議
53
  remainder = [s for s in suggestions if s not in to_remove]
54
  return unique_prefixes + remainder
55
 
@@ -64,28 +53,29 @@ def get_pipeline(model_name):
64
 
65
  @spaces.GPU
66
  def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
67
- """
68
- 使用 Diverse Beam Search 產生 m 條候選:
69
- - num_beams = m
70
- - num_beam_groups, diversity_penalty 可調整多樣性
71
- 之後轉繁體、去重、合併共同前綴後回傳。
72
- """
73
  gen_pipe = get_pipeline(model_name)
74
- outs = gen_pipe(
75
- text,
76
- max_new_tokens=k,
77
- num_beams=m,
78
- num_beam_groups=num_beam_groups,
79
- diversity_penalty=diversity_penalty,
80
- num_return_sequences=m,
81
- do_sample=False,
82
- early_stopping=True
83
- )
84
- # 提取新生成文本,過濾空字串,轉繁體
85
- suggestions = [out["generated_text"][len(text):].strip() for out in outs]
86
- suggestions = [s for s in suggestions if s]
87
- suggestions = [cc.convert(s) for s in suggestions]
88
- # 去除重複,保留順序
 
 
 
 
 
 
 
89
  unique_suggestions = []
90
  for s in suggestions:
91
  if s not in unique_suggestions:
@@ -96,88 +86,18 @@ def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
96
 
97
  return update(choices=final_suggestions, value=None)
98
 
 
 
99
 
100
- def append_suggestion(current, choice):
101
- if choice is None:
102
- return current
103
- # 直接插入選中的候選文字
104
- return current + choice
105
-
106
- # 自訂 CSS:模擬經典中文輸入法候選欄樣式,並優化手機響應與自動高度
107
- custom_css = """
108
- #suggestions-bar {
109
- width: 100%;
110
- margin-bottom: 8px;
111
- }
112
  #suggestions-bar .candidate-list {
113
- display: flex;
114
- gap: 8px;
115
- background: #fff;
116
- border: 1px solid #999;
117
- border-radius: 4px;
118
- padding: 6px;
119
- overflow-x: auto;
120
- white-space: nowrap;
121
- }
122
- #suggestions-bar .candidate-list label {
123
- cursor: pointer;
124
- padding: 6px 10px;
125
- font-size: 16px;
126
  }
127
- #suggestions-bar .candidate-list label:hover {
128
- background: #f5f5f5;
129
- }
130
- #suggestions-bar .candidate-list input[type=radio]:checked + label {
131
- background: #e6f7ff;
132
- border: 1px solid #1890ff;
133
- }
134
- #input-box textarea {
135
- width: 100%;
136
- font-size: 16px;
137
- padding: 6px;
138
- box-sizing: border-box;
139
- overflow: hidden;
140
- resize: none;
141
- }
142
- #predict-button {
143
- margin-top: 8px;
144
- width: 100%;
145
- }
146
- /* 手機響應式 */
147
- @media only screen and (max-width: 600px) {
148
- #suggestions-bar .candidate-list label {
149
- padding: 8px;
150
- font-size: 18px;
151
- }
152
- #predict-button {
153
- font-size: 18px;
154
- }
155
- }
156
- """
157
-
158
- # 自動增高腳本
159
- auto_height_js = """
160
- <script>
161
- window.addEventListener('load', () => {
162
- const textarea = document.querySelector('#input-box textarea');
163
- if (!textarea) return;
164
- textarea.style.height = 'auto';
165
- textarea.addEventListener('input', function() {
166
- this.style.height = 'auto';
167
- this.style.height = this.scrollHeight + 'px';
168
- });
169
- });
170
- </script>
171
- """
172
-
173
- with gr.Blocks(css=custom_css) as demo:
174
- gr.HTML(auto_height_js)
175
- gr.Markdown(
176
- "## 🇹🇼 繁體中文 IME 加速器 \
177
- "
178
- "結合小型語言模型與 ZeroGPU,提供即時輸入法風格候選欄。"
179
- )
180
-
181
  with gr.Column():
182
  suggestions = gr.Radio(
183
  [], label="", interactive=True, type="value",
@@ -188,14 +108,11 @@ with gr.Blocks(css=custom_css) as demo:
188
  lines=1, max_lines=20, elem_id="input-box"
189
  )
190
 
191
- # 永遠顯示預測按鈕
192
  with gr.Row():
193
  auto_predict = gr.Checkbox(
194
  value=True, label="自動預測(內容變更時觸發)", elem_id="auto-predict"
195
  )
196
- predict_button = gr.Button(
197
- "預測", elem_id="predict-button"
198
- )
199
 
200
  with gr.Accordion("進階設定", open=False):
201
  model_selector = gr.Dropdown(
@@ -216,7 +133,6 @@ with gr.Blocks(css=custom_css) as demo:
216
  label="多樣性懲罰 (diversity_penalty)"
217
  )
218
 
219
- # 綁定事件
220
  predict_button.click(
221
  fn=suggest_next,
222
  inputs=[
@@ -251,4 +167,4 @@ with gr.Blocks(css=custom_css) as demo:
251
  outputs=input_text,
252
  )
253
 
254
- demo.launch()
 
26
  ]
27
 
28
  def merge_common_prefixes(suggestions, min_len=2):
 
 
 
 
 
29
  prefixes = []
30
  to_remove = set()
 
31
  for i in range(len(suggestions)):
32
  for j in range(i+1, len(suggestions)):
33
  s1, s2 = suggestions[i], suggestions[j]
 
34
  common = ''.join(c1 for c1, c2 in zip(s1, s2) if c1 == c2)
35
  if len(common) >= min_len:
36
  prefixes.append(common)
37
  to_remove.update([s1, s2])
 
 
38
  unique_prefixes = []
39
  for p in prefixes:
40
  if p not in unique_prefixes:
41
  unique_prefixes.append(p)
 
 
42
  remainder = [s for s in suggestions if s not in to_remove]
43
  return unique_prefixes + remainder
44
 
 
53
 
54
  @spaces.GPU
55
  def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
 
 
 
 
 
 
56
  gen_pipe = get_pipeline(model_name)
57
+ # 構造 generate 參數字典,僅在 penalty>0 時加入 diversity 相關
58
+ gen_kwargs = {
59
+ "max_new_tokens": k,
60
+ "num_beams": m,
61
+ "num_return_sequences": m,
62
+ "do_sample": False,
63
+ "early_stopping": True,
64
+ }
65
+ if diversity_penalty and diversity_penalty > 0:
66
+ gen_kwargs["num_beam_groups"] = num_beam_groups
67
+ gen_kwargs["diversity_penalty"] = diversity_penalty
68
+
69
+ outs = gen_pipe(text, **gen_kwargs)
70
+
71
+ # 提取純下文、過濾空字串、繁體化
72
+ suggestions = [
73
+ cc.convert(out["generated_text"][len(text):].strip())
74
+ for out in outs
75
+ if out["generated_text"][len(text):].strip()
76
+ ]
77
+
78
+ # 去重
79
  unique_suggestions = []
80
  for s in suggestions:
81
  if s not in unique_suggestions:
 
86
 
87
  return update(choices=final_suggestions, value=None)
88
 
89
+ def append_suggestion(text, choice):
90
+ return text + choice
91
 
92
+ with gr.Blocks(css="""
93
+ #suggestions-bar { width: 100%; margin-bottom: 8px; }
 
 
 
 
 
 
 
 
 
 
94
  #suggestions-bar .candidate-list {
95
+ display: flex; gap: 8px; background: #fff;
96
+ border: 1px solid #999; border-radius: 4px;
97
+ padding: 6px; overflow-x: auto; white-space: nowrap;
 
 
 
 
 
 
 
 
 
 
98
  }
99
+ #suggestions-bar .candidate-list label { cursor: pointer; }
100
+ """) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  with gr.Column():
102
  suggestions = gr.Radio(
103
  [], label="", interactive=True, type="value",
 
108
  lines=1, max_lines=20, elem_id="input-box"
109
  )
110
 
 
111
  with gr.Row():
112
  auto_predict = gr.Checkbox(
113
  value=True, label="自動預測(內容變更時觸發)", elem_id="auto-predict"
114
  )
115
+ predict_button = gr.Button("預測", elem_id="predict-button")
 
 
116
 
117
  with gr.Accordion("進階設定", open=False):
118
  model_selector = gr.Dropdown(
 
133
  label="多樣性懲罰 (diversity_penalty)"
134
  )
135
 
 
136
  predict_button.click(
137
  fn=suggest_next,
138
  inputs=[
 
167
  outputs=input_text,
168
  )
169
 
170
+ demo.launch()