invincible-jha committed on
Commit
fb59a4d
·
verified ·
1 Parent(s): 258e204

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +232 -165
  2. requirements.txt +1 -4
app.py CHANGED
@@ -1,219 +1,308 @@
1
- # Add these imports at the top
2
- import soundfile as sf
3
- import librosa
 
 
 
4
  from pathlib import Path
 
 
 
 
 
 
 
5
  import humanize
6
- from datetime import timedelta
7
 
8
- # Add these constants
 
 
 
 
 
 
 
9
  MAX_FILE_SIZE = 25 * 1024 * 1024 # 25MB
10
- MAX_AUDIO_DURATION = 600 # 10 minutes in seconds
11
- SUPPORTED_FORMATS = {
12
- '.wav': 'WAV audio',
13
- '.mp3': 'MP3 audio',
14
- '.m4a': 'M4A audio'
 
 
 
 
15
  }
16
 
 
 
 
 
 
 
 
 
17
  class AudioValidator:
18
- """Handles audio file validation and provides detailed feedback"""
19
-
20
  @staticmethod
21
- def validate_audio_file(file) -> tuple[bool, str]:
 
 
 
 
 
 
 
22
  try:
23
- # Check if file is provided
24
  if file is None:
25
- return False, "No file was uploaded."
26
 
27
  # Check file size
28
  file_size = len(file.getvalue())
 
 
29
  if file_size > MAX_FILE_SIZE:
30
- readable_size = humanize.naturalsize(file_size)
31
- max_size = humanize.naturalsize(MAX_FILE_SIZE)
32
- return False, f"File size ({readable_size}) exceeds maximum allowed size ({max_size})"
33
 
34
  # Check file extension
35
  file_extension = Path(file.name).suffix.lower()
 
 
36
  if file_extension not in SUPPORTED_FORMATS:
37
- return False, f"Unsupported file format. Please upload {', '.join(SUPPORTED_FORMATS.values())}"
38
 
39
- # Save file temporarily for duration check
40
  with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
41
  tmp_file.write(file.getvalue())
42
  tmp_file_path = tmp_file.name
43
 
44
  try:
45
- # Check audio duration
46
- duration = librosa.get_duration(path=tmp_file_path)
 
 
 
 
 
 
 
47
  if duration > MAX_AUDIO_DURATION:
48
- formatted_duration = str(timedelta(seconds=int(duration)))
49
- max_duration = str(timedelta(seconds=MAX_AUDIO_DURATION))
50
- return False, f"Audio duration ({formatted_duration}) exceeds maximum allowed length ({max_duration})"
51
 
52
- # Check audio quality
53
- y, sr = librosa.load(tmp_file_path)
54
- if sr < 16000:
55
- return False, f"Audio quality too low. Sample rate ({sr}Hz) should be at least 16kHz"
56
 
57
- return True, "Audio file is valid"
58
 
59
  finally:
60
  os.unlink(tmp_file_path)
61
 
62
  except Exception as e:
63
- logger.error(f"Audio validation error: {str(e)}")
64
- return False, f"Error validating audio file: {str(e)}"
65
 
66
  class AudioProcessor:
67
- """Enhanced audio processor with better feedback and error handling"""
68
-
69
  def __init__(self, model):
70
  self.model = model
71
  self.validator = AudioValidator()
72
-
73
- def process_audio_chunk(self, audio_file) -> tuple[Optional[str], Dict[str, Any]]:
74
- processing_stats = {
75
- 'duration': None,
76
- 'sample_rate': None,
77
- 'file_size': None,
78
  'processing_time': None,
79
- 'status': 'pending'
80
  }
81
 
82
  try:
83
- start_time = datetime.now()
84
-
85
  # Validate file
86
- is_valid, validation_message = self.validator.validate_audio_file(audio_file)
 
 
87
  if not is_valid:
88
- processing_stats['status'] = 'failed'
89
- processing_stats['error'] = validation_message
90
- return None, processing_stats
91
-
92
- # Get file stats
93
- file_size = len(audio_file.getvalue())
94
- processing_stats['file_size'] = humanize.naturalsize(file_size)
95
 
