# Hugging Face Spaces page status captured with this source: Sleeping
import gc
import json
import logging
import os
import tempfile
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path

import humanize
import librosa
import numpy as np
import streamlit as st
import torch
import whisper
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Logging setup: INFO and above, each record prefixed with the timestamp,
# the emitting logger's name and the severity level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by every helper in this file.
logger = logging.getLogger(__name__)
# Upload limits enforced by AudioValidator.
MAX_FILE_SIZE = 25 * 1024 ** 2  # bytes (25MB)
MAX_AUDIO_DURATION = 10 * 60  # seconds (10 minutes)
MIN_SAMPLE_RATE = 16_000  # Hz (16kHz)
SUPPORTED_FORMATS = {'.wav', '.mp3', '.m4a'}  # lower-case suffixes, dot included

# Language-model settings: "path" is handed to the transformers loaders,
# the other two entries are only displayed in the sidebar.
MODEL_CONFIG = {
    "path": "gpt2",
    "description": "Efficient open-source model for analysis",
    "memory_required": "8GB",
}
@dataclass
class VCStyle:
    """Per-user note-taking preferences collected from the sidebar.

    Declared as a dataclass so it can be built with keyword arguments,
    which is how ``main`` constructs it; without the decorator the class
    has no ``__init__`` accepting these fields.
    """

    # Name the user typed into the sidebar.
    name: str
    # Note layout options; main() passes {"style": <selected note style>}.
    note_format: dict
    # Focus areas; joined into the analysis prompt by ContentAnalyzer.
    key_interests: list
    # Currently always passed as an empty list by main().
    custom_sections: list
    # Currently always passed as an empty dict by main().
    insight_preferences: dict
class AudioValidator:
    """Checks an uploaded audio file against the module's upload limits."""

    @staticmethod
    def validate_audio_file(file):
        """Validate size, extension, duration and sample rate of an upload.

        Declared static so it works both as ``AudioValidator.validate_audio_file(f)``
        and on an instance (``AudioProcessor`` calls it on an instance); as a
        plain function the instance itself would be bound to ``file``.

        Args:
            file: Uploaded file object exposing ``getvalue()`` (bytes) and
                ``name``, or ``None`` when nothing was uploaded.

        Returns:
            A ``(is_valid, message, stats)`` tuple. ``stats`` always has the
            keys ``file_size``, ``duration``, ``sample_rate`` and ``format``;
            entries for checks that were not reached stay ``None``.
        """
        stats = {
            'file_size': None,
            'duration': None,
            'sample_rate': None,
            'format': None
        }
        try:
            if file is None:
                return False, "No file was uploaded.", stats

            # Read the upload once and reuse the bytes for the size check
            # and for the temporary copy below.
            payload = file.getvalue()

            # Check file size
            file_size = len(payload)
            stats['file_size'] = humanize.naturalsize(file_size)
            if file_size > MAX_FILE_SIZE:
                return False, f"File size ({stats['file_size']}) exceeds limit", stats

            # Check file extension
            file_extension = Path(file.name).suffix.lower()
            stats['format'] = file_extension
            if file_extension not in SUPPORTED_FORMATS:
                return False, f"Unsupported format {file_extension}", stats

            # librosa loads from a path, so spill the bytes to a temp file.
            # The file is closed (and therefore flushed) before it is read.
            with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
                tmp_file.write(payload)
                tmp_file_path = tmp_file.name

            try:
                # sr=None keeps the file's native sample rate so the
                # minimum-rate check sees the real value.
                y, sr = librosa.load(tmp_file_path, sr=None)
                duration = librosa.get_duration(y=y, sr=sr)
                stats.update({
                    'duration': str(timedelta(seconds=int(duration))),
                    'sample_rate': f"{sr/1000:.1f}kHz"
                })
                if duration > MAX_AUDIO_DURATION:
                    return False, f"Duration ({stats['duration']}) exceeds limit", stats
                if sr < MIN_SAMPLE_RATE:
                    return False, f"Sample rate too low ({stats['sample_rate']})", stats
                return True, "Audio file is valid", stats
            finally:
                # delete=False above means cleanup is our responsibility.
                os.unlink(tmp_file_path)
        except Exception as e:
            # Any failure (unreadable audio, I/O error, ...) is reported as
            # an invalid file rather than raised to the caller.
            logger.error(f"Validation error: {str(e)}")
            return False, str(e), stats
