jupyterjazz committed
Commit 455d3b0 · verified · 1 Parent(s): f35e327

remove-single-vector-projection (#18)


- refactor: remove single vec projection (5bb9539014f6cef2d3662dba50b75749e0922b50)

adapters/adapter_config.json CHANGED
@@ -1,7 +1,7 @@
 {
   "alpha_pattern": {},
   "auto_mapping": null,
-  "base_model_name_or_path": "jinaai/colqwen25-duo-base",
+  "base_model_name_or_path": "jinaai/jina-embeddings-v4",
   "bias": "none",
   "corda_config": null,
   "eva_config": null,
adapters/adapter_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c9799872132988d3689a35300538fb97fc5b0e02c1c42f7afd914fd1d8b59a88
-size 360118024
+oid sha256:b6b7ab4a79daa3b4f3b5274500cc99d3dc89aa8c3419e9d79f89e366685e12e5
+size 359863776
config.json CHANGED
@@ -33,7 +33,6 @@
   },
   "rope_theta": 1000000.0,
   "single_vector_pool_strategy": "mean",
-  "single_vector_projector_dim": 1024,
   "sliding_window": 32768,
   "tie_word_embeddings": true,
   "torch_dtype": "bfloat16",
configuration_jina_embeddings_v4.py CHANGED
@@ -9,14 +9,12 @@ class JinaEmbeddingsV4Config(Qwen2_5_VLConfig):
 
     def __init__(
         self,
-        single_vector_projector_dim: int = 1024,
         single_vector_pool_strategy: str = "mean",
         multi_vector_projector_dim: int = 128,
         pretrained_peft_model_name_or_path: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.single_vector_projector_dim = single_vector_projector_dim
         self.single_vector_pool_strategy = single_vector_pool_strategy
         self.multi_vector_projector_dim = multi_vector_projector_dim
         self.pretrained_peft_model_name_or_path = pretrained_peft_model_name_or_path
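
After this change the embedding-specific constructor surface is just the pooling strategy, the multi-vector projector width, and the optional PEFT path. A construction sketch (the import path is assumed to be the repo-local module):

```python
# Hedged sketch of the post-change constructor surface.
from configuration_jina_embeddings_v4 import JinaEmbeddingsV4Config

cfg = JinaEmbeddingsV4Config(
    single_vector_pool_strategy="mean",  # kept
    multi_vector_projector_dim=128,      # kept
    pretrained_peft_model_name_or_path=None,
)
# single_vector_projector_dim is gone from the signature; passing it would
# only fall through to **kwargs rather than configure a projection layer.
```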
model-00001-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6b45c7afe391b4d9cc49f1ed3f6976f4a25ed40aa2165ed2ae118ff549355985
-size 4997750760
+oid sha256:abb244162956ec2f26d944b6c10cbb96afe211d2aff908b8b2f498ec27a9100b
+size 4997750728
model-00002-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:a20083234b15a57f34207bb99589241cf7531f01c09fd657110712cb634a811a
-size 2516308496
+oid sha256:5d5252a7ede6469220b0e7386af53fea9a45fa299a1d2af6fe68cb29897de3e3
+size 2512111904
model.safetensors.index.json CHANGED
@@ -439,8 +439,6 @@
     "model.norm.weight": "model-00002-of-00002.safetensors",
     "multi_vector_projector.bias": "model-00002-of-00002.safetensors",
     "multi_vector_projector.weight": "model-00002-of-00002.safetensors",
-    "single_vector_projector.bias": "model-00002-of-00002.safetensors",
-    "single_vector_projector.weight": "model-00002-of-00002.safetensors",
     "visual.blocks.0.attn.proj.bias": "model-00001-of-00002.safetensors",
     "visual.blocks.0.attn.proj.weight": "model-00001-of-00002.safetensors",
     "visual.blocks.0.attn.qkv.bias": "model-00001-of-00002.safetensors",
modeling_jina_embeddings_v4.py CHANGED
@@ -141,12 +141,11 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
     def __init__(self, config: JinaEmbeddingsV4Config):
         Qwen2_5_VLForConditionalGeneration.__init__(self, config)
-        self._init_projection_layers(config)
+        self._init_projection_layer(config)
         self.post_init()
         self.processor = JinaEmbeddingsV4Processor.from_pretrained(
             self.name_or_path, trust_remote_code=True, use_fast=True
         )
-        self.single_vector_projector_dim = config.single_vector_projector_dim
         self.multi_vector_projector_dim = config.multi_vector_projector_dim
         self._task = None
 
@@ -204,32 +203,25 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
 
         return hidden_states[-1]
 
-    def _init_projection_layers(self, config) -> None:
+    def _init_projection_layer(self, config) -> None:
         """
         Initializes projection layers.
         """
-        self.config.single_vector_projector_dim = config.single_vector_projector_dim
         self.config.multi_vector_projector_dim = config.multi_vector_projector_dim
 
-        self.single_vector_projector = nn.Linear(
-            in_features=self.config.text_config.hidden_size,
-            out_features=self.config.single_vector_projector_dim,
-        )
-
         self.multi_vector_projector = nn.Linear(
             in_features=self.config.text_config.hidden_size,
             out_features=self.config.multi_vector_projector_dim,
         )
 
-    def project_to_single_vector_embeddings(
+    def get_single_vector_embeddings(
         self,
-        task_label: Union[str, List[str]],
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         input_ids: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
         """
-        Project the hidden states to single-vector embeddings.
+        Get the single-vector embeddings from the hidden states.
         """
         if self._input_has_image(input_ids[0]): # got document image
             img_start_positions = torch.where(
@@ -257,12 +249,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             hidden_states * attention_mask.unsqueeze(-1), dim=1
         ) / torch.sum(attention_mask, dim=1, keepdim=True)
 
-        single_vec_emb = self.single_vector_projector(
-            pooled_output, task_label=task_label
-        )
-        return torch.nn.functional.normalize(single_vec_emb, dim=-1)
+        return torch.nn.functional.normalize(pooled_output, dim=-1)
 
-    def project_to_multi_vector_embeddings(
+    def get_multi_vector_embeddings(
         self,
         task_label: Union[str, List[str]],
         hidden_states: torch.Tensor,
@@ -306,13 +295,12 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             **kwargs,
         ) # (batch_size, seq_length, hidden_size)
         # Compute the embeddings
-        single_vec_emb = self.project_to_single_vector_embeddings(
+        single_vec_emb = self.get_single_vector_embeddings(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             input_ids=input_ids,
-            task_label=task_label,
         )
-        multi_vec_emb = self.project_to_multi_vector_embeddings(
+        multi_vec_emb = self.get_multi_vector_embeddings(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             task_label=task_label,
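
Net effect on the forward path: `get_single_vector_embeddings` now returns the L2-normalized mean-pooled hidden states directly, so single-vector outputs live in the backbone's hidden dimension rather than a projected 1024-d space. A standalone sketch of the surviving pooling logic (shapes are illustrative, not taken from the config):

```python
import torch

def mean_pool_and_normalize(hidden_states: torch.Tensor,
                            attention_mask: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len)
    masked_sum = torch.sum(hidden_states * attention_mask.unsqueeze(-1), dim=1)
    pooled = masked_sum / torch.sum(attention_mask, dim=1, keepdim=True)
    return torch.nn.functional.normalize(pooled, dim=-1)

# Toy shapes: the output dim now equals hidden_size, with no trailing projection.
h, m = torch.randn(2, 5, 2048), torch.ones(2, 5)
print(mean_pool_and_normalize(h, m).shape)  # torch.Size([2, 2048])
```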