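"""Gradio demo for PromptEnhancer: rewrites a user's image-generation prompt by
running a chain-of-thought pass through a Qwen2.5-VL model. Structured for
Hugging Face ZeroGPU: all CUDA work is deferred to a @spaces.GPU subprocess."""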
import os
import time
import logging
import re
import gradio as gr
import spaces  # provides the @spaces.GPU decorator for ZeroGPU

# Do not import torch or load the model at module level: on ZeroGPU, all
# CUDA work must happen inside the @spaces.GPU-decorated subprocess function.
# from transformers import TextIteratorStreamer, AutoTokenizer  # no longer needed

# Try to import qwen_vl_utils; if unavailable, fall back to a stub that returns no image/video inputs
try:
    from qwen_vl_utils import process_vision_info
except Exception:
    def process_vision_info(messages):
        return None, None

def replace_single_quotes(text):
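    """Normalize quotes in model output: rewrite 'single-quoted' spans as
    "double-quoted" ones and convert curly single quotes to curly doubles,
    e.g. replace_single_quotes("a 'cat' here") -> 'a "cat" here'.
    """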
    pattern = r"\B'([^']*)'\B"
    replaced_text = re.sub(pattern, r'"\1"', text)
    replaced_text = replaced_text.replace("’", "”").replace("‘", "“")
    return replaced_text

DEFAULT_MODEL_PATH = os.environ.get("MODEL_OUTPUT_PATH", "PromptEnhancer/PromptEnhancer-32B")

def _str_to_dtype(dtype_str):
    # torch is only used inside the GPU subprocess; here we just validate the dtype string before passing it along
    if dtype_str in ("bfloat16", "float16", "float32"):
        return dtype_str
    return "float32"

@spaces.GPU  # runs in the GPU-owning subprocess: model loading and inference both happen here
def gpu_predict(model_path, device_map, torch_dtype,
                prompt_cot, sys_prompt, temperature, max_new_tokens, device):
    # Note: keep all CUDA-related imports inside this subprocess function
    import torch
    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

    # Optional logger
    if not logging.getLogger(__name__).handlers:
        logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # dtype
    if torch_dtype == "bfloat16":
        dtype = torch.bfloat16
    elif torch_dtype == "float16":
        dtype = torch.float16
    else:
        dtype = torch.float32

    # Device placement, driven by the UI's device / device_map selections.
    # On ZeroGPU, "cuda" is the recommended choice for GPU inference.
    target_device = "cuda" if device == "cuda" else "cpu"
    load_device_map = "cuda" if device_map == "cuda" else "cpu"

    # Load the model and processor (inside the subprocess)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=dtype,
        device_map=load_device_map,
        attn_implementation="sdpa",  # avoid flash-attn for better compatibility
    )
    processor = AutoProcessor.from_pretrained(model_path)

    # Assemble the chat messages
    org_prompt_cot = prompt_cot
    try:
        user_prompt_format = sys_prompt + "\n" + org_prompt_cot
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_prompt_format},
                ],
            }
        ]

        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)

        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        # Move the inputs to the target device
        inputs = inputs.to(target_device)

        # Generate. Sampling parameters (temperature/top_k/top_p) only take
        # effect when do_sample=True, so enable sampling whenever a non-zero
        # temperature is requested; otherwise decode greedily as before.
        do_sample = float(temperature) > 0
        gen_kwargs = dict(max_new_tokens=int(max_new_tokens), do_sample=do_sample)
        if do_sample:
            gen_kwargs.update(temperature=float(temperature), top_k=5, top_p=0.9)
        generated_ids = model.generate(**inputs, **gen_kwargs)
        # Decode only the newly generated tokens
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        output_res = output_text[0]
        # Preserve the original logic: extract the content after the closing think> tag
        try:
            assert output_res.count("think>") == 2
            new_prompt = output_res.split("think>")[-1]
            if new_prompt.startswith("\n"):
                new_prompt = new_prompt[1:]
            new_prompt = replace_single_quotes(new_prompt)
        except Exception:
            # If the output does not match the expected format, fall back to the original input
            new_prompt = org_prompt_cot
        return new_prompt, ""
    except Exception as e:
        # On failure, log the traceback and return the original prompt plus the error message
        logger.exception("gpu_predict failed")
        return org_prompt_cot, f"Inference failed: {e}"

# -------------------------
# Gradio app
# -------------------------

def run_single(prompt, sys_prompt, temperature, max_new_tokens, device,
               model_path, device_map, torch_dtype, state):
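    """UI callback: validates the input, forwards all settings to gpu_predict
    (which runs on ZeroGPU), and reports the elapsed time or an error."""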
    if not prompt or not str(prompt).strip():
        return "", "请先输入提示词。", state

    t0 = time.time()
    try:
        new_prompt, err = gpu_predict(
            model_path=model_path,
            device_map=device_map,
            torch_dtype=_str_to_dtype(torch_dtype),
            prompt_cot=prompt,
            sys_prompt=sys_prompt,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            device=device,
        )
        dt = time.time() - t0
        if err:
            return new_prompt, f"{err}(耗时 {dt:.2f}s)", state
        return new_prompt, f"耗时:{dt:.2f}s", state
    except Exception as e:
        return "", f"调用失败:{e}", state

# Example prompts (the Chinese list is kept verbatim as test data)
test_list_zh = [
    "第三人称视角,赛车在城市赛道上飞驰,左上角是小地图,地图下面是当前名次,右下角仪表盘显示当前速度。",
    "韩系插画风女生头像,粉紫色短发+透明感腮红,侧光渲染。",
    "点彩派,盛夏海滨,两位渔夫正在搬运木箱,三艘帆船停在岸边,对角线构图。",
    "一幅由梵高绘制的梦境麦田,旋转的蓝色星云与燃烧的向日葵相纠缠。",
]
test_list_en = [
    "Create a painting depicting a 30-year-old white female white-collar worker on a business trip by plane.",
    "Depicted in the anime style of Studio Ghibli, a girl stands quietly at the deck with a gentle smile.",
    "Blue background, a lone girl gazes into the distant sea; her expression is sorrowful.",
    "A blend of expressionist and vintage styles, drawing a building with colorful walls.",
    "Paint a winter scene with crystalline ice hangings from an Antarctic research station.",
]

with gr.Blocks(title="Prompt Enhancer_V2") as demo:
    gr.Markdown("## 提示词重写器")
    with gr.Row():
        with gr.Column(scale=2):
            model_path = gr.Textbox(
                label="Model path (local directory or HF repo id)",
                value=DEFAULT_MODEL_PATH,
                placeholder="e.g. Qwen/Qwen2.5-VL-7B-Instruct",
            )
            device_map = gr.Dropdown(
                choices=["cuda", "cpu"],
                value="cuda",
                label="device_map(模型加载映射)"
            )
            torch_dtype = gr.Dropdown(
                choices=["bfloat16", "float16", "float32"],
                value="bfloat16",
                label="torch_dtype"
            )

        with gr.Column(scale=3):
            sys_prompt = gr.Textbox(
                label="System prompt (no need to change by default)",
                value="请根据用户的输入,生成思考过程的思维链并改写提示词:",
                lines=3
            )
            with gr.Row():
                temperature = gr.Slider(0, 1, value=0.1, step=0.05, label="Temperature")
                max_new_tokens = gr.Slider(16, 4096, value=2048, step=16, label="Max New Tokens")
                device = gr.Dropdown(choices=["cuda", "cpu"], value="cuda", label="Inference device")

    state = gr.State(value=None)

    with gr.Tab("推理"):
        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Textbox(label="Input prompt", lines=6, placeholder="Paste the prompt to rewrite here...")
                run_btn = gr.Button("Rewrite", variant="primary")
                gr.Examples(
                    examples=test_list_zh + test_list_en,
                    inputs=prompt,
                    label="示例"
                )
            with gr.Column(scale=3):
                out_text = gr.Textbox(label="Rewritten result", lines=10)
                out_info = gr.Markdown("Ready.")

        run_btn.click(
            run_single,
            inputs=[prompt, sys_prompt, temperature, max_new_tokens, device,
                    model_path, device_map, torch_dtype, state],
            outputs=[out_text, out_info, state]
        )

    gr.Markdown("提示:如有任何问题可 email 联系:linqing1995@buaa.edu.cn")

# To avoid VRAM exhaustion under concurrent requests, consider limiting
# concurrency (ZeroGPU is stateless, but a limit is still advisable).
# Note that Gradio 4+ renamed concurrency_count, e.g.:
# demo.queue(default_concurrency_limit=1, max_size=10)

if __name__ == "__main__":
    demo.launch(ssr_mode=False, show_error=True, share=True)