Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,651 Bytes
6d98a7c dc13b6d f358820 dc13b6d 6d98a7c 05df099 8b84de4 2f31c84 dc13b6d da40aec dc13b6d e894885 812c80c 80b7444 2a9fd77 cc1322f fd99b07 dc13b6d 42dc39c dc13b6d 9c6147f dc13b6d 05d7c3d 2f31c84 6eb5ae5 05d7c3d dc13b6d 05d7c3d 6d98a7c da40aec 6d98a7c 3350989 05d7c3d 05df099 6ec4350 3350989 216d8ce 4a676fe 216d8ce 9d21dc2 216d8ce 9d21dc2 4a676fe 3350989 4a676fe dc13b6d 3b87f8e 6d98a7c 44ac7e8 6d98a7c 44ac7e8 6d98a7c 109cc2f ee21d1a 109cc2f dc13b6d 6d98a7c 109cc2f 6d98a7c 041841e 44ac7e8 da40aec 44ac7e8 da40aec 016c189 da40aec 016c189 da40aec 9c6147f 83bd284 9c6147f 83bd284 9c6147f dc13b6d 6d98a7c 041841e dc13b6d 9c6147f 05d7c3d 9c6147f dc13b6d 6277588 05d7c3d 9c6147f 05d7c3d 9c6147f 6277588 dc13b6d 6d98a7c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 |
# app.py
import spaces
import gradio as gr
from gradio import update
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from opencc import OpenCC # 用於簡體轉繁體
from math import gcd
from termcolor import cprint
# Module-wide OpenCC converter that turns Simplified Chinese text into
# Traditional Chinese (the "s2t" configuration).
cc = OpenCC("s2t")
# Hugging Face model repo ids offered in the model dropdown; the first entry
# is used as the dropdown's default value.
MODEL_LIST = [
    "liswei/Taiwan-ELM-270M",
    "Mxode/SmolLM-Chinese-180M",
    "openbmb/BitCPM4-0.5B",
    "flyingfishinwater/chinese-baby-llama2",
    "unsloth/gemma-3-1b-pt",
    "taide/TAIDE-LX-7B",
    "ckiplab/gpt2-tiny-chinese",
    "ckiplab/gpt2-base-chinese",
    "liswei/Taiwan-ELM-1_1B",
    "benchang1110/Qwen2.5-Taiwan-1.5B-Instruct",
    "benchang1110/Taiwan-tinyllama-v1.0-base",
    "lianghsun/Llama-3.2-Taiwan-3B",
    "twinkle-ai/Llama-3.2-3B-F1-Instruct",
    "Epiculous/Violet_Twilight-v0.2",
]
@lru_cache(maxsize=8)
def get_pipeline(model_name):
    """Build a text-generation pipeline for ``model_name`` and memoize it.

    ``lru_cache`` keeps the 8 most recently used pipelines (model weights
    included) alive, so switching back to a recently used model skips
    reloading it.

    The pipeline is placed on CUDA device 0 when ``torch.cuda.is_available()``
    reports a GPU; otherwise a message is printed and it runs on the CPU
    instead of failing.

    Args:
        model_name: Hugging Face model repo id (e.g. an entry of MODEL_LIST).

    Returns:
        A ``transformers`` text-generation pipeline wrapping the model and
        its tokenizer.
    """
    import torch  # only needed here, to choose the device

    tok = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): trust_remote_code=True executes Python code shipped with
    # the model repo, and weights_only=False allows unrestricted unpickling of
    # the checkpoint. Only acceptable for trusted repos such as MODEL_LIST —
    # confirm callers can never pass an arbitrary model id.
    mdl = AutoModelForCausalLM.from_pretrained(
        model_name, weights_only=False, trust_remote_code=True
    )
    # Decide the device once and let the pipeline place the model there.
    # Falling back to CPU (-1) avoids asking for device 0 when no GPU exists.
    if torch.cuda.is_available():
        device = 0
    else:
        print("Error: CUDA is not available, falling back to CPU")
        device = -1
    # NOTE(review): this is called from inside a @spaces.GPU function; confirm
    # that cached CUDA-resident pipelines stay valid across ZeroGPU
    # allocations.
    return pipeline("text-generation", model=mdl, tokenizer=tok, device=device)