96
  # Process audio
97
- file_extension = Path(audio_file.name).suffix.lower()
98
- with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
99
- audio_file.seek(0)
100
  tmp_file.write(audio_file.getvalue())
101
  tmp_file_path = tmp_file.name
102
 
103
  try:
104
- # Get audio info
105
- y, sr = librosa.load(tmp_file_path)
106
- duration = librosa.get_duration(y=y, sr=sr)
107
- processing_stats.update({
108
- 'duration': str(timedelta(seconds=int(duration))),
109
- 'sample_rate': f"{sr/1000:.1f}kHz"
110
- })
111
-
112
- # Transcribe audio
113
  result = self.model.transcribe(
114
  tmp_file_path,
115
  language="en",
116
  task="transcribe",
117
- fp16=True if torch.cuda.is_available() else False
118
  )
119
-
120
- # Update stats
121
- processing_time = (datetime.now() - start_time).total_seconds()
122
- processing_stats.update({
123
- 'processing_time': f"{processing_time:.1f}s",
124
- 'status': 'success'
125
- })
126
-
127
- return result["text"], processing_stats
128
 
129
  finally:
130
- if os.path.exists(tmp_file_path):
131
- os.unlink(tmp_file_path)
132
 
133
  except Exception as e:
134
- error_message = str(e)
135
- logger.error(f"Audio processing error: {error_message}")
136
- processing_stats.update({
137
- 'status': 'failed',
138
- 'error': error_message
139
- })
140
- return None, processing_stats
141
  finally:
142
- gc.collect()
143
  if torch.cuda.is_available():
144
  torch.cuda.empty_cache()
 
145
 
146
- class UIManager:
147
- """Enhanced UI manager with better feedback and progress indicators"""
148
-
149
- @staticmethod
150
- def setup_page():
151
- st.set_page_config(
152
- page_title="VC Call Assistant",
153
- page_icon="πŸŽ™οΈ",
154
- layout="wide",
155
- initial_sidebar_state="expanded"
 
 
 
 
156
  )
157
-
158
- @staticmethod
159
- def show_file_uploader() -> Optional[Any]:
160
- st.markdown("""
161
- ### πŸ“ Upload Audio File
162
-
163
- **Supported formats:**
164
- - WAV (recommended)
165
- - MP3
166
- - M4A
167
 
168
- **Limitations:**
169
- - Maximum file size: 25MB
170
- - Maximum duration: 10 minutes
171
- - Minimum sample rate: 16kHz
172
- """)
 
173
 
174
- return st.file_uploader(
175
- "Choose an audio file",
176
- type=['wav', 'mp3', 'm4a']
 
 
 
 
 
 
177
  )
178
-
179
- @staticmethod
180
- def show_processing_stats(stats: Dict[str, Any]):
181
- """Display processing statistics in a nice format"""
182
- if not stats:
183
- return
 
 
184
 
185
- st.markdown("### πŸ“Š Processing Statistics")
 
 
 
 
 
 
 
 
 
 
 
186
 
187
- col1, col2, col3 = st.columns(3)
188
 
