vian123 commited on
Commit
2eb50a0
·
verified ·
1 Parent(s): 29e319e

Delete main.py

Browse files
Files changed (1) hide show
  1. main.py +0 -342
main.py DELETED
@@ -1,342 +0,0 @@
1
- import streamlit as st
2
- import pandas as pd
3
- import tempfile
4
- import time
5
- import sys
6
- import re
7
- import os
8
-
9
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, pipeline, AutoTokenizer
10
- from torchaudio.transforms import Resample
11
- import soundfile as sf
12
- import torchaudio
13
- import yt_dlp
14
- import torch
15
-
16
- class Interface:
17
- @staticmethod
18
- def get_header(title: str, description: str) -> None:
19
- """
20
- Display the header of the application.
21
- """
22
- st.set_page_config(
23
- page_title="Audio Summarization",
24
- page_icon="🗣️",
25
- )
26
-
27
- hide_decoration_bar_style = """<style>header {visibility: hidden;}</style>"""
28
- st.markdown(hide_decoration_bar_style, unsafe_allow_html=True)
29
- hide_streamlit_footer = """
30
- <style>#MainMenu {visibility: hidden;}
31
- footer {visibility: hidden;}</style>
32
- """
33
- st.markdown(hide_streamlit_footer, unsafe_allow_html=True)
34
-
35
- st.title(title)
36
-
37
- st.info(description)
38
- st.write("\n")
39
-
40
- @staticmethod
41
- def get_audio_file() -> str:
42
- """
43
- Upload an audio file for transcription and summarization.
44
- """
45
- uploaded_file = st.file_uploader(
46
- "Choose an audio file",
47
- type=["wav"],
48
- help="Upload an audio file for transcription and summarization.",
49
- )
50
- if uploaded_file is None:
51
- return None
52
-
53
- if uploaded_file.name.endswith(".wav"):
54
- st.audio(uploaded_file, format="audio/wav")
55
- else:
56
- st.warning("Please upload a valid .wav audio file.")
57
- return None
58
-
59
- return uploaded_file
60
-
61
- @staticmethod
62
- def get_approach() -> None:
63
- """
64
- Select the approach for input audio summarization.
65
- """
66
- approach = st.selectbox(
67
- "Select the approach for input audio summarization",
68
- options=["Youtube Link", "Input Audio File"],
69
- index=1,
70
- help="Choose the approach you want to use for summarization.",
71
- )
72
-
73
- return approach
74
-
75
- @staticmethod
76
- def get_link_youtube() -> str:
77
- """
78
- Input a YouTube link for audio summarization.
79
- """
80
- youtube_link = st.text_input(
81
- "Enter the YouTube link",
82
- placeholder="https://www.youtube.com/watch?v=example",
83
- help="Paste the YouTube link of the video you want to summarize.",
84
- )
85
- if youtube_link.strip():
86
- st.video(youtube_link)
87
-
88
- return youtube_link
89
-
90
- @staticmethod
91
- def get_sidebar_input(state: dict) -> str:
92
- """
93
- Handles sidebar interaction and returns the audio path if available.
94
- """
95
- with st.sidebar:
96
- st.markdown("### Select Approach")
97
- approach = Interface.get_approach()
98
- state['session'] = 1
99
-
100
- audio_path = None
101
-
102
- if approach == "Input Audio File" and state['session'] == 1:
103
- audio = Interface.get_audio_file()
104
- if audio is not None:
105
- audio_path = Utils.temporary_file(audio)
106
- state['session'] = 2
107
-
108
- elif approach == "Youtube Link" and state['session'] == 1:
109
- youtube_link = Interface.get_link_youtube()
110
- if youtube_link:
111
- audio_path = Utils.download_youtube_audio_to_tempfile(youtube_link)
112
- if audio_path is not None:
113
- with open(audio_path, "rb") as audio_file:
114
- audio_bytes = audio_file.read()
115
- st.audio(audio_bytes, format="audio/wav")
116
- state['session'] = 2
117
-
118
- generate = False
119
- if state['session'] == 2 and 'audio_path' in locals() and audio_path:
120
- generate = st.button("🚀 Generate Result !!")
121
-
122
- return audio_path, generate
123
-
124
- class Utils:
125
- @staticmethod
126
- def temporary_file(uploaded_file: str) -> str:
127
- """
128
- Create a temporary file for the uploaded audio file.
129
- """
130
- if uploaded_file is not None:
131
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
132
- temp_file.write(uploaded_file.read())
133
- temp_file_path = temp_file.name
134
- return temp_file_path
135
-
136
- @staticmethod
137
- def clean_transcript(text: str) -> str:
138
- """
139
- Clean the transcript text by removing unwanted characters and formatting.
140
- """
141
- text = text.replace(",", " ")
142
- text = re.sub(r'(?<=[a-zA-Z])\.(?=[a-zA-Z])', ' ', text)
143
- text = re.sub(r'\s+', ' ', text)
144
- text = re.sub(r'\s*\.\s*', '. ', text)
145
- return text.strip()
146
-
147
- @staticmethod
148
- def preprocess_audio(input_path: str) -> str:
149
- """
150
- Preprocess the audio file by converting it to mono and resampling to 16000 Hz.
151
- """
152
- waveform, sample_rate = torchaudio.load(input_path)
153
- print(f"📢 Original waveform shape: {waveform.shape}")
154
- print(f"📢 Original sample rate: {sample_rate}")
155
-
156
- # Convert to mono (average if stereo)
157
- if waveform.shape[0] > 1:
158
- waveform = waveform.mean(dim=0, keepdim=True)
159
- print("✅ Converted to mono.")
160
-
161
- # Resample to 16000 Hz if needed
162
- target_sample_rate = 16000
163
- if sample_rate != target_sample_rate:
164
- resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
165
- waveform = resampler(waveform)
166
- print(f"✅ Resampled to {target_sample_rate} Hz.")
167
- sample_rate = target_sample_rate
168
-
169
- # Create a temporary file for the output
170
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
171
- output_path = tmpfile.name
172
-
173
- torchaudio.save(output_path, waveform, sample_rate)
174
- print(f"✅ Saved preprocessed audio to temporary file: {output_path}")
175
-
176
- return output_path
177
-
178
- @staticmethod
179
- def _format_filename(input_string, chunk_number=0):
180
- """
181
- Format the input string to create a valid filename.
182
- Replaces non-alphanumeric characters with underscores, removes extra spaces,
183
- and appends a chunk number if provided.
184
- """
185
- input_string = input_string.strip()
186
- formatted_string = re.sub(r'[^a-zA-Z0-9\s]', '_', input_string)
187
- formatted_string = re.sub(r'[\s_]+', '_', formatted_string)
188
- formatted_string = formatted_string.lower()
189
- formatted_string += f'_chunk_{chunk_number}'
190
- return formatted_string
191
-
192
- @staticmethod
193
- def download_youtube_audio_to_tempfile(youtube_url):
194
- """
195
- Download audio from a YouTube video and save it as a WAV file in a temporary directory.
196
- Returns the path to the saved WAV file.
197
- """
198
- try:
199
- # Get video info to use its title in the filename
200
- with yt_dlp.YoutubeDL({'quiet': True}) as ydl:
201
- info_dict = ydl.extract_info(youtube_url, download=False)
202
- original_title = info_dict.get('title', 'audio')
203
- formatted_title = Utils._format_filename(original_title)
204
-
205
- # Create a temporary directory
206
- temp_dir = tempfile.mkdtemp()
207
- output_path_no_ext = os.path.join(temp_dir, formatted_title)
208
-
209
- ydl_opts = {
210
- 'format': 'bestaudio/best',
211
- 'postprocessors': [{
212
- 'key': 'FFmpegExtractAudio',
213
- 'preferredcodec': 'wav',
214
- 'preferredquality': '192',
215
- }],
216
- 'outtmpl': output_path_no_ext,
217
- 'quiet': True
218
- }
219
-
220
- with yt_dlp.YoutubeDL(ydl_opts) as ydl:
221
- ydl.download([youtube_url])
222
-
223
- # Wait for yt_dlp to actually create the WAV file
224
- expected_output = output_path_no_ext + ".wav"
225
- timeout = 5
226
- while not os.path.exists(expected_output) and timeout > 0:
227
- time.sleep(1)
228
- timeout -= 1
229
-
230
- if not os.path.exists(expected_output):
231
- raise FileNotFoundError(f"Audio file was not saved as expected: {expected_output}")
232
-
233
- st.toast(f"Audio downloaded and saved to: {expected_output}")
234
- return expected_output
235
-
236
- except Exception as e:
237
- st.toast(f"Failed to download {youtube_url}: {e}")
238
- return None
239
-
240
- class Generation:
241
- def __init__(
242
- self,
243
- summarization_model: str = "vian123/brio-finance-finetuned-v2",
244
- speech_to_text_model: str = "nyrahealth/CrisperWhisper",
245
- ):
246
- self.summarization_model = summarization_model
247
- self.speech_to_text_model = speech_to_text_model
248
- self.device = "cpu"
249
- self.dtype = torch.float32
250
- self.processor_speech = AutoProcessor.from_pretrained(speech_to_text_model)
251
- self.model_speech = AutoModelForSpeechSeq2Seq.from_pretrained(
252
- speech_to_text_model,
253
- torch_dtype=self.dtype,
254
- low_cpu_mem_usage=True,
255
- use_safetensors=True,
256
- attn_implementation="eager",
257
- ).to(self.device)
258
- self.summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
259
-
260
- def transcribe_audio_pytorch(self, file_path: str) -> str:
261
- """
262
- transcribe audio using the PyTorch-based speech-to-text model.
263
- """
264
- converted_path = Utils.preprocess_audio(file_path)
265
- waveform, sample_rate = torchaudio.load(converted_path)
266
- duration = waveform.shape[1] / sample_rate
267
- if duration < 1.0:
268
- print("❌ Audio too short to process.")
269
- return ""
270
-
271
- pipe = pipeline(
272
- "automatic-speech-recognition",
273
- model=self.model_speech,
274
- tokenizer=self.processor_speech.tokenizer,
275
- feature_extractor=self.processor_speech.feature_extractor,
276
- chunk_length_s=5,
277
- batch_size=1,
278
- return_timestamps=None,
279
- torch_dtype=self.dtype,
280
- device=self.device,
281
- model_kwargs={"language": "en"},
282
- )
283
-
284
- try:
285
- hf_pipeline_output = pipe(converted_path)
286
- print("✅ HF pipeline output:", hf_pipeline_output)
287
- return hf_pipeline_output.get("text", "")
288
- except Exception as e:
289
- print("❌ Pipeline failed with error:", e)
290
- return ""
291
-
292
- def summarize_string(self, text: str) -> str:
293
- """
294
- Summarize the input text using the summarization model.
295
- """
296
- summarizer = pipeline("summarization", model=self.summarization_model, tokenizer=self.summarization_model)
297
- try:
298
- if len(text.strip()) < 10:
299
- return ""
300
-
301
- inputs = self.summarization_tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
302
- truncated_text = self.summarization_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
303
-
304
- word_count = len(truncated_text.split())
305
- min_len = max(int(word_count * 0.5), 30)
306
- max_len = max(min_len + 20, int(word_count * 0.75))
307
-
308
- summary = summarizer(
309
- truncated_text,
310
- max_length=max_len,
311
- min_length=min_len,
312
- do_sample=False
313
- )
314
- return summary[0]['summary_text']
315
- except Exception as e:
316
- return f"Error: {e}"
317
-
318
- def main():
319
- Interface.get_header(
320
- title="Financial YouTube Video Audio Summarization",
321
- description="🎧 Upload an financial audio file or financial YouTube video link to 📝 transcribe and 📄 summarize its content using CrisperWhisper and Financial Fine-tuned BRIO 🤖."
322
- )
323
-
324
- generate = False
325
- state = dict(session=0)
326
-
327
- audio_path, generate = Interface.get_sidebar_input(state)
328
-
329
- if generate and state['session'] == 2:
330
- with st.spinner("Generating ..."):
331
- generation = Generation()
332
- transcribe = generation.transcribe_audio_pytorch(audio_path)
333
-
334
- with st.expander("Transcription Text", expanded=True):
335
- st.text_area("Transcription:", transcribe, height=300)
336
-
337
- summarization = generation.summarize_string(transcribe)
338
- with st.expander("Summarization Text", expanded=True):
339
- st.text_area("Summarization:", summarization, height=300)
340
-
341
- if __name__ == "__main__":
342
- main()