vian123 committed on
Commit
a2dca05
·
verified ·
1 Parent(s): 8296f89

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +127 -259
src/streamlit_app.py CHANGED
@@ -13,331 +13,199 @@ import torchaudio
13
  import yt_dlp
14
  import torch
15
 
 
16
  class Interface:
17
  @staticmethod
18
  def get_header(title: str, description: str) -> None:
19
- """
20
- Display the header of the application.
21
- """
22
- st.set_page_config(
23
- page_title="Audio Summarization",
24
- page_icon="🗣️",
25
- )
26
 
27
- hide_decoration_bar_style = """<style>header {visibility: hidden;}</style>"""
28
- st.markdown(hide_decoration_bar_style, unsafe_allow_html=True)
29
- hide_streamlit_footer = """
30
- <style>#MainMenu {visibility: hidden;}
31
- footer {visibility: hidden;}</style>
32
- """
33
- st.markdown(hide_streamlit_footer, unsafe_allow_html=True)
34
-
35
  st.title(title)
36
-
37
  st.info(description)
38
- st.write("\n")
39
 
40
  @staticmethod
41
  def get_audio_file() -> str:
42
- """
43
- Upload an audio file for transcription and summarization.
44
- """
45
- uploaded_file = st.file_uploader(
46
- "Choose an audio file",
47
- type=["wav"],
48
- help="Upload an audio file for transcription and summarization.",
49
- )
50
- if uploaded_file is None:
51
- return None
52
-
53
- if uploaded_file.name.endswith(".wav"):
54
  st.audio(uploaded_file, format="audio/wav")
55
- else:
 
56
  st.warning("Please upload a valid .wav audio file.")
57
- return None
58
-
59
- return uploaded_file
60
-
61
  @staticmethod
62
- def get_approach() -> None:
63
- """
64
- Select the approach for input audio summarization.
65
- """
66
- approach = st.selectbox(
67
- "Select the approach for input audio summarization",
68
- options=["Youtube Link", "Input Audio File"],
69
- index=1,
70
- help="Choose the approach you want to use for summarization.",
71
- )
72
 
73
- return approach
74
-
75
  @staticmethod
76
  def get_link_youtube() -> str:
77
- """
78
- Input a YouTube link for audio summarization.
79
- """
80
- youtube_link = st.text_input(
81
- "Enter the YouTube link",
82
- placeholder="https://www.youtube.com/watch?v=example",
83
- help="Paste the YouTube link of the video you want to summarize.",
84
- )
85
  if youtube_link.strip():
86
  st.video(youtube_link)
87
-
88
  return youtube_link
89
-
90
  @staticmethod
91
- def get_sidebar_input(state: dict) -> str:
92
- """
93
- Handles sidebar interaction and returns the audio path if available.
94
- """
95
  with st.sidebar:
96
  st.markdown("### Select Approach")
97
  approach = Interface.get_approach()
98
  state['session'] = 1
99
 
100
  audio_path = None
101
-
102
- if approach == "Input Audio File" and state['session'] == 1:
103
  audio = Interface.get_audio_file()
104
- if audio is not None:
105
  audio_path = Utils.temporary_file(audio)
106
- state['session'] = 2
107
-
108
- elif approach == "Youtube Link" and state['session'] == 1:
109
  youtube_link = Interface.get_link_youtube()
110
  if youtube_link:
111
  audio_path = Utils.download_youtube_audio_to_tempfile(youtube_link)
112
- if audio_path is not None:
113
- with open(audio_path, "rb") as audio_file:
114
- audio_bytes = audio_file.read()
115
- st.audio(audio_bytes, format="audio/wav")
116
- state['session'] = 2
117
-
118
- generate = False
119
- if state['session'] == 2 and 'audio_path' in locals() and audio_path:
120
- generate = st.button("🚀 Generate Result !!")
121
 
 
122
  return audio_path, generate
123
 
 
124
  class Utils:
125
  @staticmethod
126
  def temporary_file(uploaded_file: str) -> str:
127
- """
128
- Create a temporary file for the uploaded audio file.
129
- """
130
- if uploaded_file is not None:
131
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
132
- temp_file.write(uploaded_file.read())
133
- temp_file_path = temp_file.name
134
- return temp_file_path
135
-
136
- @staticmethod
137
  def clean_transcript(text: str) -> str:
