alakxender committed on
Commit
febf67e
·
1 Parent(s): 5ef23ad
Files changed (1) hide show
  1. app.py +87 -47
app.py CHANGED
@@ -5,46 +5,28 @@ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
5
  import tempfile
6
  import os
7
 
8
- # Model configuration, this model contains synthetic data
9
- MODEL_ID = "alakxender/whisper-small-dv-full"
 
 
 
 
 
10
  BATCH_SIZE = 8
11
  FILE_LIMIT_MB = 1000
12
  CHUNK_LENGTH_S = 10
13
  STRIDE_LENGTH_S = [3,2]
14
 
15
- # Device and dtype setup
16
  device = 0 if torch.cuda.is_available() else "cpu"
17
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
18
-
19
- # Initialize model with memory optimizations
20
- model = AutoModelForSpeechSeq2Seq.from_pretrained(
21
- MODEL_ID,
22
- torch_dtype=torch_dtype,
23
- low_cpu_mem_usage=True,
24
- use_safetensors=True
25
- )
26
- model.to(device)
27
-
28
- # Initialize processor
29
- processor = AutoProcessor.from_pretrained(MODEL_ID)
30
-
31
- # Single pipeline initialization with all components
32
- pipe = pipeline(
33
- "automatic-speech-recognition",
34
- model=model,
35
- tokenizer=processor.tokenizer,
36
- feature_extractor=processor.feature_extractor,
37
- chunk_length_s=CHUNK_LENGTH_S,
38
- stride_length_s=STRIDE_LENGTH_S,
39
- batch_size=BATCH_SIZE,
40
- torch_dtype=torch_dtype,
41
- device=device,
42
- )
43
-
44
- # Define the generation arguments
45
 
46
  # Define optimized generation arguments
47
- def get_generate_kwargs(is_short_audio=False):
48
  """
49
  Get appropriate generation parameters based on audio length.
50
  Short audio transcription benefits from different parameters.
@@ -72,29 +54,81 @@ def get_generate_kwargs(is_short_audio=False):
72
  "repetition_penalty": 1.2, # Light penalty for repeated tokens
73
  }
74
 
75
- # IMPORTANT: Fix for forced_decoder_ids error
76
- # Remove forced_decoder_ids from the model's generation config
77
- if hasattr(model.generation_config, 'forced_decoder_ids'):
78
- print("Removing forced_decoder_ids from generation config")
79
- model.generation_config.forced_decoder_ids = None
80
-
81
- # Also check if it's in the model config
82
- if hasattr(model.config, 'forced_decoder_ids'):
83
- print("Removing forced_decoder_ids from model config")
84
- delattr(model.config, 'forced_decoder_ids')
85
-
86
  @spaces.GPU
87
- def transcribe(audio_input):
 
 
88
  if audio_input is None:
89
  raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
90
 
91
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  # Use the defined generate_kwargs dictionary
93
- result = pipe(
94
  audio_input,
95
- generate_kwargs=get_generate_kwargs()
96
  )
 
 
97
  return result["text"]
 
98
  except Exception as e:
99
  # More detailed error logging might be helpful here if issues persist
100
  print(f"Detailed Error: {e}")
@@ -116,6 +150,12 @@ file_transcribe = gr.Interface(
116
  fn=transcribe,
117
  inputs=[
118
  gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio file"),
 
 
 
 
 
 
119
  ],
120
  outputs=gr.Textbox(
121
  label="",
@@ -125,11 +165,11 @@ file_transcribe = gr.Interface(
125
  ),
126
  title="Transcribe Dhivehi Audio",
127
  description=(
128
- "Upload an audio file or record using your microphone to transcribe."
129
  ),
130
  flagging_mode="never",
131
  examples=[
132
- ["sample.mp3"]
133
  ],
134
  api_name=False,
135
  cache_examples=False
 
5
  import tempfile
6
  import os
7
 
8
# Available models: checkpoint id -> human-readable display name.
# Only the first entry is active; the MX02 checkpoint is intentionally disabled.
MODELS = {
    "alakxender/whisper-small-dv-full": "Whisper Small DV Full",
    # "alakxender/whisper-small-dv-mx02": "Whisper Small DV MX02"
}

# Model configuration constants
BATCH_SIZE = 8            # pipeline batch size for chunked inference
FILE_LIMIT_MB = 1000      # upload size cap (MB) — NOTE(review): not enforced in this view; confirm a caller checks it
CHUNK_LENGTH_S = 10       # seconds per audio chunk fed to the model
STRIDE_LENGTH_S = [3, 2]  # [left, right] overlap (s) between consecutive chunks

# Global variables for device and model management.
# device: CUDA ordinal 0 when a GPU is present, otherwise the string "cpu"
# (the int/str mix is the convention transformers' pipeline() accepts).
# fp16 is used only on GPU; fp32 on CPU.
if torch.cuda.is_available():
    device, torch_dtype = 0, torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# Lazily-populated cache of the currently loaded model stack; filled on the
# first transcription request and replaced when a different model is chosen.
current_model_id = current_model = current_processor = current_pipe = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # Define optimized generation arguments
29
+ def get_generate_kwargs(model, is_short_audio=False):
30
  """
