invincible-jha committed on
Commit
3c834a4
Β·
verified Β·
1 Parent(s): 2b82738

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +209 -13
app.py CHANGED
@@ -1,6 +1,36 @@
1
- # Only showing the modified sections for brevity. The rest remains the same.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- # Update MODEL_CONFIGS to use appropriate models
 
 
 
 
 
 
 
 
 
4
  MODEL_CONFIGS = {
5
  "GPT2": {
6
  "path": "gpt2",
@@ -14,6 +44,15 @@ MODEL_CONFIGS = {
14
  }
15
  }
16
 
 
 
 
 
 
 
 
 
 
17
  class AudioProcessor:
18
  """Handles audio processing with memory optimization"""
19
 
@@ -72,14 +111,6 @@ class ModelManager:
72
  try:
73
  config = MODEL_CONFIGS[model_name]
74
 
75
- # Optimized quantization config
76
- bnb_config = BitsAndBytesConfig(
77
- load_in_4bit=True,
78
- bnb_4bit_quant_type="nf4",
79
- bnb_4bit_compute_dtype=torch.float16,
80
- bnb_4bit_use_double_quant=True,
81
- )
82
-
83
  tokenizer = AutoTokenizer.from_pretrained(
84
  config["path"],
85
  trust_remote_code=True
@@ -87,7 +118,6 @@ class ModelManager:
87
 
88
  model = AutoModelForCausalLM.from_pretrained(
89
  config["path"],
90
- quantization_config=bnb_config,
91
  device_map="auto",
92
  torch_dtype=torch.float16,
93
  low_cpu_mem_usage=True
@@ -101,7 +131,7 @@ class ModelManager:
101
  temperature=0.7,
102
  top_p=0.95,
103
  repetition_penalty=1.15,
104
- batch_size=1 # Reduced for stability
105
  )
106
 
107
  return pipe
@@ -111,6 +141,87 @@ class ModelManager:
111
  st.error(f"Failed to load language model: {str(e)}")
112
  return None
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  def main():
115
  try:
116
  # Initialize UI
@@ -130,4 +241,89 @@ def main():
130
  Memory Usage: {MODEL_CONFIGS[model_name]['memory_required']}
131
  Description: {MODEL_CONFIGS[model_name]['description']}""")
132
 
133
- # Rest of the sidebar code remains the same as before...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import whisper
3
+ import pandas as pd
4
+ from datetime import datetime
5
+ import tempfile
6
+ import os
7
+ import torch
8
+ from transformers import (
9
+ AutoModelForCausalLM,
10
+ AutoTokenizer,
11
+ pipeline,
12
+ BitsAndBytesConfig
13
+ )
14
+ import gc
15
+ from typing import Optional, Dict, Any, List
16
+ import logging
17
+ import json
18
+ import numpy as np
19
+ from dataclasses import dataclass, asdict
20
+ from queue import Queue
21
+ import threading
22
+ from collections import defaultdict
23
 
24
# Configure module-wide logging; INFO keeps model-load progress visible
# without DEBUG-level noise.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants for memory optimization
CHUNK_SIZE = 30  # audio chunk length in seconds processed per pass
MAX_AUDIO_LENGTH = 600  # hard cap on upload duration in seconds (10 minutes)
BATCH_SIZE = 1  # single-item batches to keep peak memory low
32
+
33
+ # Model configurations with memory optimization
34
  MODEL_CONFIGS = {
35
  "GPT2": {
36
  "path": "gpt2",
 
44
  }
45
  }
46
 
47
@dataclass
class VCStyle:
    """Store VC's personal style preferences.

    Built in ``main()`` from the sidebar inputs and consumed by
    ``ContentAnalyzer`` when constructing analysis prompts.
    """
    name: str  # VC's display name (sidebar "Your Name" input)
    note_format: Dict[str, Any]  # formatting options, e.g. {"style": "Bullet Points"}
    key_interests: List[str]  # focus areas woven into the analysis prompt
    custom_sections: List[str]  # extra report sections; currently always [] — reserved for future use
    insight_preferences: Dict[str, float]  # per-insight weights; currently always {} — reserved for future use
55
+
56
  class AudioProcessor:
57
  """Handles audio processing with memory optimization"""
58
 
 
111
  try:
112
  config = MODEL_CONFIGS[model_name]
113
 
 
 
 
 
 
 
 
 
114
  tokenizer = AutoTokenizer.from_pretrained(
115
  config["path"],
116
  trust_remote_code=True
 
118
 
119
  model = AutoModelForCausalLM.from_pretrained(
120
  config["path"],
 
121
  device_map="auto",
122
  torch_dtype=torch.float16,
123
  low_cpu_mem_usage=True
 
131
  temperature=0.7,
132
  top_p=0.95,
133
  repetition_penalty=1.15,
134
+ batch_size=1
135
  )
136
 
137
  return pipe
 
141
  st.error(f"Failed to load language model: {str(e)}")
142
  return None
143
 
144
+ class ContentAnalyzer:
145
+ """Handles text analysis with optimized prompts"""
146
+
147
+ def __init__(self, generator):
148
+ self.generator = generator
149
+
150
+ def analyze_text(self, text: str, vc_style: VCStyle) -> Optional[Dict[str, Any]]:
151
+ try:
152
+ prompt = self._create_analysis_prompt(text, vc_style)
153
+ response = self._generate_response(prompt, max_length=512)
154
+ return self._parse_response(response)
155
+ except Exception as e:
156
+ logger.error(f"Analysis error: {e}")
157
+ return None
158
+
159
+ def _create_analysis_prompt(self, text: str, vc_style: VCStyle) -> str:
160
+ return f"""Analyze this startup pitch focusing on {', '.join(vc_style.key_interests)}:
161
+
162
+ {text}
163
+
164
+ Provide structured insights for:
165
+ 1. Key Points
166
+ 2. Metrics
167
+ 3. Risks
168
+ 4. Questions"""
169
+
170
+ def _generate_response(self, prompt: str, max_length: int) -> str:
171
+ try:
172
+ response = self.generator(
173
+ prompt,
174
+ max_new_tokens=max_length,
175
+ temperature=0.7,
176
+ top_p=0.95,
177
+ repetition_penalty=1.15
178
+ )
179
+ return response[0]['generated_text']
180
+ except Exception as e:
181
+ logger.error(f"Generation error: {e}")
182
+ return ""
183
+
184
+ def _parse_response(self, response: str) -> Dict[str, Any]:
185
+ try:
186
+ sections = response.split('\n\n')
187
+ parsed_response = {}
188
+ current_section = "general"
189
+
190
+ for section in sections:
191
+ if section.strip().endswith(':'):
192
+ current_section = section.strip()[:-1].lower()
193
+ parsed_response[current_section] = []
194
+ else:
195
+ if current_section in parsed_response:
196
+ parsed_response[current_section].append(section.strip())
197
+ else:
198
+ parsed_response[current_section] = [section.strip()]
199
+
200
+ return parsed_response
201
+ except Exception as e:
202
+ logger.error(f"Parsing error: {e}")
203
+ return {"error": "Failed to parse response"}
204
+
205
class UIManager:
    """Thin Streamlit UI helpers shared by the app entry point."""

    @staticmethod
    def setup_page():
        """Apply the global page configuration (title, icon, layout)."""
        page_settings = {
            "page_title": "VC Call Assistant",
            "page_icon": "πŸŽ™οΈ",
            "layout": "wide",
            "initial_sidebar_state": "expanded",
        }
        st.set_page_config(**page_settings)

    @staticmethod
    def show_file_uploader() -> Optional[Any]:
        """Render the audio upload widget and return the uploaded file (or None)."""
        accepted_formats = ['wav', 'mp3', 'm4a']
        return st.file_uploader(
            "Upload Audio (Max 10 minutes)",
            type=accepted_formats,
            help="Supports WAV, MP3, M4A formats. Maximum duration: 10 minutes."
        )
224
+
225
  def main():
226
  try:
227
  # Initialize UI
 
241
  Memory Usage: {MODEL_CONFIGS[model_name]['memory_required']}
242
  Description: {MODEL_CONFIGS[model_name]['description']}""")
243
 
244
+ # VC Profile
245
+ vc_name = st.text_input("Your Name")
246
+ note_style = st.selectbox(
247
+ "Note Style",
248
+ ["Bullet Points", "Paragraphs", "Q&A"]
249
+ )
250
+
251
+ interests = st.multiselect(
252
+ "Focus Areas",
253
+ ["Product", "Market", "Team", "Financials", "Technology"],
254
+ default=["Product", "Market"]
255
+ )
256
+
257
+ # Main content
258
+ st.title("πŸŽ™οΈ VC Call Assistant")
259
+
260
+ if not vc_name:
261
+ st.warning("Please enter your name in the sidebar.")
262
+ return
263
+
264
+ # Initialize processors
265
+ with st.spinner("Loading models..."):
266
+ whisper_model = ModelManager.load_whisper()
267
+ llm = ModelManager.load_llm(model_name)
268
+
269
+ if not whisper_model or not llm:
270
+ st.error("Failed to initialize models. Please refresh the page.")
271
+ return
272
+
273
+ audio_processor = AudioProcessor(whisper_model)
274
+ analyzer = ContentAnalyzer(llm)
275
+
276
+ # File upload
277
+ audio_file = UIManager.show_file_uploader()
278
+
279
+ if audio_file:
280
+ # Process audio
281
+ with st.spinner("Processing audio..."):
282
+ transcription = audio_processor.process_audio_chunk(audio_file)
283
+
284
+ if transcription:
285
+ # Display results
286
+ col1, col2 = st.columns(2)
287
+
288
+ with col1:
289
+ st.subheader("πŸ“ Transcript")
290
+ st.write(transcription)
291
+
292
+ with col2:
293
+ st.subheader("πŸ” Analysis")
294
+ vc_style = VCStyle(
295
+ name=vc_name,
296
+ note_format={"style": note_style},
297
+ key_interests=interests,
298
+ custom_sections=[],
299
+ insight_preferences={}
300
+ )
301
+
302
+ analysis = analyzer.analyze_text(transcription, vc_style)
303
+ if analysis:
304
+ st.write(analysis)
305
+
306
+ # Export option
307
+ st.download_button(
308
+ "πŸ“₯ Export Analysis",
309
+ data=json.dumps({
310
+ "timestamp": datetime.now().isoformat(),
311
+ "transcription": transcription,
312
+ "analysis": analysis
313
+ }, indent=2),
314
+ file_name=f"vc_analysis_{datetime.now():%Y%m%d_%H%M%S}.json",
315
+ mime="application/json"
316
+ )
317
+
318
+ except Exception as e:
319
+ logger.error(f"Application error: {e}")
320
+ st.error("An unexpected error occurred. Please refresh the page.")
321
+
322
+ finally:
323
+ # Cleanup
324
+ gc.collect()
325
+ if torch.cuda.is_available():
326
+ torch.cuda.empty_cache()
327
+
328
# Script entry point: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()