import logging
import os

import numpy as np
import torch
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
)
from transformers.utils import is_flash_attn_2_available


logger = logging.getLogger(__name__)

# --- Model selection ----------------------------------------------------------
MODEL_ID = "openai/whisper-large-v3-turbo"  # Hugging Face Hub checkpoint id
# NOTE(review): LANGUAGE is not referenced anywhere in the visible code;
# confirm whether other code passes it to the pipeline (e.g. generate_kwargs).
LANGUAGE = "english"

# --- Loading / runtime switches -------------------------------------------------
device = "cuda"                 # target of explicit .to() moves when no device_map is used
use_device_map = True           # first load attempt uses device_map="auto" when True
try_compile_model = True        # attempt torch.compile after the pipeline is built
try_use_flash_attention = True  # prefer "flash_attention_2" over "sdpa" when loading

# --- Precision --------------------------------------------------------------------
torch_dtype, np_dtype = torch.float16, np.float16  # model weights / synthetic warmup audio


# Exceptions that justify retrying with a more conservative configuration.
# Typical sources: RuntimeError (e.g. CUDA out-of-memory or placement failures),
# ImportError (e.g. `flash_attn` / `accelerate` missing or broken), ValueError
# (e.g. an attention implementation the model rejects). Anything else (such as
# an OSError for an unknown model id) is not retried.
_RECOVERABLE_LOAD_ERRORS = (RuntimeError, ImportError, ValueError)


def _load_asr_model(attn_implementation, device_map):
    """Load ``MODEL_ID`` once with the given attention backend and device map.

    Args:
        attn_implementation: Value forwarded to ``from_pretrained``
            (``"flash_attention_2"`` or ``"sdpa"``).
        device_map: ``"auto"`` to let ``from_pretrained`` place the weights,
            or ``None`` to load and then move the model to ``device`` explicitly.

    Returns:
        The loaded speech seq2seq model.
    """
    loaded = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation=attn_implementation,
        device_map=device_map,
    )
    if device_map is None:
        loaded.to(device)
    return loaded


def _load_asr_model_with_fallbacks():
    """Return the first model that loads, trying progressively safer settings.

    Attempt order: the preferred settings, then the same attention backend
    without a device map, then SDPA attention without a device map. Identical
    configurations are tried only once. For example, when flash attention was
    never selected, the last two attempts collapse into one.

    Raises:
        The last recoverable error if every configuration fails; any
        non-recoverable error propagates immediately.
    """
    # Only request FlashAttention-2 when transformers reports it usable; asking
    # for it without the package makes from_pretrained raise ImportError.
    flash_ok = try_use_flash_attention and is_flash_attn_2_available()
    preferred_attn = "flash_attention_2" if flash_ok else "sdpa"
    if try_use_flash_attention and not flash_ok:
        logger.warning("FlashAttention-2 is not available in this environment; using sdpa")

    # dict.fromkeys keeps first-seen order while dropping duplicate configs.
    configs = list(
        dict.fromkeys(
            [
                (preferred_attn, "auto" if use_device_map else None),
                (preferred_attn, None),
                ("sdpa", None),
            ]
        )
    )

    last_error = None
    for attn_implementation, device_map in configs:
        if last_error is not None:
            logger.warning(
                f"Loading ASR model failed ({last_error}); retrying with "
                f"attn_implementation={attn_implementation!r}, device_map={device_map!r}"
            )
        try:
            return _load_asr_model(attn_implementation, device_map)
        except _RECOVERABLE_LOAD_ERRORS as err:
            last_error = err
    raise last_error


try:
    model = _load_asr_model_with_fallbacks()
except Exception as e:
    logger.error(f"Error loading ASR model: {e}")
    logger.error(f"Are you providing a valid model ID? {MODEL_ID}")
    raise


# Tokenizer and feature extractor for the same checkpoint as the loaded model.
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Build the ASR pipeline around the already-loaded model, reusing the
# processor's components instead of letting the pipeline load its own.
transcribe_pipeline = pipeline(
    "automatic-speech-recognition",
    model=model,
    feature_extractor=processor.feature_extractor,
    tokenizer=processor.tokenizer,
    torch_dtype=torch_dtype,
)


# Best-effort torch.compile of the model's forward pass.
#
# The ASR pipeline drives Whisper through `model.generate`. Compiling the whole
# model (`torch.compile(model)`) returns an OptimizedModule that forwards
# attribute lookups such as `generate` to the original, uncompiled module, so
# the compiled graph never runs during transcription. Compiling `forward` in
# place, as Hugging Face documents for generation, is what `generate` calls.
_asr_model = transcribe_pipeline.model
# `_attn_implementation` is a private transformers config field; getattr keeps
# this check harmless on versions that lack it.
_attn_implementation = getattr(_asr_model.config, "_attn_implementation", None)

if not try_compile_model:
    logger.warning("Proceeding without compiling the model (try_compile_model is False)")
elif _attn_implementation == "flash_attention_2":
    # The Whisper model card lists torch.compile as incompatible with Flash Attention 2.
    logger.warning(
        "Proceeding without compiling the model "
        "(requirements not met: flash_attention_2 is active)"
    )
else:
    try:
        # torch.compile is lazy: most compilation failures only surface on the
        # first real call (the warmup), outside this try block. Have Dynamo fall
        # back to eager execution instead of raising, so compilation stays
        # best-effort as intended.
        torch._dynamo.config.suppress_errors = True
        _asr_model.forward = torch.compile(_asr_model.forward, mode="max-autotune")
    except Exception as e:
        logger.warning(f"Error compiling model: {e}")
        logger.warning("Proceeding without compiling the model")


# Run the pipeline once on synthetic input so one-off initialisation costs are
# paid at import time rather than on the first real request.
logger.info("Warming up Whisper model with dummy input")

# 16000 uniform samples in [0, 1) -- presumably one second of audio at the
# feature extractor's 16 kHz rate (TODO confirm if the rate ever changes).
warmup_audio = np.random.rand(16_000)
warmup_audio = warmup_audio.astype(np_dtype)

transcribe_pipeline(warmup_audio)
logger.info("Model warmup complete")