invincible-jha committed
Commit 2b82738 · verified · 1 Parent(s): 2db9a09

Upload app.py

Files changed (1)
  1. app.py +48 -259
app.py CHANGED
@@ -1,67 +1,57 @@
- import streamlit as st
- import whisper
- import pandas as pd
- from datetime import datetime
- import tempfile
- import os
- import torch
- from transformers import (
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     pipeline,
-     BitsAndBytesConfig
- )
- import gc
- from typing import Optional, Dict, Any, List
- import logging
- import json
- import numpy as np
- from dataclasses import dataclass, asdict
- from queue import Queue
- import threading
- from collections import defaultdict
 
- # Configure logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- # Constants for memory optimization
- CHUNK_SIZE = 30  # seconds
- MAX_AUDIO_LENGTH = 600  # seconds (10 minutes)
- BATCH_SIZE = 8
-
- # Model configurations with memory optimization
  MODEL_CONFIGS = {
-     "FLAN-T5-Large": {
-         "path": "google/flan-t5-large",
          "description": "Efficient open-source model for analysis",
          "memory_required": "8GB"
      },
-     "OpenAssistant": {
-         "path": "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
-         "description": "Powerful open-source assistant model",
          "memory_required": "12GB"
      }
  }
 
- @dataclass
- class VCStyle:
-     """Store VC's personal style preferences"""
-     name: str
-     note_format: Dict[str, Any]
-     key_interests: List[str]
-     custom_sections: List[str]
-     insight_preferences: Dict[str, float]
-
- @dataclass
- class LiveCallContext:
-     """Store context for live calls"""
-     meeting_id: str
-     participants: List[str]
-     topics: List[str]
-     key_points: List[str]
-     questions_asked: List[str]
-     action_items: List[str]
 
  class ModelManager:
      """Handles model loading and resource management"""
@@ -111,131 +101,15 @@ class ModelManager:
                  temperature=0.7,
                  top_p=0.95,
                  repetition_penalty=1.15,
-                 batch_size=BATCH_SIZE
              )
 
              return pipe
 
          except Exception as e:
              logger.error(f"Failed to load LLM {model_name}: {e}")
-             st.error("Failed to load language model. Please try again.")
-             return None
-
- class AudioProcessor:
-     """Handles audio processing with memory optimization"""
-
-     def __init__(self, model):
-         self.model = model
-         self.chunk_queue = Queue()
-
-     def process_audio_chunk(self, audio_chunk) -> Optional[str]:
-         try:
-             # Clear GPU memory before processing
-             if torch.cuda.is_available():
-                 torch.cuda.empty_cache()
-
-             result = self.model.transcribe(
-                 audio_chunk,
-                 language="en",
-                 task="transcribe",
-                 fp16=True  # Use half precision
-             )
-             return result["text"]
-
-         except Exception as e:
-             logger.error(f"Error processing audio chunk: {e}")
              return None
-         finally:
-             # Cleanup
-             gc.collect()
-             if torch.cuda.is_available():
-                 torch.cuda.empty_cache()
-
- class ContentAnalyzer:
-     """Handles text analysis with optimized prompts"""
-
-     def __init__(self, generator):
-         self.generator = generator
-
-     def analyze_text(self, text: str, vc_style: VCStyle) -> Optional[Dict[str, Any]]:
-         try:
-             prompt = self._create_analysis_prompt(text, vc_style)
-             response = self._generate_response(prompt, max_length=512)
-             return self._parse_response(response)
-         except Exception as e:
-             logger.error(f"Analysis error: {e}")
-             return None
-
-     def _create_analysis_prompt(self, text: str, vc_style: VCStyle) -> str:
-         return f"""Analyze this startup pitch focusing on {', '.join(vc_style.key_interests)}:
-
- {text}
-
- Provide structured insights for:
- 1. Key Points
- 2. Metrics
- 3. Risks
- 4. Questions"""
-
-     def _generate_response(self, prompt: str, max_length: int) -> str:
-         try:
-             response = self.generator(
-                 prompt,
-                 max_new_tokens=max_length,
-                 temperature=0.7,
-                 top_p=0.95,
-                 repetition_penalty=1.15
-             )
-             return response[0]['generated_text']
-         except Exception as e:
-             logger.error(f"Generation error: {e}")
-             return ""
-
-     def _parse_response(self, response: str) -> Dict[str, Any]:
-         try:
-             # Simple parsing of the response into sections
-             sections = response.split('\n\n')
-             parsed_response = {}
-             current_section = "general"
-
-             for section in sections:
-                 if section.strip().endswith(':'):
-                     current_section = section.strip()[:-1].lower()
-                     parsed_response[current_section] = []
-                 else:
-                     if current_section in parsed_response:
-                         parsed_response[current_section].append(section.strip())
-                     else:
-                         parsed_response[current_section] = [section.strip()]
-
-             return parsed_response
-         except Exception as e:
-             logger.error(f"Parsing error: {e}")
-             return {"error": "Failed to parse response"}
-
- class UIManager:
-     """Manages Streamlit UI with performance optimization"""
-
-     @staticmethod
-     def setup_page():
-         st.set_page_config(
-             page_title="VC Call Assistant",
-             page_icon="🎙️",
-             layout="wide",
-             initial_sidebar_state="expanded"
-         )
-
-     @staticmethod
-     def show_file_uploader() -> Optional[Any]:
-         return st.file_uploader(
-             "Upload Audio (Max 10 minutes)",
-             type=['wav', 'mp3', 'm4a'],
-             help="Supports WAV, MP3, M4A formats. Maximum duration: 10 minutes."
-         )
-
-     @staticmethod
-     def show_progress(text: str) -> Any:
-         return st.progress(0, text=text)
 
  def main():
      try:
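
For context only (not part of this commit): the removed ContentAnalyzer._generate_response above passes its sampling settings to the pipeline at call time. Below is a minimal standalone sketch of that call pattern with a transformers text-generation pipeline; the model choice, prompt, and do_sample flag are illustrative assumptions, not taken from app.py.

    from transformers import pipeline

    # Placeholder model; app.py selects the model from MODEL_CONFIGS instead.
    generator = pipeline("text-generation", model="gpt2")

    outputs = generator(
        "Analyze this startup pitch focusing on Product, Market:",
        max_new_tokens=128,
        do_sample=True,          # enables the temperature/top_p settings below
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.15,
    )
    print(outputs[0]["generated_text"])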
@@ -256,89 +130,4 @@ def main():
              Memory Usage: {MODEL_CONFIGS[model_name]['memory_required']}
              Description: {MODEL_CONFIGS[model_name]['description']}""")
 
-             # VC Profile
-             vc_name = st.text_input("Your Name")
-             note_style = st.selectbox(
-                 "Note Style",
-                 ["Bullet Points", "Paragraphs", "Q&A"]
-             )
-
-             interests = st.multiselect(
-                 "Focus Areas",
-                 ["Product", "Market", "Team", "Financials", "Technology"],
-                 default=["Product", "Market"]
-             )
-
-         # Main content
-         st.title("🎙️ VC Call Assistant")
-
-         if not vc_name:
-             st.warning("Please enter your name in the sidebar.")
-             return
-
-         # Initialize processors
-         with st.spinner("Loading models..."):
-             whisper_model = ModelManager.load_whisper()
-             llm = ModelManager.load_llm(model_name)
-
-         if not whisper_model or not llm:
-             st.error("Failed to initialize models. Please refresh the page.")
-             return
-
-         audio_processor = AudioProcessor(whisper_model)
-         analyzer = ContentAnalyzer(llm)
-
-         # File upload
-         audio_file = UIManager.show_file_uploader()
-
-         if audio_file:
-             # Process audio
-             with st.spinner("Processing audio..."):
-                 transcription = audio_processor.process_audio_chunk(audio_file)
-
-             if transcription:
-                 # Display results
-                 col1, col2 = st.columns(2)
-
-                 with col1:
-                     st.subheader("📝 Transcript")
-                     st.write(transcription)
-
-                 with col2:
-                     st.subheader("🔍 Analysis")
-                     vc_style = VCStyle(
-                         name=vc_name,
-                         note_format={"style": note_style},
-                         key_interests=interests,
-                         custom_sections=[],
-                         insight_preferences={}
-                     )
-
-                     analysis = analyzer.analyze_text(transcription, vc_style)
-                     if analysis:
-                         st.write(analysis)
-
-                         # Export option
-                         st.download_button(
-                             "📥 Export Analysis",
-                             data=json.dumps({
-                                 "timestamp": datetime.now().isoformat(),
-                                 "transcription": transcription,
-                                 "analysis": analysis
-                             }, indent=2),
-                             file_name=f"vc_analysis_{datetime.now():%Y%m%d_%H%M%S}.json",
-                             mime="application/json"
-                         )
-
-     except Exception as e:
-         logger.error(f"Application error: {e}")
-         st.error("An unexpected error occurred. Please refresh the page.")
-
-     finally:
-         # Cleanup
-         gc.collect()
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-
- if __name__ == "__main__":
-     main()
+ # Only showing the modified sections for brevity. The rest remains the same.
 
+ # Update MODEL_CONFIGS to use appropriate models
  MODEL_CONFIGS = {
+     "GPT2": {
+         "path": "gpt2",
          "description": "Efficient open-source model for analysis",
          "memory_required": "8GB"
      },
+     "GPT-Neo": {
+         "path": "EleutherAI/gpt-neo-1.3B",
+         "description": "Powerful open-source model",
          "memory_required": "12GB"
      }
  }
 
+ class AudioProcessor:
+     """Handles audio processing with memory optimization"""
+
+     def __init__(self, model):
+         self.model = model
+
+     def process_audio_chunk(self, audio_file) -> Optional[str]:
+         try:
+             # Clear GPU memory before processing
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+
+             # Save the uploaded file temporarily
+             with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
+                 tmp_file.write(audio_file.read())
+                 tmp_file_path = tmp_file.name
+
+             # Process the audio file
+             result = self.model.transcribe(
+                 tmp_file_path,
+                 language="en",
+                 task="transcribe",
+                 fp16=True if torch.cuda.is_available() else False
+             )
+
+             # Cleanup
+             os.unlink(tmp_file_path)
+
+             return result["text"]
+
+         except Exception as e:
+             logger.error(f"Error processing audio chunk: {e}")
+             return None
+         finally:
+             # Cleanup
+             gc.collect()
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
 
  class ModelManager:
      """Handles model loading and resource management"""
 
@@ -111,131 +101,15 @@ class ModelManager:
                  temperature=0.7,
                  top_p=0.95,
                  repetition_penalty=1.15,
+                 batch_size=1  # Reduced for stability
              )
 
              return pipe
 
          except Exception as e:
              logger.error(f"Failed to load LLM {model_name}: {e}")
+             st.error(f"Failed to load language model: {str(e)}")
              return None
 
  def main():
      try:
 
@@ -256,89 +130,4 @@ def main():
              Memory Usage: {MODEL_CONFIGS[model_name]['memory_required']}
              Description: {MODEL_CONFIGS[model_name]['description']}""")
 
+             # Rest of the sidebar code remains the same as before...
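
For reference only (not part of this commit): a minimal sketch of how a loader in the style of ModelManager.load_llm could turn one of the new MODEL_CONFIGS entries into a text-generation pipeline with the batch_size=1 setting shown above; load_llm_sketch is a placeholder name and the device handling is an assumption.

    import logging

    import torch
    from transformers import pipeline

    logger = logging.getLogger(__name__)

    # Mirrors the MODEL_CONFIGS paths added in this commit (descriptions omitted).
    MODEL_CONFIGS = {
        "GPT2": {"path": "gpt2"},
        "GPT-Neo": {"path": "EleutherAI/gpt-neo-1.3B"},
    }

    def load_llm_sketch(model_name: str):
        """Return a text-generation pipeline for the selected config, or None on failure."""
        try:
            return pipeline(
                "text-generation",
                model=MODEL_CONFIGS[model_name]["path"],
                device=0 if torch.cuda.is_available() else -1,  # assumption: first GPU, else CPU
                batch_size=1,  # reduced for stability, as in this commit
            )
        except Exception as e:
            logger.error(f"Failed to load LLM {model_name}: {e}")
            return None

    # Example usage (assumption): llm = load_llm_sketch("GPT2")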