Gregniuki committed on
Commit
4dd25c8
·
verified ·
1 Parent(s): c50bf62

Update infer/utils_infer.py

Browse files
Files changed (1) hide show
  1. infer/utils_infer.py +18 -7
infer/utils_infer.py CHANGED
@@ -137,9 +137,13 @@ asr_pipe = None
137
 
138
  def initialize_asr_pipeline(device: str = device, dtype=None):
139
  if dtype is None:
140
- dtype = (
141
- torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
142
- )
 
 
 
 
143
  global asr_pipe
144
  asr_pipe = pipeline(
145
  "automatic-speech-recognition",
@@ -149,6 +153,7 @@ def initialize_asr_pipeline(device: str = device, dtype=None):
149
  )
150
 
151
 
 
152
  # transcribe
153
 
154
 
@@ -170,10 +175,16 @@ def transcribe(ref_audio, language=None):
170
 
171
  def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
172
  if dtype is None:
173
- dtype = (
174
- torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
175
- )
176
- model = model.to(dtype)
 
 
 
 
 
 
177
 
178
  ckpt_type = ckpt_path.split(".")[-1]
179
  if ckpt_type == "safetensors":
 
137
 
138
  def initialize_asr_pipeline(device: str = device, dtype=None):
139
  if dtype is None:
140
+ if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6:
141
+ dtype = torch.float16
142
+ elif "cpu" in device:
143
+ dtype = torch.bfloat16
144
+ else:
145
+ dtype = torch.float32
146
+
147
  global asr_pipe
148
  asr_pipe = pipeline(
149
  "automatic-speech-recognition",
 
153
  )
154
 
155
 
156
+
157
  # transcribe
158
 
159
 
 
175
 
176
  def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
177
  if dtype is None:
178
+ if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6:
179
+ dtype = torch.float16
180
+ elif "cpu" in device:
181
+ dtype = torch.bfloat16
182
+ else:
183
+ dtype = torch.float32
184
+
185
+ # Move the model to the desired device and dtype
186
+ model = model.to(device=device, dtype=dtype)
187
+ #model = model.to(dtype)
188
 
189
  ckpt_type = ckpt_path.split(".")[-1]
190
  if ckpt_type == "safetensors":