138
- """
139
- Clean the transcript text by removing unwanted characters and formatting.
140
- """
141
- text = text.replace(",", " ")
142
  text = re.sub(r'(?<=[a-zA-Z])\.(?=[a-zA-Z])', ' ', text)
143
- text = re.sub(r'\s+', ' ', text)
144
- text = re.sub(r'\s*\.\s*', '. ', text)
145
- return text.strip()
146
-
147
  @staticmethod
148
  def preprocess_audio(input_path: str) -> str:
149
- """
150
- Preprocess the audio file by converting it to mono and resampling to 16000 Hz.
151
- """
152
  waveform, sample_rate = torchaudio.load(input_path)
153
- print(f"📢 Original waveform shape: {waveform.shape}")
154
- print(f"📢 Original sample rate: {sample_rate}")
155
-
156
- # Convert to mono (average if stereo)
157
  if waveform.shape[0] > 1:
158
  waveform = waveform.mean(dim=0, keepdim=True)
159
- print("✅ Converted to mono.")
160
-
161
- # Resample to 16000 Hz if needed
162
- target_sample_rate = 16000
163
- if sample_rate != target_sample_rate:
164
- resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
165
- waveform = resampler(waveform)
166
- print(f"✅ Resampled to {target_sample_rate} Hz.")
167
- sample_rate = target_sample_rate
168
 
169
- # Create a temporary file for the output
170
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
171
- output_path = tmpfile.name
172
-
173
- torchaudio.save(output_path, waveform, sample_rate)
174
- print(f"✅ Saved preprocessed audio to temporary file: {output_path}")
175
-
176
- return output_path
177
-
178
  @staticmethod
179
- def _format_filename(input_string, chunk_number=0):
180
- """
181
- Format the input string to create a valid filename.
182
- Replaces non-alphanumeric characters with underscores, removes extra spaces,
183
- and appends a chunk number if provided.
184
- """
185
- input_string = input_string.strip()
186
- formatted_string = re.sub(r'[^a-zA-Z0-9\s]', '_', input_string)
187
- formatted_string = re.sub(r'[\s_]+', '_', formatted_string)
188
- formatted_string = formatted_string.lower()
189
- formatted_string += f'_chunk_{chunk_number}'
190
- return formatted_string
191
 
192
  @staticmethod
193
- def download_youtube_audio_to_tempfile(youtube_url):
194
- """
195
- Download audio from a YouTube video and save it as a WAV file in a temporary directory.
196
- Returns the path to the saved WAV file.
197
- """
198
- with st.spinner("Downloading and converting YouTube audio..."):
199
- try:
200
- # Get video info to use its title in the filename
201
- with yt_dlp.YoutubeDL({'quiet': True}) as ydl:
202
- info_dict = ydl.extract_info(youtube_url, download=False)
203
- original_title = info_dict.get('title', 'audio')
204
- formatted_title = Utils._format_filename(original_title)
205
-
206
- # Create a temporary directory
207
- temp_dir = tempfile.mkdtemp()
208
- output_path_no_ext = os.path.join(temp_dir, formatted_title)
209
-
210
- ydl_opts = {
211
- 'format': 'bestaudio/best',
212
- 'postprocessors': [{
213
- 'key': 'FFmpegExtractAudio',
214
- 'preferredcodec': 'wav',
215
- 'preferredquality': '192',
216
- }],
217
- 'outtmpl': output_path_no_ext,
218
- 'quiet': True
219
- }
220
-
221
- with yt_dlp.YoutubeDL(ydl_opts) as ydl:
222
- ydl.download([youtube_url])
223
-
224
- # Wait for yt_dlp to actually create the WAV file
225
- expected_output = output_path_no_ext + ".wav"
226
- timeout = 5
227
- while not os.path.exists(expected_output) and timeout > 0:
228
- time.sleep(1)
229
- timeout -= 1
230
-
231
- if not os.path.exists(expected_output):
232
- raise FileNotFoundError(f"Audio file was not saved as expected: {expected_output}")
233
-
234
- st.toast(f"Audio downloaded and saved to: {expected_output}")
235
- return expected_output
236
-
237
- except Exception as e:
238
- st.toast(f"Failed to download {youtube_url}: {e}")
239
- return None
240
 
