# Hugging Face Spaces — status: Running
import streamlit as st | |
import pandas as pd | |
import tempfile | |
import time | |
import sys | |
import re | |
import os | |
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, pipeline, AutoTokenizer | |
from torchaudio.transforms import Resample | |
import soundfile as sf | |
import torchaudio | |
import yt_dlp | |
import torch | |
class Interface:
    """Streamlit widgets that collect the audio source and draw the page chrome.

    The class is only a namespace: every method is static and is called as
    ``Interface.method()``.
    """

    @staticmethod
    def get_header(title: str, description: str) -> None:
        """Configure the page, hide Streamlit's header/menu/footer, show title + info box.

        Args:
            title: text for ``st.title``.
            description: text for the ``st.info`` box under the title.
        """
        st.set_page_config(page_title="Audio Summarization", page_icon="🗣️")
        st.markdown("""
<style>
header, #MainMenu, footer {visibility: hidden;}
</style>
""", unsafe_allow_html=True)
        st.title(title)
        st.info(description)

    @staticmethod
    def get_audio_file():
        """Show a .wav uploader and a player for the upload.

        Returns:
            The uploaded file object itself (not a path), or None when nothing
            was uploaded or the name does not end in ``.wav``.
        """
        uploaded_file = st.file_uploader("Choose an audio file", type=["wav"], help="Upload a .wav audio file.")
        if uploaded_file is None:
            return None
        # Case-insensitive check so "MEETING.WAV" is accepted like "meeting.wav".
        if not uploaded_file.name.lower().endswith(".wav"):
            st.warning("Please upload a valid .wav audio file.")
            return None
        st.audio(uploaded_file, format="audio/wav")
        return uploaded_file

    @staticmethod
    def get_approach() -> str:
        """Return the chosen input mode: "Youtube Link" or "Input Audio File" (default)."""
        return st.selectbox("Select summarization approach", ["Youtube Link", "Input Audio File"], index=1)

    @staticmethod
    def get_link_youtube() -> str:
        """Ask for a YouTube URL; preview it if non-blank and return it stripped.

        Returns "" for empty or whitespace-only input, so callers can test truthiness.
        """
        youtube_link = st.text_input("Enter YouTube link", placeholder="https://www.youtube.com/watch?v=example").strip()
        if youtube_link:
            st.video(youtube_link)
        return youtube_link

    @staticmethod
    def _cached_youtube_audio(url: str):
        """Return a local .wav for *url*, downloading it at most once per session.

        Streamlit re-executes the script on every interaction (including the
        Generate click), so without this cache the video would be downloaded
        again on each rerun. Failed downloads (None) are not cached, so a later
        rerun retries.
        """
        cache = st.session_state.setdefault("_youtube_audio_paths", {})
        path = cache.get(url)
        if path is None or not os.path.exists(path):
            path = Utils.download_youtube_audio_to_tempfile(url)
            if path:
                cache[url] = path
        return path

    @staticmethod
    def get_sidebar_input(state: dict) -> tuple:
        """Render the sidebar controls and return ``(audio_path, generate)``.

        Args:
            state: dict whose ``'session'`` key is set to 1.

        Returns:
            audio_path: local .wav path, or None when no source is ready.
            generate: True only when a path exists and the button was clicked.
        """
        with st.sidebar:
            st.markdown("### Select Approach")
            approach = Interface.get_approach()
            state['session'] = 1
            audio_path = None
            if approach == "Input Audio File":
                audio = Interface.get_audio_file()
                if audio:
                    audio_path = Utils.temporary_file(audio)
                # No second player here: get_audio_file already rendered one.
            elif approach == "Youtube Link":
                youtube_link = Interface.get_link_youtube()
                if youtube_link:
                    audio_path = Interface._cached_youtube_audio(youtube_link)
                if audio_path:
                    with open(audio_path, "rb") as af:
                        st.audio(af.read(), format="audio/wav")
            generate = bool(audio_path) and st.button("🚀 Generate Result !!")
        return audio_path, generate