189
- with col1:
190
- st.metric("Duration", stats.get('duration', 'N/A'))
191
- st.metric("File Size", stats.get('file_size', 'N/A'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
- with col2:
194
- st.metric("Sample Rate", stats.get('sample_rate', 'N/A'))
195
- st.metric("Processing Time", stats.get('processing_time', 'N/A'))
196
 
197
- with col3:
198
  status = stats.get('status', 'unknown')
199
  if status == 'success':
200
- st.success("Processing Completed")
201
  elif status == 'failed':
202
- st.error(f"Processing Failed: {stats.get('error', 'Unknown error')}")
203
  else:
204
- st.info("Processing Pending")
205
 
206
  def main():
207
  try:
208
- UIManager.setup_page()
209
 
210
  with st.sidebar:
211
  st.title("VC Assistant Settings")
212
- model_name = "GPT2"
213
 
214
- st.info(f"""Using {model_name}
215
- Memory Usage: {MODEL_CONFIGS[model_name]['memory_required']}
216
- Description: {MODEL_CONFIGS[model_name]['description']}""")
217
 
218
  vc_name = st.text_input("Your Name")
219
  note_style = st.selectbox(
@@ -233,40 +322,25 @@ def main():
233
  st.warning("Please enter your name in the sidebar.")
234
  return
235
 
236
- # Initialize models with progress tracking
237
- progress_text = "Loading models..."
238
- progress_bar = st.progress(0, text=progress_text)
239
-
240
- try:
241
- progress_bar.progress(25, text="Loading Whisper model...")
242
- whisper_model = ModelManager.load_whisper()
243
-
244
- progress_bar.progress(50, text="Loading language model...")
245
- llm = ModelManager.load_llm(model_name)
246
 
247
  if not whisper_model or not llm:
248
  st.error("Failed to initialize models. Please refresh the page.")
249
  return
250
 
251
- progress_bar.progress(75, text="Initializing processors...")
252
  audio_processor = AudioProcessor(whisper_model)
253
  analyzer = ContentAnalyzer(llm)
254
-
255
- progress_bar.progress(100, text="Ready!")
256
- finally:
257
- progress_bar.empty()
258
 
259
- # File upload and processing
260
- audio_file = UIManager.show_file_uploader()
261
 
262
  if audio_file:
263
  with st.spinner("Processing audio..."):
264
- transcription, processing_stats = audio_processor.process_audio_chunk(audio_file)
 
265
 
266
- # Show processing statistics
267
- UIManager.show_processing_stats(processing_stats)
268
-
269
- if transcription:
270
  col1, col2 = st.columns(2)
271
 
272
  with col1:
@@ -294,7 +368,7 @@ def main():
294
  "timestamp": datetime.now().isoformat(),
295
  "transcription": transcription,
296
  "analysis": analysis,
297
- "processing_stats": processing_stats
298
  }, indent=2),
299
  file_name=f"vc_analysis_{datetime.now():%Y%m%d_%H%M%S}.json",
300
  mime="application/json"
@@ -302,14 +376,7 @@ def main():
302
 
303
  except Exception as e:
304
  logger.error(f"Application error: {str(e)}")
305
- st.error(f"""
306
- An unexpected error occurred: {str(e)}
307
-
308
- Please try:
309
- 1. Refreshing the page
310
- 2. Using a different audio file
311
- 3. Checking your internet connection
312
- """)
313
 
314
  finally:
315
  gc.collect()
 
1
+ import os
2
+ import gc
3
+ import json
4
+ import logging
5
+ import tempfile
6
+ from datetime import datetime, timedelta
7
  from pathlib import Path
8
+ from dataclasses import dataclass
9
+ import streamlit as st
10
+ import whisper
11
+ import torch
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
13
+ import numpy as np
14
+ import librosa
15
  import humanize
 
16
 
17
+ # Configure logging
18
+ logging.basicConfig(
19
+ level=logging.INFO,
20
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
21
+ )
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Constants
25
  MAX_FILE_SIZE = 25 * 1024 * 1024 # 25MB
26
+ MAX_AUDIO_DURATION = 600 # 10 minutes
27
+ MIN_SAMPLE_RATE = 16000 # 16kHz
28
+ SUPPORTED_FORMATS = {'.wav', '.mp3', '.m4a'}
29
+
30
+ # Model configuration
31
+ MODEL_CONFIG = {
32
+ "path": "gpt2",
33
+ "description": "Efficient open-source model for analysis",
34
+ "memory_required": "8GB"
35
  }
36
 
37
+ @dataclass
38
+ class VCStyle:
39
+ name: str
40
+ note_format: dict
41
+ key_interests: list
42
+ custom_sections: list
43
+ insight_preferences: dict
44
+
45
  class AudioValidator:
 
 
46
  @staticmethod
47
+ def validate_audio_file(file):
48
+ stats = {
49
+ 'file_size': None,
50
+ 'duration': None,
51
+ 'sample_rate': None,
52
+ 'format': None
53
+ }
54
+
55
  try:
 
56
  if file is None:
57
+ return False, "No file was uploaded.", stats
58
 
59
  # Check file size
60
  file_size = len(file.getvalue())
61
+ stats['file_size'] = humanize.naturalsize(file_size)
62
+
63
  if file_size > MAX_FILE_SIZE:
64
+ return False, f"File size ({stats['file_size']}) exceeds limit", stats
 
 
65
 
66
  # Check file extension
67
  file_extension = Path(file.name).suffix.lower()
68
+ stats['format'] = file_extension
69
+
70
  if file_extension not in SUPPORTED_FORMATS:
71
+ return False, f"Unsupported format {file_extension}", stats
72
 
73
+ # Create temporary file
74
  with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
75
  tmp_file.write(file.getvalue())
76
  tmp_file_path = tmp_file.name
77
 
78
  try:
79
+ # Check audio properties
80
+ y, sr = librosa.load(tmp_file_path, sr=None)
81
+ duration = librosa.get_duration(y=y, sr=sr)
82
+
83
+ stats.update({
84
+ 'duration': str(timedelta(seconds=int(duration))),
85
+ 'sample_rate': f"{sr/1000:.1f}kHz"
86
+ })
87
+
88
  if duration > MAX_AUDIO_DURATION:
89
+ return False, f"Duration ({stats['duration']}) exceeds limit", stats
 
 
90
 
91
+ if sr < MIN_SAMPLE_RATE:
92
+ return False, f"Sample rate too low ({stats['sample_rate']})", stats
 
 
93
 
94
+ return True, "Audio file is valid", stats
95
 
96
  finally:
97
  os.unlink(tmp_file_path)
98
 
99
  except Exception as e:
100
+ logger.error(f"Validation error: {str(e)}")
101
+ return False, str(e), stats
102
 
103
class AudioProcessor:
    """Validates and transcribes uploaded audio with a Whisper model."""

    def __init__(self, model):
        self.model = model                  # Whisper model exposing .transcribe()
        self.validator = AudioValidator()

    def process_audio(self, audio_file):
        """Validate, transcribe and time an uploaded audio file.

        Args:
            audio_file: uploaded file object exposing ``.name`` and
                ``.getvalue()``.

        Returns:
            ``(transcription_text, stats)`` on success or ``(None, stats)``
            on failure. All stats values are JSON-serializable, because
            main() feeds this dict straight into json.dumps() for the
            download button.
        """
        started = datetime.now()
        stats = {
            'status': 'processing',
            # ISO string rather than a datetime object: a raw datetime
            # would make json.dumps(stats) raise TypeError downstream.
            'start_time': started.isoformat(),
            'file_info': None,
            'processing_time': None,
            'error': None
        }

        try:
            # Validate file
            is_valid, message, file_stats = self.validator.validate_audio_file(audio_file)
            stats['file_info'] = file_stats

            if not is_valid:
                stats['status'] = 'failed'
                stats['error'] = message
                return None, stats

            # Copy the upload to disk: Whisper's transcribe() takes a path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=file_stats['format']) as tmp_file:
                tmp_file.write(audio_file.getvalue())
                tmp_file_path = tmp_file.name

            try:
                result = self.model.transcribe(
                    tmp_file_path,
                    language="en",
                    task="transcribe",
                    fp16=torch.cuda.is_available()
                )

                stats['status'] = 'success'
                stats['processing_time'] = str(datetime.now() - started)
                return result["text"], stats

            finally:
                os.unlink(tmp_file_path)

        except Exception as e:
            logger.error(f"Processing error: {str(e)}")
            stats['status'] = 'failed'
            stats['error'] = str(e)
            return None, stats

        finally:
            # Release GPU/host memory between runs.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
157
 
158
@st.cache_resource
def load_whisper():
    """Load the base Whisper speech-to-text model (cached by Streamlit).

    Returns:
        The loaded model, or None when loading fails (error is logged).
    """
    model = None
    try:
        model = whisper.load_model("base")
    except Exception as e:
        logger.error(f"Whisper model loading error: {str(e)}")
    return model
165
+
166
@st.cache_resource
def load_llm():
    """Load and cache the GPT-2 text-generation pipeline.

    Returns:
        A transformers text-generation pipeline, or None if loading fails
        (error is logged).
    """
    try:
        # NOTE(review): trust_remote_code=True is unnecessary for the stock
        # "gpt2" checkpoint and widens the attack surface if MODEL_CONFIG
        # is ever pointed at an untrusted repo — consider dropping it.
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_CONFIG["path"],
            trust_remote_code=True
        )

        model = AutoModelForCausalLM.from_pretrained(
            MODEL_CONFIG["path"],
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True
        )

        return pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            # do_sample must be enabled explicitly; without it generation is
            # greedy and temperature/top_p are silently ignored.
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.15,
            batch_size=1
        )

    except Exception as e:
        logger.error(f"LLM loading error: {str(e)}")
        return None
