izhx committed · Commit 9cfa641 · verified · 1 Parent(s): 5b46bc4

Fix model loading (#17)


- Fix model loading (bd36ab0de70e8c918ebaef21ab674e45760e489b)
- Update README.md (333d6e0ec3210304dc43a4c7d75731378c4aa612)
- Update modeling_gme_qwen2vl.py (280b82673f3afce78487f0287755a7cc23199420)
- Update config.json (4174ed03d1015398958742e9dcc4c34277640a4f)
- Update README.md (dca92bddc7528a380b87c21f5383c8c856c27143)
- Update modeling_gme_qwen2vl.py (39953e5c545ac83822ad50cb3ef09214db7fbba5)
- Update README.md (62f5d69e260988e066a2e60c59180d25151d5e9b)

Files changed (3)
  1. README.md +12 -0
  2. config.json +7 -4
  3. modeling_gme_qwen2vl.py +39 -16
README.md CHANGED
@@ -3696,7 +3696,19 @@ The `GME` models support three types of input: **text**, **image**, and **image-
 
 **Transformers**
 
+The remote code has some issues with `transformers>=4.52.0`, please downgrade or use `sentence_transformers`
+
 ```python
+from transformers import AutoModel
+from transformers.utils.versions import require_version
+
+
+require_version(
+    "transformers<4.52.0",
+    "The remote code has some issues with transformers>=4.52.0, please downgrade: pip install transformers==4.51.3"
+)
+
+
 t2i_prompt = 'Find an image that matches the given text.'
 texts = [
     "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023.",
config.json CHANGED
@@ -1,9 +1,12 @@
 {
   "_name_or_path": "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct",
-  "architectures": ["GmeQwen2VLForVision2Seq"],
+  "architectures": [
+    "Qwen2VLForConditionalGeneration",
+    "GmeQwen2VL"
+  ],
   "auto_map": {
-    "AutoModel": "modeling_gme_qwen2vl.GmeQwen2VLForVision2Seq",
-    "AutoConfig": "modeling_gme_qwen2vl.GmeQwen2VLConfig"
+    "AutoConfig": "modeling_gme_qwen2vl.GmeQwen2VLConfig",
+    "AutoModel": "modeling_gme_qwen2vl.GmeQwen2VL"
   },
   "attention_dropout": 0.0,
   "bos_token_id": 151643,
@@ -15,7 +18,7 @@
   "intermediate_size": 8960,
   "max_position_embeddings": 32768,
   "max_window_layers": 28,
-  "model_type": "gme_qwen2_vl",
+  "model_type": "qwen2_vl",
   "num_attention_heads": 12,
   "num_hidden_layers": 28,
   "num_key_value_heads": 2,
modeling_gme_qwen2vl.py CHANGED
@@ -12,16 +12,25 @@ import torch
 from PIL import Image
 from torch.utils.data import DataLoader
 from tqdm.autonotebook import tqdm
-from transformers import (
-    AutoProcessor,
-    PreTrainedModel,
+from transformers import AutoProcessor, PreTrainedModel
+from transformers.models.qwen2_vl.modeling_qwen2_vl import (
+    Qwen2VisionTransformerPretrainedModel,
     Qwen2VLConfig,
     Qwen2VLForConditionalGeneration,
+    Qwen2VLModel,
+)
+from transformers.utils.versions import require_version
+
+
+require_version(
+    "transformers<4.52.0",
+    "This code has some issues with transformers>=4.52.0, please downgrade: pip install transformers==4.51.3"
 )
-import os
 
 
 class GmeQwen2VLConfig(Qwen2VLConfig):
+    # model_type = ''
+
     def __init__(
         self,
         min_image_tokens: int = 256,
@@ -35,14 +44,25 @@ class GmeQwen2VLConfig(Qwen2VLConfig):
         self.max_length = max_length
 
 
-class GmeQwen2VLForVision2Seq(PreTrainedModel):
+class GmeQwen2VL(PreTrainedModel):
     config_class = GmeQwen2VLConfig
-    base_model_prefix: str = "base"
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
+    # _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    # _supports_cache_class = True
+    _supports_static_cache = False # TODO (joao): fix. torch.compile failing probably due to `cache_positions`
+    # _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config: GmeQwen2VLConfig, **kwargs: Any) -> None:
         super().__init__(config)
-        self.base = Qwen2VLForConditionalGeneration.from_pretrained(config._name_or_path)
-        self.base.tie_weights() # It's important to produce same outputs.
+        self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
+        self.model = Qwen2VLModel(config)
+        self.vocab_size = config.vocab_size
+        # self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.rope_deltas = None # cache rope_deltas here
 
         min_pixels: int = config.min_image_tokens * 28 * 28
         max_pixels: int = config.max_image_tokens * 28 * 28
@@ -55,6 +75,9 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
         self.default_instruction: str = "You are a helpful assistant."
         self.sep: str = " "
 
+        # Initialize weights and apply final processing
+        self.post_init()
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -70,21 +93,21 @@ class GmeQwen2VLForVision2Seq(PreTrainedModel):
         **kwargs
     ) -> torch.Tensor:
         if inputs_embeds is None:
-            inputs_embeds = self.base.model.embed_tokens(input_ids)
+            inputs_embeds = self.model.get_input_embeddings()(input_ids)
             if pixel_values is not None:
-                pixel_values = pixel_values.type(self.base.visual.get_dtype())
-                image_embeds = self.base.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
-                image_mask = input_ids == self.base.config.image_token_id
+                pixel_values = pixel_values.type(self.visual.get_dtype())
+                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
+                image_mask = input_ids == self.config.image_token_id
                 inputs_embeds[image_mask] = image_embeds
             # if pixel_values_videos is not None:
-            #     pixel_values_videos = pixel_values_videos.type(self.base.visual.get_dtype())
-            #     video_embeds = self.base.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
-            #     video_mask = input_ids == self.base.config.video_token_id
+            #     pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
+            #     video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
+            #     video_mask = input_ids == self.config.video_token_id
             #     inputs_embeds[video_mask] = video_embeds
             if attention_mask is not None:
                 attention_mask = attention_mask.to(inputs_embeds.device)
 
-        outputs = self.base.model(
+        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
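
The modeling refactor drops the old wrapper, whose `__init__` re-ran `Qwen2VLForConditionalGeneration.from_pretrained(config._name_or_path)`, and instead builds `visual` and `model` from the config and calls `post_init()`, so `from_pretrained` can place the checkpoint weights onto the module tree directly. Below is a usage sketch of the refactored class; the `get_text_embeddings` helper and its `instruction` argument are assumed from the rest of the remote code (they do not appear in these hunks), and the dtype choice is illustrative.

```python
# Sketch: exercising the refactored GmeQwen2VL end to end.
# get_text_embeddings(..., instruction=...) is assumed from the rest of
# modeling_gme_qwen2vl.py; it is not part of the hunks in this commit.
import torch
from transformers import AutoModel

gme = AutoModel.from_pretrained(
    "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct",
    torch_dtype=torch.float16,  # illustrative dtype
    trust_remote_code=True,
)

t2i_prompt = "Find an image that matches the given text."
texts = [
    "The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023.",
]
text_embeddings = gme.get_text_embeddings(texts=texts, instruction=t2i_prompt)
print(text_embeddings.shape)  # (len(texts), embedding_dim)
```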