File size: 6,651 Bytes
6d98a7c
dc13b6d
 
f358820
dc13b6d
 
6d98a7c
05df099
8b84de4
2f31c84
 
 
dc13b6d
da40aec
dc13b6d
e894885
812c80c
80b7444
2a9fd77
cc1322f
fd99b07
dc13b6d
 
 
 
 
42dc39c
 
 
dc13b6d
 
9c6147f
 
dc13b6d
05d7c3d
2f31c84
 
 
6eb5ae5
 
 
 
05d7c3d
dc13b6d
 
05d7c3d
6d98a7c
 
 
 
 
 
da40aec
6d98a7c
3350989
 
 
 
 
 
 
 
05d7c3d
05df099
6ec4350
3350989
 
 
216d8ce
4a676fe
216d8ce
9d21dc2
216d8ce
 
9d21dc2
4a676fe
 
3350989
4a676fe
dc13b6d
3b87f8e
6d98a7c
 
 
 
 
 
 
 
 
 
 
 
44ac7e8
6d98a7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44ac7e8
6d98a7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109cc2f
 
 
 
 
 
 
ee21d1a
109cc2f
dc13b6d
6d98a7c
109cc2f
 
 
 
6d98a7c
 
 
041841e
44ac7e8
da40aec
44ac7e8
da40aec
 
016c189
da40aec
 
016c189
da40aec
9c6147f
83bd284
9c6147f
 
 
83bd284
9c6147f
 
dc13b6d
6d98a7c
041841e
dc13b6d
9c6147f
 
 
 
 
 
05d7c3d
9c6147f
dc13b6d
 
6277588
05d7c3d
 
9c6147f
 
 
 
 
 
 
 
 
05d7c3d
9c6147f
6277588
 
dc13b6d
 
 
 
 
 
6d98a7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
# app.py
import spaces
import gradio as gr
from gradio import update
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from opencc import OpenCC  # 用於簡體轉繁體
from math import gcd
from termcolor import cprint

# Module-level Simplified -> Traditional Chinese converter (OpenCC 's2t'
# config); suggest_next runs every generated candidate through cc.convert().
cc = OpenCC('s2t')

# Hugging Face Hub model IDs offered in the "模型" dropdown. get_pipeline loads
# each one with AutoTokenizer / AutoModelForCausalLM; the first entry is the
# dropdown's default selection.
MODEL_LIST = [
    "liswei/Taiwan-ELM-270M",
    "Mxode/SmolLM-Chinese-180M",
    "openbmb/BitCPM4-0.5B",
    "flyingfishinwater/chinese-baby-llama2",
    "unsloth/gemma-3-1b-pt",
    "taide/TAIDE-LX-7B",
    "ckiplab/gpt2-tiny-chinese",
    "ckiplab/gpt2-base-chinese",
    "liswei/Taiwan-ELM-1_1B",
    "benchang1110/Qwen2.5-Taiwan-1.5B-Instruct",
    "benchang1110/Taiwan-tinyllama-v1.0-base",
    "lianghsun/Llama-3.2-Taiwan-3B",
    "twinkle-ai/Llama-3.2-3B-F1-Instruct",
    "Epiculous/Violet_Twilight-v0.2",
]


@lru_cache(maxsize=8)
def get_pipeline(model_name):
    """Load *model_name* and wrap it in a text-generation pipeline.

    Results are memoised per model name by ``lru_cache`` (at most 8 entries),
    so switching back to a recently used model does not reload it.

    Args:
        model_name: Hugging Face Hub model ID (e.g. an entry of MODEL_LIST).

    Returns:
        A ``transformers`` text-generation pipeline. It runs on CUDA device 0
        when the model could be moved there, otherwise on the CPU.
    """
    # The model is loaded with trust_remote_code=True; give the tokenizer the
    # same opt-in so a checkpoint whose tokenizer class is also custom loads
    # consistently with its model.
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_name, weights_only=False, trust_remote_code=True
    )
    try:
        mdl.to("cuda")
    except Exception as e:  # e.g. no CUDA device / torch built without CUDA
        # Binding the pipeline to device=0 after this failure would only fail
        # again with a less clear error, so fall back to the CPU instead.
        print(f'Error: could not move {model_name} to CUDA ({e}); using CPU')
        device = -1
    else:
        device = 0
    return pipeline("text-generation", model=mdl, tokenizer=tok, device=device)

