import os
import time
import logging
import re
import gradio as gr
import spaces  # HF Spaces ZeroGPU: provides the @spaces.GPU decorator
# Do not import torch or load the model at module level (ZeroGPU requirement)
# from transformers import TextIteratorStreamer, AutoTokenizer  # no longer needed
# Try to import qwen_vl_utils; if unavailable, fall back to a stub that
# returns no image/video inputs (sufficient for text-only prompts)
try:
from qwen_vl_utils import process_vision_info
except Exception:
def process_vision_info(messages):
return None, None
def replace_single_quotes(text):
pattern = r"\B'([^']*)'\B"
replaced_text = re.sub(pattern, r'"\1"', text)
replaced_text = replaced_text.replace("’", "”").replace("‘", "“")
return replaced_text
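# Worked example (illustrative, not from the original source):
#   replace_single_quotes("he said 'hello' to her") -> 'he said "hello" to her'
# Apostrophes inside words (e.g. "it's") survive because \B requires the
# character just outside each quote to be a non-word character (or string edge).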
DEFAULT_MODEL_PATH = os.environ.get("MODEL_OUTPUT_PATH", "PromptEnhancer/PromptEnhancer-32B")
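# Override via the environment, e.g. (hypothetical checkpoint path and filename):
#   MODEL_OUTPUT_PATH=/data/checkpoints/enhancer python app.py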
def _str_to_dtype(dtype_str):
    # torch is only actually used in the GPU subprocess; here we just normalize the string for passing along
if dtype_str in ("bfloat16", "float16", "float32"):
return dtype_str
return "float32"
@spaces.GPU  # executed in the GPU-owning subprocess: model loading and inference both happen here
def gpu_predict(model_path, device_map, torch_dtype,
prompt_cot, sys_prompt, temperature, max_new_tokens, device):
    # Note: keep all CUDA-related imports inside this subprocess function
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
    # logger (optional)
if not logging.getLogger(__name__).handlers:
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# dtype
if torch_dtype == "bfloat16":
dtype = torch.bfloat16
elif torch_dtype == "float16":
dtype = torch.float16
else:
dtype = torch.float32
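    # Note (general CUDA guidance, not from the original): bfloat16 requires an
    # Ampere-or-newer GPU; float16 is the safer reduced-precision choice elsewhere.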
    # Device placement: decided by the UI's device / device_map selections
    # ZeroGPU recommends "cuda" for GPU inference
target_device = "cuda" if device == "cuda" else "cpu"
load_device_map = "cuda" if device_map == "cuda" else "cpu"
    # Load the model and processor (inside the subprocess)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=dtype,
device_map=load_device_map,
        attn_implementation="sdpa",  # avoid flash-attn for better compatibility
)
processor = AutoProcessor.from_pretrained(model_path)
    # Assemble the chat messages
org_prompt_cot = prompt_cot
try:
user_prompt_format = sys_prompt + "\n" + org_prompt_cot
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": user_prompt_format},
],
}
]
text = processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
        # Move the inputs to the target device
inputs = inputs.to(target_device)
        # Generate. Only pass sampling parameters when sampling is enabled:
        # with do_sample=False, transformers ignores temperature/top_k/top_p and
        # warns about it, so the UI temperature previously had no effect.
        do_sample = float(temperature) > 0
        gen_kwargs = {"max_new_tokens": int(max_new_tokens), "do_sample": do_sample}
        if do_sample:
            gen_kwargs.update(temperature=float(temperature), top_k=5, top_p=0.9)
        generated_ids = model.generate(**inputs, **gen_kwargs)
        # Decode only the newly generated tokens
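        # Illustrative (shape assumption, not from the original): if the prompt
        # encodes to 20 tokens and generate() returns sequences of length 70,
        # the slice below keeps the final 50 tokens, i.e. only the continuation.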
generated_ids_trimmed = [
out_ids[len(in_ids):]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
output_res = output_text[0]
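        # Expected output format (inferred from the parsing below):
        #   "<think>chain of thought...</think>\n<rewritten prompt>"
        # i.e. "think>" occurs exactly twice and the rewrite follows the last one.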
        # Keep the original logic: extract the content after the closing think> tag
try:
assert output_res.count("think>") == 2
new_prompt = output_res.split("think>")[-1]
if new_prompt.startswith("\n"):
new_prompt = new_prompt[1:]
new_prompt = replace_single_quotes(new_prompt)
except Exception:
            # If the format is not as expected, fall back to the original input
new_prompt = org_prompt_cot
return new_prompt, ""
except Exception as e:
        # On failure, return the original prompt together with the error message
        return org_prompt_cot, f"Inference failed: {e}"
# -------------------------
# Gradio app
# -------------------------
def run_single(prompt, sys_prompt, temperature, max_new_tokens, device,
model_path, device_map, torch_dtype, state):
if not prompt or not str(prompt).strip():
return "", "请先输入提示词。", state
t0 = time.time()
try:
new_prompt, err = gpu_predict(
model_path=model_path,
device_map=device_map,
torch_dtype=_str_to_dtype(torch_dtype),
prompt_cot=prompt,
sys_prompt=sys_prompt,
temperature=temperature,
max_new_tokens=max_new_tokens,
device=device,
)
dt = time.time() - t0
if err:
return new_prompt, f"{err}(耗时 {dt:.2f}s)", state
return new_prompt, f"耗时:{dt:.2f}s", state
except Exception as e:
return "", f"调用失败:{e}", state
# Example data (the Chinese examples are kept as-is: they are test inputs)
test_list_zh = [
"第三人称视角,赛车在城市赛道上飞驰,左上角是小地图,地图下面是当前名次,右下角仪表盘显示当前速度。",
"韩系插画风女生头像,粉紫色短发+透明感腮红,侧光渲染。",
"点彩派,盛夏海滨,两位渔夫正在搬运木箱,三艘帆船停在岸边,对角线构图。",
"一幅由梵高绘制的梦境麦田,旋转的蓝色星云与燃烧的向日葵相纠缠。",
]
test_list_en = [
"Create a painting depicting a 30-year-old white female white-collar worker on a business trip by plane.",
"Depicted in the anime style of Studio Ghibli, a girl stands quietly at the deck with a gentle smile.",
"Blue background, a lone girl gazes into the distant sea; her expression is sorrowful.",
"A blend of expressionist and vintage styles, drawing a building with colorful walls.",
"Paint a winter scene with crystalline ice hangings from an Antarctic research station.",
]
with gr.Blocks(title="Prompt Enhancer_V2") as demo:
gr.Markdown("## 提示词重写器")
with gr.Row():
with gr.Column(scale=2):
model_path = gr.Textbox(
label="模型路径(本地或HF地址)",
value=DEFAULT_MODEL_PATH,
placeholder="例如:Qwen/Qwen2.5-VL-7B-Instruct",
)
device_map = gr.Dropdown(
choices=["cuda", "cpu"],
value="cuda",
label="device_map(模型加载映射)"
)
torch_dtype = gr.Dropdown(
choices=["bfloat16", "float16", "float32"],
value="bfloat16",
label="torch_dtype"
)
with gr.Column(scale=3):
sys_prompt = gr.Textbox(
label="系统提示词(默认无需修改)",
value="请根据用户的输入,生成思考过程的思维链并改写提示词:",
lines=3
)
with gr.Row():
temperature = gr.Slider(0, 1, value=0.1, step=0.05, label="Temperature")
max_new_tokens = gr.Slider(16, 4096, value=2048, step=16, label="Max New Tokens")
        device = gr.Dropdown(choices=["cuda", "cpu"], value="cuda", label="Inference device")
state = gr.State(value=None)
with gr.Tab("推理"):
with gr.Row():
with gr.Column(scale=2):
prompt = gr.Textbox(label="输入提示词", lines=6, placeholder="在此粘贴要改写的提示词...")
run_btn = gr.Button("生成重写", variant="primary")
gr.Examples(
examples=test_list_zh + test_list_en,
inputs=prompt,
label="示例"
)
with gr.Column(scale=3):
out_text = gr.Textbox(label="重写结果", lines=10)
out_info = gr.Markdown("准备就绪。")
run_btn.click(
run_single,
inputs=[prompt, sys_prompt, temperature, max_new_tokens, device,
model_path, device_map, torch_dtype, state],
outputs=[out_text, out_info, state]
)
gr.Markdown("提示:如有任何问题可 email 联系:linqing1995@buaa.edu.cn")
# To avoid GPU OOM under concurrent requests, consider limiting concurrency
# (ZeroGPU workers are stateless, but a limit is still recommended)
# demo.queue(concurrency_count=1, max_size=10)  # Gradio 3.x signature
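# On Gradio 4.x the equivalent call (assuming that version) would be:
# demo.queue(max_size=10, default_concurrency_limit=1)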
if __name__ == "__main__":
demo.launch(ssr_mode=False, show_error=True, share=True) |