195
+
196
class ContentAnalyzer:
    """Runs the LLM over a transcription and structures its output."""

    def __init__(self, generator):
        self.generator = generator  # callable text-generation pipeline

    def analyze_text(self, text, vc_style):
        """Analyze a pitch transcription according to a VC's preferences.

        Args:
            text: transcription text to analyze.
            vc_style: object exposing ``.key_interests`` (list of topics).

        Returns:
            Dict mapping section name -> list of paragraphs, or None on error.
        """
        try:
            prompt = self._create_analysis_prompt(text, vc_style)
            response = self._generate_response(prompt)
            return self._parse_response(response)
        except Exception as e:
            logger.error(f"Analysis error: {str(e)}")
            return None

    def _create_analysis_prompt(self, text, vc_style):
        """Build the generation prompt from the transcription and interests."""
        interests = ', '.join(vc_style.key_interests)
        return f"""Analyze this startup pitch focusing on {interests}:

{text}

Provide structured insights for:
1. Key Points
2. Metrics
3. Risks
4. Questions"""

    def _generate_response(self, prompt):
        """Run the generator; returns "" on failure so parsing still works."""
        try:
            response = self.generator(prompt)
            return response[0]['generated_text']
        except Exception as e:
            logger.error(f"Generation error: {str(e)}")
            return ""

    def _parse_response(self, response):
        """Split generated text into sections keyed by trailing-colon headers.

        Paragraphs are separated by blank lines; a paragraph ending with
        ':' starts a new (lowercased) section, anything else is appended to
        the current section ('general' until a header is seen).
        """
        try:
            parsed = {}
            current_section = "general"

            for section in response.split('\n\n'):
                section = section.strip()
                if not section:
                    # Skip runs of blank lines instead of recording empty
                    # paragraphs (previous version appended '' entries).
                    continue
                if section.endswith(':'):
                    current_section = section[:-1].lower()
                    parsed[current_section] = []
                else:
                    parsed.setdefault(current_section, []).append(section)

            return parsed
        except Exception as e:
            logger.error(f"Parsing error: {str(e)}")
            return {"error": "Failed to parse response"}