241
  class Generation:
242
- def __init__(
243
- self,
244
- summarization_model: str = "vian123/brio-finance-finetuned-v2",
245
- speech_to_text_model: str = "nyrahealth/CrisperWhisper",
246
- ):
247
- self.summarization_model = summarization_model
248
- self.speech_to_text_model = speech_to_text_model
249
  self.device = "cpu"
250
  self.dtype = torch.float32
251
- self.processor_speech = AutoProcessor.from_pretrained(speech_to_text_model)
252
- self.model_speech = AutoModelForSpeechSeq2Seq.from_pretrained(
253
- speech_to_text_model,
254
- torch_dtype=self.dtype,
255
- low_cpu_mem_usage=True,
256
- use_safetensors=True,
257
- attn_implementation="eager",
258
- ).to(self.device)
259
- self.summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
260
-
261
- def transcribe_audio_pytorch(self, file_path: str) -> str:
262
- """
263
- transcribe audio using the PyTorch-based speech-to-text model.
264
- """
265
- converted_path = Utils.preprocess_audio(file_path)
266
- waveform, sample_rate = torchaudio.load(converted_path)
267
- duration = waveform.shape[1] / sample_rate
268
- if duration < 1.0:
269
- print("❌ Audio too short to process.")
270
  return ""
271
 
272
- pipe = pipeline(
273
  "automatic-speech-recognition",
274
- model=self.model_speech,
275
- tokenizer=self.processor_speech.tokenizer,
276
- feature_extractor=self.processor_speech.feature_extractor,
277
  chunk_length_s=5,
278
- batch_size=1,
279
- return_timestamps=None,
280
  torch_dtype=self.dtype,
281
- device=self.device,
282
- model_kwargs={"language": "en"},
283
  )
284
 
285
  try:
286
- hf_pipeline_output = pipe(converted_path)
287
- print("✅ HF pipeline output:", hf_pipeline_output)
288
- return hf_pipeline_output.get("text", "")
289
  except Exception as e:
290
- print(" Pipeline failed with error:", e)
291
  return ""
292
 
293
- def summarize_string(self, text: str) -> str:
294
- """
295
- Summarize the input text using the summarization model.
296
- """
297
- summarizer = pipeline("summarization", model=self.summarization_model, tokenizer=self.summarization_model)
 
 
 
 
298
  try:
299
- if len(text.strip()) < 10:
300
- return ""
301
-
302
- inputs = self.summarization_tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
303
- truncated_text = self.summarization_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
304
-
305
- word_count = len(truncated_text.split())
306
- min_len = max(int(word_count * 0.5), 30)
307
- max_len = max(min_len + 20, int(word_count * 0.75))
308
-
309
- summary = summarizer(
310
- truncated_text,
311
- max_length=max_len,
312
- min_length=min_len,
313
- do_sample=False
314
- )
315
  return summary[0]['summary_text']
316
  except Exception as e:
317
- return f"Error: {e}"
318
-
 
319
  def main():
320
- Interface.get_header(
321
- title="Financial YouTube Video Audio Summarization",
322
- description="🎧 Upload an financial audio file or financial YouTube video link to 📝 transcribe and 📄 summarize its content using CrisperWhisper and Financial Fine-tuned BRIO 🤖."
323
- )
324
-
325
- generate = False
326
- state = dict(session=0)
327
-
328
- audio_path, generate = Interface.get_sidebar_input(state)
329
-
330
- if generate and state['session'] == 2:
331
- with st.spinner("Generating ..."):
332
- generation = Generation()
333
- transcribe = generation.transcribe_audio_pytorch(audio_path)
334
-
335
- with st.expander("Transcription Text", expanded=True):
336
- st.text_area("Transcription:", transcribe, height=300)
337
-
338
- summarization = generation.summarize_string(transcribe)
339
- with st.expander("Summarization Text", expanded=True):
340
- st.text_area("Summarization:", summarization, height=300)
341
 
342
  if __name__ == "__main__":
343
  main()
 
13
  import yt_dlp
14
  import torch
15
 
16
+
17
  class Interface:
18
  @staticmethod
