Spaces: Running on Zero

Upload 8 files

- clip_analyzer.py +24 -27
- clip_model_manager.py +15 -18
- clip_prompts.py +128 -5
- clip_zero_shot_classifier.py +4 -4
- llm_enhancer.py +2 -2
- llm_model_manager.py +358 -0
- requirements.txt +2 -2
- scene_scoring_engine.py +9 -8
clip_analyzer.py
CHANGED

@@ -1,5 +1,5 @@
 import torch
-import clip
+import open_clip
 import numpy as np
 from PIL import Image
 from typing import Dict, List, Tuple, Any, Optional, Union

@@ -20,13 +20,14 @@ class CLIPAnalyzer:
     Use CLIP to integrate scene-understanding functionality
     """

-    def __init__(self, model_name: str = "ViT-B/16", device: str = None):
+    def __init__(self, model_name: str = "ViT-B-16", device: str = None, pretrained: str = "laion2b_s34b_b88k"):
         """
-        Initialize the CLIP analyzer
+        Initialize the CLIP analyzer, using the OpenCLIP implementation

         Args:
-            model_name: CLIP model name
-            device: runtime device
+            model_name: OpenCLIP model name, defaults to "ViT-B-16"
+            device: runtime device
+            pretrained: pretrained weights, defaults to "laion2b_s34b_b88k"
         """
         # Automatically select the device
         if device is None:

@@ -34,12 +35,17 @@ class CLIPAnalyzer:
         else:
             self.device = device

-        print(f"Loading CLIP model {model_name} on {self.device}...")
+        print(f"Loading OpenCLIP model {model_name} with {pretrained} on {self.device}...")
         try:
-            self.model, self.preprocess = clip.load(
-                model_name, device=self.device)
+            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
+                model_name,
+                pretrained=pretrained,
+                device=self.device
+            )
+            self.tokenizer = open_clip.get_tokenizer(model_name)
+            print(f"OpenCLIP model loaded successfully.")
         except Exception as e:
-            print(f"Error loading CLIP model: {e}")
+            print(f"Error loading OpenCLIP model: {e}")
             raise

         self.scene_type_prompts = SCENE_TYPE_PROMPTS

@@ -64,7 +70,7 @@ class CLIPAnalyzer:
         if scene_texts:
             self.text_features_cache["scene_type_keys"] = list(self.scene_type_prompts.keys())
             try:
-                self.text_features_cache["scene_type_tokens"] = clip.tokenize(scene_texts).to(self.device)
+                self.text_features_cache["scene_type_tokens"] = self.tokenizer(scene_texts).to(self.device)
             except Exception as e:
                 print(f"Warning: Error tokenizing scene_type_prompts: {e}")
                 self.text_features_cache["scene_type_tokens"] = None  # mark as error or empty

@@ -82,7 +88,7 @@ class CLIPAnalyzer:
         for scene_type, prompts in self.cultural_scene_prompts.items():
             if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
                 try:
-                    cultural_tokens_dict_val[scene_type] = clip.tokenize(prompts).to(self.device)
+                    cultural_tokens_dict_val[scene_type] = self.tokenizer(prompts).to(self.device)
                 except Exception as e:
                     print(f"Warning: Error tokenizing cultural_scene_prompts for {scene_type}: {e}")
                     cultural_tokens_dict_val[scene_type] = None  # mark as error or empty

@@ -96,7 +102,7 @@ class CLIPAnalyzer:
         if lighting_texts:
             self.text_features_cache["lighting_condition_keys"] = list(self.lighting_condition_prompts.keys())
             try:
-                self.text_features_cache["lighting_tokens"] = clip.tokenize(lighting_texts).to(self.device)
+                self.text_features_cache["lighting_tokens"] = self.tokenizer(lighting_texts).to(self.device)
             except Exception as e:
                 print(f"Warning: Error tokenizing lighting_condition_prompts: {e}")
                 self.text_features_cache["lighting_tokens"] = None

@@ -113,7 +119,7 @@ class CLIPAnalyzer:
         for scene_type, prompts in self.specialized_scene_prompts.items():
             if prompts and isinstance(prompts, list) and all(isinstance(p, str) for p in prompts):
                 try:
-                    specialized_tokens_dict_val[scene_type] = clip.tokenize(prompts).to(self.device)
+                    specialized_tokens_dict_val[scene_type] = self.tokenizer(prompts).to(self.device)
                 except Exception as e:
                     print(f"Warning: Error tokenizing specialized_scene_prompts for {scene_type}: {e}")
                     specialized_tokens_dict_val[scene_type] = None

@@ -127,7 +133,7 @@ class CLIPAnalyzer:
         if viewpoint_texts:
             self.text_features_cache["viewpoint_keys"] = list(self.viewpoint_prompts.keys())
             try:
-                self.text_features_cache["viewpoint_tokens"] = clip.tokenize(viewpoint_texts).to(self.device)
+                self.text_features_cache["viewpoint_tokens"] = self.tokenizer(viewpoint_texts).to(self.device)
             except Exception as e:
                 print(f"Warning: Error tokenizing viewpoint_prompts: {e}")
                 self.text_features_cache["viewpoint_tokens"] = None

@@ -144,7 +150,7 @@ class CLIPAnalyzer:
         if object_combination_texts:
             self.text_features_cache["object_combination_keys"] = list(self.object_combination_prompts.keys())
             try:
-                self.text_features_cache["object_combination_tokens"] = clip.tokenize(object_combination_texts).to(self.device)
+                self.text_features_cache["object_combination_tokens"] = self.tokenizer(object_combination_texts).to(self.device)
             except Exception as e:
                 print(f"Warning: Error tokenizing object_combination_prompts: {e}")
                 self.text_features_cache["object_combination_tokens"] = None

@@ -161,7 +167,7 @@ class CLIPAnalyzer:
         if activity_texts:
             self.text_features_cache["activity_keys"] = list(self.activity_prompts.keys())
             try:
-                self.text_features_cache["activity_tokens"] = clip.tokenize(activity_texts).to(self.device)
+                self.text_features_cache["activity_tokens"] = self.tokenizer(activity_texts).to(self.device)
             except Exception as e:
                 print(f"Warning: Error tokenizing activity_prompts: {e}")
                 self.text_features_cache["activity_tokens"] = None

@@ -180,7 +186,7 @@ class CLIPAnalyzer:
         self.cultural_tokens_dict = self.text_features_cache["cultural_tokens_dict"]
         self.specialized_tokens_dict = self.text_features_cache["specialized_tokens_dict"]

-        print("CLIP text_features_cache prepared.")
+        print("OpenCLIP text_features_cache prepared.")

     def analyze_image(self, image, include_cultural_analysis=True, exclude_categories=None, enable_landmark=True, places365_guidance=None):
         """