@spaces.GPU
def suggest_next(text, model_name, k, m, num_beam_groups, diversity_penalty):
    """Generate next-text candidates for ``text`` with (diverse) beam search.

    Runs beam search with ``num_beams = num_return_sequences = m`` and at
    most ``k`` new tokens. When ``diversity_penalty`` > 0, group (diverse)
    beam search is enabled with ``gcd(m, num_beam_groups)`` groups. That count
    always divides the beam count, as group beam search requires.

    Each continuation has its trailing whitespace removed and empty ones are
    dropped. The rest are converted to Traditional Chinese and de-duplicated,
    keeping the order in which the pipeline returned them.

    Args:
        text: Prompt to continue (current contents of the input box).
        model_name: Model id passed to ``get_pipeline``.
        k: Maximum number of new tokens per candidate.
        m: Number of beams and of returned sequences.
        num_beam_groups: Requested beam-group count; only used when
            ``diversity_penalty`` > 0.
        diversity_penalty: Group-diversity penalty; 0 (or falsy) disables
            diverse beam search.

    Returns:
        A ``gradio.update`` that replaces the candidate choices and clears
        the current selection.
    """
    # Slider values may arrive as floats (e.g. 10.0); generation arguments
    # such as num_beams must be real ints.
    k, m = int(k), int(m)

    gen_pipe = get_pipeline(model_name)
    gen_kwargs = {
        "max_new_tokens": k,
        "num_beams": m,
        "num_return_sequences": m,
        "do_sample": False,
        "early_stopping": True,
        # Let the pipeline strip the prompt. It removes the *decoded* prompt,
        # which stays correct when decode(encode(text)) != text (e.g.
        # tokenizers that insert spaces), unlike slicing by len(text).
        "return_full_text": False,
    }
    # Diversity options are only added when a positive penalty is requested.
    if diversity_penalty and diversity_penalty > 0:
        # gcd() rejects floats, so normalize the slider value first.
        gen_kwargs["num_beam_groups"] = gcd(m, int(num_beam_groups))
        gen_kwargs["diversity_penalty"] = float(diversity_penalty)
    outs = gen_pipe(text, **gen_kwargs)

    snippets = (out["generated_text"].rstrip() for out in outs)
    # dict.fromkeys de-duplicates while preserving order, so candidates keep
    # the order generation returned them in (best-scoring beams first). A set
    # would shuffle them.
    suggestions = list(dict.fromkeys(cc.convert(s) for s in snippets if s))
    return update(choices=suggestions, value=None)
def append_suggestion(current, choice):
    """Return ``current`` with the selected candidate ``choice`` appended.

    When no candidate is selected (``choice`` is None), the text is returned
    unchanged.
    """
    return current if choice is None else current + choice