249
+
250
def setup_page():
    """Configure the Streamlit page (title, icon, wide layout)."""
    st.set_page_config(
        page_title="VC Call Assistant",
        page_icon="🎙️",  # fixed: was mojibake ("πŸŽ™οΈ") from a bad encoding round-trip
        layout="wide",
    )
256
+
257
def show_file_uploader():
    """Render upload instructions and return the uploaded file (or None)."""
    st.markdown("""
    ### 📁 Upload Audio File

    **Supported formats:** WAV, MP3, M4A
    **Limits:** 25MB, 10 minutes, 16kHz min quality
    """)

    # Accepted extensions mirror SUPPORTED_FORMATS.
    return st.file_uploader(
        "Choose an audio file",
        type=['wav', 'mp3', 'm4a']
    )
269
+
270
def show_processing_stats(stats):
    """Render file metadata and processing status for a transcription run.

    Args:
        stats: dict produced by AudioProcessor.process_audio; expects an
            optional 'file_info' sub-dict plus 'status' / 'error' /
            'processing_time' keys.
    """
    if not stats:
        return

    st.markdown("### 📊 Processing Information")

    cols = st.columns(3)

    if stats.get('file_info'):
        with cols[0]:
            st.metric("File Size", stats['file_info'].get('file_size', 'N/A'))
            st.metric("Format", stats['file_info'].get('format', 'N/A'))

        with cols[1]:
            st.metric("Duration", stats['file_info'].get('duration', 'N/A'))
            st.metric("Sample Rate", stats['file_info'].get('sample_rate', 'N/A'))

    # Status is shown even when file_info is missing (e.g. an early failure
    # before validation could populate it).
    with cols[2]:
        status = stats.get('status', 'unknown')
        if status == 'success':
            st.success(f"Processed in {stats.get('processing_time', 'N/A')}")
        elif status == 'failed':
            st.error(f"Failed: {stats.get('error', 'Unknown error')}")
        else:
            st.info("Processing...")
