izhx kosung commited on
Commit
5b46bc4
·
verified ·
1 Parent(s): 40ed72b

Update custom_st.py (#19)

Browse files

- Update custom_st.py (89146e49f8e1cd1cd5231642a803167c3868f443)
- Update custom_st.py (23e2bf96c6f5d8b13f352d4ac29cd153f522ab91)
- Update custom_st.py (7a9a21cc961c2fd89de1c56c616a7c64e9cbcd2a)


Co-authored-by: kosung <kosung@users.noreply.huggingface.co>

Files changed (1) hide show
  1. custom_st.py +3 -1
custom_st.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from io import BytesIO
2
  from typing import Any, Dict, Optional, List
3
  import torch
@@ -51,7 +53,7 @@ class MultiModalTransformer(BaseTransformer):
51
  self, features: Dict[str, torch.Tensor], **kwargs
52
  ) -> Dict[str, torch.Tensor]:
53
  if features.get("inputs_embeds", None) is None:
54
- features["inputs_embeds"] = self.auto_model.base_model.embed_tokens(features["input_ids"])
55
  if features.get("pixel_values", None) is not None:
56
  features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
57
  image_embeds = self.auto_model.visual(
 
1
+ import math
2
+ import logging
3
  from io import BytesIO
4
  from typing import Any, Dict, Optional, List
5
  import torch
 
53
  self, features: Dict[str, torch.Tensor], **kwargs
54
  ) -> Dict[str, torch.Tensor]:
55
  if features.get("inputs_embeds", None) is None:
56
+ features["inputs_embeds"] = self.auto_model.base_model.get_input_embeddings()(features["input_ids"])
57
  if features.get("pixel_values", None) is not None:
58
  features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
59
  image_embeds = self.auto_model.visual(