@@ -581,16 +587,7 @@ class CLIPAnalyzer:
         return image_features.cpu().numpy()[0] if self.device == "cuda" else image_features.numpy()[0]

     def text_to_embedding(self, text: str) -> np.ndarray:
-        """
-        Convert text into a CLIP embedding representation
-
-        Args:
-            text: input text
-
-        Returns:
-            np.ndarray: the CLIP feature vector for the text
-        """
-        text_token = clip.tokenize([text]).to(self.device)
+        text_token = self.tokenizer([text]).to(self.device)

         with torch.no_grad():
             text_features = self.model.encode_text(text_token)
clip_model_manager.py
CHANGED

@@ -1,6 +1,6 @@

 import torch
-import clip
+import open_clip
 import numpy as np
 import logging
 import traceback

@@ -12,7 +12,7 @@ class CLIPModelManager:
     Manages CLIP-model operations: model loading, device management, and the core image/text feature encoding
     """

-    def __init__(self, model_name: str = "ViT-B/16", device: str = None):
+    def __init__(self, model_name: str = "ViT-B-16", device: str = None, pretrained: str = "laion2b_s34b_b88k"):
         """
         Initialize the CLIP model manager

@@ -22,6 +22,8 @@ class CLIPModelManager:
         """
         self.logger = logging.getLogger(__name__)
         self.model_name = model_name
+        self.pretrained = pretrained
+        self.tokenizer = None

         # Set the runtime device
         if device is None:

@@ -29,19 +31,23 @@ class CLIPModelManager:
         else:
             self.device = device

-        self.model = None
         self.preprocess = None

         self._initialize_model()

     def _initialize_model(self):
         """
