Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +127 -259
src/streamlit_app.py
CHANGED
@@ -13,331 +13,199 @@ import torchaudio
|
|
13 |
import yt_dlp
|
14 |
import torch
|
15 |
|
|
|
16 |
class Interface:
|
17 |
@staticmethod
|
18 |
def get_header(title: str, description: str) -> None:
|
19 |
-
"""
|
20 |
-
|
21 |
-
"""
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
)
|
26 |
|
27 |
-
hide_decoration_bar_style = """<style>header {visibility: hidden;}</style>"""
|
28 |
-
st.markdown(hide_decoration_bar_style, unsafe_allow_html=True)
|
29 |
-
hide_streamlit_footer = """
|
30 |
-
<style>#MainMenu {visibility: hidden;}
|
31 |
-
footer {visibility: hidden;}</style>
|
32 |
-
"""
|
33 |
-
st.markdown(hide_streamlit_footer, unsafe_allow_html=True)
|
34 |
-
|
35 |
st.title(title)
|
36 |
-
|
37 |
st.info(description)
|
38 |
-
st.write("\n")
|
39 |
|
40 |
@staticmethod
|
41 |
def get_audio_file() -> str:
|
42 |
-
"""
|
43 |
-
|
44 |
-
"""
|
45 |
-
uploaded_file = st.file_uploader(
|
46 |
-
"Choose an audio file",
|
47 |
-
type=["wav"],
|
48 |
-
help="Upload an audio file for transcription and summarization.",
|
49 |
-
)
|
50 |
-
if uploaded_file is None:
|
51 |
-
return None
|
52 |
-
|
53 |
-
if uploaded_file.name.endswith(".wav"):
|
54 |
st.audio(uploaded_file, format="audio/wav")
|
55 |
-
|
|
|
56 |
st.warning("Please upload a valid .wav audio file.")
|
57 |
-
|
58 |
-
|
59 |
-
return uploaded_file
|
60 |
-
|
61 |
@staticmethod
|
62 |
-
def get_approach() ->
|
63 |
-
"""
|
64 |
-
Select the approach for input audio summarization.
|
65 |
-
"""
|
66 |
-
approach = st.selectbox(
|
67 |
-
"Select the approach for input audio summarization",
|
68 |
-
options=["Youtube Link", "Input Audio File"],
|
69 |
-
index=1,
|
70 |
-
help="Choose the approach you want to use for summarization.",
|
71 |
-
)
|
72 |
|
73 |
-
return approach
|
74 |
-
|
75 |
@staticmethod
|
76 |
def get_link_youtube() -> str:
|
77 |
-
"""
|
78 |
-
Input a YouTube link for audio summarization.
|
79 |
-
"""
|
80 |
-
youtube_link = st.text_input(
|
81 |
-
"Enter the YouTube link",
|
82 |
-
placeholder="https://www.youtube.com/watch?v=example",
|
83 |
-
help="Paste the YouTube link of the video you want to summarize.",
|
84 |
-
)
|
85 |
if youtube_link.strip():
|
86 |
st.video(youtube_link)
|
87 |
-
|
88 |
return youtube_link
|
89 |
-
|
90 |
@staticmethod
|
91 |
-
def get_sidebar_input(state: dict) ->
|
92 |
-
"""
|
93 |
-
Handles sidebar interaction and returns the audio path if available.
|
94 |
-
"""
|
95 |
with st.sidebar:
|
96 |
st.markdown("### Select Approach")
|
97 |
approach = Interface.get_approach()
|
98 |
state['session'] = 1
|
99 |
|
100 |
audio_path = None
|
101 |
-
|
102 |
-
if approach == "Input Audio File" and state['session'] == 1:
|
103 |
audio = Interface.get_audio_file()
|
104 |
-
if audio
|
105 |
audio_path = Utils.temporary_file(audio)
|
106 |
-
|
107 |
-
|
108 |
-
elif approach == "Youtube Link" and state['session'] == 1:
|
109 |
youtube_link = Interface.get_link_youtube()
|
110 |
if youtube_link:
|
111 |
audio_path = Utils.download_youtube_audio_to_tempfile(youtube_link)
|
112 |
-
if audio_path
|
113 |
-
with open(audio_path, "rb") as
|
114 |
-
|
115 |
-
st.audio(audio_bytes, format="audio/wav")
|
116 |
-
state['session'] = 2
|
117 |
-
|
118 |
-
generate = False
|
119 |
-
if state['session'] == 2 and 'audio_path' in locals() and audio_path:
|
120 |
-
generate = st.button("🚀 Generate Result !!")
|
121 |
|
|
|
122 |
return audio_path, generate
|
123 |
|
|
|
124 |
class Utils:
|
125 |
@staticmethod
|
126 |
def temporary_file(uploaded_file: str) -> str:
|
127 |
-
""
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
temp_file.write(uploaded_file.read())
|
133 |
-
temp_file_path = temp_file.name
|
134 |
-
return temp_file_path
|
135 |
-
|
136 |
-
@staticmethod
|
137 |
def clean_transcript(text: str) -> str:
|
138 |
-
"""
|
139 |
-
Clean the transcript text by removing unwanted characters and formatting.
|
140 |
-
"""
|
141 |
-
text = text.replace(",", " ")
|
142 |
text = re.sub(r'(?<=[a-zA-Z])\.(?=[a-zA-Z])', ' ', text)
|
143 |
-
text = re.sub(r'
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
@staticmethod
|
148 |
def preprocess_audio(input_path: str) -> str:
|
149 |
-
"""
|
150 |
-
Preprocess the audio file by converting it to mono and resampling to 16000 Hz.
|
151 |
-
"""
|
152 |
waveform, sample_rate = torchaudio.load(input_path)
|
153 |
-
print(f"📢 Original waveform shape: {waveform.shape}")
|
154 |
-
print(f"📢 Original sample rate: {sample_rate}")
|
155 |
-
|
156 |
-
# Convert to mono (average if stereo)
|
157 |
if waveform.shape[0] > 1:
|
158 |
waveform = waveform.mean(dim=0, keepdim=True)
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
|
165 |
-
waveform = resampler(waveform)
|
166 |
-
print(f"✅ Resampled to {target_sample_rate} Hz.")
|
167 |
-
sample_rate = target_sample_rate
|
168 |
|
169 |
-
# Create a temporary file for the output
|
170 |
-
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
|
171 |
-
output_path = tmpfile.name
|
172 |
-
|
173 |
-
torchaudio.save(output_path, waveform, sample_rate)
|
174 |
-
print(f"✅ Saved preprocessed audio to temporary file: {output_path}")
|
175 |
-
|
176 |
-
return output_path
|
177 |
-
|
178 |
@staticmethod
|
179 |
-
def _format_filename(
|
180 |
-
|
181 |
-
|
182 |
-
Replaces non-alphanumeric characters with underscores, removes extra spaces,
|
183 |
-
and appends a chunk number if provided.
|
184 |
-
"""
|
185 |
-
input_string = input_string.strip()
|
186 |
-
formatted_string = re.sub(r'[^a-zA-Z0-9\s]', '_', input_string)
|
187 |
-
formatted_string = re.sub(r'[\s_]+', '_', formatted_string)
|
188 |
-
formatted_string = formatted_string.lower()
|
189 |
-
formatted_string += f'_chunk_{chunk_number}'
|
190 |
-
return formatted_string
|
191 |
|
192 |
@staticmethod
|
193 |
-
def download_youtube_audio_to_tempfile(
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
timeout = 5
|
227 |
-
while not os.path.exists(expected_output) and timeout > 0:
|
228 |
-
time.sleep(1)
|
229 |
-
timeout -= 1
|
230 |
-
|
231 |
-
if not os.path.exists(expected_output):
|
232 |
-
raise FileNotFoundError(f"Audio file was not saved as expected: {expected_output}")
|
233 |
-
|
234 |
-
st.toast(f"Audio downloaded and saved to: {expected_output}")
|
235 |
-
return expected_output
|
236 |
-
|
237 |
-
except Exception as e:
|
238 |
-
st.toast(f"Failed to download {youtube_url}: {e}")
|
239 |
-
return None
|
240 |
|
241 |
class Generation:
|
242 |
-
def __init__(
|
243 |
-
self,
|
244 |
-
summarization_model: str = "vian123/brio-finance-finetuned-v2",
|
245 |
-
speech_to_text_model: str = "nyrahealth/CrisperWhisper",
|
246 |
-
):
|
247 |
-
self.summarization_model = summarization_model
|
248 |
-
self.speech_to_text_model = speech_to_text_model
|
249 |
self.device = "cpu"
|
250 |
self.dtype = torch.float32
|
251 |
-
|
252 |
-
self.
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
def transcribe_audio_pytorch(self, file_path: str) -> str:
|
262 |
-
"""
|
263 |
-
transcribe audio using the PyTorch-based speech-to-text model.
|
264 |
-
"""
|
265 |
-
converted_path = Utils.preprocess_audio(file_path)
|
266 |
-
waveform, sample_rate = torchaudio.load(converted_path)
|
267 |
-
duration = waveform.shape[1] / sample_rate
|
268 |
-
if duration < 1.0:
|
269 |
-
print("❌ Audio too short to process.")
|
270 |
return ""
|
271 |
|
272 |
-
|
273 |
"automatic-speech-recognition",
|
274 |
-
model=self.
|
275 |
-
tokenizer=self.
|
276 |
-
feature_extractor=self.
|
277 |
chunk_length_s=5,
|
278 |
-
batch_size=1,
|
279 |
-
return_timestamps=None,
|
280 |
torch_dtype=self.dtype,
|
281 |
-
device=self.device
|
282 |
-
model_kwargs={"language": "en"},
|
283 |
)
|
284 |
|
285 |
try:
|
286 |
-
|
287 |
-
|
288 |
-
return hf_pipeline_output.get("text", "")
|
289 |
except Exception as e:
|
290 |
-
print("
|
291 |
return ""
|
292 |
|
293 |
-
def
|
294 |
-
|
295 |
-
|
296 |
-
""
|
297 |
-
|
|
|
|
|
|
|
|
|
298 |
try:
|
299 |
-
|
300 |
-
return ""
|
301 |
-
|
302 |
-
inputs = self.summarization_tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
|
303 |
-
truncated_text = self.summarization_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
|
304 |
-
|
305 |
-
word_count = len(truncated_text.split())
|
306 |
-
min_len = max(int(word_count * 0.5), 30)
|
307 |
-
max_len = max(min_len + 20, int(word_count * 0.75))
|
308 |
-
|
309 |
-
summary = summarizer(
|
310 |
-
truncated_text,
|
311 |
-
max_length=max_len,
|
312 |
-
min_length=min_len,
|
313 |
-
do_sample=False
|
314 |
-
)
|
315 |
return summary[0]['summary_text']
|
316 |
except Exception as e:
|
317 |
-
return f"
|
318 |
-
|
|
|
319 |
def main():
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
with st.expander("Summarization Text", expanded=True):
|
340 |
-
st.text_area("Summarization:", summarization, height=300)
|
341 |
|
342 |
if __name__ == "__main__":
|
343 |
main()
|
|
|
13 |
import yt_dlp
|
14 |
import torch
|
15 |
|
16 |
+
|
17 |
class Interface:
|
18 |
@staticmethod
|
19 |
def get_header(title: str, description: str) -> None:
|
20 |
+
st.set_page_config(page_title="Audio Summarization", page_icon="🗣️")
|
21 |
+
|
22 |
+
st.markdown("""
|
23 |
+
<style>
|
24 |
+
header, #MainMenu, footer {visibility: hidden;}
|
25 |
+
</style>
|
26 |
+
""", unsafe_allow_html=True)
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
st.title(title)
|
|
|
29 |
st.info(description)
|
|
|
30 |
|
31 |
@staticmethod
|
32 |
def get_audio_file() -> str:
|
33 |
+
uploaded_file = st.file_uploader("Choose an audio file", type=["wav"], help="Upload a .wav audio file.")
|
34 |
+
if uploaded_file and uploaded_file.name.endswith(".wav"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
st.audio(uploaded_file, format="audio/wav")
|
36 |
+
return uploaded_file
|
37 |
+
elif uploaded_file:
|
38 |
st.warning("Please upload a valid .wav audio file.")
|
39 |
+
return None
|
40 |
+
|
|
|
|
|
41 |
@staticmethod
|
42 |
+
def get_approach() -> str:
|
43 |
+
return st.selectbox("Select summarization approach", ["Youtube Link", "Input Audio File"], index=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
|
|
|
|
45 |
@staticmethod
|
46 |
def get_link_youtube() -> str:
|
47 |
+
youtube_link = st.text_input("Enter YouTube link", placeholder="https://www.youtube.com/watch?v=example")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
if youtube_link.strip():
|
49 |
st.video(youtube_link)
|
|
|
50 |
return youtube_link
|
51 |
+
|
52 |
@staticmethod
|
53 |
+
def get_sidebar_input(state: dict) -> tuple:
|
|
|
|
|
|
|
54 |
with st.sidebar:
|
55 |
st.markdown("### Select Approach")
|
56 |
approach = Interface.get_approach()
|
57 |
state['session'] = 1
|
58 |
|
59 |
audio_path = None
|
60 |
+
if approach == "Input Audio File":
|
|
|
61 |
audio = Interface.get_audio_file()
|
62 |
+
if audio:
|
63 |
audio_path = Utils.temporary_file(audio)
|
64 |
+
elif approach == "Youtube Link":
|
|
|
|
|
65 |
youtube_link = Interface.get_link_youtube()
|
66 |
if youtube_link:
|
67 |
audio_path = Utils.download_youtube_audio_to_tempfile(youtube_link)
|
68 |
+
if audio_path:
|
69 |
+
with open(audio_path, "rb") as af:
|
70 |
+
st.audio(af.read(), format="audio/wav")
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
+
generate = audio_path and st.button("🚀 Generate Result !!")
|
73 |
return audio_path, generate
|
74 |
|
75 |
+
|
76 |
class Utils:
|
77 |
@staticmethod
|
78 |
def temporary_file(uploaded_file: str) -> str:
|
79 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
80 |
+
tmp.write(uploaded_file.read())
|
81 |
+
return tmp.name
|
82 |
+
|
83 |
+
@staticmethod
|
|
|
|
|
|
|
|
|
|
|
84 |
def clean_transcript(text: str) -> str:
|
|
|
|
|
|
|
|
|
85 |
text = re.sub(r'(?<=[a-zA-Z])\.(?=[a-zA-Z])', ' ', text)
|
86 |
+
text = re.sub(r'[^\w. ]+', ' ', text)
|
87 |
+
return re.sub(r'\s+', ' ', text).strip()
|
88 |
+
|
|
|
89 |
@staticmethod
|
90 |
def preprocess_audio(input_path: str) -> str:
|
|
|
|
|
|
|
91 |
waveform, sample_rate = torchaudio.load(input_path)
|
|
|
|
|
|
|
|
|
92 |
if waveform.shape[0] > 1:
|
93 |
waveform = waveform.mean(dim=0, keepdim=True)
|
94 |
+
if sample_rate != 16000:
|
95 |
+
waveform = Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
|
96 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
97 |
+
torchaudio.save(tmp.name, waveform, 16000)
|
98 |
+
return tmp.name
|
|
|
|
|
|
|
|
|
99 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
@staticmethod
|
101 |
+
def _format_filename(name: str, chunk=0) -> str:
|
102 |
+
clean = re.sub(r'[^a-zA-Z0-9]', '_', name.strip().lower())
|
103 |
+
return f"{clean}_chunk_{chunk}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
@staticmethod
|
106 |
+
def download_youtube_audio_to_tempfile(url: str) -> str:
|
107 |
+
try:
|
108 |
+
with yt_dlp.YoutubeDL({'quiet': True}) as ydl:
|
109 |
+
info = ydl.extract_info(url, download=False)
|
110 |
+
filename = Utils._format_filename(info.get('title', 'audio'))
|
111 |
+
|
112 |
+
out_dir = tempfile.mkdtemp()
|
113 |
+
output_path = os.path.join(out_dir, filename)
|
114 |
+
|
115 |
+
ydl_opts = {
|
116 |
+
'format': 'bestaudio/best',
|
117 |
+
'postprocessors': [{
|
118 |
+
'key': 'FFmpegExtractAudio',
|
119 |
+
'preferredcodec': 'wav',
|
120 |
+
'preferredquality': '192',
|
121 |
+
}],
|
122 |
+
'outtmpl': output_path,
|
123 |
+
'quiet': True
|
124 |
+
}
|
125 |
+
|
126 |
+
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
127 |
+
ydl.download([url])
|
128 |
+
|
129 |
+
final_path = output_path + ".wav"
|
130 |
+
for _ in range(5):
|
131 |
+
if os.path.exists(final_path):
|
132 |
+
return final_path
|
133 |
+
time.sleep(1)
|
134 |
+
raise FileNotFoundError(f"File not found: {final_path}")
|
135 |
+
except Exception as e:
|
136 |
+
st.toast(f"Download failed: {e}")
|
137 |
+
return None
|
138 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
class Generation:
|
141 |
+
def __init__(self, summarization_model="vian123/brio-finance-finetuned-v2", speech_to_text_model="nyrahealth/CrisperWhisper"):
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
self.device = "cpu"
|
143 |
self.dtype = torch.float32
|
144 |
+
|
145 |
+
self.processor = AutoProcessor.from_pretrained(speech_to_text_model)
|
146 |
+
self.model = AutoModelForSpeechSeq2Seq.from_pretrained(speech_to_text_model, torch_dtype=self.dtype).to(self.device)
|
147 |
+
self.tokenizer = AutoTokenizer.from_pretrained(summarization_model)
|
148 |
+
self.summarizer = pipeline("summarization", model=summarization_model, tokenizer=summarization_model)
|
149 |
+
|
150 |
+
def transcribe(self, audio_path: str) -> str:
|
151 |
+
processed_path = Utils.preprocess_audio(audio_path)
|
152 |
+
waveform, rate = torchaudio.load(processed_path)
|
153 |
+
if waveform.shape[1] / rate < 1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
return ""
|
155 |
|
156 |
+
asr_pipe = pipeline(
|
157 |
"automatic-speech-recognition",
|
158 |
+
model=self.model,
|
159 |
+
tokenizer=self.processor.tokenizer,
|
160 |
+
feature_extractor=self.processor.feature_extractor,
|
161 |
chunk_length_s=5,
|
|
|
|
|
162 |
torch_dtype=self.dtype,
|
163 |
+
device=self.device
|
|
|
164 |
)
|
165 |
|
166 |
try:
|
167 |
+
output = asr_pipe(processed_path)
|
168 |
+
return output.get("text", "")
|
|
|
169 |
except Exception as e:
|
170 |
+
print("ASR error:", e)
|
171 |
return ""
|
172 |
|
173 |
+
def summarize(self, text: str) -> str:
|
174 |
+
if len(text.strip()) < 10:
|
175 |
+
return ""
|
176 |
+
cleaned = self.tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
|
177 |
+
decoded = self.tokenizer.decode(cleaned["input_ids"][0], skip_special_tokens=True)
|
178 |
+
|
179 |
+
word_count = len(decoded.split())
|
180 |
+
min_len, max_len = max(30, int(word_count * 0.5)), max(50, int(word_count * 0.75))
|
181 |
+
|
182 |
try:
|
183 |
+
summary = self.summarizer(decoded, max_length=max_len, min_length=min_len, do_sample=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
return summary[0]['summary_text']
|
185 |
except Exception as e:
|
186 |
+
return f"Summarization error: {e}"
|
187 |
+
|
188 |
+
|
189 |
def main():
|
190 |
+
Interface.get_header(
|
191 |
+
title="Financial YouTube Video Audio Summarization",
|
192 |
+
description="🎧 Upload a financial audio or YouTube video to transcribe and summarize using CrisperWhisper + fine-tuned BRIO."
|
193 |
+
)
|
194 |
+
|
195 |
+
state = dict(session=0)
|
196 |
+
audio_path, generate = Interface.get_sidebar_input(state)
|
197 |
+
|
198 |
+
if generate:
|
199 |
+
with st.spinner("Processing..."):
|
200 |
+
gen = Generation()
|
201 |
+
transcript = gen.transcribe(audio_path)
|
202 |
+
|
203 |
+
st.expander("Transcription Text", expanded=True).text_area("Transcription", transcript, height=300)
|
204 |
+
|
205 |
+
with st.spinner("Summarizing..."):
|
206 |
+
summary = gen.summarize(transcript)
|
207 |
+
st.expander("Summarization Text", expanded=True).text_area("Summarization", summary, height=300)
|
208 |
+
|
|
|
|
|
209 |
|
210 |
if __name__ == "__main__":
|
211 |
main()
|