Update model.py
Browse files
model.py
CHANGED
@@ -115,8 +115,8 @@ class SALMONN(nn.Module):
|
|
115 |
self.speech_Qformer.config.hidden_size, self.llama_model.config.hidden_size)
|
116 |
|
117 |
# load ckpt
|
118 |
-
ckpt_dict = torch.load(ckpt)['model']
|
119 |
-
self.load_state_dict(ckpt_dict, strict=False
|
120 |
|
121 |
def generate(
|
122 |
self,
|
|
|
115 |
self.speech_Qformer.config.hidden_size, self.llama_model.config.hidden_size)
|
116 |
|
117 |
# load ckpt
|
118 |
+
ckpt_dict = torch.load(ckpt, map_location=device)['model']
|
119 |
+
self.load_state_dict(ckpt_dict, strict=False)
|
120 |
|
121 |
def generate(
|
122 |
self,
|