Coool2 committed on
Commit 1a3b775 · 1 Parent(s): 6aad47c

Update agent.py

Files changed (1)
  1. agent.py +29 -42
agent.py CHANGED
@@ -200,64 +200,51 @@ def initialize_models(use_api_mode=False):
 
     from typing import Any, List, Optional
     from llama_index.core.embeddings import BaseEmbedding
-    import torch
-    from FlagEmbedding.visual.modeling import Visualized_BGE
+    from sentence_transformers import SentenceTransformer
+    from PIL import Image
 
-    class BAAIVisualizedAdvanced(BaseEmbedding):
+    class MultimodalCLIPEmbedding(BaseEmbedding):
         """
-        Advanced implementation using FlagEmbedding's Visualized_BGE.
+        Custom embedding class using CLIP for multimodal capabilities.
         """
 
-        def __init__(self,
-                     model_name_bge: str = "BAAI/bge-base-en-v1.5",
-                     model_weight_path: str = "path/to/Visualized_base_en_v1.5.pth",
-                     **kwargs: Any) -> None:
+        def __init__(self, model_name: str = "clip-ViT-B-32", **kwargs: Any) -> None:
             super().__init__(**kwargs)
-            # Initialize the Visualized BGE model
-            self._model = Visualized_BGE(
-                model_name_bge=model_name_bge,
-                model_weight=model_weight_path
-            )
-            self._model.eval()
+            self._model = SentenceTransformer(model_name)
 
         @classmethod
         def class_name(cls) -> str:
-            return "baai_visualized_advanced"
+            return "multimodal_clip"
 
         def _get_query_embedding(self, query: str, image_path: Optional[str] = None) -> List[float]:
-            """Generate embedding for query with optional image."""
-            with torch.no_grad():
-                if image_path:
-                    # Encode both text and image
-                    embedding = self._model.encode(image=image_path, text=query)
-                else:
-                    # Text-only encoding
-                    embedding = self._model.encode(text=query)
-                return embedding.cpu().numpy().tolist()
+            if image_path:
+                image = Image.open(image_path)
+                embedding = self._model.encode(image)
+                return embedding.tolist()
+            else:
+                embedding = self._model.encode(query)
+                return embedding.tolist()
 
         def _get_text_embedding(self, text: str, image_path: Optional[str] = None) -> List[float]:
-            """Generate embedding for text with optional image."""
-            with torch.no_grad():
-                if image_path:
-                    # Image-only encoding
-                    embedding = self._model.encode(image=image_path)
-                else:
-                    # Text-only encoding
-                    embedding = self._model.encode(text=text)
-                return embedding.cpu().numpy().tolist()
+            if image_path:
+                image = Image.open(image_path)
+                embedding = self._model.encode(image)
+                return embedding.tolist()
+            else:
+                embedding = self._model.encode(text)
+                return embedding.tolist()
 
         def _get_text_embeddings(self, texts: List[str], image_paths: Optional[List[str]] = None) -> List[List[float]]:
-            """Batch embedding generation."""
             embeddings = []
             image_paths = image_paths or [None] * len(texts)
 
-            with torch.no_grad():
-                for text, img_path in zip(texts, image_paths):
-                    if img_path:
-                        emb = self._model.encode(image=img_path, text=text)
-                    else:
-                        emb = self._model.encode(text=text)
-                    embeddings.append(emb.cpu().numpy().tolist())
+            for text, img_path in zip(texts, image_paths):
+                if img_path:
+                    image = Image.open(img_path)
+                    emb = self._model.encode(image)
+                else:
+                    emb = self._model.encode(text)
+                embeddings.append(emb.tolist())
 
             return embeddings
 
@@ -268,7 +255,7 @@ def initialize_models(use_api_mode=False):
             return self._get_text_embedding(text, image_path)
 
 
-    embed_model = BAAIVisualizedEmbedding()
+    embed_model = MultimodalCLIPEmbedding()
    # Code LLM
    code_llm = HuggingFaceLLM(
        model_name="Qwen/Qwen2.5-Coder-3B-Instruct",