19
  def get_header(title: str, description: str) -> None:
20
+ st.set_page_config(page_title="Audio Summarization", page_icon="🗣️")
21
+
22
+ st.markdown("""
23
+ <style>
24
+ header, #MainMenu, footer {visibility: hidden;}
25
+ </style>
26
+ """, unsafe_allow_html=True)
27
 
 
 
 
 
 
 
 
 
28
  st.title(title)
 
29
  st.info(description)
 
30
 
31
  @staticmethod
32
  def get_audio_file() -> str:
33
+ uploaded_file = st.file_uploader("Choose an audio file", type=["wav"], help="Upload a .wav audio file.")
34
+ if uploaded_file and uploaded_file.name.endswith(".wav"):
 
 
 
 
 
 
 
 
 
 
35
  st.audio(uploaded_file, format="audio/wav")
36
+ return uploaded_file
37
+ elif uploaded_file:
38
  st.warning("Please upload a valid .wav audio file.")
39
+ return None
40
+
 
 
41
  @staticmethod
42
+ def get_approach() -> str:
43
+ return st.selectbox("Select summarization approach", ["Youtube Link", "Input Audio File"], index=1)
 
 
 
 
 
 
 
 
44
 
 
 
45
  @staticmethod
46
  def get_link_youtube() -> str:
47
+ youtube_link = st.text_input("Enter YouTube link", placeholder="https://www.youtube.com/watch?v=example")
 
 
 
 
 
 
 
48
  if youtube_link.strip():
49
  st.video(youtube_link)
 
50
  return youtube_link
51
+
52
  @staticmethod
53
+ def get_sidebar_input(state: dict) -> tuple:
 
 
 
54
  with st.sidebar:
55
  st.markdown("### Select Approach")
56
  approach = Interface.get_approach()
57
  state['session'] = 1
58
 
59
  audio_path = None
60
+ if approach == "Input Audio File":
 
61
  audio = Interface.get_audio_file()
62
+ if audio:
63
  audio_path = Utils.temporary_file(audio)
64
+ elif approach == "Youtube Link":
 
 
65
  youtube_link = Interface.get_link_youtube()
66
  if youtube_link:
67
  audio_path = Utils.download_youtube_audio_to_tempfile(youtube_link)
68
+ if audio_path:
69
+ with open(audio_path, "rb") as af:
70
+ st.audio(af.read(), format="audio/wav")
 
 
 
 
 
 
71
 
72
+ generate = audio_path and st.button("🚀 Generate Result !!")
73
  return audio_path, generate
74
 
75
+
76
  class Utils:
77
  @staticmethod
78
  def temporary_file(uploaded_file: str) -> str:
79
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
80
+ tmp.write(uploaded_file.read())
81
+ return tmp.name
82
+
83
+ @staticmethod
 
 
 
 
 
84
  def clean_transcript(text: str) -> str:
 
 
 
 
85
  text = re.sub(r'(?<=[a-zA-Z])\.(?=[a-zA-Z])', ' ', text)
86
+ text = re.sub(r'[^\w. ]+', ' ', text)
87
+ return re.sub(r'\s+', ' ', text).strip()
88
+
 
89
  @staticmethod
90
  def preprocess_audio(input_path: str) -> str:
 
 
 
91
  waveform, sample_rate = torchaudio.load(input_path)
 
 
 
 
92
  if waveform.shape[0] > 1:
93
  waveform = waveform.mean(dim=0, keepdim=True)
94
+ if sample_rate != 16000:
95
+ waveform = Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
96
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
97
+ torchaudio.save(tmp.name, waveform, 16000)
98
+ return tmp.name
 
 
 
 
99
 
 
 
 
 
 
 
 
 
 
100
  @staticmethod
101
+ def _format_filename(name: str, chunk=0) -> str:
102
+ clean = re.sub(r'[^a-zA-Z0-9]', '_', name.strip().lower())
103
+ return f"{clean}_chunk_{chunk}"
 
 
 
 
 
 
 
 
 
104
 
105
  @staticmethod