-        Initialize the CLIP model
+        Initialize the OpenCLIP model
         """
         try:
-            self.logger.info(f"Initializing CLIP model ({self.model_name}) on {self.device}")
-            self.model, self.preprocess = clip.load(
-                self.model_name, device=self.device)
+            self.logger.info(f"Initializing OpenCLIP model ({self.model_name}) with {self.pretrained} on {self.device}")
+            self.model, _, self.preprocess = open_clip.create_model_and_transforms(
+                self.model_name,
+                pretrained=self.pretrained,
+                device=self.device
+            )
+            self.tokenizer = open_clip.get_tokenizer(self.model_name)
+            self.logger.info("Successfully loaded OpenCLIP model")
         except Exception as e:
             self.logger.error(f"Error loading CLIP model: {e}")
             self.logger.error(traceback.format_exc())

@@ -87,7 +93,7 @@ class CLIPModelManager:

             for i in range(0, len(text_prompts), batch_size):
                 batch_prompts = text_prompts[i:i+batch_size]
-                text_tokens = clip.tokenize(batch_prompts).to(self.device)
+                text_tokens = self.tokenizer(batch_prompts).to(self.device)
                 batch_features = self.model.encode_text(text_tokens)
                 batch_features = batch_features / batch_features.norm(dim=-1, keepdim=True)
                 features_list.append(batch_features)

@@ -106,18 +112,9 @@ class CLIPModelManager:
             raise

     def encode_single_text(self, text_prompts: List[str]) -> torch.Tensor:
-        """
-        Encode the features of a single text batch
-
-        Args:
-            text_prompts: list of text prompts
-
-        Returns:
-            torch.Tensor: normalized text features
-        """
         try:
             with torch.no_grad():
-                text_tokens = clip.tokenize(text_prompts).to(self.device)
+                text_tokens = self.tokenizer(text_prompts).to(self.device)
                 text_features = self.model.encode_text(text_tokens)
                 text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                 return text_features
clip_prompts.py
CHANGED

@@ -69,7 +69,49 @@ SCENE_TYPE_PROMPTS = {
     "construction_site": "A photo of a construction site with building materials, equipment and workers.",
     "medical_facility": "A photo of a medical facility with healthcare equipment and professional staff.",
     "educational_setting": "A photo of an educational setting with learning spaces and academic resources.",
-    "professional_kitchen": "A photo of a professional commercial kitchen with industrial cooking equipment and food preparation stations."
+    "professional_kitchen": "A photo of a professional commercial kitchen with industrial cooking equipment and food preparation stations.",
+
+    # Workspace variations
+    "modern_open_office": "A photo of a modern open office with collaborative workspaces, standing desks and contemporary furniture design.",
+    "traditional_cubicle_office": "A photo of a traditional office with individual cubicles, separated workstations and formal business environment.",
+    "home_office_study": "A photo of a home office or study room with personal workspace setup and residential comfort elements.",
+    "creative_workspace": "A photo of a creative workspace with design materials, artistic tools and inspiring work environment.",
+    "shared_workspace_hub": "A photo of a shared coworking space with flexible seating, community areas and collaborative atmosphere.",
+
+    # Context-specific dining spaces
+    "casual_family_dining": "A photo of a casual family dining area with comfortable seating and everyday meal setup.",
+    "formal_dining_room": "A photo of a formal dining room with elegant table setting and sophisticated dining arrangement.",
+    "breakfast_nook_area": "A photo of a cozy breakfast nook with intimate seating and morning meal atmosphere.",
+    "outdoor_patio_dining": "A photo of an outdoor patio dining area with weather-resistant furniture and al fresco dining setup.",
+    "kitchen_island_dining": "A photo of a kitchen island used for casual dining with bar-style seating and integrated cooking space.",
+
+    # Living spaces by usage context
+    "family_entertainment_room": "A photo of a family room focused on entertainment with large TV, comfortable seating and recreational atmosphere.",
+    "reading_lounge_area": "A photo of a quiet reading area with comfortable chairs, good lighting and book storage.",
+    "social_gathering_space": "A photo of a living area arranged for social gatherings with multiple seating options and conversation-friendly layout.",
+    "relaxation_living_space": "A photo of a living room designed for relaxation with soft furnishings and calm atmosphere.",
+
+    # Service-oriented commercial spaces
+    "quick_service_restaurant": "A photo of a quick service restaurant with efficient ordering system and fast-casual dining setup.",
+    "coffee_shop_workspace": "A photo of a coffee shop that doubles as workspace with WiFi-friendly seating and laptop users.",
+    "boutique_retail_space": "A photo of a boutique retail store with curated merchandise display and personalized shopping experience.",
+    "convenience_store_market": "A photo of a convenience store with everyday items, quick shopping layout and accessible product arrangement.",
+
+    # Specialized learning environments
+    "collaborative_classroom": "A photo of a modern classroom designed for group work with flexible seating and interactive learning setup.",
+    "lecture_hall_setting": "A photo of a traditional lecture hall with tiered seating and formal educational presentation setup.",
+    "study_hall_library": "A photo of a quiet study area in a library with individual study spaces and academic atmosphere.",
+    "computer_lab_classroom": "A photo of a computer lab or digital classroom with technology workstations and learning equipment.",
+
+    # Time-of-day cues
+    "morning_routine_kitchen": "A photo of a kitchen during morning routine with breakfast preparation and daily startup activities.",
+    "evening_relaxation_living": "A photo of a living room in evening mode with dim lighting and relaxation activities.",
+    "weekend_leisure_space": "A photo of a living area during weekend with casual activities and relaxed atmosphere.",
+
+    # Activity-intensity descriptions
+    "busy_work_environment": "A photo of an active workplace with multiple people engaged in work tasks and productive atmosphere.",
+    "quiet_study_atmosphere": "A photo of a peaceful study or work environment with focused activity and minimal distractions.",
+    "social_interaction_space": "A photo of a space designed for social interaction with multiple people engaging in conversation."
 }

 # Culture-specific scene prompts

@@ -151,6 +193,30 @@ COMPARATIVE_PROMPTS = {
         "A street-level view showing pedestrian perspective and immediate surroundings.",
         "A bird's-eye view of city organization and movement patterns from high above.",
         "An eye-level perspective showing direct human interaction with urban elements."
+    ],
+    "modern_vs_traditional_kitchen": [
+        "A modern kitchen with sleek stainless steel appliances, minimalist design and contemporary fixtures.",
+        "A traditional kitchen with classic wooden cabinets, vintage appliances and conventional design elements."
+    ],
+
+    "business_vs_leisure_dining": [
+        "A business dining environment with professional atmosphere, formal table settings and corporate meeting setup.",
+        "A leisure dining space with relaxed atmosphere, casual seating and recreational meal environment."
+    ],
+
+    "dense_vs_spacious_retail": [
+        "A densely packed retail space with closely arranged merchandise and compact shopping aisles.",
+        "A spacious retail environment with wide aisles, generous display spacing and open shopping layout."
+    ],
+
+    "private_vs_shared_workspace": [
+        "A private office space with individual workstation, personal storage and isolated work environment.",
+        "A shared workspace with communal tables, collaborative areas and open interaction zones."
+    ],
+
+    "functional_vs_aesthetic_space": [
+        "A purely functional workspace focused on efficiency with practical furniture and utilitarian design.",
+        "An aesthetically designed space emphasizing visual appeal with decorative elements and stylistic choices."
     ]
 }

@@ -170,7 +236,16 @@ LIGHTING_CONDITION_PROMPTS = {
     "mixed_lighting": "A scene with combined natural and artificial light sources creating transition zones.",
     "beach_daylight": "A photo taken at a beach with bright natural sunlight and reflections from water.",
     "sports_arena_lighting": "A photo of a sports venue illuminated by powerful overhead lighting systems.",
-    "kitchen_task_lighting": "A photo of a kitchen with focused lighting concentrated on work surfaces."
+    "kitchen_task_lighting": "A photo of a kitchen with focused lighting concentrated on work surfaces.",
+    "photography_studio_lighting": "A photo taken in a photography studio with controlled professional lighting and even illumination.",
+    "retail_display_lighting": "A photo taken in retail environment with strategic product lighting and commercial illumination design.",
+    "conference_room_lighting": "A photo taken in a conference room with balanced meeting lighting and presentation-friendly illumination.",
+    "golden_hour_outdoor": "A photo taken during golden hour with warm, low-angle sunlight creating dramatic shadows and highlights.",
+    "overcast_diffused_light": "A photo taken under overcast sky with soft, even diffused lighting and minimal shadows.",
+    "harsh_midday_sun": "A photo taken under intense midday sunlight with strong contrasts and sharp shadows.",
+    "office_mixed_lighting": "A photo taken in office environment combining natural window light with artificial ceiling illumination.",
+    "restaurant_ambient_lighting": "A photo taken in restaurant with carefully designed ambient lighting combining multiple warm light sources.",
+    "retail_accent_lighting": "A photo taken in retail space with accent lighting highlighting products against general ambient illumination."
 }

 # Special prompts for the new scene types

@@ -228,6 +303,29 @@ SPECIALIZED_SCENE_PROMPTS = {
         "A high-angle view of an intersection showing traffic and pedestrian flow patterns.",
         "A drone perspective of urban crossing design viewed from directly above.",
         "A vertical view of a street intersection showing crossing infrastructure."
+    ],
+    "medical_waiting_room": [
+        "A medical facility waiting area with comfortable seating, health information displays and patient-focused design.",
+        "A healthcare waiting space with sanitized surfaces, medical equipment visibility and clinical atmosphere.",
+        "A medical office reception area with appointment scheduling setup and healthcare service information."
+    ],
+
+    "science_laboratory": [
+        "A science laboratory with experimental equipment, safety features and research workstations.",
+        "A chemistry lab with fume hoods, lab benches and scientific instrument arrangements.",
+        "A biology laboratory with microscopes, specimen storage and life science research setup."
+    ],
+
+    "design_studio_workspace": [
+        "A design studio with creative tools, inspiration boards and artistic project development areas.",
+        "An architecture office with drafting tables, model displays and design development workspaces.",
+        "A graphic design workspace with computer workstations, color calibration tools and creative project areas."
+    ],
+
+    "maintenance_workshop": [
+        "A maintenance workshop with repair tools, work benches and technical service equipment.",
+        "A mechanical service area with diagnostic equipment, repair stations and automotive maintenance setup.",
+        "A technical workshop with specialized tools, parts storage and equipment maintenance facilities."
     ]
 }

@@ -239,7 +337,15 @@ VIEWPOINT_PROMPTS = {
     "bird_eye": "A photo taken from very high above showing a complete overhead perspective.",
     "street_level": "A photo taken from the perspective of someone standing on the street.",
     "interior": "A photo taken from inside a building showing the internal environment.",
-    "vehicular": "A photo taken from inside or mounted on a moving vehicle."
+    "vehicular": "A photo taken from inside or mounted on a moving vehicle.",
+
+    # More detailed viewpoints
+    "security_camera_angle": "A photo taken from fixed security camera position showing surveillance perspective of the area.",
+    "drone_inspection_view": "A photo taken from drone perspective for inspection purposes showing detailed overhead examination angle.",
+    "architectural_documentation_view": "A photo taken specifically for architectural documentation showing building features and structural details.",
+    "customer_entering_view": "A photo taken from the perspective of a customer or visitor entering the space for the first time.",
+    "worker_daily_perspective": "A photo taken from the viewpoint of someone who works in this environment on a daily basis.",
+    "maintenance_access_view": "A photo taken from the perspective needed for maintenance or service access to equipment and facilities."
 }

 OBJECT_COMBINATION_PROMPTS = {

@@ -250,7 +356,15 @@ OBJECT_COMBINATION_PROMPTS = {
     "retail_environment": "A scene with merchandise displays, shoppers, and store fixtures.",
     "crosswalk_scene": "A scene with street markings, pedestrians crossing, and traffic signals.",
     "cooking_area": "A scene with stoves, prep surfaces, cooking utensils, and food items.",
-    "recreational_space": "A scene with sports equipment, play areas, and activity participants."
+    "recreational_space": "A scene with sports equipment, play areas, and activity participants.",
+    "medical_examination_setup": "A scene with medical examination table, diagnostic equipment, and healthcare monitoring devices.",
+    "laboratory_research_arrangement": "A scene with scientific instruments, sample containers, and research documentation materials.",
+    "technical_repair_station": "A scene with diagnostic tools, replacement parts, and mechanical repair equipment.",
+    "art_creation_workspace": "A scene with artistic supplies, canvases, and creative project materials arranged for art making.",
+    "music_practice_setup": "A scene with musical instruments, sheet music, and sound equipment for music practice.",
+    "craft_workshop_arrangement": "A scene with crafting tools, materials, and project supplies organized for handmade creation.",
+    "language_learning_environment": "A scene with language learning materials, reference books, and communication practice tools.",
+    "science_experiment_setup": "A scene with scientific apparatus, measurement tools, and experimental materials for hands-on learning."
 }

 ACTIVITY_PROMPTS = {

@@ -261,5 +375,14 @@ ACTIVITY_PROMPTS = {
     "exercising": "People engaged in physical activities, using sports equipment, and training.",
     "cooking": "People preparing food, using kitchen equipment, and creating meals.",
     "crossing_street": "People walking across designated crosswalks and navigating intersections.",
-    "recreational_activity": "People engaged in leisure activities, games, and social recreation."
+    "recreational_activity": "People engaged in leisure activities, games, and social recreation.",
+    "consulting": "People engaged in professional consultation with documents, presentations, and advisory discussions.",
+    "training": "People participating in skill development training with instructional materials and practice exercises.",
+    "maintenance": "People performing maintenance tasks with technical equipment and repair procedures.",
+    "brainstorming": "People engaged in creative brainstorming with idea development tools and collaborative thinking.",
+    "designing": "People working on design projects with creative tools, sketches, and visual development materials.",
+    "prototyping": "People building and testing prototypes with development materials and experimental approaches.",
+    "researching": "People conducting research with reference materials, databases, and investigative methods.",
+    "experimenting": "People performing experiments with scientific equipment and systematic testing procedures.",
+    "practicing": "People engaged in skill practice with repetitive exercises and performance improvement activities."
 }
clip_zero_shot_classifier.py
CHANGED

@@ -1,6 +1,6 @@

 import torch
-import clip
+import open_clip
 from PIL import Image
 import numpy as np
 import logging

@@ -21,18 +21,18 @@ class CLIPZeroShotClassifier:
     This is a facade class that coordinates the components to provide a unified external interface.
     """

