import os import time import base64 import requests import argparse import json from pathlib import Path from multiprocessing import Pool from tqdm import tqdm from io import BytesIO from PIL import Image from urllib.parse import urljoin, urlparse import random # --- 默认配置 --- DEFAULT_MODEL_PATH = "qwen2.5vl_7b_caption" DEFAULT_IMAGE_DIR = "./images" DEFAULT_PROMPT = "请用自然流畅的中文对以下图片进行全面而详细的描述。包括所有可见元素及其属性(如颜色、大小、形状、质地),它们的空间关系,以及任何显著特征或上下文。确保用自然流畅的中文描述清晰、生动,能够捕捉图片的每一个方面,不遗漏任何重要细节和NSFW内容。" # 注意:多模态模型通常使用 chat/completions 端点 DEFAULT_API_URL = "http://localhost:8000/v1/chat/completions" DEFAULT_MAX_PIXELS = 800 * 800 # 1M pixels,可以适当调大 CHECK_INTERVAL = 5 # seconds def get_base_url(api_url): """从完整的 API URL 中提取基础 URL,例如 'http://localhost:8000/'""" parsed_url = urlparse(api_url) return f"{parsed_url.scheme}://{parsed_url.netloc}" def is_server_running(api_url): """快速检查模型服务是否已经在运行""" try: # vLLM 的健康检查或模型列表端点 check_url = urljoin(get_base_url(api_url), "/health") resp = requests.get(check_url, timeout=2) if resp.status_code == 200: return True except requests.RequestException: pass return False def wait_for_model_ready(api_url, timeout=300): """轮询检查模型服务是否启动并准备好""" start_time = time.time() check_url = urljoin(get_base_url(api_url), "/health") print(f"⏳ 正在等待模型服务启动... (检查点: {check_url})") while True: try: resp = requests.get(check_url) if resp.status_code == 200: print("✅ 模型服务已就绪!") return True except requests.RequestException: pass if time.time() - start_time > timeout: print(f"❌ 模型服务启动超时(超过 {timeout} 秒)。") return False time.sleep(CHECK_INTERVAL) print(f" ...仍在等待...") def load_and_resize_image(image_path, max_pixels): """加载并根据需要缩放图像,然后返回 base64 编码的字符串""" with Image.open(image_path) as img: if img.mode != "RGB": img = img.convert("RGB") w, h = img.size if w * h > max_pixels: ratio = (max_pixels / (w * h)) ** 0.5 new_w, new_h = int(w * ratio), int(h * ratio) img = img.resize((new_w, new_h), Image.LANCZOS) buffer = BytesIO() img.save(buffer, format="JPEG") return base64.b64encode(buffer.getvalue()).decode("utf-8") def generate_caption(args): """调用 API 为单个图片生成 caption""" image_path, prompt, api_url, max_pixels, model_name = args txt_path = Path(image_path).with_suffix(".txt") if txt_path.exists() and txt_path.stat().st_size > 300: return f"✅ 已跳过 (caption 已存在): {txt_path.name}" try: base64_image = load_and_resize_image(image_path, max_pixels) # --- 这是符合 vLLM 多模态聊天补全 API 的正确 payload 格式 --- payload = { "messages": [ { "role": "user", "content": [ {"type": "text", "text": prompt}, { "type": "image_url", "image_url": { "url": f"data:image/jpeg;base64,{base64_image}" } } ] } ], # "max_tokens": 1024, # 可以根据需要调整 # "temperature": 0.1, # 可以根据需要调整 } response = requests.post( api_url, headers={"Content-Type": "application/json"}, data=json.dumps(payload) ) if response.status_code == 200: result = response.json() # 从聊天补全的响应中提取内容 caption = result.get("choices", [{}])[0].get("message", {}).get("content", "").strip() print(image_path) print(caption) if caption: with open(txt_path, "w", encoding="utf-8") as f: f.write(caption) return f"✅ 成功生成: {txt_path.name}" else: return f"⚠️ 生成内容为空: {Path(image_path).name}" else: return f"⚠️ 生成失败: {Path(image_path).name}, 状态码: {response.status_code}, 响应: {response.text}" except Exception as e: return f"❌ 发生异常: {Path(image_path).name}, 错误: {str(e)}" def collect_images(image_dir, extensions=(".jpg", ".jpeg", ".png", ".webp")): """递归收集所有图片文件的路径""" image_paths = [] print(f"🔍 正在从 '{image_dir}' 目录中收集图片...") for root, _, files in os.walk(image_dir): for file in files: if file.lower().endswith(extensions): image_paths.append(os.path.join(root, file)) return image_paths def main(): parser = argparse.ArgumentParser(description="为图片目录生成 caption (使用 vLLM 托管的多模态模型)") parser.add_argument("--model-path", type=str, default=DEFAULT_MODEL_PATH, help="vLLM 加载的本地模型路径或 HuggingFace 名称") parser.add_argument("--image-dir", type=str, default=DEFAULT_IMAGE_DIR, help="图片目录路径") parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="生成 caption 的提示词") parser.add_argument("--api-url", type=str, default=DEFAULT_API_URL, help="vLLM 的聊天补全 API 地址") parser.add_argument("--max-pixels", type=int, default=DEFAULT_MAX_PIXELS, help="图片最大像素数,超过此值会按比例缩放") parser.add_argument("--num-process", type=int, default=18, help="用于处理图片的并发进程数") args = parser.parse_args() # --- 检查模型服务是否已运行 --- if is_server_running(args.api_url): print("✅ 检测到模型服务已在运行,直接使用。") else: print("ℹ️ 未检测到正在运行的模型服务,现在尝试启动...") # 在后台启动 vLLM 模型服务 # 注意:这里的 --model 参数直接使用了 args.model_path,它将被用作 API 请求中的模型名称 command = f"nohup vllm serve {args.model_path} --max_model_len 3072 --trust-remote-code > /tmp/vllm.log 2>&1 &" print(f"🚀 执行启动命令: {command}") os.system(command) # 轮询检测模型是否启动完成 if not wait_for_model_ready(args.api_url): print("❌ 模型启动失败,请检查 /tmp/vllm_caption.log 文件获取错误详情。程序退出。") exit(1) # 收集所有图片文件 image_paths = collect_images(args.image_dir) if not image_paths: print("⚠️ 在指定目录中没有找到任何图片。") return random.shuffle(image_paths) print(f"📸 找到 {len(image_paths)} 张图片,准备开始处理...") # 多进程处理图像 # 将模型路径(作为模型名称)传递给处理函数 pool_args = [(img, args.prompt, args.api_url, args.max_pixels, args.model_path) for img in image_paths] # 使用 args.num_process 控制并发数 with Pool(args.num_process) as pool: # 使用 tqdm 显示进度条 for result in tqdm(pool.imap_unordered(generate_caption, pool_args), total=len(image_paths), desc="处理进度"): # 只打印非成功的消息,避免刷屏 if not result.startswith("✅"): print(result) print("\n🎉 全部处理完成!") if __name__ == "__main__": main()