# Add these imports at the top
import soundfile as sf
import librosa
from pathlib import Path
import humanize
from datetime import timedelta

# Add these constants
MAX_FILE_SIZE = 25 * 1024 * 1024  # 25 MB
MAX_AUDIO_DURATION = 600  # 10 minutes in seconds
SUPPORTED_FORMATS = {
    '.wav': 'WAV audio',
    '.mp3': 'MP3 audio',
    '.m4a': 'M4A audio'
}
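
# The code below also relies on the modules listed here. They are assumed to be
# imported already elsewhere in the original file; if not, add them as well.
import os
import gc
import json
import tempfile
import logging
import torch
import streamlit as st
from datetime import datetime
from typing import Any, Dict, Optional

# `logger`, MODEL_CONFIGS, ModelManager, ContentAnalyzer and VCStyle are referenced
# further down and are assumed to be defined elsewhere in the app; only the logger
# is sketched here so this excerpt can run on its own.
logger = logging.getLogger(__name__)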
class AudioValidator:
    """Handles audio file validation and provides detailed feedback"""

    @staticmethod
    def validate_audio_file(file) -> tuple[bool, str]:
        try:
            # Check that a file was provided
            if file is None:
                return False, "No file was uploaded."

            # Check file size
            file_size = len(file.getvalue())
            if file_size > MAX_FILE_SIZE:
                readable_size = humanize.naturalsize(file_size)
                max_size = humanize.naturalsize(MAX_FILE_SIZE)
                return False, f"File size ({readable_size}) exceeds maximum allowed size ({max_size})"

            # Check file extension
            file_extension = Path(file.name).suffix.lower()
            if file_extension not in SUPPORTED_FORMATS:
                return False, f"Unsupported file format. Please upload {', '.join(SUPPORTED_FORMATS.values())}"

            # Save file temporarily for duration check
            with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
                tmp_file.write(file.getvalue())
                tmp_file_path = tmp_file.name

            try:
                # Check audio duration
                duration = librosa.get_duration(path=tmp_file_path)
                if duration > MAX_AUDIO_DURATION:
                    formatted_duration = str(timedelta(seconds=int(duration)))
                    max_duration = str(timedelta(seconds=MAX_AUDIO_DURATION))
                    return False, f"Audio duration ({formatted_duration}) exceeds maximum allowed length ({max_duration})"

                # Check audio quality. sr=None keeps the native sample rate;
                # librosa.load would otherwise resample to 22,050 Hz and mask low-quality input.
                y, sr = librosa.load(tmp_file_path, sr=None)
                if sr < 16000:
                    return False, f"Audio quality too low. Sample rate ({sr}Hz) should be at least 16kHz"

                return True, "Audio file is valid"
            finally:
                os.unlink(tmp_file_path)
        except Exception as e:
            logger.error(f"Audio validation error: {str(e)}")
            return False, f"Error validating audio file: {str(e)}"
class AudioProcessor:
    """Enhanced audio processor with better feedback and error handling"""

    def __init__(self, model):
        self.model = model
        self.validator = AudioValidator()

    def process_audio_chunk(self, audio_file) -> tuple[Optional[str], Dict[str, Any]]:
        processing_stats = {
            'duration': None,
            'sample_rate': None,
            'file_size': None,
            'processing_time': None,
            'status': 'pending'
        }
        try:
            start_time = datetime.now()

            # Validate file
            is_valid, validation_message = self.validator.validate_audio_file(audio_file)
            if not is_valid:
                processing_stats['status'] = 'failed'
                processing_stats['error'] = validation_message
                return None, processing_stats

            # Get file stats
            file_size = len(audio_file.getvalue())
            processing_stats['file_size'] = humanize.naturalsize(file_size)

            # Write the upload to a temporary file for processing
            file_extension = Path(audio_file.name).suffix.lower()
            with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
                audio_file.seek(0)
                tmp_file.write(audio_file.getvalue())
                tmp_file_path = tmp_file.name

            try:
                # Get audio info (sr=None keeps the native sample rate for accurate reporting)
                y, sr = librosa.load(tmp_file_path, sr=None)
                duration = librosa.get_duration(y=y, sr=sr)
                processing_stats.update({
                    'duration': str(timedelta(seconds=int(duration))),
                    'sample_rate': f"{sr / 1000:.1f}kHz"
                })

                # Transcribe audio
                result = self.model.transcribe(
                    tmp_file_path,
                    language="en",
                    task="transcribe",
                    fp16=torch.cuda.is_available()
                )

                # Update stats
                processing_time = (datetime.now() - start_time).total_seconds()
                processing_stats.update({
                    'processing_time': f"{processing_time:.1f}s",
                    'status': 'success'
                })
                return result["text"], processing_stats
            finally:
                if os.path.exists(tmp_file_path):
                    os.unlink(tmp_file_path)
        except Exception as e:
            error_message = str(e)
            logger.error(f"Audio processing error: {error_message}")
            processing_stats.update({
                'status': 'failed',
                'error': error_message
            })
            return None, processing_stats
        finally:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
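
# Minimal usage sketch (hypothetical; `whisper_model` stands for an already-loaded Whisper model):
#
#   processor = AudioProcessor(whisper_model)
#   text, stats = processor.process_audio_chunk(uploaded_file)
#   if stats['status'] == 'success':
#       st.write(text)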
class UIManager:
    """Enhanced UI manager with better feedback and progress indicators"""

    @staticmethod
    def setup_page():
        st.set_page_config(
            page_title="VC Call Assistant",
            page_icon="🎙️",
            layout="wide",
            initial_sidebar_state="expanded"
        )

    @staticmethod
    def show_file_uploader() -> Optional[Any]:
        st.markdown("""
        ### 📁 Upload Audio File

        **Supported formats:**
        - WAV (recommended)
        - MP3
        - M4A

        **Limitations:**
        - Maximum file size: 25MB
        - Maximum duration: 10 minutes
        - Minimum sample rate: 16kHz
        """)
        return st.file_uploader(
            "Choose an audio file",
            type=['wav', 'mp3', 'm4a']
        )

    @staticmethod
    def show_processing_stats(stats: Dict[str, Any]):
        """Display processing statistics in a readable format"""
        if not stats:
            return

        st.markdown("### 📊 Processing Statistics")
        col1, col2, col3 = st.columns(3)

        with col1:
            st.metric("Duration", stats.get('duration', 'N/A'))
            st.metric("File Size", stats.get('file_size', 'N/A'))
        with col2:
            st.metric("Sample Rate", stats.get('sample_rate', 'N/A'))
            st.metric("Processing Time", stats.get('processing_time', 'N/A'))
        with col3:
            status = stats.get('status', 'unknown')
            if status == 'success':
                st.success("Processing Completed")
            elif status == 'failed':
                st.error(f"Processing Failed: {stats.get('error', 'Unknown error')}")
            else:
                st.info("Processing Pending")
def main():
    try:
        UIManager.setup_page()

        with st.sidebar:
            st.title("VC Assistant Settings")
            model_name = "GPT2"
            st.info(f"""Using {model_name}
            Memory Usage: {MODEL_CONFIGS[model_name]['memory_required']}
            Description: {MODEL_CONFIGS[model_name]['description']}""")

            vc_name = st.text_input("Your Name")
            note_style = st.selectbox(
                "Note Style",
                ["Bullet Points", "Paragraphs", "Q&A"]
            )
            interests = st.multiselect(
                "Focus Areas",
                ["Product", "Market", "Team", "Financials", "Technology"],
                default=["Product", "Market"]
            )

        st.title("🎙️ VC Call Assistant")

        if not vc_name:
            st.warning("Please enter your name in the sidebar.")
            return

        # Initialize models with progress tracking
        progress_text = "Loading models..."
        progress_bar = st.progress(0, text=progress_text)
        try:
            progress_bar.progress(25, text="Loading Whisper model...")
            whisper_model = ModelManager.load_whisper()

            progress_bar.progress(50, text="Loading language model...")
            llm = ModelManager.load_llm(model_name)

            if not whisper_model or not llm:
                st.error("Failed to initialize models. Please refresh the page.")
                return

            progress_bar.progress(75, text="Initializing processors...")
            audio_processor = AudioProcessor(whisper_model)
            analyzer = ContentAnalyzer(llm)

            progress_bar.progress(100, text="Ready!")
        finally:
            progress_bar.empty()

        # File upload and processing
        audio_file = UIManager.show_file_uploader()
        if audio_file:
            with st.spinner("Processing audio..."):
                transcription, processing_stats = audio_processor.process_audio_chunk(audio_file)

            # Show processing statistics
            UIManager.show_processing_stats(processing_stats)

            if transcription:
                col1, col2 = st.columns(2)

                with col1:
                    st.subheader("📝 Transcript")
                    st.write(transcription)

                with col2:
                    st.subheader("🔍 Analysis")
                    with st.spinner("Analyzing transcript..."):
                        vc_style = VCStyle(
                            name=vc_name,
                            note_format={"style": note_style},
                            key_interests=interests,
                            custom_sections=[],
                            insight_preferences={}
                        )
                        analysis = analyzer.analyze_text(transcription, vc_style)

                        if analysis:
                            st.write(analysis)
                            st.download_button(
                                "📥 Export Analysis",
                                data=json.dumps({
                                    "timestamp": datetime.now().isoformat(),
                                    "transcription": transcription,
                                    "analysis": analysis,
                                    "processing_stats": processing_stats
                                }, indent=2),
                                file_name=f"vc_analysis_{datetime.now():%Y%m%d_%H%M%S}.json",
                                mime="application/json"
                            )
    except Exception as e:
        logger.error(f"Application error: {str(e)}")
        st.error(f"""
        An unexpected error occurred: {str(e)}

        Please try:
        1. Refreshing the page
        2. Using a different audio file
        3. Checking your internet connection
        """)
    finally:
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


if __name__ == "__main__":
    main()