def _build_gen_kwargs(k, m, num_beam_groups, diversity_penalty):
    """Build generation kwargs for an m-beam search producing at most k new tokens."""
    # Gradio sliders may deliver numbers as floats; gcd() and generate() need ints.
    k, m, num_beam_groups = int(k), int(m), int(num_beam_groups)
    gen_kwargs = {
        "max_new_tokens": k,
        "num_beams": m,
        "num_return_sequences": m,
        "do_sample": False,
        "early_stopping": True,
    }
    if diversity_penalty and diversity_penalty > 0:
        # num_beams must be divisible by num_beam_groups; gcd(m, g) always is.
        groups = gcd(m, num_beam_groups)
        # With a single group there is no other group to be penalised
        # against, so diverse beam search would be plain beam search.
        if groups > 1:
            gen_kwargs["num_beam_groups"] = groups
            gen_kwargs["diversity_penalty"] = float(diversity_penalty)
    return gen_kwargs


def _collect_suggestions(outs):
    """Turn pipeline outputs into unique, non-empty Traditional Chinese candidates.

    Keeps the pipeline's output order (first occurrence wins), so the ranking
    produced by beam search is preserved rather than scrambled by a set.
    """
    unique = {}
    for out in outs:
        snippet = out["generated_text"].rstrip()
        if snippet:
            unique.setdefault(cc.convert(snippet), None)
    return list(unique)


@spaces.GPU
def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
    """Return a Radio update listing candidate continuations of ``text``.

    Runs deterministic beam search (num_beams = num_return_sequences = m, at
    most k new tokens) with the pipeline for ``model_name``. When
    ``diversity_penalty`` > 0, diverse beam search is used with
    gcd(m, num_beam_groups) groups (skipped if that gcd is 1). Each
    continuation is right-stripped, converted to Traditional Chinese and
    de-duplicated in model output order; empty continuations are dropped.

    Args:
        text: Current input text used as the prompt. An empty string yields
            no candidates without invoking the model.
        model_name: Hub model ID passed to get_pipeline.
        k: Maximum number of new tokens per candidate.
        m: Number of beams and of returned sequences.
        num_beam_groups: Requested beam-group count for diverse beam search.
        diversity_penalty: Diversity penalty; 0 disables diverse beam search.

    Returns:
        A gradio ``update`` whose ``choices`` are the candidates and whose
        selection is cleared (``value=None``).
    """
    # Clearing the textbox also fires the change event; an empty prompt can
    # leave the model with no input tokens, so just clear the candidate bar.
    if not text:
        return update(choices=[], value=None)

    gen_pipe = get_pipeline(model_name)
    gen_kwargs = _build_gen_kwargs(k, m, num_beam_groups, diversity_penalty)
    # return_full_text=False makes the pipeline return only the continuation,
    # instead of prompt + continuation that we would have to slice off.
    outs = gen_pipe(text, return_full_text=False, **gen_kwargs)

    return update(choices=_collect_suggestions(outs), value=None)


def append_suggestion(current, choice):
    """Return ``current`` with the selected candidate ``choice`` appended.

    A ``None`` choice (no selection, e.g. after suggest_next resets the
    candidate Radio with value=None) leaves the text unchanged.
    """
    return current if choice is None else current + choice

# Custom CSS: render the candidate Radio like a classic Chinese IME candidate
# bar (one horizontally scrolling row), with larger touch targets on phones.
#
# Gradio applies a component's elem_id and elem_classes to the SAME wrapper
# element, so the bar is addressed with the compound selector
# "#suggestions-bar.candidate-list" (a descendant selector with a space
# matches nothing). The radio <input> is rendered inside its <label>, so the
# selected option is styled via label:has(input:checked) rather than
# "input:checked + label".
custom_css = """
#suggestions-bar {
    width: 100%;
    margin-bottom: 8px;
}
#suggestions-bar.candidate-list .wrap {
    display: flex;
    flex-wrap: nowrap;
    gap: 8px;
    background: #fff;
    border: 1px solid #999;
    border-radius: 4px;
    padding: 6px;
    overflow-x: auto;
    white-space: nowrap;
}
#suggestions-bar.candidate-list label {
    cursor: pointer;
    padding: 6px 10px;
    font-size: 16px;
}
#suggestions-bar.candidate-list label:hover {
    background: #f5f5f5;
}
#suggestions-bar.candidate-list label:has(input[type=radio]:checked) {
    background: #e6f7ff;
    border: 1px solid #1890ff;
}
#input-box textarea {
    width: 100%;
    font-size: 16px;
    padding: 6px;
    box-sizing: border-box;
    overflow: hidden;
    resize: none;
}
#predict-button {
    margin-top: 8px;
    width: 100%;
}
/* 手機響應式 */
@media only screen and (max-width: 600px) {
    #suggestions-bar.candidate-list label {
        padding: 8px;
        font-size: 18px;
    }
    #predict-button {
        font-size: 18px;
    }
}
"""

