orpatashnik commited on
Commit
09790b2
·
1 Parent(s): 8eeff9c
Files changed (1) hide show
  1. nested_attention_pipeline.py +2 -1
nested_attention_pipeline.py CHANGED
@@ -4,6 +4,7 @@ from typing import List
4
  import torch
5
  from PIL import Image
6
  from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
7
 
8
  from nested_attention_processor import AttnProcessor, NestedAttnProcessor
9
  from utils import get_generator
@@ -110,7 +111,7 @@ class NestedAdapterInference:
110
 
111
  def load_nested_adapter(self):
112
  state_dict = {"adapter_modules": {}, "qformer": {}}
113
- f = torch.load(self.adapter_ckpt, map_location="cpu")
114
  for key in f.keys():
115
  if key.startswith("adapter_modules."):
116
  state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = f[
 
4
  import torch
5
  from PIL import Image
6
  from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
7
+ from safetensors import load_file
8
 
9
  from nested_attention_processor import AttnProcessor, NestedAttnProcessor
10
  from utils import get_generator
 
111
 
112
  def load_nested_adapter(self):
113
  state_dict = {"adapter_modules": {}, "qformer": {}}
114
+ f = load_file(self.adapter_ckpt)
115
  for key in f.keys():
116
  if key.startswith("adapter_modules."):
117
  state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = f[