class AudioProcessor:
    """Validates an uploaded audio file and transcribes it with the given model."""

    def __init__(self, model):
        """Keep the transcription model and create the validator used per upload."""
        self.validator = AudioValidator()
        self.model = model

    def process_audio(self, audio_file):
        """Validate and transcribe ``audio_file``.

        Args:
            audio_file: Uploaded file object exposing ``getvalue()``.

        Returns:
            ``(text, stats)`` where ``text`` is the transcript, or ``None``
            when validation or transcription failed. ``stats`` holds
            ``status``, ``start_time``, ``file_info``, ``processing_time``
            and ``error``.
        """
        started_at = datetime.now()
        stats = {
            'status': 'processing',
            'start_time': started_at,
            'file_info': None,
            'processing_time': None,
            'error': None
        }
        try:
            # Validate file
            outcome = self.validator.validate_audio_file(audio_file)
            valid, reason, info = outcome
            stats['file_info'] = info
            if not valid:
                stats['status'] = 'failed'
                stats['error'] = reason
                return None, stats

            # The model is given a filesystem path, so write the bytes to a
            # temp file first (closed before it is read).
            with tempfile.NamedTemporaryFile(delete=False, suffix=info['format']) as handle:
                handle.write(audio_file.getvalue())
                audio_path = handle.name

            try:
                use_half_precision = torch.cuda.is_available()
                result = self.model.transcribe(
                    audio_path,
                    language="en",
                    task="transcribe",
                    fp16=use_half_precision
                )
                elapsed = datetime.now() - stats['start_time']
                stats['status'] = 'success'
                stats['processing_time'] = str(elapsed)
                return result["text"], stats
            finally:
                # Temp file was created with delete=False; remove it here.
                os.unlink(audio_path)
        except Exception as e:
            logger.error(f"Processing error: {str(e)}")
            stats['status'] = 'failed'
            stats['error'] = str(e)
            return None, stats
        finally:
            # Release memory after every attempt, successful or not.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
@st.cache_resource(show_spinner=False)
def _load_whisper_cached():
    """Load the Whisper "base" model once per server process.

    Streamlit re-executes the script on every widget interaction; caching
    the model as a resource avoids reloading it on each rerun. Exceptions
    propagate (and are not cached), so a later call retries the load.
    """
    return whisper.load_model("base")


def load_whisper():
    """Return the Whisper "base" model, or ``None`` if it cannot be loaded.

    Returns:
        The loaded model, or ``None`` after logging the failure.
    """
    try:
        return _load_whisper_cached()
    except Exception as e:
        logger.error(f"Whisper model loading error: {str(e)}")
        return None
@st.cache_resource(show_spinner=False)
def _load_llm_cached():
    """Build the text-generation pipeline once per server process.

    Streamlit re-executes the script on every widget interaction; caching
    the pipeline as a resource avoids reloading the model on each rerun.
    Exceptions propagate (and are not cached), so a later call retries.
    """
    # NOTE(review): trust_remote_code=True lets the Hub repository run its
    # own Python code at load time. "gpt2" does not need it; confirm before
    # pointing MODEL_CONFIG["path"] at a less trusted repository.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_CONFIG["path"],
        trust_remote_code=True
    )
    # Half precision is only used on GPU; on CPU-only hosts fall back to
    # float32 rather than forcing float16 weights.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_CONFIG["path"],
        device_map="auto",
        torch_dtype=dtype,
        low_cpu_mem_usage=True
    )
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        # temperature / top_p only take effect when sampling is enabled.
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.15,
        batch_size=1
    )


def load_llm():
    """Return the text-generation pipeline, or ``None`` if loading fails.

    Returns:
        A transformers ``pipeline`` for ``MODEL_CONFIG["path"]``, or
        ``None`` after logging the failure.
    """
    try:
        return _load_llm_cached()
    except Exception as e:
        logger.error(f"LLM loading error: {str(e)}")
        return None