# Auto-grow script for the input textarea: once the window "load" event fires
# it looks up "#input-box textarea" and, on every "input" event, resets the
# height to 'auto' and then sets it to the content's scrollHeight.
# NOTE(review): this string is passed to gr.HTML below. Browsers do not run
# <script> elements inserted via innerHTML, so if gr.HTML renders its value
# that way (likely — confirm) this script never executes. Gradio may also
# mount the textbox after "load", in which case querySelector returns null.
# Gradio's Textbox presumably already grows between `lines` and `max_lines`
# by itself — verify before relying on this script.
auto_height_js = """
<script>
  window.addEventListener('load', () => {
    const textarea = document.querySelector('#input-box textarea');
    if (!textarea) return;
    textarea.style.height = 'auto';
    textarea.addEventListener('input', function() {
      this.style.height = 'auto';
      this.style.height = this.scrollHeight + 'px';
    });
  });
</script>
"""

with gr.Blocks(css=custom_css) as demo:
    gr.HTML(auto_height_js)
    # The explicit "\n" ends the "##" heading line so the description renders
    # as its own paragraph (a backslash at the end of a line inside a string
    # literal is a line continuation and would merge both into the heading).
    gr.Markdown(
        "## 🇹🇼 繁體中文 IME 加速器  \n"
        "結合小型語言模型與 ZeroGPU,提供即時輸入法風格候選欄。"
    )

    with gr.Column():
        # IME-style candidate bar shown above the text input.
        suggestions = gr.Radio(
            [], label="", interactive=True, type="value",
            elem_id="suggestions-bar", elem_classes="candidate-list"
        )
        input_text = gr.Textbox(
            label="", placeholder="請輸入拼音或文字…",
            lines=1, max_lines=20, elem_id="input-box"
        )

    # The predict button is always shown, even when auto-predict is enabled.
    with gr.Row():
        auto_predict = gr.Checkbox(
            value=True, label="自動預測(內容變更時觸發)", elem_id="auto-predict"
        )
        predict_button = gr.Button(
            "預測", elem_id="predict-button"
        )

    with gr.Accordion("進階設定", open=False):
        model_selector = gr.Dropdown(
            MODEL_LIST, value=MODEL_LIST[0], label="模型"
        )
        k_slider = gr.Slider(
            minimum=1, maximum=50, step=1, value=1, label="K(最大新詞元數)"
        )
        m_slider = gr.Slider(
            minimum=1, maximum=30, step=1, value=10, label="M(建議數/Beam 數)"
        )
        group_slider = gr.Slider(
            minimum=2, maximum=30, step=2, value=6,
            label="Beam 群組數 (num_beam_groups)"
        )
        diversity_penalty_slider = gr.Slider(
            minimum=0.0, maximum=2.0, step=0.1, value=0.0,
            label="多樣性懲罰 (diversity_penalty)"
        )

    # Inputs shared by the manual and automatic prediction events, in the
    # positional order of suggest_next's parameters.
    predict_inputs = [
        input_text,
        model_selector,
        k_slider,
        m_slider,
        group_slider,
        diversity_penalty_slider,
    ]

    def _auto_suggest(txt, mdl, k, m, g, d, auto):
        """On text change, predict when auto-predict is on; else clear candidates."""
        if auto:
            return suggest_next(txt, mdl, k, m, g, d)
        return update(choices=[], value=None)

    # Event wiring.
    predict_button.click(
        fn=suggest_next,
        inputs=predict_inputs,
        outputs=suggestions,
    )
    input_text.change(
        fn=_auto_suggest,
        inputs=[*predict_inputs, auto_predict],
        outputs=suggestions,
    )
    # Append the selected candidate to the input; a None selection (as set by
    # suggest_next) leaves the text unchanged.
    suggestions.change(
        fn=append_suggestion,
        inputs=[input_text, suggestions],
        outputs=input_text,
    )

# Launch after the Blocks context has closed so the layout and event wiring
# are finalized first.
if __name__ == "__main__":
    demo.launch()