31
  Get appropriate generation parameters based on audio length.
32
  Short audio transcription benefits from different parameters.
 
54
  "repetition_penalty": 1.2, # Light penalty for repeated tokens
55
  }
56
 
 
 
 
 
 
 
 
 
 
 
 
57
  @spaces.GPU
58
+ def transcribe(audio_input, model_choice, progress=gr.Progress()):
59
+ global current_model_id, current_model, current_processor, current_pipe, device, torch_dtype
60
+
61
  if audio_input is None:
62
  raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
63
 
64
  try:
65
+ # Load the selected model if not already loaded or different model selected
66
+ if current_model_id != model_choice or current_model is None:
67
+ progress(0, desc=f"Loading model: {MODELS[model_choice]}")
68
+ print(f"Loading model: {model_choice}")
69
+
70
+ # Initialize model with memory optimizations
71
+ progress(0.2, desc="Downloading model weights...")
72
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
73
+ model_choice,
74
+ torch_dtype=torch_dtype,
75
+ low_cpu_mem_usage=True,
76
+ use_safetensors=True
77
+ )
78
+
79
+ progress(0.4, desc="Moving model to device...")
80
+ model.to(device)
81
+
82
+ # Initialize processor
83
+ progress(0.6, desc="Loading processor...")
84
+ processor = AutoProcessor.from_pretrained(model_choice)
85
+
86
+ # Single pipeline initialization with all components
87
+ progress(0.8, desc="Creating pipeline...")
88
+ pipe = pipeline(
89
+ "automatic-speech-recognition",
90
+ model=model,
91
+ tokenizer=processor.tokenizer,
92
+ feature_extractor=processor.feature_extractor,
93
+ chunk_length_s=CHUNK_LENGTH_S,
94
+ stride_length_s=STRIDE_LENGTH_S,
95
+ batch_size=BATCH_SIZE,
96
+ torch_dtype=torch_dtype,
97
+ device=device,
98
+ )
99
+
100
+ # IMPORTANT: Fix for forced_decoder_ids error
101
+ progress(0.9, desc="Configuring model...")
102
+ # Remove forced_decoder_ids from the model's generation config
103
+ if hasattr(model.generation_config, 'forced_decoder_ids'):
104
+ print("Removing forced_decoder_ids from generation config")
105
+ model.generation_config.forced_decoder_ids = None
106
+
107
+ # Also check if it's in the model config
108
+ if hasattr(model.config, 'forced_decoder_ids'):
109
+ print("Removing forced_decoder_ids from model config")
110
+ delattr(model.config, 'forced_decoder_ids')
111
+
112
+ # Update global variables
113
+ current_model_id = model_choice
114
+ current_model = model
115
+ current_processor = processor
116
+ current_pipe = pipe
117
+
118
+ print(f"Model {model_choice} loaded successfully on {device}")
119
+
120
+ # Start transcription
121
+ progress(0.95, desc="Processing audio...")
122
+
123
  # Use the defined generate_kwargs dictionary
124
+ result = current_pipe(
125
  audio_input,
126
+ generate_kwargs=get_generate_kwargs(current_model)
127
  )
128
+
129
+ progress(1.0, desc="Transcription complete!")
130
  return result["text"]
131
+
132
  except Exception as e:
133
  # More detailed error logging might be helpful here if issues persist
134
  print(f"Detailed Error: {e}")
 
150
  fn=transcribe,
151
  inputs=[
152
  gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio file"),
153
+ gr.Dropdown(
154
+ choices=list(MODELS.keys()),
155
+ value=list(MODELS.keys())[0], # Default to first model
156
+ label="Select Model",
157
+ info="Choose the Whisper model for transcription"
158
+ )
159
  ],
160
  outputs=gr.Textbox(
161
  label="",
 
165
  ),
166
  title="Transcribe Dhivehi Audio",
167
  description=(
168
+ "Upload an audio file or record using your microphone to transcribe. Select your preferred model from the dropdown."
169
  ),
170
  flagging_mode="never",
171
  examples=[
172
+ ["sample.mp3", "alakxender/whisper-small-dv-full"]
173
  ],
174
  api_name=False,
175
  cache_examples=False