# app.py
import spaces
import os
import gradio as gr
from gradio import update
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from opencc import OpenCC  # used for Simplified-to-Traditional Chinese conversion

# Initialize the Simplified-to-Traditional Chinese converter
cc = OpenCC('s2t')
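# Example: cc.convert("简体字") returns "簡體字".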

# Available model choices
MODEL_LIST = [
    "liswei/Taiwan-ELM-270M",
    "Mxode/SmolLM-Chinese-180M",
    "flyingfishinwater/chinese-baby-llama2",
    "unsloth/gemma-3-1b-pt",
    "ckiplab/gpt2-tiny-chinese",
    "ckiplab/gpt2-base-chinese",
    "liswei/Taiwan-ELM-1_1B",
    "benchang1110/Qwen2.5-Taiwan-1.5B-Instruct",
    "benchang1110/Taiwan-tinyllama-v1.0-base",
    "lianghsun/Llama-3.2-Taiwan-3B",
    "twinkle-ai/Llama-3.2-3B-F1-Instruct",
    "Epiculous/Violet_Twilight-v0.2",
]

def merge_common_prefixes(suggestions, min_len=2):
    """
    合併具有共同前綴的建議:
    - 找出所有長度 ≥ min_len 的共同前綴
    - 將這些前綴作為新建議,移除原有被合併的項目
    """
    prefixes = []
    to_remove = set()

    for i in range(len(suggestions)):
        for j in range(i+1, len(suggestions)):
            s1, s2 = suggestions[i], suggestions[j]
            # Character-level longest common prefix (stops at the first mismatch)
            common = os.path.commonprefix([s1, s2])
            if len(common) >= min_len:
                prefixes.append(common)
                to_remove.update([s1, s2])

    # De-duplicate prefixes while preserving order
    unique_prefixes = []
    for p in prefixes:
        if p not in unique_prefixes:
            unique_prefixes.append(p)

    # Keep suggestions that were not merged into any prefix
    remainder = [s for s in suggestions if s not in to_remove]
    return unique_prefixes + remainder
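
# Illustrative example (hypothetical values): with min_len=2,
#   merge_common_prefixes(["今天天氣", "今天心情", "你好"])
# returns ["今天", "你好"]; the first two candidates share the prefix "今天"
# and are replaced by it, while "你好" is kept unchanged.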

@lru_cache(maxsize=8)
def get_pipeline(model_name):
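    # Cached per model name (up to 8 entries via lru_cache above), so switching
    # back to a recently used model does not reload its weights.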
    tok = AutoTokenizer.from_pretrained(model_name)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_name, weights_only=False, trust_remote_code=True
    )
    mdl.to("cuda")
    return pipeline("text-generation", model=mdl, tokenizer=tok, device=0)

@spaces.GPU
def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
    """
    使用 Diverse Beam Search 產生 m 條候選:
     - num_beams = m
     - num_beam_groups, diversity_penalty 可調整多樣性
    之後轉繁體、去重、合併共同前綴後回傳。
    """
    gen_pipe = get_pipeline(model_name)
    outs = gen_pipe(
        text,
        max_new_tokens=k,
        num_beams=m,
        num_beam_groups=num_beam_groups,
        diversity_penalty=diversity_penalty,
        num_return_sequences=m,
        do_sample=False,
        early_stopping=True
    )
    # Extract the newly generated text, drop empty strings, convert to Traditional Chinese
    suggestions = [out["generated_text"][len(text):].strip() for out in outs]
    suggestions = [s for s in suggestions if s]
    suggestions = [cc.convert(s) for s in suggestions]
    # Remove duplicates while preserving order
    unique_suggestions = []
    for s in suggestions:
        if s not in unique_suggestions:
            unique_suggestions.append(s)

    # Merge common prefixes
    final_suggestions = merge_common_prefixes(unique_suggestions, min_len=2)

    return update(choices=final_suggestions, value=None)


def append_suggestion(current, choice):
    if choice is None:
        return current
    # Append the selected candidate text directly
    return current + choice
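
# Example: append_suggestion("今天", "天氣") returns "今天天氣", and
# append_suggestion("今天", None) returns "今天" unchanged.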

# Custom CSS: mimic a classic Chinese IME candidate bar, optimized for mobile and with an auto-growing input box
custom_css = """
#suggestions-bar {
    width: 100%;
    margin-bottom: 8px;
}
#suggestions-bar .candidate-list {
    display: flex;
    gap: 8px;
    background: #fff;
    border: 1px solid #999;
    border-radius: 4px;
    padding: 6px;
    overflow-x: auto;
    white-space: nowrap;
}
#suggestions-bar .candidate-list label {
    cursor: pointer;
    padding: 6px 10px;
    font-size: 16px;
}
#suggestions-bar .candidate-list label:hover {
    background: #f5f5f5;
}
#suggestions-bar .candidate-list input[type=radio]:checked + label {
    background: #e6f7ff;
    border: 1px solid #1890ff;
}
#input-box textarea {
    width: 100%;
    font-size: 16px;
    padding: 6px;
    box-sizing: border-box;
    overflow: hidden;
    resize: none;
}
#predict-button {
    margin-top: 8px;
    width: 100%;
}
/* Mobile responsive styles */
@media only screen and (max-width: 600px) {
    #suggestions-bar .candidate-list label {
        padding: 8px;
        font-size: 18px;
    }
    #predict-button {
        font-size: 18px;
    }
}
"""

# Auto-grow script for the input textarea
auto_height_js = """
<script>
  window.addEventListener('load', () => {
    const textarea = document.querySelector('#input-box textarea');
    if (!textarea) return;
    textarea.style.height = 'auto';
    textarea.addEventListener('input', function() {
      this.style.height = 'auto';
      this.style.height = this.scrollHeight + 'px';
    });
  });
</script>
"""

with gr.Blocks(css=custom_css) as demo:
    gr.HTML(auto_height_js)
    gr.Markdown(
        "## 🇹🇼 繁體中文 IME 加速器  \
"
        "結合小型語言模型與 ZeroGPU,提供即時輸入法風格候選欄。"
    )

    with gr.Column():
        suggestions = gr.Radio(
            [], label="", interactive=True, type="value",
            elem_id="suggestions-bar", elem_classes="candidate-list"
        )
        input_text = gr.Textbox(
            label="", placeholder="請輸入拼音或文字…",
            lines=1, max_lines=20, elem_id="input-box"
        )

    # Always show the predict button
    with gr.Row():
        auto_predict = gr.Checkbox(
            value=True, label="自動預測(內容變更時觸發)", elem_id="auto-predict"
        )
        predict_button = gr.Button(
            "預測", elem_id="predict-button"
        )

    with gr.Accordion("進階設定", open=False):
        model_selector = gr.Dropdown(
            MODEL_LIST, value=MODEL_LIST[0], label="模型"
        )
        k_slider = gr.Slider(
            minimum=1, maximum=50, step=1, value=10, label="K(最大新詞元數)"
        )
        m_slider = gr.Slider(
            minimum=1, maximum=30, step=1, value=30, label="M(建議數/Beam 數)"
        )
        group_slider = gr.Slider(
            minimum=1, maximum=30, step=1, value=30,
            label="Beam 群組數 (num_beam_groups)"
        )
        diversity_penalty_slider = gr.Slider(
            minimum=0.0, maximum=2.0, step=0.1, value=1.0,
            label="多樣性懲罰 (diversity_penalty)"
        )

    # Wire up events
    predict_button.click(
        fn=suggest_next,
        inputs=[
            input_text,
            model_selector,
            k_slider,
            m_slider,
            group_slider,
            diversity_penalty_slider
        ],
        outputs=suggestions,
    )
    input_text.change(
        fn=lambda txt, mdl, k, m, g, d, auto: (
            suggest_next(txt, mdl, k, m, g, d)
            if auto else update(choices=[], value=None)
        ),
        inputs=[
            input_text,
            model_selector,
            k_slider,
            m_slider,
            group_slider,
            diversity_penalty_slider,
            auto_predict
        ],
        outputs=suggestions,
    )
    suggestions.change(
        fn=append_suggestion,
        inputs=[input_text, suggestions],
        outputs=input_text,
    )

# Launch outside the Blocks context so the UI definition is complete first.
demo.launch()