class Utils:
    """Stateless helpers for temp files, transcript text and audio I/O.

    The class is only a namespace: every method is static and is called as
    ``Utils.method()``.
    """

    @staticmethod
    def temporary_file(uploaded_file) -> str:
        """Write the full contents of a binary file-like object to a new .wav temp file.

        Args:
            uploaded_file: binary file-like object (e.g. a Streamlit upload).

        Returns:
            Path of the temp file. It is created with ``delete=False``, so the
            caller owns its removal.
        """
        # Rewind first: an earlier consumer may have advanced the stream, and
        # read() would then return only the tail (or nothing).
        try:
            uploaded_file.seek(0)
        except (AttributeError, OSError):
            pass  # not seekable: write whatever remains, as before
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp.write(uploaded_file.read())
            return tmp.name

    @staticmethod
    def clean_transcript(text: str) -> str:
        """Normalise transcript text.

        Steps:
            1. Turn dots between letters into spaces ("U.S" -> "U S").
            2. Replace every run of characters other than word chars, '.' and
               ' ' with a space.
            3. Collapse whitespace and strip the ends.
        """
        text = re.sub(r'(?<=[a-zA-Z])\.(?=[a-zA-Z])', ' ', text)
        text = re.sub(r'[^\w. ]+', ' ', text)
        return re.sub(r'\s+', ' ', text).strip()

    @staticmethod
    def preprocess_audio(input_path: str) -> str:
        """Convert an audio file to a mono, 16 kHz .wav temp file and return its path.

        Multi-channel input is averaged over dim 0 (the channel dim of
        ``torchaudio.load``). The output file is created with ``delete=False``;
        the caller owns its removal.
        """
        waveform, sample_rate = torchaudio.load(input_path)
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sample_rate != 16000:
            waveform = Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            torchaudio.save(tmp.name, waveform, 16000)
            return tmp.name

    @staticmethod
    def _format_filename(name: str, chunk=0) -> str:
        """Build a filesystem-safe base name ``<slug>_chunk_<chunk>`` from *name*.

        The slug is *name* stripped and lower-cased, with every character
        outside ``[a-zA-Z0-9]`` replaced by '_'.
        """
        clean = re.sub(r'[^a-zA-Z0-9]', '_', name.strip().lower())
        return f"{clean}_chunk_{chunk}"

    @staticmethod
    def download_youtube_audio_to_tempfile(url: str) -> str:
        """Download the audio of a YouTube *url* as .wav into a fresh temp directory.

        Returns:
            Path of the .wav file, or None on any failure. Failures are reported
            with ``st.toast``, and the temp directory is removed again.
        """
        import shutil  # only needed for failure cleanup

        out_dir = None
        try:
            with yt_dlp.YoutubeDL({'quiet': True}) as ydl:
                info = ydl.extract_info(url, download=False)
            # `or` also covers a present-but-None title.
            filename = Utils._format_filename(info.get('title') or 'audio')
            out_dir = tempfile.mkdtemp()
            output_path = os.path.join(out_dir, filename)
            ydl_opts = {
                'format': 'bestaudio/best',
                'postprocessors': [{
                    'key': 'FFmpegExtractAudio',
                    'preferredcodec': 'wav',
                    'preferredquality': '192',
                }],
                'outtmpl': output_path,
                'quiet': True
            }
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                ydl.download([url])
            # The extracted audio is expected at the output template + ".wav".
            final_path = output_path + ".wav"
            for _ in range(5):
                if os.path.exists(final_path):
                    return final_path
                time.sleep(1)
            raise FileNotFoundError(f"File not found: {final_path}")
        except Exception as e:
            if out_dir is not None:
                shutil.rmtree(out_dir, ignore_errors=True)
            st.toast(f"Download failed: {e}")
            return None