class ContentAnalyzer:
    """Turns a transcript into sectioned notes using a text-generation callable."""

    def __init__(self, generator):
        """Store the generator.

        Args:
            generator: Callable taking a prompt string and returning a list
                whose first item is a dict with a ``'generated_text'`` key.
        """
        self.generator = generator

    def analyze_text(self, text, vc_style):
        """Generate and parse an analysis of ``text``.

        Args:
            text: Transcript to analyze.
            vc_style: Object whose ``key_interests`` (iterable of strings)
                steers the prompt.

        Returns:
            Dict mapping lower-cased section names to lists of text chunks,
            or ``None`` when generation failed or produced no text.
        """
        try:
            prompt = self._create_analysis_prompt(text, vc_style)
            response = self._generate_response(prompt)
            # _generate_response returns "" on failure (already logged
            # there); report that as "no analysis" instead of parsing it
            # into a dict holding a single empty string.
            if not response or not response.strip():
                return None
            return self._parse_response(response)
        except Exception as e:
            logger.error(f"Analysis error: {str(e)}")
            return None

    def _create_analysis_prompt(self, text, vc_style):
        """Build the prompt from the transcript and the focus areas."""
        interests = ', '.join(vc_style.key_interests)
        return "\n".join([
            f"Analyze this startup pitch focusing on {interests}:",
            f"{text}",
            "Provide structured insights for:",
            "1. Key Points",
            "2. Metrics",
            "3. Risks",
            "4. Questions",
        ])

    def _generate_response(self, prompt):
        """Run the generator; return its text, or ``""`` if it raises."""
        try:
            response = self.generator(prompt)
            return response[0]['generated_text']
        except Exception as e:
            logger.error(f"Generation error: {str(e)}")
            return ""

    def _parse_response(self, response):
        """Split ``response`` into sections.

        Chunks are separated by blank lines. A chunk ending in ``:`` starts
        a new section named after it (lower-cased, colon removed); any other
        chunk is appended to the current section, which is ``"general"``
        until the first heading. Blank chunks are skipped, and a heading
        that repeats keeps the content already collected under it.

        Returns:
            Dict of section name to list of chunks, or
            ``{"error": "Failed to parse response"}`` if parsing raises.
        """
        try:
            parsed = {}
            current_section = "general"
            for chunk in response.split('\n\n'):
                chunk = chunk.strip()
                if not chunk:
                    continue
                if chunk.endswith(':'):
                    current_section = chunk[:-1].lower()
                    parsed.setdefault(current_section, [])
                else:
                    parsed.setdefault(current_section, []).append(chunk)
            return parsed
        except Exception as e:
            logger.error(f"Parsing error: {str(e)}")
            return {"error": "Failed to parse response"}
def setup_page():
    """Set the browser-tab title, icon and wide layout for the app."""
    # NOTE(review): the icon was mis-encoded in the original source; the
    # microphone emoji is the best match for the surviving bytes — confirm.
    st.set_page_config(
        page_title="VC Call Assistant",
        page_icon="🎙️",
        layout="wide",
    )
def show_file_uploader():
    """Show the upload instructions and the audio file picker.

    Returns:
        Whatever ``st.file_uploader`` returns: the uploaded file object, or
        ``None`` while nothing has been uploaded.
    """
    # NOTE(review): the heading emoji was mis-encoded in the original
    # source and has been replaced with a folder icon — confirm the glyph.
    # Blank lines keep each statement on its own line when rendered.
    st.markdown(
        "### 📁 Upload Audio File\n"
        "\n"
        "**Supported formats:** WAV, MP3, M4A\n"
        "\n"
        "**Limits:** 25MB, 10 minutes, 16kHz min quality\n"
    )
    return st.file_uploader(
        "Choose an audio file",
        type=['wav', 'mp3', 'm4a']
    )
