import torch
import clip
import numpy as np
from PIL import Image
from typing import Dict, List, Tuple, Any, Optional, Union

from clip_prompts import (
    SCENE_TYPE_PROMPTS,
    CULTURAL_SCENE_PROMPTS,
    COMPARATIVE_PROMPTS,
    LIGHTING_CONDITION_PROMPTS,
    SPECIALIZED_SCENE_PROMPTS,
    VIEWPOINT_PROMPTS,
    OBJECT_COMBINATION_PROMPTS,
    ACTIVITY_PROMPTS
)

class CLIPAnalyzer:
    """
    Uses CLIP to provide scene understanding capabilities.
    """

    def __init__(self, model_name: str = "ViT-B/32", device: Optional[str] = None):
        """
        Initialize the CLIP analyzer.

        Args:
            model_name: CLIP model name, e.g. "ViT-B/32", "ViT-B/16", or "ViT-L/14"
            device: Device to run on; defaults to CUDA when available
        """
        # Automatically select the device
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        print(f"Loading CLIP model {model_name} on {self.device}...")
        try:
            self.model, self.preprocess = clip.load(model_name, device=self.device)
            print("CLIP model loaded successfully.")
        except Exception as e:
            print(f"Error loading CLIP model: {e}")
            raise

        self.scene_type_prompts = SCENE_TYPE_PROMPTS
        self.cultural_scene_prompts = CULTURAL_SCENE_PROMPTS
        self.comparative_prompts = COMPARATIVE_PROMPTS
        self.lighting_condition_prompts = LIGHTING_CONDITION_PROMPTS
        self.specialized_scene_prompts = SPECIALIZED_SCENE_PROMPTS
        self.viewpoint_prompts = VIEWPOINT_PROMPTS
        self.object_combination_prompts = OBJECT_COMBINATION_PROMPTS
        self.activity_prompts = ACTIVITY_PROMPTS

        # Tokenize all prompts into CLIP's input format
        self._prepare_text_prompts()

    def _prepare_text_prompts(self):
        """Tokenize all text prompts for CLIP."""
        # Base scene-type prompts
        scene_texts = [self.scene_type_prompts[scene_type] for scene_type in self.scene_type_prompts]
        self.scene_type_tokens = clip.tokenize(scene_texts).to(self.device)

        # Cultural scene prompts
        self.cultural_tokens_dict = {}
        for scene_type, prompts in self.cultural_scene_prompts.items():
            self.cultural_tokens_dict[scene_type] = clip.tokenize(prompts).to(self.device)

        # Lighting condition prompts
        lighting_texts = [self.lighting_condition_prompts[cond] for cond in self.lighting_condition_prompts]
        self.lighting_tokens = clip.tokenize(lighting_texts).to(self.device)

        # Specialized scene prompts
        self.specialized_tokens_dict = {}
        for scene_type, prompts in self.specialized_scene_prompts.items():
            self.specialized_tokens_dict[scene_type] = clip.tokenize(prompts).to(self.device)

        # Viewpoint prompts
        viewpoint_texts = [self.viewpoint_prompts[viewpoint] for viewpoint in self.viewpoint_prompts]
        self.viewpoint_tokens = clip.tokenize(viewpoint_texts).to(self.device)

        # Object-combination prompts
        object_combination_texts = [self.object_combination_prompts[combo] for combo in self.object_combination_prompts]
        self.object_combination_tokens = clip.tokenize(object_combination_texts).to(self.device)

        # Activity prompts
        activity_texts = [self.activity_prompts[activity] for activity in self.activity_prompts]
        self.activity_tokens = clip.tokenize(activity_texts).to(self.device)
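
    # Note (optimization idea, not in the original design): since these prompt
    # sets are fixed, their text features could be encoded once here with
    # self.model.encode_text(...) and cached, instead of being re-encoded on
    # every call as the analysis methods below do.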

    def analyze_image(self, image, include_cultural_analysis: bool = True) -> Dict[str, Any]:
        """
        Analyze an image, predicting its scene type and lighting conditions.

        Args:
            image: Input image (PIL Image or numpy array)
            include_cultural_analysis: Whether to include detailed analysis of cultural scenes

        Returns:
            Dict: Analysis results, including scene-type predictions and lighting conditions
        """
        try:
            # Make sure the image is in PIL format
            if not isinstance(image, Image.Image):
                if isinstance(image, np.ndarray):
                    image = Image.fromarray(image)
                else:
                    raise ValueError("Unsupported image format. Expected PIL Image or numpy array.")

            # Preprocess the image
            image_input = self.preprocess(image).unsqueeze(0).to(self.device)

            # Extract and normalize image features
            with torch.no_grad():
                image_features = self.model.encode_image(image_input)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Analyze scene type
            scene_scores = self._analyze_scene_type(image_features)

            # Analyze lighting conditions
            lighting_scores = self._analyze_lighting_condition(image_features)

            # Enhanced analysis for cultural scenes
            cultural_analysis = {}
            if include_cultural_analysis:
                for scene_type in self.cultural_scene_prompts:
                    if scene_type in scene_scores and scene_scores[scene_type] > 0.2:
                        cultural_analysis[scene_type] = self._analyze_cultural_scene(
                            image_features, scene_type
                        )

            specialized_analysis = {}
            for scene_type in self.specialized_scene_prompts:
                if scene_type in scene_scores and scene_scores[scene_type] > 0.2:
                    specialized_analysis[scene_type] = self._analyze_specialized_scene(
                        image_features, scene_type
                    )

            viewpoint_scores = self._analyze_viewpoint(image_features)
            object_combination_scores = self._analyze_object_combinations(image_features)
            activity_scores = self._analyze_activities(image_features)

            # Assemble the results
            result = {
                "scene_scores": scene_scores,
                "top_scene": max(scene_scores.items(), key=lambda x: x[1]),
                "lighting_condition": max(lighting_scores.items(), key=lambda x: x[1]),
                "embedding": image_features.cpu().numpy().tolist()[0],
                "viewpoint": max(viewpoint_scores.items(), key=lambda x: x[1]),
                "object_combinations": sorted(object_combination_scores.items(), key=lambda x: x[1], reverse=True)[:3],
                "activities": sorted(activity_scores.items(), key=lambda x: x[1], reverse=True)[:3]
            }
            if cultural_analysis:
                result["cultural_analysis"] = cultural_analysis
            if specialized_analysis:
                result["specialized_analysis"] = specialized_analysis

            return result
        except Exception as e:
            print(f"Error analyzing image with CLIP: {e}")
            import traceback
            traceback.print_exc()
            return {"error": str(e)}

    def _analyze_scene_type(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Score the similarity between the image features and each scene type."""
        with torch.no_grad():
            # Encode the scene-type text prompts
            text_features = self.model.encode_text(self.scene_type_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Compute similarity scores (temperature-scaled softmax over prompts)
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]

        # Build the scene-score dictionary
        scene_scores = {}
        for i, scene_type in enumerate(self.scene_type_prompts.keys()):
            scene_scores[scene_type] = float(similarity[i])

        return scene_scores

    def _analyze_lighting_condition(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Analyze the image's lighting conditions."""
        with torch.no_grad():
            # Encode the lighting-condition text prompts
            text_features = self.model.encode_text(self.lighting_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Compute similarity scores
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]

        # Build the lighting-score dictionary
        lighting_scores = {}
        for i, lighting_type in enumerate(self.lighting_condition_prompts.keys()):
            lighting_scores[lighting_type] = float(similarity[i])

        return lighting_scores

    def _analyze_cultural_scene(self, image_features: torch.Tensor, scene_type: str) -> Dict[str, Any]:
        """Run a deeper analysis for a specific cultural scene type."""
        if scene_type not in self.cultural_tokens_dict:
            return {"error": f"No cultural analysis available for {scene_type}"}

        with torch.no_grad():
            # Encode the text prompts for this cultural scene
            cultural_tokens = self.cultural_tokens_dict[scene_type]
            text_features = self.model.encode_text(cultural_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Compute similarity scores (raw scaled logits; no softmax here)
            similarity = (100 * image_features @ text_features.T)
            similarity = similarity.cpu().numpy()[0]

        # Find the best-matching cultural description
        prompts = self.cultural_scene_prompts[scene_type]
        scores = [(prompts[i], float(similarity[i])) for i in range(len(prompts))]
        scores.sort(key=lambda x: x[1], reverse=True)

        return {
            "best_description": scores[0][0],
            "confidence": scores[0][1],
            "all_matches": scores
        }

    def _analyze_specialized_scene(self, image_features: torch.Tensor, scene_type: str) -> Dict[str, Any]:
        """Run a deeper analysis for a specific specialized scene type."""
        if scene_type not in self.specialized_tokens_dict:
            return {"error": f"No specialized analysis available for {scene_type}"}

        with torch.no_grad():
            # Encode the text prompts for this specialized scene
            specialized_tokens = self.specialized_tokens_dict[scene_type]
            text_features = self.model.encode_text(specialized_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Compute similarity scores (raw scaled logits; no softmax here)
            similarity = (100 * image_features @ text_features.T)
            similarity = similarity.cpu().numpy()[0]

        # Find the best-matching specialized description
        prompts = self.specialized_scene_prompts[scene_type]
        scores = [(prompts[i], float(similarity[i])) for i in range(len(prompts))]
        scores.sort(key=lambda x: x[1], reverse=True)

        return {
            "best_description": scores[0][0],
            "confidence": scores[0][1],
            "all_matches": scores
        }

    def _analyze_viewpoint(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Analyze the camera viewpoint of the image."""
        with torch.no_grad():
            # Encode the viewpoint text prompts
            text_features = self.model.encode_text(self.viewpoint_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Compute similarity scores
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]

        # Build the viewpoint-score dictionary
        viewpoint_scores = {}
        for i, viewpoint in enumerate(self.viewpoint_prompts.keys()):
            viewpoint_scores[viewpoint] = float(similarity[i])

        return viewpoint_scores

    def _analyze_object_combinations(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Analyze object combinations in the image."""
        with torch.no_grad():
            # Encode the object-combination text prompts
            text_features = self.model.encode_text(self.object_combination_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Compute similarity scores
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]

        # Build the object-combination score dictionary
        combination_scores = {}
        for i, combination in enumerate(self.object_combination_prompts.keys()):
            combination_scores[combination] = float(similarity[i])

        return combination_scores

    def _analyze_activities(self, image_features: torch.Tensor) -> Dict[str, float]:
        """Analyze activities taking place in the image."""
        with torch.no_grad():
            # Encode the activity text prompts
            text_features = self.model.encode_text(self.activity_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Compute similarity scores
            similarity = (100 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]

        # Build the activity-score dictionary
        activity_scores = {}
        for i, activity in enumerate(self.activity_prompts.keys()):
            activity_scores[activity] = float(similarity[i])

        return activity_scores

    def get_image_embedding(self, image) -> np.ndarray:
        """
        Get the CLIP embedding for an image.

        Args:
            image: PIL Image or numpy array

        Returns:
            np.ndarray: The image's CLIP feature vector
        """
        # Make sure the image is in PIL format
        if not isinstance(image, Image.Image):
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            else:
                raise ValueError("Unsupported image format. Expected PIL Image or numpy array.")

        # Preprocess and encode
        image_input = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image_input)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # Convert to numpy and return
        return image_features.cpu().numpy()[0]

    def text_to_embedding(self, text: str) -> np.ndarray:
        """
        Convert text to its CLIP embedding.

        Args:
            text: Input text

        Returns:
            np.ndarray: The text's CLIP feature vector
        """
        text_token = clip.tokenize([text]).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_token)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        return text_features.cpu().numpy()[0]
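
    # Note: get_image_embedding and text_to_embedding both return L2-normalized
    # vectors, so their cosine similarity reduces to a plain dot product, e.g.
    #   float(np.dot(analyzer.get_image_embedding(img), analyzer.text_to_embedding("a dog")))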

    def calculate_similarity(self, image, text_queries: List[str]) -> Dict[str, float]:
        """
        Compute the similarity between an image and multiple text queries.

        Args:
            image: PIL Image or numpy array
            text_queries: List of text queries

        Returns:
            Dict: Similarity score for each query
        """
        # Get the image embedding
        if isinstance(image, np.ndarray) and len(image.shape) == 1:
            # Already an embedding vector
            image_features = torch.tensor(image).unsqueeze(0).to(self.device)
        else:
            # An actual image; extract its embedding first
            image_features = torch.tensor(self.get_image_embedding(image)).unsqueeze(0).to(self.device)

        # Calculate similarity
        text_tokens = clip.tokenize(text_queries).to(self.device)
        with torch.no_grad():
            text_features = self.model.encode_text(text_tokens)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
            similarity = similarity.cpu().numpy()[0]

        # Build the result dictionary
        result = {}
        for i, query in enumerate(text_queries):
            result[query] = float(similarity[i])

        return result
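
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). Assumes
# clip_prompts is importable and that a local file "example.jpg" exists.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    analyzer = CLIPAnalyzer()
    img = Image.open("example.jpg")

    # Full scene analysis: scene type, lighting, viewpoint, activities, etc.
    result = analyzer.analyze_image(img)
    print("Top scene:", result["top_scene"])
    print("Lighting:", result["lighting_condition"])

    # Ad-hoc zero-shot queries against the same image
    scores = analyzer.calculate_similarity(img, ["a busy street", "a quiet park"])
    print("Query scores:", scores)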