106
+ def download_youtube_audio_to_tempfile(url: str) -> str:
107
+ try:
108
+ with yt_dlp.YoutubeDL({'quiet': True}) as ydl:
109
+ info = ydl.extract_info(url, download=False)
110
+ filename = Utils._format_filename(info.get('title', 'audio'))
111
+
112
+ out_dir = tempfile.mkdtemp()
113
+ output_path = os.path.join(out_dir, filename)
114
+
115
+ ydl_opts = {
116
+ 'format': 'bestaudio/best',
117
+ 'postprocessors': [{
118
+ 'key': 'FFmpegExtractAudio',
119
+ 'preferredcodec': 'wav',
120
+ 'preferredquality': '192',
121
+ }],
122
+ 'outtmpl': output_path,
123
+ 'quiet': True
124
+ }
125
+
126
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
127
+ ydl.download([url])
128
+
129
+ final_path = output_path + ".wav"
130
+ for _ in range(5):
131
+ if os.path.exists(final_path):
132
+ return final_path
133
+ time.sleep(1)
134
+ raise FileNotFoundError(f"File not found: {final_path}")
135
+ except Exception as e:
136
+ st.toast(f"Download failed: {e}")
137
+ return None
138
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  class Generation:
141
+ def __init__(self, summarization_model="vian123/brio-finance-finetuned-v2", speech_to_text_model="nyrahealth/CrisperWhisper"):
 
 
 
 
 
 
142
  self.device = "cpu"
143
  self.dtype = torch.float32
144
+
145
+ self.processor = AutoProcessor.from_pretrained(speech_to_text_model)
146
+ self.model = AutoModelForSpeechSeq2Seq.from_pretrained(speech_to_text_model, torch_dtype=self.dtype).to(self.device)
147
+ self.tokenizer = AutoTokenizer.from_pretrained(summarization_model)
148
+ self.summarizer = pipeline("summarization", model=summarization_model, tokenizer=summarization_model)
149
+
150
+ def transcribe(self, audio_path: str) -> str:
151
+ processed_path = Utils.preprocess_audio(audio_path)
152
+ waveform, rate = torchaudio.load(processed_path)
153
+ if waveform.shape[1] / rate < 1:
 
 
 
 
 
 
 
 
 
154
  return ""
155
 
156
+ asr_pipe = pipeline(
157
  "automatic-speech-recognition",
158
+ model=self.model,
159
+ tokenizer=self.processor.tokenizer,
160
+ feature_extractor=self.processor.feature_extractor,
161
  chunk_length_s=5,
 
 
162
  torch_dtype=self.dtype,
163
+ device=self.device
 
164
  )
165
 
166
  try:
167
+ output = asr_pipe(processed_path)
168
+ return output.get("text", "")
 
169
  except Exception as e:
170
+ print("ASR error:", e)
171
  return ""
172
 
173
+ def summarize(self, text: str) -> str:
174
+ if len(text.strip()) < 10:
175
+ return ""
176
+ cleaned = self.tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
177
+ decoded = self.tokenizer.decode(cleaned["input_ids"][0], skip_special_tokens=True)
178
+
179
+ word_count = len(decoded.split())
180
+ min_len, max_len = max(30, int(word_count * 0.5)), max(50, int(word_count * 0.75))
181
+
182
  try:
183
+ summary = self.summarizer(decoded, max_length=max_len, min_length=min_len, do_sample=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  return summary[0]['summary_text']
185
  except Exception as e:
186
+ return f"Summarization error: {e}"
187
+
188
+
189
  def main():
190
+ Interface.get_header(
191
+ title="Financial YouTube Video Audio Summarization",
192
+ description="🎧 Upload a financial audio or YouTube video to transcribe and summarize using CrisperWhisper + fine-tuned BRIO."
193
+ )
194
+
195
+ state = dict(session=0)
196
+ audio_path, generate = Interface.get_sidebar_input(state)
197
+
198
+ if generate:
199
+ with st.spinner("Processing..."):
200
+ gen = Generation()
201
+ transcript = gen.transcribe(audio_path)
202
+
203
+ st.expander("Transcription Text", expanded=True).text_area("Transcription", transcript, height=300)
204
+
205
+ with st.spinner("Summarizing..."):
206
+ summary = gen.summarize(transcript)
207
+ st.expander("Summarization Text", expanded=True).text_area("Summarization", summary, height=300)
208
+
 
 
209
 
210
  if __name__ == "__main__":
211
  main()