-    def __init__(self, model_name: str = "ViT-B/16", device: str = None):
+    def __init__(self, model_name: str = "ViT-B-16", device: str = None, pretrained: str = "laion2b_s34b_b88k"):
         """
         Initialize the CLIP zero-shot classifier

         Args:
-            model_name: CLIP model name
+            model_name: OpenCLIP model name, defaults to "ViT-B-16"
             device: runtime device; None selects automatically
         """
         self.logger = logging.getLogger(__name__)

         # Initialize the component objects
-        self.clip_model_manager = CLIPModelManager(model_name, device)
+        self.clip_model_manager = CLIPModelManager(model_name, device, pretrained)
         self.landmark_data_manager = LandmarkDataManager()
         self.image_analyzer = ImageAnalyzer()
         self.confidence_manager = ConfidenceManager()
llm_enhancer.py
CHANGED

@@ -3,7 +3,7 @@
 import re
 from typing import Dict, List, Any, Optional

-from
+from llm_model_manager import LLMModelManager
 from prompt_template_manager import PromptTemplateManager
 from response_processor import ResponseProcessor
 from text_quality_validator import TextQualityValidator

@@ -44,7 +44,7 @@ class LLMEnhancer:

         try:
             # Initialize the four core components
-            self.model_manager =
+            self.model_manager = LLMModelManager(
                 model_path=model_path,
                 tokenizer_path=tokenizer_path,
                 device=device,
llm_model_manager.py
ADDED

@@ -0,0 +1,358 @@
+import os
+import torch
+import logging
+from typing import Dict, Optional, Any
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from huggingface_hub import login
+
+class ModelLoadingError(Exception):
+    """Custom exception for model loading failures"""
+    pass
+
+
+class ModelGenerationError(Exception):
+    """Custom exception for model generation failures"""
+    pass
+
+
+class LLMModelManager:
+    """
+    Handles LLM model loading, device management, and text generation.
+    Manages the model, memory optimization, and device configuration.
+    """
+
+    def __init__(self,
+                 model_path: Optional[str] = None,
+                 tokenizer_path: Optional[str] = None,
+                 device: Optional[str] = None,
+                 max_length: int = 2048,
+                 temperature: float = 0.3,
+                 top_p: float = 0.85):
+        """
+        Initialize the model manager
+
+        Args:
+            model_path: path or HuggingFace name of the LLM model; defaults to Llama 3.2
+            tokenizer_path: tokenizer path, usually the same as model_path
+            device: runtime device ('cpu' or 'cuda'); auto-detected when None
+            max_length: maximum length of the input text
+            temperature: temperature parameter for text generation
+            top_p: nucleus-sampling probability threshold for generation
+        """
+        # Set up a dedicated logger
+        self.logger = logging.getLogger(self.__class__.__name__)
+        if not self.logger.handlers:
+            handler = logging.StreamHandler()
+            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+            handler.setFormatter(formatter)
+            self.logger.addHandler(handler)
+            self.logger.setLevel(logging.INFO)
+
+        # Model configuration
+        self.model_path = model_path or "meta-llama/Llama-3.2-3B-Instruct"
+        self.tokenizer_path = tokenizer_path or self.model_path
+
+        # Device management
+        self.device = self._detect_device(device)
+        self.logger.info(f"Device selected: {self.device}")
+
+        # Generation parameters
+        self.max_length = max_length
+        self.temperature = temperature
+        self.top_p = top_p
+
+        # Model state
+        self.model = None
+        self.tokenizer = None
+        self._model_loaded = False
+        self.call_count = 0
+
+        # HuggingFace authentication
+        self.hf_token = self._setup_huggingface_auth()
+
+    def _detect_device(self, device: Optional[str]) -> str:
+        """
+        Detect and set the runtime device
+
+        Args:
+            device: user-specified device; auto-detected when None
+
+        Returns:
+            str: ('cuda' or 'cpu')
+        """
+        if device:
+            if device == 'cuda' and not torch.cuda.is_available():
+                self.logger.warning("CUDA requested but not available, falling back to CPU")
+                return 'cpu'
+            return device
+
+        detected_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+        if detected_device == 'cuda':
+            gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
+            self.logger.info(f"CUDA detected with {gpu_memory:.2f} GB GPU memory")
+
+        return detected_device
+
+    def _setup_huggingface_auth(self) -> Optional[str]:
+        """
+        Set up HuggingFace authentication
+
+        Returns:
+            Optional[str]: the HuggingFace token, if available
+        """
+        hf_token = os.environ.get("HF_TOKEN")
+
+        if hf_token:
+            try:
+                login(token=hf_token)
+                self.logger.info("Successfully authenticated with HuggingFace")
+                return hf_token
+            except Exception as e:
+                self.logger.error(f"HuggingFace authentication failed: {e}")
+                return None
+        else:
+            self.logger.warning("HF_TOKEN not found. Access to gated models may be limited")
+            return None
+
+    def _load_model(self):
+        """
+        Load the LLM model and tokenizer, using 8-bit quantization to save memory
+
+        Raises:
+            ModelLoadingError: when model loading fails
+        """
+        if self._model_loaded:
+            return
+
+        try:
+            self.logger.info(f"Loading model from {self.model_path} with 8-bit quantization")
+
+            # Clear GPU memory
+            self._clear_gpu_cache()
+
+            # Configure 8-bit quantization
+            quantization_config = BitsAndBytesConfig(
+                load_in_8bit=True,
+                llm_int8_enable_fp32_cpu_offload=True
+            )
+
+            # Load the tokenizer
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.tokenizer_path,
+                padding_side="left",
+                use_fast=False,
+                token=self.hf_token
+            )
+
+            # Set the special tokens
+            if self.tokenizer.pad_token is None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+
+            # Load the model
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.model_path,
+                quantization_config=quantization_config,
+                device_map="auto",
+                low_cpu_mem_usage=True,
+                token=self.hf_token
+            )
+
+            self._model_loaded = True
+            self.logger.info("Model loaded successfully")
+
+        except Exception as e:
+            error_msg = f"Failed to load model: {str(e)}"
+            self.logger.error(error_msg)
+            raise ModelLoadingError(error_msg) from e
+
+    def _clear_gpu_cache(self):
+        """Clear the GPU memory cache"""
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            self.logger.debug("GPU cache cleared")
+
+    def generate_response(self, prompt: str, **generation_kwargs) -> str:
+        """
+        Generate an LLM response
+
+        Args:
+            prompt: input prompt
+            **generation_kwargs: extra generation parameters that can override the defaults
+
+        Returns:
+            str: the generated response text
+
+        Raises:
+            ModelGenerationError: when generation fails
+        """
+        # Make sure the model is loaded
+        if not self._model_loaded:
+            self._load_model()
+
+        try:
+            self.call_count += 1
+            self.logger.info(f"Generating response (call #{self.call_count})")
+
+            # Clean up GPU memory
+            self._clear_gpu_cache()
+
+            # Fix the seed to improve consistency
+            torch.manual_seed(42)
+
+            # Prepare the input
+            inputs = self.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=True,
+                max_length=self.max_length
+            ).to(self.device)
+
+            # Prepare the generation parameters
+            generation_params = self._prepare_generation_params(**generation_kwargs)
+            generation_params.update({
+                "pad_token_id": self.tokenizer.eos_token_id,
+                "attention_mask": inputs.attention_mask,
+                "use_cache": True,
+            })
+
+            # Generate the response
+            with torch.no_grad():
+                outputs = self.model.generate(inputs.input_ids, **generation_params)
+
+            # Decode the response
+            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            response = self._extract_generated_response(full_response, prompt)
+
+            if not response or len(response.strip()) < 10:
+                raise ModelGenerationError("Generated response is too short or empty")
+
+            self.logger.info(f"Response generated successfully ({len(response)} characters)")
+            return response
+
+        except Exception as e:
+            error_msg = f"Text generation failed: {str(e)}"
+            self.logger.error(error_msg)
+            raise ModelGenerationError(error_msg) from e
+
+    def _prepare_generation_params(self, **kwargs) -> Dict[str, Any]:
+        """
+        Prepare generation parameters, with model-specific optimizations
+
+        Args:
+            **kwargs: user-supplied generation parameters
+
+        Returns:
+            Dict[str, Any]: the complete generation-parameter configuration
+        """
+        # Base parameters
+        params = {
+            "max_new_tokens": 120,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "do_sample": True,
+        }
+
+        # Special tuning for Llama models
+        if "llama" in self.model_path.lower():
+            params.update({
+                "max_new_tokens": 600,
+                "temperature": 0.35,  # keep the temperature low
+                "top_p": 0.75,
+                "repetition_penalty": 1.5,
+                "num_beams": 5,
+                "length_penalty": 1,
+                "no_repeat_ngram_size": 3
+            })
+        else:
+            params.update({
+                "max_new_tokens": 300,
+                "temperature": 0.6,
+                "top_p": 0.9,
+                "num_beams": 1,
+                "repetition_penalty": 1.05
+            })
+
+        # User parameters override the defaults
+        params.update(kwargs)
+
+        return params
+
+    def _extract_generated_response(self, full_response: str, prompt: str) -> str:
+        """
+        Extract the generated portion from the full response
+
+        Args:
+            full_response: the model's complete output
+            prompt: the original prompt
+
+        Returns:
+            str: the extracted generated response
+        """
+        # Look for the assistant tag
+        assistant_tag = "<|assistant|>"
+        if assistant_tag in full_response:
+            response = full_response.split(assistant_tag)[-1].strip()
+
+            # Check for an unclosed user tag
+            user_tag = "<|user|>"
+            if user_tag in response:
+                response = response.split(user_tag)[0].strip()
+
+            return response
+
+        # Strip the input prompt
+        if full_response.startswith(prompt):
+            return full_response[len(prompt):].strip()
+
+        return full_response.strip()
+
+    def reset_context(self):
+        """Reset the model context and clear the GPU cache"""
+        if self._model_loaded:
+            self._clear_gpu_cache()
+            self.logger.info("Model context reset")
+        else:
+            self.logger.info("Model not loaded, no context to reset")
+
+    def get_current_device(self) -> str:
+        """
+        Get the current runtime device
+
+        Returns:
+            str: current device name
+        """
+        return self.device
+
+    def is_model_loaded(self) -> bool:
+        """
+        Check whether the model is loaded
+
+        Returns:
+            bool: model loading state
+        """
+        return self._model_loaded
+
+    def get_call_count(self) -> int:
+        """
+        Get the number of model calls
+
+        Returns:
+            int: call count
+        """
+        return self.call_count
+
+    def get_model_info(self) -> Dict[str, Any]:
+        """
+        Get model information
+
+        Returns:
+            Dict[str, Any]: model path, device, loading state, and related fields
+        """
+        return {
+            "model_path": self.model_path,
+            "device": self.device,
+            "is_loaded": self._model_loaded,
+            "call_count": self.call_count,
+            "has_hf_token": self.hf_token is not None
+        }
requirements.txt
CHANGED