def show_processing_stats(stats):
    """Render file details and the processing outcome in three columns.

    Args:
        stats: Dict shaped like the one returned by
            ``AudioProcessor.process_audio`` (keys ``file_info``,
            ``status``, ``processing_time``, ``error``). Nothing is
            rendered when it is empty or ``None``.
    """
    if not stats:
        return

    # NOTE(review): heading emoji restored from a mis-encoded original —
    # confirm the intended glyph.
    st.markdown("### 📊 Processing Information")
    cols = st.columns(3)

    # The stats dicts are created with every key present and set to None,
    # so a plain .get(key, 'N/A') would never fall back; use `or` instead.
    file_info = stats.get('file_info')
    if file_info:
        with cols[0]:
            st.metric("File Size", file_info.get('file_size') or 'N/A')
            st.metric("Format", file_info.get('format') or 'N/A')
        with cols[1]:
            st.metric("Duration", file_info.get('duration') or 'N/A')
            st.metric("Sample Rate", file_info.get('sample_rate') or 'N/A')

    # The outcome is shown even when no file details are available, so a
    # failure message is never hidden.
    with cols[2]:
        status = stats.get('status', 'unknown')
        if status == 'success':
            st.success(f"Processed in {stats.get('processing_time') or 'N/A'}")
        elif status == 'failed':
            st.error(f"Failed: {stats.get('error') or 'Unknown error'}")
        else:
            st.info("Processing...")
def main():
    """Run the Streamlit app: settings sidebar, upload, transcript, analysis, export."""
    try:
        setup_page()

        with st.sidebar:
            st.title("VC Assistant Settings")
            st.info(
                "Using GPT2\n"
                f"Memory: {MODEL_CONFIG['memory_required']}\n"
                f"Info: {MODEL_CONFIG['description']}"
            )
            vc_name = st.text_input("Your Name")
            note_style = st.selectbox(
                "Note Style",
                ["Bullet Points", "Paragraphs", "Q&A"]
            )
            interests = st.multiselect(
                "Focus Areas",
                ["Product", "Market", "Team", "Financials", "Technology"],
                default=["Product", "Market"]
            )

        # NOTE(review): the emoji in the labels below were mis-encoded in
        # the original source and have been restored by best guess — confirm.
        st.title("🎙️ VC Call Assistant")

        if not vc_name:
            st.warning("Please enter your name in the sidebar.")
            return

        with st.spinner("Loading models..."):
            whisper_model = load_whisper()
            llm = load_llm()

        # The loaders signal failure with None; compare explicitly instead
        # of relying on the truthiness of model objects.
        if whisper_model is None or llm is None:
            st.error("Failed to initialize models. Please refresh the page.")
            return

        audio_processor = AudioProcessor(whisper_model)
        analyzer = ContentAnalyzer(llm)

        audio_file = show_file_uploader()
        if not audio_file:
            return

        with st.spinner("Processing audio..."):
            transcription, stats = audio_processor.process_audio(audio_file)
        show_processing_stats(stats)

        if not (transcription and stats['status'] == 'success'):
            return

        col1, col2 = st.columns(2)
        with col1:
            st.subheader("📝 Transcript")
            st.write(transcription)
        with col2:
            st.subheader("🔍 Analysis")
            with st.spinner("Analyzing transcript..."):
                vc_style = VCStyle(
                    name=vc_name,
                    note_format={"style": note_style},
                    key_interests=interests,
                    custom_sections=[],
                    insight_preferences={}
                )
                analysis = analyzer.analyze_text(transcription, vc_style)
            if analysis:
                st.write(analysis)
                export_time = datetime.now()
                st.download_button(
                    "📥 Export Analysis",
                    # stats['start_time'] is a datetime, which json cannot
                    # encode by itself; default=str writes it as text
                    # instead of raising TypeError.
                    data=json.dumps({
                        "timestamp": export_time.isoformat(),
                        "transcription": transcription,
                        "analysis": analysis,
                        "processing_stats": stats
                    }, indent=2, default=str),
                    file_name=f"vc_analysis_{export_time:%Y%m%d_%H%M%S}.json",
                    mime="application/json"
                )
    except Exception as e:
        # Top-level boundary: record the traceback, show a generic message.
        logger.exception(f"Application error: {str(e)}")
        st.error("An error occurred. Please refresh the page and try again.")
    finally:
        # Free memory at the end of every script run.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Entry point: start the app only when this file is executed as a script.
if __name__ == "__main__":
    main()