class Generation:
    """Two-stage pipeline: speech-to-text, then abstractive summarization.

    Both models are loaded in ``__init__`` and run on CPU in float32.

    Args:
        summarization_model: Hugging Face model id for the summarizer.
        speech_to_text_model: Hugging Face model id for the seq2seq ASR model.
    """

    def __init__(self, summarization_model="vian123/brio-finance-finetuned-v2", speech_to_text_model="nyrahealth/CrisperWhisper"):
        self.device = "cpu"
        self.dtype = torch.float32
        self.processor = AutoProcessor.from_pretrained(speech_to_text_model)
        self.model = AutoModelForSpeechSeq2Seq.from_pretrained(speech_to_text_model, torch_dtype=self.dtype).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(summarization_model)
        # Reuse the tokenizer loaded above instead of loading it a second time.
        self.summarizer = pipeline("summarization", model=summarization_model, tokenizer=self.tokenizer)
        # ASR pipeline over the already-loaded model; built on first transcribe().
        self._asr_pipe = None

    def _get_asr_pipe(self):
        """Build the speech-recognition pipeline once and return the cached instance."""
        if getattr(self, "_asr_pipe", None) is None:
            self._asr_pipe = pipeline(
                "automatic-speech-recognition",
                model=self.model,
                tokenizer=self.processor.tokenizer,
                feature_extractor=self.processor.feature_extractor,
                chunk_length_s=5,
                torch_dtype=self.dtype,
                device=self.device
            )
        return self._asr_pipe

    def transcribe(self, audio_path: str) -> str:
        """Transcribe an audio file.

        The audio is first converted to 16 kHz mono by ``Utils.preprocess_audio``.
        That intermediate temp file is always deleted before returning.

        Returns:
            The transcript text, or "" for clips shorter than one second or when
            the ASR call raises. The error is printed; it is not propagated.
        """
        processed_path = Utils.preprocess_audio(audio_path)
        try:
            waveform, rate = torchaudio.load(processed_path)
            if waveform.shape[1] / rate < 1:
                return ""
            asr_pipe = self._get_asr_pipe()
            try:
                output = asr_pipe(processed_path)
                return output.get("text", "")
            except Exception as e:
                print("ASR error:", e)
                return ""
        finally:
            # preprocess_audio creates the file with delete=False; without this,
            # every call leaked one temp wav.
            try:
                os.remove(processed_path)
            except OSError:
                pass

    def summarize(self, text: str) -> str:
        """Summarize *text* with the summarization pipeline.

        Only the first 512 tokens of *text* are summarized; the rest is truncated.
        Length bounds are derived from the word count of the truncated text:
        min = max(30, 0.5*words) and max = max(50, 0.75*words).

        Returns:
            The summary; "" when the stripped text is under 10 characters; or
            "Summarization error: ..." if the pipeline raises.
        """
        if len(text.strip()) < 10:
            return ""
        encoded = self.tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
        decoded = self.tokenizer.decode(encoded["input_ids"][0], skip_special_tokens=True)
        word_count = len(decoded.split())
        min_len, max_len = max(30, int(word_count * 0.5)), max(50, int(word_count * 0.75))
        try:
            summary = self.summarizer(decoded, max_length=max_len, min_length=min_len, do_sample=False)
            return summary[0]['summary_text']
        except Exception as e:
            return f"Summarization error: {e}"
@st.cache_resource(show_spinner=False)
def _load_generation():
    """Load the ASR and summarization models once and reuse them across reruns."""
    return Generation()


def main():
    """Render the app: pick an audio source, then transcribe and summarize on demand."""
    Interface.get_header(
        title="Financial YouTube Video Audio Summarization",
        description="🎧 Upload a financial audio or YouTube video to transcribe and summarize using CrisperWhisper + fine-tuned BRIO."
    )
    state = dict(session=0)
    audio_path, generate = Interface.get_sidebar_input(state)
    if not generate:
        return
    with st.spinner("Processing..."):
        # Cached: constructing Generation() per click reloaded both models every time.
        gen = _load_generation()
        transcript = gen.transcribe(audio_path)
    st.expander("Transcription Text", expanded=True).text_area("Transcription", transcript, height=300)
    with st.spinner("Summarizing..."):
        summary = gen.summarize(transcript)
    st.expander("Summarization Text", expanded=True).text_area("Summarization", summary, height=300)


if __name__ == "__main__":
    main()