jupyterjazz committed (verified)
Commit f35e327 · Parent: 4453d02

refactor-image-processing (#16)


- refactor: support urls, fast processor, flash attn check (9624180b21e596eb410896719e1c1708aaed343c)
- refactor: image loading in st wrapper (9ef2e43d97b27bc27da6b71bc68d6160d317da20)

custom_st.py CHANGED
@@ -1,32 +1,34 @@
+from io import BytesIO
+from pathlib import Path
 from typing import Any, Dict, List, Literal, Optional, Union
 
+import requests
 import torch
 from PIL import Image
 from torch import nn
-from transformers import AutoConfig, AutoProcessor, AutoModel
+from transformers import AutoConfig, AutoModel, AutoProcessor
 
 
 class Transformer(nn.Module):
 
     save_in_root: bool = True
 
     def __init__(
         self,
-        model_name_or_path: str = 'jinaai/jina-embeddings-v4',
+        model_name_or_path: str = "jinaai/jina-embeddings-v4",
         max_seq_length: Optional[int] = None,
         config_args: Optional[Dict[str, Any]] = None,
         model_args: Optional[Dict[str, Any]] = None,
         tokenizer_args: Optional[Dict[str, Any]] = None,
         cache_dir: Optional[str] = None,
-        backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
+        backend: Literal["torch", "onnx", "openvino"] = "torch",
         **kwargs,
     ) -> None:
         super(Transformer, self).__init__()
-        if backend != 'torch':
+        if backend != "torch":
             raise ValueError(
-                f'Backend \'{backend}\' is not supported, please use \'torch\' instead'
+                f"Backend '{backend}' is not supported, please use 'torch' instead"
             )
-
         config_kwargs = config_args or {}
         model_kwargs = model_args or {}
         tokenizer_kwargs = tokenizer_args or {}
@@ -34,9 +36,11 @@ class Transformer(nn.Module):
         self.config = AutoConfig.from_pretrained(
             model_name_or_path, cache_dir=cache_dir, **config_kwargs
         )
-        self.default_task = model_args.pop('default_task', None)
+        self.default_task = model_args.pop("default_task", None)
         if self.default_task and self.default_task not in self.config.task_names:
-            raise ValueError(f"Invalid task: {self.default_task}. Must be one of {self.config.task_names}.")
+            raise ValueError(
+                f"Invalid task: {self.default_task}. Must be one of {self.config.task_names}."
+            )
 
         self.model = AutoModel.from_pretrained(
             model_name_or_path, config=self.config, cache_dir=cache_dir, **model_kwargs
@@ -45,6 +49,7 @@ class Transformer(nn.Module):
         self.processor = AutoProcessor.from_pretrained(
             model_name_or_path,
             cache_dir=cache_dir,
+            use_fast=True,
             **tokenizer_kwargs,
         )
         self.max_seq_length = max_seq_length or 8192
@@ -55,33 +60,52 @@ class Transformer(nn.Module):
         encoding = {}
         text_indices = []
         image_indices = []
-
        for i, text in enumerate(texts):
             if isinstance(text, str):
-                text_indices.append(i)
+                # Remove Query: or Passage: prefixes when checking for URLs or file paths
+                clean_text = text
+                if text.startswith("Query: "):
+                    clean_text = text[len("Query: ") :]
+                elif text.startswith("Passage: "):
+                    clean_text = text[len("Passage: ") :]
+
+                if clean_text.startswith("http"):
+                    response = requests.get(clean_text)
+                    texts[i] = Image.open(BytesIO(response.content)).convert("RGB")
+                    image_indices.append(i)
+                elif Path(clean_text).is_file():
+                    try:
+                        texts[i] = Image.open(clean_text).convert("RGB")
+                        image_indices.append(i)
+                    except Exception as e:
+                        text_indices.append(i)
+                else:
+                    text_indices.append(i)
             elif isinstance(text, Image.Image):
                 image_indices.append(i)
             else:
-                raise ValueError(f'Invalid input type: {type(text)}')
-
+                raise ValueError(f"Invalid input type: {type(text)}")
         if text_indices:
             _texts = [texts[i] for i in text_indices]
-            text_features = self.processor.process_texts(_texts, max_length=self.max_seq_length)
+            text_features = self.processor.process_texts(
+                _texts, max_length=self.max_seq_length
+            )
             for key, value in text_features.items():
-                encoding[f'text_{key}'] = value
-            encoding['text_indices'] = text_indices
+                encoding[f"text_{key}"] = value
+            encoding["text_indices"] = text_indices
 
         if image_indices:
             _images = [texts[i] for i in image_indices]
             img_features = self.processor.process_images(_images)
             for key, value in img_features.items():
-                encoding[f'image_{key}'] = value
-            encoding['image_indices'] = image_indices
+                encoding[f"image_{key}"] = value
+            encoding["image_indices"] = image_indices
 
         return encoding
-
 
-    def forward(self, features: Dict[str, torch.Tensor], task: Optional[str] = None) -> Dict[str, torch.Tensor]:
+    def forward(
+        self, features: Dict[str, torch.Tensor], task: Optional[str] = None
+    ) -> Dict[str, torch.Tensor]:
         self.model.eval()
 
         if task is None:
@@ -94,41 +118,55 @@ class Transformer(nn.Module):
                 task = self.default_task
         else:
             if task not in self.config.task_names:
-                raise ValueError(f"Invalid task: {task}. Must be one of {self.config.task_names}.")
+                raise ValueError(
+                    f"Invalid task: {task}. Must be one of {self.config.task_names}."
+                )
 
         device = self.model.device.type
         all_embeddings = []
 
         with torch.no_grad():
-            if any(k.startswith('text_') for k in features.keys()):
-                text_batch = {k[len('text_'):]: v.to(device) for k, v in features.items() if k.startswith('text_') and k != 'text_indices'}
-                text_indices = features.get('text_indices', [])
+            if any(k.startswith("text_") for k in features.keys()):
+                text_batch = {
+                    k[len("text_") :]: v.to(device)
+                    for k, v in features.items()
+                    if k.startswith("text_") and k != "text_indices"
+                }
+                text_indices = features.get("text_indices", [])
 
                 with torch.autocast(device_type=device):
-                    text_embeddings = self.model(**text_batch, task_label=task).single_vec_emb
+                    text_embeddings = self.model(
+                        **text_batch, task_label=task
+                    ).single_vec_emb
                    if self.config.truncate_dim:
-                        text_embeddings = text_embeddings[:, :self.config.truncate_dim]
+                        text_embeddings = text_embeddings[:, : self.config.truncate_dim]
 
                     for i, embedding in enumerate(text_embeddings):
                         all_embeddings.append((text_indices[i], embedding))
 
-            if any(k.startswith('image_') for k in features.keys()):
-                image_batch = {k[len('image_'):]: v.to(device) for k, v in features.items() if k.startswith('image_') and k != 'image_indices'}
-                image_indices = features.get('image_indices', [])
+            if any(k.startswith("image_") for k in features.keys()):
+                image_batch = {
+                    k[len("image_") :]: v.to(device)
+                    for k, v in features.items()
+                    if k.startswith("image_") and k != "image_indices"
+                }
+                image_indices = features.get("image_indices", [])
 
                 with torch.autocast(device_type=device):
-                    img_embeddings = self.model(**image_batch, task_label=task).single_vec_emb
+                    img_embeddings = self.model(
+                        **image_batch, task_label=task
+                    ).single_vec_emb
                     if self.config.truncate_dim:
-                        img_embeddings = img_embeddings[:, :self.config.truncate_dim]
+                        img_embeddings = img_embeddings[:, : self.config.truncate_dim]
 
                     for i, embedding in enumerate(img_embeddings):
                         all_embeddings.append((image_indices[i], embedding))
 
         if not all_embeddings:
-            raise RuntimeError('No embeddings were generated')
+            raise RuntimeError("No embeddings were generated")
 
         all_embeddings.sort(key=lambda x: x[0])  # sort by original index
         combined_embeddings = torch.stack([emb for _, emb in all_embeddings])
-        features['sentence_embedding'] = combined_embeddings
+        features["sentence_embedding"] = combined_embeddings
 
         return features
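
Below is a minimal usage sketch (not part of the commit) that exercises the refactored wrapper directly. The class, its methods, and the "retrieval" task name come from the diff above; the trust_remote_code kwargs and the example URLs/paths are assumptions.

# Hypothetical sketch: mixed inputs through the refactored custom_st.Transformer.
# Assumes custom_st.py from this repo is importable and the weights are reachable.
from custom_st import Transformer

st_module = Transformer(
    "jinaai/jina-embeddings-v4",
    config_args={"trust_remote_code": True},   # assumption: remote code needed to load
    model_args={"trust_remote_code": True, "default_task": "retrieval"},
    tokenizer_args={"trust_remote_code": True},
)
batch = [
    "Query: what does the picture show?",  # stays text; the prefix is only stripped for the URL/path check
    "https://example.com/cat.jpg",         # URL: fetched with requests, decoded with PIL
    "./photo.png",                         # local path: opened from disk if the file exists
]
features = st_module.tokenize(batch)       # routes each entry to the text or image branch
outputs = st_module.forward(features)      # task falls back to default_task ("retrieval")
print(outputs["sentence_embedding"].shape)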
modeling_jina_embeddings_v4.py CHANGED
@@ -5,20 +5,24 @@ import os
 from dataclasses import dataclass
 from enum import Enum
 from functools import partial
+from io import BytesIO
 from typing import Any, Callable, ClassVar, Dict, List, Optional, Union, cast
 
 import numpy as np
+import requests
 import torch
 from huggingface_hub import snapshot_download
-from peft import PeftModel, LoraConfig
+from peft import LoraConfig, PeftModel
 from PIL import Image
 from torch import nn
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import BatchFeature
-from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
+from transformers.utils import is_flash_attn_2_available
+
 from .configuration_jina_embeddings_v4 import JinaEmbeddingsV4Config
 from .custom_lora_module import MultiAdapterLinear
+from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
 
 
 class PromptType(str, Enum):
@@ -140,7 +144,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         self._init_projection_layers(config)
         self.post_init()
         self.processor = JinaEmbeddingsV4Processor.from_pretrained(
-            self.name_or_path, trust_remote_code=True
+            self.name_or_path, trust_remote_code=True, use_fast=True
         )
         self.single_vector_projector_dim = config.single_vector_projector_dim
         self.multi_vector_projector_dim = config.multi_vector_projector_dim
@@ -160,7 +164,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             task (str): The task name. Must be one of ['retrieval', 'text-matching', 'code']
         """
         if task not in self.config.task_names:
-            raise ValueError(f"Invalid task: {task}. Must be one of {self.config.task_names}.")
+            raise ValueError(
+                f"Invalid task: {task}. Must be one of {self.config.task_names}."
+            )
         self._task = task
 
     def get_last_hidden_states(
@@ -342,7 +348,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         for batch in tqdm(dataloader, desc=desc):
             with torch.no_grad():
                 batch = {k: v.to(self.device) for k, v in batch.items()}
-                with torch.autocast(device_type=torch.device(self.device).type, dtype=torch.bfloat16):
+                with torch.autocast(
+                    device_type=torch.device(self.device).type, dtype=torch.bfloat16
+                ):
                     embeddings = self(**batch, task_label=task_label)
                     if vector_type == "single_vector":
                         embeddings = embeddings.single_vec_emb
@@ -395,7 +403,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
            encode_kwargs["truncate_dim"] = truncate_dim
 
         return encode_kwargs
 
     def _validate_task(self, task: Optional[str] = None) -> str:
         if task is None:
             if self.task is None:
@@ -406,7 +414,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             task = self.task
         else:
             if task not in self.config.task_names:
-                raise ValueError(f"Invalid task: {task}. Must be one of {self.config.task_names}.")
+                raise ValueError(
+                    f"Invalid task: {task}. Must be one of {self.config.task_names}."
+                )
         return task
 
     def encode_texts(
@@ -460,9 +470,23 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
         return embeddings
 
+    def _load_images_if_needed(
+        self, images: List[Union[str, Image.Image]]
+    ) -> List[Image.Image]:
+        loaded_images = []
+        for image in images:
+            if isinstance(image, str):
+                if image.startswith("http"):
+                    response = requests.get(image)
+                    image = Image.open(BytesIO(response.content)).convert("RGB")
+                else:
+                    image = Image.open(image).convert("RGB")
+            loaded_images.append(image)
+        return loaded_images
+
     def encode_images(
         self,
-        images: List[Image.Image],
+        images: List[Union[str, Image.Image]],
         task: Optional[str] = None,
         batch_size: int = 8,
         vector_type: Optional[str] = None,
@@ -474,7 +498,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        Encodes a list of images into embeddings.
 
         Args:
-            images: List of PIL images to encode
+            images: List of PIL images, URLs, or local file paths to encode
             batch_size: Number of images to process at once
             vector_type: Type of embedding vector to generate ('single_vector' or 'multi_vector')
             return_numpy: Whether to return numpy arrays instead of torch tensors
@@ -489,9 +513,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
            self.processor.image_processor.max_pixels = (
                max_pixels  # change during encoding
            )
-
        encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
        task = self._validate_task(task)
+        images = self._load_images_if_needed(images)
        embeddings = self._process_batches(
            data=images,
            processor_fn=self.processor.process_images,
@@ -519,8 +543,10 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
        """
        if "torch_dtype" not in kwargs:
            kwargs["torch_dtype"] = "auto"
 
        kwargs["key_mapping"] = super()._checkpoint_conversion_mapping
+        if not is_flash_attn_2_available():
+            kwargs["attn_implementation"] = "sdpa"
 
        base_model = super().from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
@@ -547,19 +573,19 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
            model_id=adapter_dir,
            config=lora_config,
        )
 
        @property
        def task(self):
            return self.model.task
 
        @task.setter
        def task(self, value):
            self.model.task = value
 
        peft_model.task = property(task.fget, task.fset)
        peft_model.__class__.task = property(
            lambda self: self.model.task,
-            lambda self, value: setattr(self.model, 'task', value)
+            lambda self, value: setattr(self.model, "task", value),
        )
 
        return peft_model
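
And a hedged sketch of the new encode_images() path at the model level. The argument names, the URL/path handling, and the sdpa fallback when flash-attn 2 is unavailable come from the diff; the AutoModel entry point and trust_remote_code are assumptions about how the checkpoint is normally loaded.

# Hypothetical sketch: encode_images() now accepts URLs and file paths, and
# from_pretrained() switches attn_implementation to "sdpa" if flash-attn 2 is missing.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v4",
    trust_remote_code=True,  # assumption: required for the custom modeling code
)
embeddings = model.encode_images(
    images=[
        "https://example.com/diagram.png",  # fetched with requests, decoded with PIL
        "local_scan.jpg",                   # opened from disk
    ],
    task="retrieval",
    batch_size=8,
)
print(len(embeddings), embeddings[0].shape)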
tokenizer_config.json CHANGED
@@ -202,7 +202,7 @@
   "extra_special_tokens": {},
   "model_max_length": 131072,
   "pad_token": "<|endoftext|>",
-  "processor_class": "ColQwen25DuoProcessor",
+  "processor_class": "JinaEmbeddingsV4Processor",
   "split_special_tokens": false,
   "tokenizer_class": "Qwen2Tokenizer",
   "unk_token": null