Luigi committed on
Commit
a94f020
·
1 Parent(s): ce9b231

add suggestions cleaning

Browse files
Files changed (1) hide show
  1. app.py +88 -8
app.py CHANGED
@@ -10,6 +10,7 @@ from termcolor import cprint
10
 
11
  # 初始化簡體到繁體轉換器
12
  cc = OpenCC('s2t')
 
13
 
14
  # 可選模型列表
15
  MODEL_LIST = [
@@ -27,10 +28,82 @@ MODEL_LIST = [
27
  "Epiculous/Violet_Twilight-v0.2",
28
  ]
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  @lru_cache(maxsize=8)
32
  def get_pipeline(model_name):
33
- tok = AutoTokenizer.from_pretrained(model_name)
 
34
  mdl = AutoModelForCausalLM.from_pretrained(
35
  model_name, weights_only=False, trust_remote_code=True
36
  )
@@ -38,10 +111,10 @@ def get_pipeline(model_name):
38
  mdl.to("cuda")
39
  except Exception as e:
40
  print(f'Error: {e}')
41
- return pipeline("text-generation", model=mdl, tokenizer=tok, device=0)
42
 
43
  @spaces.GPU
44
- def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
45
  """
46
  使用 Diverse Beam Search 產生 m 條候選:
47
  - num_beams = m
@@ -58,7 +131,7 @@ def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
58
  "early_stopping": True,
59
  }
60
  if diversity_penalty and diversity_penalty > 0:
61
- valid_group = gcd(m, num_beam_groups)
62
  gen_kwargs["num_beam_groups"] = valid_group
63
  gen_kwargs["diversity_penalty"] = float(diversity_penalty)
64
 
@@ -73,6 +146,7 @@ def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
73
  converted = cc.convert(snippet).strip()
74
  suggestions.add(converted)
75
  suggestions = list(suggestions)
 
76
 
77
  return update(choices=suggestions, value=None)
78
 
@@ -195,6 +269,10 @@ with gr.Blocks(css=custom_css) as demo:
195
  minimum=0.0, maximum=2.0, step=0.1, value=1.0,
196
  label="多樣性懲罰 (diversity_penalty)"
197
  )
 
 
 
 
198
 
199
  # 綁定事件
200
  predict_button.click(
@@ -205,13 +283,14 @@ with gr.Blocks(css=custom_css) as demo:
205
  k_slider,
206
  m_slider,
207
  group_slider,
208
- diversity_penalty_slider
 
209
  ],
210
  outputs=suggestions,
211
  )
212
  input_text.change(
213
- fn=lambda txt, mdl, k, m, g, d, auto: (
214
- suggest_next(txt, mdl, k, m, g, d)
215
  if auto else update(choices=[], value=None)
216
  ),
217
  inputs=[
@@ -221,7 +300,8 @@ with gr.Blocks(css=custom_css) as demo:
221
  m_slider,
222
  group_slider,
223
  diversity_penalty_slider,
224
- auto_predict
 
225
  ],
226
  outputs=suggestions,
227
  )
 
10
 
11
  # 初始化簡體到繁體轉換器
12
  cc = OpenCC('s2t')
13
+ tokenizer = None
14
 
15
  # 可選模型列表
16
  MODEL_LIST = [
 
28
  "Epiculous/Violet_Twilight-v0.2",
29
  ]
30
 
31
def clean_suggestions(suggestions: list[str], max_levels: int) -> list[str]:
    """Collapse a list of suggestion strings into their shared token prefixes.

    Steps:
      1. Tokenize each suggestion with the module-level ``tokenizer``; falls
         back to whitespace splitting when the tokenizer is unavailable or
         cannot handle raw text.
      2. Insert every token sequence into a prefix trie.
      3. Collect every trie prefix of depth <= ``max_levels`` that still has
         children, i.e. prefixes shared with at least one longer continuation.
      4. Convert those token prefixes back to text and return them de-duplicated.

    Args:
        suggestions: Candidate strings to clean.
        max_levels: Maximum prefix depth, in tokens, to extract.

    Returns:
        De-duplicated list of prefix strings (order unspecified, set-backed).
    """
    # Minimal trie node: children only — the original also tracked a pass-through
    # count and kept a token_seqs list, but neither was ever read, so both are
    # removed as dead code.
    class _TrieNode:
        __slots__ = ("children",)

        def __init__(self) -> None:
            self.children: dict[str, _TrieNode] = {}

    # Build the prefix trie from tokenized suggestions.
    root = _TrieNode()
    for text in suggestions:
        try:
            toks = tokenizer.tokenize(text)
        except Exception:
            # tokenizer may be uninitialized (None) or reject raw text;
            # fall back to basic whitespace tokenization.
            toks = text.split()
        if not toks:
            continue
        node = root
        for tok in toks:
            node = node.children.setdefault(tok, _TrieNode())

    # Collect token prefixes of depth <= max_levels that have continuations.
    prefixes: set[tuple[str, ...]] = set()

    def dfs(node: "_TrieNode", path: list[str], depth: int) -> None:
        # path holds the tokens walked so far; depth == len(path).
        if depth > max_levels:
            return
        # A non-root node with children marks a shared-prefix candidate.
        if depth > 0 and node.children:
            prefixes.add(tuple(path))
        # Stop descending once the maximum depth is reached.
        if depth == max_levels:
            return
        for tok, child in node.children.items():
            path.append(tok)
            dfs(child, path, depth + 1)
            path.pop()

    dfs(root, [], 0)

    # Detokenize each prefix back into a display string.
    cleaned: set[str] = set()
    for tok_prefix in prefixes:
        try:
            # convert_tokens_to_string is supported by most HF tokenizers.
            text_pref = tokenizer.convert_tokens_to_string(list(tok_prefix)).strip()
        except Exception:
            # Fallback: concatenate tokens directly (spacing rules unknown here).
            text_pref = "".join(tok_prefix).strip()
        if text_pref:
            cleaned.add(text_pref)

    return list(cleaned)
102
 
103
  @lru_cache(maxsize=8)
104
  def get_pipeline(model_name):
105
+ global tokenizer
106
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
107
  mdl = AutoModelForCausalLM.from_pretrained(
108
  model_name, weights_only=False, trust_remote_code=True
109
  )
 
111
  mdl.to("cuda")
112
  except Exception as e:
113
  print(f'Error: {e}')
114
+ return pipeline("text-generation", model=mdl, tokenizer=tokenizer, device=0)
115
 
116
  @spaces.GPU
117
+ def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty, max_prefix_levels=2):
118
  """
119
  使用 Diverse Beam Search 產生 m 條候選:
120
  - num_beams = m
 
131
  "early_stopping": True,
132
  }