295
 
296
  def main():
297
  try:
298
+ setup_page()
299
 
300
  with st.sidebar:
301
  st.title("VC Assistant Settings")
 
302
 
303
+ st.info(f"""Using GPT2
304
+ Memory: {MODEL_CONFIG['memory_required']}
305
+ Info: {MODEL_CONFIG['description']}""")
306
 
307
  vc_name = st.text_input("Your Name")
308
  note_style = st.selectbox(
 
322
  st.warning("Please enter your name in the sidebar.")
323
  return
324
 
325
+ with st.spinner("Loading models..."):
326
+ whisper_model = load_whisper()
327
+ llm = load_llm()
 
 
 
 
 
 
 
328
 
329
  if not whisper_model or not llm:
330
  st.error("Failed to initialize models. Please refresh the page.")
331
  return
332
 
 
333
  audio_processor = AudioProcessor(whisper_model)
334
  analyzer = ContentAnalyzer(llm)
 
 
 
 
335
 
336
+ audio_file = show_file_uploader()
 
337
 
338
  if audio_file:
339
  with st.spinner("Processing audio..."):
340
+ transcription, stats = audio_processor.process_audio(audio_file)
341
+ show_processing_stats(stats)
342
 
343
+ if transcription and stats['status'] == 'success':
 
 
 
344
  col1, col2 = st.columns(2)
345
 
346
  with col1:
 
368
  "timestamp": datetime.now().isoformat(),
369
  "transcription": transcription,
370
  "analysis": analysis,
371
+ "processing_stats": stats
372
  }, indent=2),
373
  file_name=f"vc_analysis_{datetime.now():%Y%m%d_%H%M%S}.json",
374
  mime="application/json"
 
376
 
377
  except Exception as e:
378
  logger.error(f"Application error: {str(e)}")
379
+ st.error("An error occurred. Please refresh the page and try again.")
 
 
 
 
 
 
 
380
 
381
  finally:
382
  gc.collect()
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  streamlit==1.24.0
2
- whisper-openai==1.0.0
3
  pandas==1.5.3
4
  numpy==1.23.5
5
  torch==2.0.1
@@ -9,8 +9,5 @@ bitsandbytes==0.41.1
9
  scipy==1.11.3
10
  sentencepiece==0.1.99
11
  huggingface-hub==0.19.4
12
- python-dotenv==1.0.0
13
- dataclasses-json==0.5.7
14
  librosa==0.10.1
15
- soundfile==0.12.1
16
  humanize==4.7.0
 
1
  streamlit==1.24.0
2
+ openai-whisper==20231117
3
  pandas==1.5.3
4
  numpy==1.23.5
5
  torch==2.0.1
 
9
  scipy==1.11.3
10
  sentencepiece==0.1.99
11
  huggingface-hub==0.19.4
 
 
12
  librosa==0.10.1
 
13
  humanize==4.7.0