@@ -6,7 +6,7 @@ pillow>=9.4.0
 numpy>=1.23.5
 matplotlib>=3.7.0
 gradio>=3.32.0
-
+open-clip-torch>=2.20.0
 yt-dlp>=2023.3.4
 requests>=2.28.1
 transformers

@@ -14,4 +14,4 @@ accelerate
 bitsandbytes
 sentencepiece
 huggingface_hub>=0.19.0
-urllib3>=1.26.0
+urllib3>=1.26.0
scene_scoring_engine.py
CHANGED

@@ -249,13 +249,13 @@ class SceneScoringEngine:
         Returns:
             A (best scene type, confidence) tuple
         """
+        print(f"DEBUG: determine_scene_type input scores: {scene_scores}")
         if not scene_scores:
            return "unknown", 0.0

-        # Check whether the landmark-related scores reach the threshold; if so, return "tourist_landmark" directly.
+        # Check whether the landmark-related scores reach the threshold; if so, return "tourist_landmark" directly.
         # The scene-score dictionary is assumed to hold three keys, "tourist_landmark", "historical_monument",
         # and "natural_landmark", for the different landmark types. Sum them; if the total exceeds 0.3, treat the scene as a landmark.
-        # print(f"DEBUG: determine_scene_type input scores: {scene_scores}")
         landmark_score = (
             scene_scores.get("tourist_landmark", 0.0) +
             scene_scores.get("historical_monument", 0.0) +

@@ -268,7 +268,7 @@ class SceneScoringEngine:
         # Pick the scene with the highest score
         best_scene = max(scene_scores, key=scene_scores.get)
         best_score = scene_scores[best_scene]
-
+        print(f"DEBUG: determine_scene_type result: scene={best_scene}, score={best_score}")
         return best_scene, float(best_score)

     def fuse_scene_scores(self, yolo_scene_scores: Dict[str, float],

@@ -361,8 +361,9 @@ class SceneScoringEngine:
             current_yolo_weight = default_yolo_weight
             current_clip_weight = default_clip_weight
             current_places365_weight = default_places365_weight
-
-
+            print(f"DEBUG: Scene {scene_type} - yolo_score: {yolo_score}, clip_score: {clip_score}, places365_score: {places365_score}")
+            print(f"DEBUG: Scene {scene_type} - weights: yolo={current_yolo_weight:.3f}, clip={current_clip_weight:.3f}, places365={current_places365_weight:.3f}")
+

             scene_definition = self.scene_types.get(scene_type, {})

@@ -394,8 +395,8 @@ class SceneScoringEngine:
                                "professional_kitchen", "cafe", "library", "gym", "retail_store",
                                "supermarket", "classroom", "conference_room", "medical_facility",
                                "educational_setting", "dining_area"]):
-                current_yolo_weight = 0.
-                current_clip_weight = 0.
+                current_yolo_weight = 0.50
+                current_clip_weight = 0.25
                 current_places365_weight = 0.25

             # For certain common outdoor (non-landmark) scenes, objects still matter

@@ -491,7 +492,7 @@ class SceneScoringEngine:
             fused_scores[scene_type] = min(1.0, max(0.0, fused_score))

         return fused_scores
-
+        print(f"DEBUG: fuse_scene_scores final result: {fused_scores}")

     def update_enable_landmark_status(self, enable_landmark: bool):
         """
|