133
  if diversity_penalty and diversity_penalty > 0:
134
+ valid_group = max(gcd(m, num_beam_groups),2)
135
  gen_kwargs["num_beam_groups"] = valid_group
136
  gen_kwargs["diversity_penalty"] = float(diversity_penalty)
137
 
 
146
  converted = cc.convert(snippet).strip()
147
  suggestions.add(converted)
148
  suggestions = list(suggestions)
149
+ suggestions = clean_suggestions(suggestions, max_prefix_levels)
150
 
151
  return update(choices=suggestions, value=None)
152
 
 
269
  minimum=0.0, maximum=2.0, step=0.1, value=1.0,
270
  label="多樣性懲罰 (diversity_penalty)"
271
  )
272
+ prefix_levels_slider = gr.Slider(
273
+ minimum=1, maximum=5, step=1, value=2,
274
+ label="Clean 前綴深度 (max_levels)"
275
+ )
276
 
277
  # 綁定事件
278
  predict_button.click(
 
283
  k_slider,
284
  m_slider,
285
  group_slider,
286
+ diversity_penalty_slider,
287
+ prefix_levels_slider # 新增
288
  ],
289
  outputs=suggestions,
290
  )
291
  input_text.change(
292
+ fn=lambda txt, mdl, k, m, g, d, auto, pl: (
293
+ suggest_next(txt, mdl, k, m, g, d, pl)
294
  if auto else update(choices=[], value=None)
295
  ),
296
  inputs=[
 
300
  m_slider,
301
  group_slider,
302
  diversity_penalty_slider,
303
+ auto_predict,
304
+ prefix_levels_slider # 新增
305
  ],
306
  outputs=suggestions,
307
  )