# Custom CSS: styles the suggestions radio like a classic Chinese IME
# candidate bar (a single horizontally scrollable row), with larger touch
# targets on narrow (mobile) screens.
#
# NOTE(review): Gradio applies a component's ``elem_id`` and ``elem_classes``
# to the same wrapper element, so a descendant selector such as
# ``#suggestions-bar .candidate-list`` matches nothing. These rules target the
# options container (``.wrap``) inside ``#suggestions-bar`` instead.
# Gradio also nests each radio ``<input>`` inside its ``<label>`` and marks
# the chosen option's label with the ``selected`` class, so
# ``input:checked + label`` can never match. Both are Gradio-internal markup;
# re-check them after Gradio upgrades. ``flex-wrap: nowrap`` keeps the
# candidates on one scrollable row, as ``overflow-x``/``nowrap`` intend.
custom_css = """
#suggestions-bar {
    width: 100%;
    margin-bottom: 8px;
}
#suggestions-bar .wrap {
    display: flex;
    flex-wrap: nowrap;
    gap: 8px;
    background: #fff;
    border: 1px solid #999;
    border-radius: 4px;
    padding: 6px;
    overflow-x: auto;
    white-space: nowrap;
}
#suggestions-bar .wrap label {
    cursor: pointer;
    padding: 6px 10px;
    font-size: 16px;
}
#suggestions-bar .wrap label:hover {
    background: #f5f5f5;
}
#suggestions-bar .wrap label.selected {
    background: #e6f7ff;
    border: 1px solid #1890ff;
}
#input-box textarea {
    width: 100%;
    font-size: 16px;
    padding: 6px;
    box-sizing: border-box;
    overflow: hidden;
    resize: none;
}
#predict-button {
    margin-top: 8px;
    width: 100%;
}
/* 手機響應式 */
@media only screen and (max-width: 600px) {
    #suggestions-bar .wrap label {
        padding: 8px;
        font-size: 18px;
    }
    #predict-button {
        font-size: 18px;
    }
}
"""
# Auto-grow script for the input textarea: on every `input` event it resets
# the textarea height to its scrollHeight so the box fits its content.
# It listens only for `input` events, so programmatic value changes (such as
# an appended suggestion) do not trigger a resize.
# NOTE(review): this markup is passed to gr.HTML; browsers generally do not
# execute <script> elements inserted as HTML markup, so confirm this script
# actually runs (gr.Blocks(head=...) or an event `js=` hook may be needed).
# The listener is also attached on window `load`, which may fire before
# Gradio has rendered '#input-box textarea' — verify.
auto_height_js = """
<script>
window.addEventListener('load', () => {
const textarea = document.querySelector('#input-box textarea');
if (!textarea) return;
textarea.style.height = 'auto';
textarea.addEventListener('input', function() {
this.style.height = 'auto';
this.style.height = this.scrollHeight + 'px';
});
});
</script>
"""
with gr.Blocks(css=custom_css) as demo:
    # NOTE(review): gr.HTML inserts this as HTML markup; browsers generally do
    # not execute <script> elements inserted that way, so the auto-height
    # script may be inert — verify.
    gr.HTML(auto_height_js)
    # Title on its own line, then a one-line description paragraph. The "\n"
    # keeps the description out of the "##" heading; a line-continuation
    # backslash inside the literal would merge both strings into the heading.
    gr.Markdown(
        "## 🇹🇼 繁體中文 IME 加速器\n"
        "結合小型語言模型與 ZeroGPU,提供即時輸入法風格候選欄。"
    )
    with gr.Column():
        # IME-style candidate bar; its choices are filled in by suggest_next.
        suggestions = gr.Radio(
            [],
            label="",
            interactive=True,
            type="value",
            elem_id="suggestions-bar",
            elem_classes="candidate-list",
        )
        input_text = gr.Textbox(
            label="",
            placeholder="請輸入拼音或文字…",
            lines=1,
            max_lines=20,
            elem_id="input-box",
        )
        # The predict button is always shown, even when auto-predict is on.
        with gr.Row():
            auto_predict = gr.Checkbox(
                value=True,
                label="自動預測(內容變更時觸發)",
                elem_id="auto-predict",
            )
            predict_button = gr.Button("預測", elem_id="predict-button")
        # Advanced generation settings, collapsed by default.
        with gr.Accordion("進階設定", open=False):
            model_selector = gr.Dropdown(
                MODEL_LIST, value=MODEL_LIST[0], label="模型"
            )
            k_slider = gr.Slider(
                minimum=1, maximum=50, step=1, value=1, label="K(最大新詞元數)"
            )
            m_slider = gr.Slider(
                minimum=1, maximum=30, step=1, value=10, label="M(建議數/Beam 數)"
            )
            group_slider = gr.Slider(
                minimum=2, maximum=30, step=2, value=6,
                label="Beam 群組數 (num_beam_groups)",
            )
            diversity_penalty_slider = gr.Slider(
                minimum=0.0, maximum=2.0, step=0.1, value=0.0,
                label="多樣性懲罰 (diversity_penalty)",
            )

    # ---- Event wiring ----
    # Generation inputs shared by the manual and the automatic trigger, in the
    # positional order suggest_next expects.
    generation_inputs = [
        input_text,
        model_selector,
        k_slider,
        m_slider,
        group_slider,
        diversity_penalty_slider,
    ]
    # Manual prediction via the button.
    predict_button.click(
        fn=suggest_next,
        inputs=generation_inputs,
        outputs=suggestions,
    )
    # Automatic prediction whenever the text changes; with auto-predict off
    # the candidate bar is simply cleared instead.
    input_text.change(
        fn=lambda txt, mdl, k, m, g, d, auto: (
            suggest_next(txt, mdl, k, m, g, d)
            if auto else update(choices=[], value=None)
        ),
        inputs=generation_inputs + [auto_predict],
        outputs=suggestions,
    )
    # Picking a candidate appends it to the input box. The resulting text
    # change presumably re-fires input_text.change, which refreshes the
    # candidates with value=None; append_suggestion then leaves the text
    # unchanged.
    suggestions.change(
        fn=append_suggestion,
        inputs=[input_text, suggestions],
        outputs=input_text,
    )
# Start the Gradio server. The stray " |" that followed this call was a
# copy/paste artifact and a syntax error.
demo.launch()