Commit e7f92e1 (verified) · committed by jupyterjazz · 1 Parent(s): 76705e2

stateless-adapter-switching (#11)

- feat: implement stateless adapter switching [wip] (85f64e256eb783288093a9ff8fc63f9ba66ba2f5)
- feat: finalized implementation (8a9e9edbdb6c678e370f8fed2211f688e30a2fca)
- feat: merged checkpoint, modified qwen, readme (b3d45f64324f0d2b9542b691694e97e581136e97)

README.md CHANGED
@@ -22,11 +22,9 @@ image_paths = ['/<path_to_image>']
 images = [Image.open(path) for path in image_paths]
 
 # Example 1: Text matching task with single vector embeddings
-model.set_task(task='text-matching')
-
 # Generate embeddings with dimension truncation (256), decrease max_pixels
-img_embeddings = model.encode_images(images=images, truncate_dim=256, max_pixels=602112)
-text_embeddings = model.encode_texts(texts=texts, truncate_dim=256, max_length=512)
+img_embeddings = model.encode_images(images=images, truncate_dim=256, max_pixels=602112, task='text-matching')
+text_embeddings = model.encode_texts(texts=texts, truncate_dim=256, max_length=512, task='text-matching')
 
 # Example 2: Retrieval task with multi-vector embeddings
 model.set_task(task='retrieval')
@@ -36,10 +34,8 @@ img_embeddings = model.encode_images(images=images, vector_type='multi_vector')
 text_embeddings = model.encode_texts(texts=texts, vector_type='multi_vector', prompt_name='passage')
 
 # Example 3: Code task with single vector embeddings
-model.set_task(task='code')
-
 code = ["def hello_world():\n print('Hello, World!')"]
-code_embeddings = model.encode_texts(texts=code)
+code_embeddings = model.encode_texts(texts=code, task='code')
 
 ```
@@ -75,8 +71,8 @@ with torch.no_grad():
 
 with torch.autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):
     # Get embeddings
-    text_embeddings = model.model(**text_batch).single_vec_emb
-    img_embeddings = model.model(**image_batch).single_vec_emb
+    text_embeddings = model.model(**text_batch, task_label='retrieval').single_vec_emb
+    img_embeddings = model.model(**image_batch, task_label='retrieval').single_vec_emb
 
 
 ```
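The README change above drops the stateful `set_task` calls for text matching and code in favor of passing `task` directly to each encode call. A minimal sketch of the new call pattern, assuming a model loaded with `trust_remote_code=True`; the repository id below is a placeholder, not taken from this commit:

```python
from transformers import AutoModel

# Placeholder repo id; substitute the actual model repository.
model = AutoModel.from_pretrained("<model_repo_id>", trust_remote_code=True)

# Stateless style introduced by this commit: the task is a per-call argument,
# so no adapter hot-swap happens between calls with different tasks.
text_embeddings = model.encode_texts(
    texts=["A short query"], truncate_dim=256, max_length=512, task="text-matching"
)
code_embeddings = model.encode_texts(
    texts=["def hello_world():\n    print('Hello, World!')"], task="code"
)
```

Because the adapter is now selected per call (and, at the lower level, per example), alternating between tasks no longer mutates any adapter state on the model.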
adapters/{retrieval/adapter_config.json → adapter_config.json} RENAMED
File without changes
adapters/{text-matching/adapter_model.safetensors → adapter_model.safetensors} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3677815cef695c54aae2358c574c046d6d9a5787fd96ca457ee00ac656576985
-size 120138416
+oid sha256:7a5cb8cc0f4e10f184ccc10f8864999098b887dbc4107221ec0e400d927f4555
+size 360095344
adapters/code/adapter_config.json DELETED
@@ -1,26 +0,0 @@
-{
-  "alpha_pattern": {},
-  "auto_mapping": null,
-  "base_model_name_or_path": "jinaai/colqwen25-duo-base",
-  "bias": "none",
-  "fan_in_fan_out": false,
-  "inference_mode": false,
-  "init_lora_weights": "gaussian",
-  "layer_replication": null,
-  "layers_pattern": null,
-  "layers_to_transform": null,
-  "loftq_config": {},
-  "lora_alpha": 32,
-  "lora_dropout": 0.1,
-  "megatron_config": null,
-  "megatron_core": "megatron.core",
-  "modules_to_save": null,
-  "peft_type": "LORA",
-  "r": 32,
-  "rank_pattern": {},
-  "revision": null,
-  "target_modules": "(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(single_vector_projector|multi_vector_projector).*$)",
-  "task_type": "FEATURE_EXTRACTION",
-  "use_dora": false,
-  "use_rslora": false
-}
adapters/code/adapter_model.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:510d017efc64c97e2db985ed1a96b17477ac97e1a5470996209041ad35beeee7
-size 119802032
adapters/retrieval/adapter_model.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:0c2b1d85506d01bd29a942975cb0abbd8c4af3487fb80b5ad408ae0e55f8bb3a
-size 120138416
adapters/text-matching/adapter_config.json DELETED
@@ -1,26 +0,0 @@
-{
-  "alpha_pattern": {},
-  "auto_mapping": null,
-  "base_model_name_or_path": "jinaai/colqwen25-duo-base",
-  "bias": "none",
-  "fan_in_fan_out": false,
-  "inference_mode": true,
-  "init_lora_weights": "gaussian",
-  "layer_replication": null,
-  "layers_pattern": null,
-  "layers_to_transform": null,
-  "loftq_config": {},
-  "lora_alpha": 32,
-  "lora_dropout": 0.1,
-  "megatron_config": null,
-  "megatron_core": "megatron.core",
-  "modules_to_save": null,
-  "peft_type": "LORA",
-  "r": 32,
-  "rank_pattern": {},
-  "revision": null,
-  "target_modules": "(.*(model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$|.*(single_vector_projector|multi_vector_projector).*$)",
-  "task_type": "FEATURE_EXTRACTION",
-  "use_dora": false,
-  "use_rslora": false
-}
config.json CHANGED
@@ -54,5 +54,7 @@
   "vision_start_token_id": 151652,
   "vision_token_id": 151654,
   "vocab_size": 151936,
-  "truncate_dim": null
+  "truncate_dim": null,
+  "task_names": ["retrieval", "text-matching", "code"],
+  "matryoshka_dims": [128, 256, 512, 1024]
 }
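The two new config fields drive the validation added in modeling_jina_embeddings_v4.py: `task_names` bounds the accepted `task` values and `matryoshka_dims` bounds `truncate_dim`. A standalone sketch of those checks, with the values copied from the diff; this is an illustration, not the module itself:

```python
from typing import Optional

TASK_NAMES = ["retrieval", "text-matching", "code"]
MATRYOSHKA_DIMS = [128, 256, 512, 1024]

def validate(task: Optional[str], truncate_dim: Optional[int]) -> None:
    # Mirrors _validate_task and the truncate_dim check in the modeling code.
    if task is not None and task not in TASK_NAMES:
        raise ValueError(f"Invalid task: {task}. Must be one of {TASK_NAMES}.")
    if truncate_dim is not None and truncate_dim not in MATRYOSHKA_DIMS:
        raise ValueError(
            f"Invalid truncate_dim: {truncate_dim}. Must be one of {MATRYOSHKA_DIMS}."
        )

validate("retrieval", 256)        # passes
# validate("summarization", 64)   # would raise ValueError
```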
custom_lora_module.py ADDED
@@ -0,0 +1,193 @@
+from __future__ import annotations
+
+import math
+import warnings
+from typing import Any, Optional, Union, List
+
+import torch
+import torch.nn as nn
+
+from peft.tuners.lora import LoraLayer
+
+class MultiAdapterLinear(nn.Module, LoraLayer):
+    """
+    Custom LoRA module supporting multiple adapters for a linear layer.
+
+    This module extends the standard LoRA implementation to support multiple task-specific
+    adapters that can be dynamically selected during the forward pass. The task_label
+    parameter passed to the forward function determines which LoRA adapter(s) to use:
+    - If task_label is a string, all examples in the batch use the same adapter
+    - If task_label is a list of strings, each example can use a different adapter
+
+    This enables efficient multi-task inference where all task-specific LoRA adapters
+    are loaded in memory simultaneously and dynamically selected per example, eliminating
+    the need to switch adapter states between tasks and allowing optimal throughput
+    for mixed-task batches.
+
+    Derived from peft.tuners.lora.Linear.
+    """
+    def __init__(
+        self,
+        base_layer,
+        adapter_name: str,
+        task_names: List[str],
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
+        is_target_conv_1d_layer: bool = False,
+        init_lora_weights: Union[bool, str] = True,
+        use_rslora: bool = False,
+        use_dora: bool = False,
+        lora_bias: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        LoraLayer.__init__(self, base_layer, **kwargs)
+
+        self.fan_in_fan_out = fan_in_fan_out
+        self.task_names = task_names
+        self._active_adapter = adapter_name
+        self.update_layer(
+            adapter_name,
+            r,
+            lora_alpha=lora_alpha,
+            lora_dropout=lora_dropout,
+            init_lora_weights=init_lora_weights,
+            use_rslora=use_rslora,
+            use_dora=use_dora,
+            lora_bias=lora_bias,
+        )
+        self.is_target_conv_1d_layer = is_target_conv_1d_layer
+
+
+    def forward(self, x: torch.Tensor, task_label: Union[str, List[str]], *args: Any, **kwargs: Any) -> torch.Tensor:
+        self._check_forward_args(x, *args, **kwargs)
+
+        if self.disable_adapters:
+            if self.merged:
+                self.unmerge()
+            result = self.base_layer(x, *args, **kwargs)
+        elif self.merged:
+            result = self.base_layer(x, *args, **kwargs)
+        else:
+            result = self.base_layer(x, *args, **kwargs)
+            torch_result_dtype = result.dtype
+
+            lora_A_keys = self.lora_A.keys()
+            for active_adapter in self.active_adapters:
+                if active_adapter not in lora_A_keys:
+                    continue
+
+                if isinstance(task_label, str):
+                    lora_A = self.lora_A[active_adapter][task_label]
+                    lora_B = self.lora_B[active_adapter][task_label]
+                    dropout = self.lora_dropout[active_adapter]
+                    scaling = self.scaling[active_adapter]
+                    x = self._cast_input_dtype(x, lora_A.weight.dtype)
+                    result = result + lora_B(lora_A(dropout(x))) * scaling
+                else:
+                    unique_tasks = list(set(task_label))
+                    lora_output = torch.zeros_like(result)
+
+                    for task in unique_tasks:
+                        task_indices = [i for i, t in enumerate(task_label) if t == task]
+                        task_x = x[task_indices]
+
+                        lora_A = self.lora_A[active_adapter][task]
+                        lora_B = self.lora_B[active_adapter][task]
+                        dropout = self.lora_dropout[active_adapter]
+                        scaling = self.scaling[active_adapter]
+
+                        task_x = self._cast_input_dtype(task_x, lora_A.weight.dtype)
+                        task_lora_value = lora_B(lora_A(dropout(task_x))) * scaling
+
+                        for i, idx in enumerate(task_indices):
+                            lora_output[idx] = task_lora_value[i]
+
+                    result = result + lora_output
+
+            result = result.to(torch_result_dtype)
+
+        return result
+
+    def __repr__(self) -> str:
+        rep = super().__repr__()
+        return "lora." + rep
+
+
+    def update_layer(
+        self,
+        adapter_name,
+        r,
+        lora_alpha,
+        lora_dropout,
+        init_lora_weights,
+        use_rslora,
+        use_dora: bool = False,
+        lora_bias: bool = False,
+    ):
+        # This code works for linear layers, override for other layer types
+        if r <= 0:
+            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
+
+        self.r[adapter_name] = r
+        self.lora_alpha[adapter_name] = lora_alpha
+        if lora_dropout > 0.0:
+            lora_dropout_layer = nn.Dropout(p=lora_dropout)
+        else:
+            lora_dropout_layer = nn.Identity()
+
+        self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
+        # Actual trainable parameters
+        self.lora_A[adapter_name] = nn.ModuleDict({
+            task_name: nn.Linear(self.in_features, r, bias=False)
+            for task_name in self.task_names
+        })
+        self.lora_B[adapter_name] = nn.ModuleDict({
+            task_name: nn.Linear(r, self.out_features, bias=lora_bias)
+            for task_name in self.task_names
+        })
+        self.lora_bias[adapter_name] = lora_bias
+
+        if use_rslora:
+            self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
+        else:
+            self.scaling[adapter_name] = lora_alpha / r
+
+        self.reset_lora_parameters(adapter_name, init_lora_weights)
+        self._move_adapter_to_device_of_base_layer(adapter_name)
+        self.use_dora[adapter_name] = False
+        self.set_adapter(self.active_adapters)
+
+    def reset_lora_parameters(self, adapter_name, init_lora_weights):
+        if init_lora_weights is False:
+            return
+        if init_lora_weights is True:
+            # initialize A the same way as the default for nn.Linear and B to zero
+            # https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L124
+            for task_name in self.task_names:
+                nn.init.kaiming_uniform_(self.lora_A[adapter_name][task_name].weight, a=math.sqrt(5))
+        elif init_lora_weights.lower() == "gaussian":
+            for task_name in self.task_names:
+                nn.init.normal_(self.lora_A[adapter_name][task_name].weight, std=1 / self.r[adapter_name])
+        else:
+            raise ValueError(f"Unknown initialization {init_lora_weights=}")
+        for task_name in self.task_names:
+            nn.init.zeros_(self.lora_B[adapter_name][task_name].weight)
+        if self.lora_bias[adapter_name]:
+            for task_name in self.task_names:
+                nn.init.zeros_(self.lora_B[adapter_name][task_name].bias)
+
+
+    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
+        """
+        Merge the active adapter weights into the base weights
+        """
+        raise NotImplementedError("Merge operation is not supported")
+
+    def unmerge(self) -> None:
+        """
+        This method unmerges all merged adapter layers from the base weights.
+        """
+        raise NotImplementedError("Unmerge operation is not supported")
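The routing in `MultiAdapterLinear.forward` is the core of the stateless switching: every task's LoRA `A`/`B` pair stays resident, and `task_label` picks the pair per batch (string) or per example (list of strings). The self-contained toy below illustrates that routing without importing the repo's class; layer sizes and the standalone `forward` helper are illustrative only:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_f, out_f, r = 8, 6, 4
task_names = ["retrieval", "text-matching", "code"]

base = nn.Linear(in_f, out_f)                                   # shared base layer
lora_A = nn.ModuleDict({t: nn.Linear(in_f, r, bias=False) for t in task_names})
lora_B = nn.ModuleDict({t: nn.Linear(r, out_f, bias=False) for t in task_names})
scaling = 1.0

def forward(x: torch.Tensor, task_label) -> torch.Tensor:
    result = base(x)
    if isinstance(task_label, str):            # whole batch uses one adapter
        return result + lora_B[task_label](lora_A[task_label](x)) * scaling
    lora_out = torch.zeros_like(result)        # mixed-task batch: route per example
    for task in set(task_label):
        idx = [i for i, t in enumerate(task_label) if t == task]
        lora_out[idx] = lora_B[task](lora_A[task](x[idx])) * scaling
    return result + lora_out

x = torch.randn(4, in_f)
out = forward(x, ["retrieval", "code", "retrieval", "text-matching"])
print(out.shape)  # torch.Size([4, 6])
```

Because selection happens row-wise inside the linear layer, a single forward pass can serve a batch that mixes retrieval, text-matching, and code inputs, which is what makes the previous hotswap-based `set_task` unnecessary.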
modeling_jina_embeddings_v4.py CHANGED
@@ -10,17 +10,17 @@ from typing import Any, Callable, ClassVar, Dict, List, Optional, Union, cast
 import numpy as np
 import torch
 from huggingface_hub import snapshot_download
-from peft import PeftModel
+from peft import PeftModel, LoraConfig
 from peft.utils.hotswap import hotswap_adapter
 from PIL import Image
 from torch import nn
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import BatchFeature
-from transformers.models.qwen2_5_vl import (Qwen2_5_VLForConditionalGeneration,
-                                            Qwen2_5_VLProcessor)
-
+from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
 from .configuration_jina_embeddings_v4 import JinaEmbeddingsV4Config
+import peft
+from .custom_lora_module import MultiAdapterLinear
 
 
 class PromptType(str, Enum):
@@ -28,14 +28,7 @@ class PromptType(str, Enum):
     passage = "passage"
 
 
-class TaskType(str, Enum):
-    retrieval = "retrieval"
-    code = "code"
-    text_matching = "text-matching"
-
-
 PREFIX_DICT = {"query": "Query", "passage": "Passage"}
-TRUNCATE_DIMS = [128, 256, 512, 1024]
 VECTOR_TYPES = ["single_vector", "multi_vector"]
 
 
@@ -153,9 +146,28 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         )
         self.single_vector_projector_dim = config.single_vector_projector_dim
         self.multi_vector_projector_dim = config.multi_vector_projector_dim
+        self._task = None
+
+    @property
+    def task(self) -> Optional[str]:
+        """Get the current task set for the model."""
+        return self._task
+
+    @task.setter
+    def task(self, task: str):
+        """
+        Set the task for the model.
+
+        Args:
+            task (str): The task name. Must be one of ['retrieval', 'text-matching', 'code']
+        """
+        if task not in self.config.task_names:
+            raise ValueError(f"Invalid task: {task}. Must be one of {self.config.task_names}.")
+        self._task = task
 
     def get_last_hidden_states(
         self,
+        task_label: Union[str, List[str]],
         input_ids: torch.LongTensor,
         attention_mask: torch.Tensor,
         **kwargs,
@@ -173,10 +185,10 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         )
 
         kwargs["output_hidden_states"] = True
-
         outputs = super().forward(
-            input_ids,
-            attention_mask,
+            task_label=task_label,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
             **kwargs,
             position_ids=position_ids,
             rope_deltas=rope_deltas,
@@ -208,6 +220,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
     def project_to_single_vector_embeddings(
         self,
+        task_label: Union[str, List[str]],
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         input_ids: Optional[torch.LongTensor] = None,
@@ -216,33 +229,48 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Project the hidden states to single-vector embeddings.
         """
         if self._input_has_image(input_ids[0]):  # got document image
-            img_start_positions = torch.where(input_ids == self.config.vision_start_token_id)[1]
-            img_end_positions = torch.where(input_ids == self.config.vision_end_token_id)[1]
-
+            img_start_positions = torch.where(
+                input_ids == self.config.vision_start_token_id
+            )[1]
+            img_end_positions = torch.where(
+                input_ids == self.config.vision_end_token_id
+            )[1]
+
             batch_size, seq_len = input_ids.shape
-            position_indices = torch.arange(seq_len, device=input_ids.device).expand(batch_size, -1)
-            image_mask = (position_indices >= img_start_positions.unsqueeze(1)) & (position_indices <= img_end_positions.unsqueeze(1))
-
+            position_indices = torch.arange(seq_len, device=input_ids.device).expand(
+                batch_size, -1
+            )
+            image_mask = (position_indices >= img_start_positions.unsqueeze(1)) & (
+                position_indices <= img_end_positions.unsqueeze(1)
+            )
+
             masked_hidden_states = hidden_states * image_mask.unsqueeze(-1)
-            pooled_output = masked_hidden_states.sum(dim=1) / image_mask.sum(dim=1, keepdim=True)
+            pooled_output = masked_hidden_states.sum(dim=1) / image_mask.sum(
+                dim=1, keepdim=True
+            )
 
         else:  # got query text
             pooled_output = torch.sum(
                 hidden_states * attention_mask.unsqueeze(-1), dim=1
             ) / torch.sum(attention_mask, dim=1, keepdim=True)
 
-        single_vec_emb = self.single_vector_projector(pooled_output)
+        single_vec_emb = self.single_vector_projector(
+            pooled_output, task_label=task_label
+        )
         return torch.nn.functional.normalize(single_vec_emb, dim=-1)
 
     def project_to_multi_vector_embeddings(
         self,
+        task_label: Union[str, List[str]],
         hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
     ) -> torch.Tensor:
         """
         Project the hidden states to multi-vector embeddings.
         """
-        multi_vec_emb = self.multi_vector_projector(hidden_states)
+        multi_vec_emb = self.multi_vector_projector(
+            hidden_states, task_label=task_label
+        )
         multi_vec_emb = torch.nn.functional.normalize(multi_vec_emb, dim=-1)
         return multi_vec_emb * attention_mask.unsqueeze(-1)
 
@@ -251,6 +279,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
     def forward(
         self,
+        task_label: Union[str, List[str]],
         input_ids: torch.LongTensor,
         attention_mask: torch.Tensor,
         output_vlm_last_hidden_states: bool = False,
@@ -268,15 +297,22 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         """
         # Forward pass through the VLM
         hidden_states = self.get_last_hidden_states(
-            input_ids=input_ids, attention_mask=attention_mask, **kwargs
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            task_label=task_label,
+            **kwargs,
         )  # (batch_size, seq_length, hidden_size)
-
         # Compute the embeddings
         single_vec_emb = self.project_to_single_vector_embeddings(
-            hidden_states, attention_mask, input_ids=input_ids
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            input_ids=input_ids,
+            task_label=task_label,
         )
         multi_vec_emb = self.project_to_multi_vector_embeddings(
-            hidden_states, attention_mask
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            task_label=task_label,
         )
 
         return JinaEmbeddingsV4ModelOutput(
@@ -290,6 +326,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
     def _process_batches(
         self,
         data: List[Union[str, Image.Image]],
+        task_label: Union[str, List[str]],
         processor_fn: Callable,
         desc: str,
         vector_type: str = "single_vector",
@@ -309,7 +346,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             with torch.no_grad():
                 batch = {k: v.to(self.device) for k, v in batch.items()}
                 with torch.autocast(device_type=torch.device(self.device).type):
-                    embeddings = self(**batch)
+                    embeddings = self(**batch, task_label=task_label)
                     if vector_type == "single_vector":
                         embeddings = embeddings.single_vec_emb
                         if truncate_dim is not None:
@@ -340,7 +377,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         else:
             encode_kwargs["prefix"] = (
                 PREFIX_DICT[prompt_name]
-                if self.task != TaskType.text_matching
+                if self.task != "text-matching"
                 else PREFIX_DICT["query"]
             )
 
@@ -353,18 +390,32 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         encode_kwargs["vector_type"] = vector_type
 
         truncate_dim = truncate_dim or self.config.truncate_dim
-        if truncate_dim is not None and truncate_dim not in TRUNCATE_DIMS:
+        if truncate_dim is not None and truncate_dim not in self.config.matryoshka_dims:
             raise ValueError(
-                f"Invalid truncate_dim: {truncate_dim}. Must be one of {TRUNCATE_DIMS}."
+                f"Invalid truncate_dim: {truncate_dim}. Must be one of {self.config.matryoshka_dims}."
            )
         else:
             encode_kwargs["truncate_dim"] = truncate_dim
 
         return encode_kwargs
+
+    def _validate_task(self, task: Optional[str] = None) -> str:
+        if task is None:
+            if self.task is None:
+                raise ValueError(
+                    "Task must be specified before encoding data. You can set it either as a model property "
+                    "(e.g., model.task = 'retrieval') or pass it as an argument to the encode method."
+                )
+            task = self.task
+        else:
+            if task not in self.config.task_names:
+                raise ValueError(f"Invalid task: {task}. Must be one of {self.config.task_names}.")
+        return task
 
     def encode_texts(
         self,
         texts: List[str],
+        task: Optional[str] = None,
         max_length: int = 8192,
         batch_size: int = 8,
         vector_type: Optional[str] = None,
@@ -392,6 +443,8 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             vector_type, truncate_dim, prompt_name
         )
 
+        task = self._validate_task(task)
+
         processor_fn = partial(
             self.processor.process_texts,
             max_length=max_length,
@@ -402,6 +455,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             data=texts,
             processor_fn=processor_fn,
             desc="Encoding texts...",
+            task_label=task,
             return_numpy=return_numpy,
             batch_size=batch_size,
             **encode_kwargs,
@@ -412,6 +466,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
     def encode_images(
         self,
         images: List[Image.Image],
+        task: Optional[str] = None,
         batch_size: int = 8,
         vector_type: Optional[str] = None,
         return_numpy: bool = False,
@@ -434,14 +489,17 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         """
         if max_pixels:
             default_max_pixels = self.processor.image_processor.max_pixels
-            self.processor.image_processor.max_pixels = max_pixels  # change during encoding
+            self.processor.image_processor.max_pixels = (
+                max_pixels  # change during encoding
+            )
 
         encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
-
+        task = self._validate_task(task)
         embeddings = self._process_batches(
             data=images,
             processor_fn=self.processor.process_images,
             desc="Encoding images...",
+            task_label=task,
            batch_size=batch_size,
             return_numpy=return_numpy,
             **encode_kwargs,
@@ -464,15 +522,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         """
         if "torch_dtype" not in kwargs:
            kwargs["torch_dtype"] = "auto"
-
-        task_value = kwargs.pop("task", "retrieval")
-        try:
-            task = TaskType(task_value)
-        except ValueError:
-            valid_tasks = [t.value for t in TaskType]
-            raise ValueError(
-                f"Invalid task: {task_value}. Must be one of {valid_tasks}."
-            )
+
+        if torch.cuda.is_available() and "attn_implementation" not in kwargs:
+            kwargs["attn_implementation"] = "flash_attention_2"
 
         base_model = super().from_pretrained(
             pretrained_model_name_or_path, *args, **kwargs
@@ -487,44 +539,31 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
         )
         adapter_dir = os.path.join(adapter_cache_path, "adapters")
 
-        base_model.adapter_dir = adapter_dir
-        base_model.task = task
-
-        # Create the PEFT model with the requested task adapter
+        lora_config = LoraConfig.from_pretrained(adapter_dir)
+        lora_config._custom_modules = {
+            torch.nn.modules.linear.Linear: partial(
+                MultiAdapterLinear,
+                task_names=base_model.config.task_names,
+            )
+        }
         peft_model = PeftModel.from_pretrained(
-            base_model, os.path.join(adapter_dir, task.value)
+            model=base_model,
+            model_id=adapter_dir,
+            config=lora_config,
+        )
+
+        @property
+        def task(self):
+            return self.model.task
+
+        @task.setter
+        def task(self, value):
+            self.model.task = value
+
+        peft_model.task = property(task.fget, task.fset)
+        peft_model.__class__.task = property(
+            lambda self: self.model.task,
+            lambda self, value: setattr(self.model, 'task', value)
         )
-
-        # Add set_task method to the PEFT model instance
-        def set_task_method(self, task: Union[str, TaskType]):
-            """
-            Set the task adapter for the model.
-
-            Args:
-                task (Union[str, TaskType]): The task name. Must be one of TaskType values or
-                one of ['retrieval', 'text-matching', 'code']
-            """
-            if isinstance(task, str):
-                try:
-                    task = TaskType(task)
-                except ValueError:
-                    valid_tasks = [t.value for t in TaskType]
-                    raise ValueError(
-                        f"Invalid task: {task}. Must be one of {valid_tasks}"
-                    )
-            if self.model.task != task:
-                adapter_path = os.path.join(self.adapter_dir, task.value)
-                hotswap_adapter(self, adapter_path, adapter_name="default")
-                self.model.task = task
-
-        def get_task_method(self):
-            """
-            Get the task adapter for the model.
-            """
-            return self.model.task.value
-
-        # Bind the methods to the instance
-        peft_model.set_task = set_task_method.__get__(peft_model, type(peft_model))
-        peft_model.get_task = get_task_method.__get__(peft_model, type(peft_model))
 
         return peft_model
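With the merged adapter and the `task` property above, selecting a task no longer mutates adapter weights. A short sketch of the two supported styles, plus the lower-level per-example form; the repository id is a placeholder, the `AutoModel` entry point is assumed, and batch construction via the repo's processor is omitted (it follows the README):

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("<model_repo_id>", trust_remote_code=True)

# Option 1: set the task once as a property (see _validate_task's error message).
model.task = "retrieval"
emb_a = model.encode_texts(texts=["What is graph attention?"])

# Option 2: stateless, per call.
emb_b = model.encode_texts(texts=["def add(a, b): return a + b"], task="code")

# Lower level: forward() accepts one label per example, so a single batch can
# mix adapters. text_batch would come from the repo's processor, as in the README.
# out = model.model(**text_batch, task_label=["retrieval", "code"])
# emb = out.single_vec_emb
```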
qwen2_5_vl.py ADDED
The diff for this file is too large to render. See raw diff