allanctan committed
Commit 0a3dd60 · 1 Parent(s): 94dc091

revert to working state

Files changed (1): main.py (+8 -113)
main.py CHANGED
@@ -1,31 +1,12 @@
-from fastapi import FastAPI, UploadFile, File, WebSocket, WebSocketDisconnect
-from fastapi.middleware.cors import CORSMiddleware
+from fastapi import FastAPI, UploadFile, File
 from unsloth import FastVisionModel
 import torch
 import shutil
 import os
-import json
-import base64
-import tempfile
-import logging
-
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
 os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/torchinductor"
 
 app = FastAPI()
 
-# Add CORS for WebSocket
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-# Load model at startup (same as your original)
 model, processor = FastVisionModel.from_pretrained("unsloth/gemma-3n-e2b-it", load_in_4bit=True)
 model.generation_config.cache_implementation = "static"
 
@@ -64,101 +45,15 @@ async def transcribe_audio(file: UploadFile = File(...)):
         tokenize=True, return_dict=True, return_tensors="pt"
     ).to(model.device, dtype=model.dtype)
 
-    outputs = model.generate(**input_ids, max_new_tokens=64, do_sample=False, temperature=0.1)
+    # Generate output from the model
+    outputs = model.generate(**input_ids, max_new_tokens=64, do_sample=False,
+                             temperature=0.1)
+
+    # decode and print the output as text
     result = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+
+    # Extract only transcription
     result = result.split("model\n")[-1].split("<end_of_turn>")[0].strip()
-
-    # Cleanup
-    if os.path.exists(filepath):
-        os.remove(filepath)
-
     return {"text": result}
 
-# Simple WebSocket endpoint
-@app.websocket("/ws")
-async def websocket_endpoint(websocket: WebSocket):
-    await websocket.accept()
-    logger.info("WebSocket client connected")
-
-    try:
-        while True:
-            # Receive message
-            data = await websocket.receive_text()
-            message = json.loads(data)
-            logger.info(f"Received message: {message}")
-
-            # Handle audio data
-            if "audio_data" in message:
-                audio_b64 = message["audio_data"]
-                mime_type = message.get("mime_type", "audio/wav")
-
-                try:
-                    # Use your exact transcribe logic
-                    transcription = await transcribe_base64_audio(audio_b64, mime_type)
-
-                    # Send response
-                    response = {
-                        "type": "transcription",
-                        "text": transcription
-                    }
-                    await websocket.send_text(json.dumps(response))
-
-                except Exception as e:
-                    logger.error(f"Transcription error: {e}")
-                    await websocket.send_text(json.dumps({
-                        "type": "error",
-                        "message": str(e)
-                    }))
-
-            # Handle ping/pong
-            elif message.get("type") == "ping":
-                await websocket.send_text(json.dumps({"type": "pong"}))
-
-            else:
-                await websocket.send_text(json.dumps({
-                    "type": "error",
-                    "message": "Unknown message format"
-                }))
-
-    except WebSocketDisconnect:
-        logger.info("WebSocket client disconnected")
-    except Exception as e:
-        logger.error(f"WebSocket error: {e}")
-
-async def transcribe_base64_audio(audio_b64: str, mime_type: str) -> str:
-    """Use your exact transcribe logic but with base64 audio data"""
-
-    # Convert base64 to file (same as your transcribe logic)
-    audio_data = base64.b64decode(audio_b64)
-
-    # Create temp file
-    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
-        temp_file.write(audio_data)
-        filepath = temp_file.name
-
-    try:
-        # Your exact transcribe logic
-        messages = [{
-            "role": "user",
-            "content": [
-                {"type": "audio", "audio": filepath},
-                {"type": "text", "text": "Transcribe this audio"},
-            ]
-        }]
-
-        input_ids = processor.apply_chat_template(
-            messages, add_generation_prompt=True,
-            tokenize=True, return_dict=True, return_tensors="pt"
-        ).to(model.device, dtype=model.dtype)
-
-        outputs = model.generate(**input_ids, max_new_tokens=64, do_sample=False, temperature=0.1)
-        result = processor.batch_decode(outputs, skip_special_tokens=True)[0]
-        print(result)
-        result = result.split("model\n")[-1].split("<end_of_turn>")[0].strip()
-
-        return result
-
-    finally:
-        # Cleanup temp file
-        if os.path.exists(filepath):
-            os.remove(filepath)
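
After this revert, only the multipart upload endpoint remains. A minimal client sketch follows, assuming the handler is mounted at /transcribe and the server runs on localhost:8000 (both are assumptions: the @app.post decorator line sits outside the diff context, so the actual route path is not visible here).

    # client.py -- hypothetical caller for the surviving upload endpoint
    import requests

    # FastAPI binds the `file: UploadFile = File(...)` parameter to a
    # multipart form field named "file".
    with open("sample.wav", "rb") as f:
        resp = requests.post(
            "http://localhost:8000/transcribe",  # assumed path, host, and port
            files={"file": ("sample.wav", f, "audio/wav")},
        )
    resp.raise_for_status()
    print(resp.json()["text"])  # the handler returns {"text": result}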
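
For reference, the removed /ws endpoint expected a JSON text frame carrying base64 audio and replied with a JSON transcription or error. A sketch of a client matching that wire format, using the third-party websockets package (an assumption; any WebSocket client would do):

    import asyncio
    import base64
    import json

    import websockets  # third-party package, assumed installed

    async def main():
        async with websockets.connect("ws://localhost:8000/ws") as ws:
            with open("sample.wav", "rb") as f:
                audio_b64 = base64.b64encode(f.read()).decode("ascii")
            # The removed handler read "audio_data" (base64) plus an optional
            # "mime_type", defaulting to "audio/wav".
            await ws.send(json.dumps({"audio_data": audio_b64,
                                      "mime_type": "audio/wav"}))
            # Reply: {"type": "transcription", "text": ...} or {"type": "error", ...}
            print(json.loads(await ws.recv()))

    asyncio.run(main())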