vian123 committed
Commit 2debac3 · verified · 1 Parent(s): 2eb50a0

Upload streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +342 -0
src/streamlit_app.py ADDED
@@ -0,0 +1,342 @@
import streamlit as st
import tempfile
import time
import re
import os

from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, pipeline, AutoTokenizer
import torchaudio
import yt_dlp
import torch


class Interface:
    @staticmethod
    def get_header(title: str, description: str) -> None:
        """
        Display the header of the application.
        """
        st.set_page_config(
            page_title="Audio Summarization",
            page_icon="🗣️",
        )

        # Hide Streamlit's default header, menu, and footer chrome
        hide_decoration_bar_style = """<style>header {visibility: hidden;}</style>"""
        st.markdown(hide_decoration_bar_style, unsafe_allow_html=True)
        hide_streamlit_footer = """
        <style>#MainMenu {visibility: hidden;}
        footer {visibility: hidden;}</style>
        """
        st.markdown(hide_streamlit_footer, unsafe_allow_html=True)

        st.title(title)
        st.info(description)
        st.write("\n")

    @staticmethod
    def get_audio_file():
        """
        Upload a .wav audio file for transcription and summarization.
        Returns the uploaded file object, or None if nothing valid was uploaded.
        """
        uploaded_file = st.file_uploader(
            "Choose an audio file",
            type=["wav"],
            help="Upload an audio file for transcription and summarization.",
        )
        if uploaded_file is None:
            return None

        if uploaded_file.name.endswith(".wav"):
            st.audio(uploaded_file, format="audio/wav")
        else:
            st.warning("Please upload a valid .wav audio file.")
            return None

        return uploaded_file

    @staticmethod
    def get_approach() -> str:
        """
        Select the approach for input audio summarization.
        """
        approach = st.selectbox(
            "Select the approach for input audio summarization",
            options=["Youtube Link", "Input Audio File"],
            index=1,
            help="Choose the approach you want to use for summarization.",
        )
        return approach

    @staticmethod
    def get_link_youtube() -> str:
        """
        Input a YouTube link for audio summarization.
        """
        youtube_link = st.text_input(
            "Enter the YouTube link",
            placeholder="https://www.youtube.com/watch?v=example",
            help="Paste the YouTube link of the video you want to summarize.",
        )
        if youtube_link.strip():
            st.video(youtube_link)

        return youtube_link

    @staticmethod
    def get_sidebar_input(state: dict):
        """
        Handle the sidebar interaction and return a tuple of
        (audio_path, generate): the audio path if one is available,
        and whether the generate button was pressed.
        """
        with st.sidebar:
            st.markdown("### Select Approach")
            approach = Interface.get_approach()
            state['session'] = 1

            audio_path = None

            if approach == "Input Audio File" and state['session'] == 1:
                audio = Interface.get_audio_file()
                if audio is not None:
                    audio_path = Utils.temporary_file(audio)
                    state['session'] = 2

            elif approach == "Youtube Link" and state['session'] == 1:
                youtube_link = Interface.get_link_youtube()
                if youtube_link:
                    audio_path = Utils.download_youtube_audio_to_tempfile(youtube_link)
                    if audio_path is not None:
                        with open(audio_path, "rb") as audio_file:
                            audio_bytes = audio_file.read()
                        st.audio(audio_bytes, format="audio/wav")
                        state['session'] = 2

            generate = False
            if state['session'] == 2 and audio_path:
                generate = st.button("🚀 Generate Result!!")

        return audio_path, generate


class Utils:
    @staticmethod
    def temporary_file(uploaded_file) -> str:
        """
        Write the uploaded audio file to a temporary .wav file and return its path.
        """
        if uploaded_file is not None:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
                temp_file.write(uploaded_file.read())
                temp_file_path = temp_file.name
            return temp_file_path

    @staticmethod
    def clean_transcript(text: str) -> str:
        """
        Clean the transcript text by removing unwanted characters and formatting.
        """
        text = text.replace(",", " ")
        text = re.sub(r'(?<=[a-zA-Z])\.(?=[a-zA-Z])', ' ', text)  # drop dots inside words
        text = re.sub(r'\s+', ' ', text)                          # collapse whitespace
        text = re.sub(r'\s*\.\s*', '. ', text)                    # normalize sentence ends
        return text.strip()

    @staticmethod
    def preprocess_audio(input_path: str) -> str:
        """
        Preprocess the audio file by converting it to mono and resampling to 16000 Hz.
        """
        waveform, sample_rate = torchaudio.load(input_path)
        print(f"📢 Original waveform shape: {waveform.shape}")
        print(f"📢 Original sample rate: {sample_rate}")

        # Convert to mono (average if stereo)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
            print("✅ Converted to mono.")

        # Resample to 16000 Hz if needed
        target_sample_rate = 16000
        if sample_rate != target_sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
            waveform = resampler(waveform)
            print(f"✅ Resampled to {target_sample_rate} Hz.")
            sample_rate = target_sample_rate

        # Create a temporary file for the output
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
            output_path = tmpfile.name

        torchaudio.save(output_path, waveform, sample_rate)
        print(f"✅ Saved preprocessed audio to temporary file: {output_path}")

        return output_path

    @staticmethod
    def _format_filename(input_string: str, chunk_number: int = 0) -> str:
        """
        Format the input string into a valid filename: replace non-alphanumeric
        characters with underscores, collapse runs of whitespace and underscores,
        lowercase the result, and append a chunk number.
        """
        input_string = input_string.strip()
        formatted_string = re.sub(r'[^a-zA-Z0-9\s]', '_', input_string)
        formatted_string = re.sub(r'[\s_]+', '_', formatted_string)
        formatted_string = formatted_string.lower()
        formatted_string += f'_chunk_{chunk_number}'
        return formatted_string

    @staticmethod
    def download_youtube_audio_to_tempfile(youtube_url: str):
        """
        Download the audio track of a YouTube video and save it as a WAV file in a
        temporary directory. Return the path to the saved file, or None on failure.
        """
        try:
            # Get video info to use its title in the filename
            with yt_dlp.YoutubeDL({'quiet': True}) as ydl:
                info_dict = ydl.extract_info(youtube_url, download=False)
                original_title = info_dict.get('title', 'audio')
                formatted_title = Utils._format_filename(original_title)

            # Create a temporary directory
            temp_dir = tempfile.mkdtemp()
            output_path_no_ext = os.path.join(temp_dir, formatted_title)

            ydl_opts = {
                'format': 'bestaudio/best',
                'postprocessors': [{
                    'key': 'FFmpegExtractAudio',
                    'preferredcodec': 'wav',
                    'preferredquality': '192',
                }],
                'outtmpl': output_path_no_ext,
                'quiet': True,
            }

            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                ydl.download([youtube_url])

            # Wait for yt_dlp to actually create the WAV file
            expected_output = output_path_no_ext + ".wav"
            timeout = 5
            while not os.path.exists(expected_output) and timeout > 0:
                time.sleep(1)
                timeout -= 1

            if not os.path.exists(expected_output):
                raise FileNotFoundError(f"Audio file was not saved as expected: {expected_output}")

            st.toast(f"Audio downloaded and saved to: {expected_output}")
            return expected_output

        except Exception as e:
            st.toast(f"Failed to download {youtube_url}: {e}")
            return None


class Generation:
    def __init__(
        self,
        summarization_model: str = "vian123/brio-finance-finetuned-v2",
        speech_to_text_model: str = "nyrahealth/CrisperWhisper",
    ):
        self.summarization_model = summarization_model
        self.speech_to_text_model = speech_to_text_model
        self.device = "cpu"
        self.dtype = torch.float32
        self.processor_speech = AutoProcessor.from_pretrained(speech_to_text_model)
        self.model_speech = AutoModelForSpeechSeq2Seq.from_pretrained(
            speech_to_text_model,
            torch_dtype=self.dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            attn_implementation="eager",
        ).to(self.device)
        self.summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)

    def transcribe_audio_pytorch(self, file_path: str) -> str:
        """
        Transcribe audio using the PyTorch-based speech-to-text model.
        """
        converted_path = Utils.preprocess_audio(file_path)
        waveform, sample_rate = torchaudio.load(converted_path)
        duration = waveform.shape[1] / sample_rate
        if duration < 1.0:
            print("❌ Audio too short to process.")
            return ""

        pipe = pipeline(
            "automatic-speech-recognition",
            model=self.model_speech,
            tokenizer=self.processor_speech.tokenizer,
            feature_extractor=self.processor_speech.feature_extractor,
            chunk_length_s=5,
            batch_size=1,
            return_timestamps=None,
            torch_dtype=self.dtype,
            device=self.device,
            # The language hint must go through generate_kwargs to reach the
            # model's generate() call; it is not a model-loading argument.
            generate_kwargs={"language": "en"},
        )

        try:
            hf_pipeline_output = pipe(converted_path)
            print("✅ HF pipeline output:", hf_pipeline_output)
            return hf_pipeline_output.get("text", "")
        except Exception as e:
            print("❌ Pipeline failed with error:", e)
            return ""

    def summarize_string(self, text: str) -> str:
        """
        Summarize the input text using the summarization model.
        """
        summarizer = pipeline(
            "summarization",
            model=self.summarization_model,
            tokenizer=self.summarization_tokenizer,
        )
        try:
            if len(text.strip()) < 10:
                return ""

            # Truncate the input to the encoder's 512-token limit
            inputs = self.summarization_tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
            truncated_text = self.summarization_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)

            # Target a summary of roughly 50-75% of the truncated word count
            word_count = len(truncated_text.split())
            min_len = max(int(word_count * 0.5), 30)
            max_len = max(min_len + 20, int(word_count * 0.75))

            summary = summarizer(
                truncated_text,
                max_length=max_len,
                min_length=min_len,
                do_sample=False,
            )
            return summary[0]['summary_text']
        except Exception as e:
            return f"Error: {e}"


def main():
    Interface.get_header(
        title="Financial YouTube Video Audio Summarization",
        description=(
            "🎧 Upload a financial audio file or a financial YouTube video link to "
            "📝 transcribe and 📄 summarize its content using CrisperWhisper and a "
            "financial fine-tuned BRIO model 🤖."
        ),
    )

    state = dict(session=0)

    audio_path, generate = Interface.get_sidebar_input(state)

    if generate and state['session'] == 2:
        with st.spinner("Generating ..."):
            generation = Generation()
            transcribe = generation.transcribe_audio_pytorch(audio_path)

            with st.expander("Transcription Text", expanded=True):
                st.text_area("Transcription:", transcribe, height=300)

            summarization = generation.summarize_string(transcribe)
            with st.expander("Summarization Text", expanded=True):
                st.text_area("Summarization:", summarization, height=300)


if __name__ == "__main__":
    main()
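
For local testing, the classes can also be exercised outside the Streamlit UI. A minimal smoke-test sketch, assuming the module is importable as src.streamlit_app from the repository root and that sample.wav is a placeholder for a real recording (both names are assumptions, not part of this commit; the Generation() constructor downloads the CrisperWhisper and BRIO checkpoints from the Hub on first use):

    # Hypothetical smoke test; run from the repository root.
    from src.streamlit_app import Generation

    gen = Generation()  # loads nyrahealth/CrisperWhisper and vian123/brio-finance-finetuned-v2 on CPU
    transcript = gen.transcribe_audio_pytorch("sample.wav")  # "sample.wav" is a placeholder path
    print(gen.summarize_string(transcript))

The app itself starts with streamlit run src/streamlit_app.py; the yt-dlp WAV post-processor additionally requires ffmpeg on the PATH.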