Issamohammed committed
Commit 37f7d1f · verified · 1 Parent(s): be25d7c

Update app.py

Files changed (1)
  1. app.py +115 -45
app.py CHANGED
@@ -1,48 +1,69 @@
 import os
 import torch
 import gradio as gr
+import logging
 from pydub import AudioSegment
+from pydub.exceptions import CouldntDecodeError
 from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
 from pathlib import Path
 from tempfile import NamedTemporaryFile
 from datetime import timedelta
 import time
 
+# Setup logging
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
 # Configuration
 MODEL_ID = "KBLab/kb-whisper-large"
 CHUNK_DURATION_MS = 10000
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
+SUPPORTED_FORMATS = {".wav", ".mp3", ".m4a"}
 
 # Initialize model and pipeline
 def initialize_pipeline():
-    model = AutoModelForSpeechSeq2Seq.from_pretrained(
-        MODEL_ID,
-        torch_dtype=TORCH_DTYPE,
-        low_cpu_mem_usage=True
-    ).to(DEVICE)
-
-    processor = AutoProcessor.from_pretrained(MODEL_ID)
-
-    return pipeline(
-        "automatic-speech-recognition",
-        model=model,
-        tokenizer=processor.tokenizer,
-        feature_extractor=processor.feature_extractor,
-        device=DEVICE,
-        torch_dtype=TORCH_DTYPE,
-        model_kwargs={"use_flash_attention_2": torch.cuda.is_available()}
-    )
+    try:
+        model = AutoModelForSpeechSeq2Seq.from_pretrained(
+            MODEL_ID,
+            torch_dtype=TORCH_DTYPE,
+            low_cpu_mem_usage=True
+        ).to(DEVICE)
+        processor = AutoProcessor.from_pretrained(MODEL_ID)
+        return pipeline(
+            "automatic-speech-recognition",
+            model=model,
+            tokenizer=processor.tokenizer,
+            feature_extractor=processor.feature_extractor,
+            device=DEVICE,
+            torch_dtype=TORCH_DTYPE,
+            model_kwargs={"use_flash_attention_2": torch.cuda.is_available()}
+        )
+    except Exception as e:
+        logger.error(f"Failed to initialize pipeline: {str(e)}")
+        raise RuntimeError("Unable to load transcription model. Please check your network connection or model ID.")
 
 # Convert audio if needed
 def convert_to_wav(audio_path: str) -> str:
-    ext = str(Path(audio_path).suffix).lower()
-    if ext != ".wav":
-        audio = AudioSegment.from_file(audio_path)
-        wav_path = str(Path(audio_path).with_suffix(".converted.wav"))
-        audio.export(wav_path, format="wav")
-        return wav_path
-    return audio_path
+    try:
+        ext = str(Path(audio_path).suffix).lower()
+        if ext not in SUPPORTED_FORMATS:
+            raise ValueError(f"Unsupported audio format: {ext}. Supported formats: {', '.join(SUPPORTED_FORMATS)}")
+        if ext != ".wav":
+            audio = AudioSegment.from_file(audio_path)
+            wav_path = str(Path(audio_path).with_suffix(".converted.wav"))
+            audio.export(wav_path, format="wav")
+            return wav_path
+        return audio_path
+    except CouldntDecodeError:
+        logger.error(f"Failed to decode audio file: {audio_path}")
+        raise ValueError("Audio file is corrupted or in an unsupported format.")
+    except OSError as e:
+        logger.error(f"OS error during audio conversion: {str(e)}")
+        raise ValueError("Failed to process audio file due to a system error.")
+    except Exception as e:
+        logger.error(f"Unexpected error during audio conversion: {str(e)}")
+        raise ValueError("An unexpected error occurred while converting the audio.")
 
 # Split audio into chunks
 def split_audio(audio_path: str) -> list:
@@ -50,8 +71,12 @@ def split_audio(audio_path: str) -> list:
         audio = AudioSegment.from_file(audio_path)
         if len(audio) == 0:
             raise ValueError("Audio file is empty or invalid.")
-        return [audio[i:i + CHUNK_DURATION_MS] for i in range(0, len(audio), CHUNK_DURATION_MS)]
+        return [audio[i:i + CHUNK_DURATION_MS] for i in range(0, len(audio), CHUNK_DURATION_MS)]
+    except CouldntDecodeError:
+        logger.error(f"Failed to decode audio for splitting: {audio_path}")
+        raise ValueError("Audio file is corrupted or in an unsupported format.")
     except Exception as e:
+        logger.error(f"Failed to split audio: {str(e)}")
         raise ValueError(f"Failed to process audio: {str(e)}")
 
 # Helper to compute chunk start time
@@ -62,9 +87,10 @@ def get_chunk_time(index: int, chunk_duration_ms: int) -> str:
 # Transcribe audio with progress and timestamps
 def transcribe(audio_path: str, include_timestamps: bool = False, progress=gr.Progress()):
     try:
-        if not audio_path:
-            return "No audio file provided.", None
-
+        if not audio_path or not os.path.exists(audio_path):
+            logger.warning("Invalid or missing audio file path.")
+            return "Please upload a valid audio file.", None
+
         # Convert to WAV if needed
         wav_path = convert_to_wav(audio_path)
 
@@ -73,43 +99,83 @@ def transcribe(audio_path: str, include_timestamps: bool = False, progress=gr.Pr
         total_chunks = len(chunks)
         transcript = []
         timestamped_transcript = []
+        failed_chunks = 0
 
         for i, chunk in enumerate(chunks):
+            temp_file_path = None
             try:
                 with NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
+                    temp_file_path = temp_file.name
                     chunk.export(temp_file.name, format="wav")
                     result = PIPELINE(temp_file.name,
                                       generate_kwargs={"task": "transcribe", "language": "sv"})
                     text = result["text"].strip()
-                    transcript.append(text)
-                    if include_timestamps:
-                        timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
-                        timestamped_transcript.append(f"[{timestamp}] {text}")
+                    if text:  # Only append non-empty transcriptions
+                        transcript.append(text)
+                        if include_timestamps:
+                            timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
+                            timestamped_transcript.append(f"[{timestamp}] {text}")
+            except RuntimeError as e:
+                logger.warning(f"Failed to transcribe chunk {i+1}/{total_chunks}: {str(e)}")
+                failed_chunks += 1
+                transcript.append("[Transcription failed for this segment]")
+                if include_timestamps:
+                    timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
+                    timestamped_transcript.append(f"[{timestamp}] [Transcription failed]")
+            except Exception as e:
+                logger.error(f"Unexpected error in chunk {i+1}/{total_chunks}: {str(e)}")
+                failed_chunks += 1
+                transcript.append("[Transcription failed for this segment]")
+                if include_timestamps:
+                    timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
+                    timestamped_transcript.append(f"[{timestamp}] [Transcription failed]")
             finally:
-                if os.path.exists(temp_file.name):
-                    os.remove(temp_file.name)
-
+                if temp_file_path and os.path.exists(temp_file_path):
+                    try:
+                        os.remove(temp_file_path)
+                    except OSError as e:
+                        logger.warning(f"Failed to delete temporary file {temp_file_path}: {str(e)}")
+
             progress((i + 1) / total_chunks)
             yield " ".join(transcript), None
-
+
         # Clean up converted file if created
         if wav_path != audio_path and os.path.exists(wav_path):
-            os.remove(wav_path)
-
+            try:
+                os.remove(wav_path)
+            except OSError as e:
+                logger.warning(f"Failed to delete converted WAV file {wav_path}: {str(e)}")
+
         # Prepare final transcript and downloadable file
        final_transcript = " ".join(transcript)
+        if failed_chunks > 0:
+            final_transcript = f"Warning: {failed_chunks}/{total_chunks} chunks failed to transcribe.\n{final_transcript}"
+
         download_content = "\n".join(timestamped_transcript) if include_timestamps else final_transcript
-        with NamedTemporaryFile(suffix=".txt", delete=False, mode='w', encoding='utf-8') as temp_file:
-            temp_file.write(download_content)
-        download_path = temp_file.name
-
+        download_path = None
+        try:
+            with NamedTemporaryFile(suffix=".txt", delete=False, mode='w', encoding='utf-8') as temp_file:
+                temp_file.write(download_content)
+            download_path = temp_file.name
+        except OSError as e:
+            logger.error(f"Failed to create downloadable transcript: {str(e)}")
+            final_transcript = f"{final_transcript}\nNote: Could not generate downloadable transcript due to a file error."
+
         return final_transcript, download_path
 
+    except ValueError as e:
+        logger.error(f"Value error during transcription: {str(e)}")
+        return str(e), None
     except Exception as e:
-        return f"Error during transcription: {str(e)}", None
+        logger.error(f"Unexpected error during transcription: {str(e)}")
+        return f"An unexpected error occurred: {str(e)}. Please try again or contact support.", None
 
 # Initialize pipeline globally
-PIPELINE = initialize_pipeline()
+try:
+    PIPELINE = initialize_pipeline()
+except RuntimeError as e:
+    logger.critical(f"Pipeline initialization failed: {str(e)}")
+    raise
 
 # Gradio Interface with Blocks
 def create_interface():
@@ -136,4 +202,8 @@ def create_interface():
     return demo
 
 if __name__ == "__main__":
-    create_interface().launch()
+    try:
+        create_interface().launch()
+    except Exception as e:
+        logger.critical(f"Failed to launch Gradio interface: {str(e)}")
+        print(f"Error: Could not start the application. Please check the logs for details.")
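For context, the updated transcribe() above is a generator that yields (partial_transcript, download_path) tuples, which Gradio streams into the UI as chunks finish. The body of create_interface() is not part of this diff, so the sketch below is only an assumption of how such a generator is typically wired into a Blocks layout; the component names and the stub generator fake_transcribe() are hypothetical and stand in for the real app.py functions.

import time
import gradio as gr

def fake_transcribe(audio_path, include_timestamps=False, progress=gr.Progress()):
    # Stub standing in for app.py's transcribe(): yields (partial_text, download_path)
    # tuples so Gradio can stream intermediate results into the outputs below.
    words = ["Hej", "och", "välkommen", "till", "demon"]
    partial = []
    for i, word in enumerate(words):
        partial.append(word)
        progress((i + 1) / len(words))
        time.sleep(0.2)
        yield " ".join(partial), None
    # Final yield; the real app would also produce a .txt file path here.
    yield " ".join(partial), None

def build_demo():
    with gr.Blocks() as demo:
        audio_in = gr.Audio(type="filepath", label="Audio file")
        timestamps = gr.Checkbox(label="Include timestamps")
        transcript_out = gr.Textbox(label="Transcript", lines=10)
        download_out = gr.File(label="Download transcript")
        run_btn = gr.Button("Transcribe")
        # Generator handlers stream each yielded tuple into the two outputs.
        run_btn.click(fn=fake_transcribe,
                      inputs=[audio_in, timestamps],
                      outputs=[transcript_out, download_out])
    return demo

if __name__ == "__main__":
    build_demo().launch()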