from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
from transformers.pipelines import pipeline
import torch
import torchaudio.transforms as T
import numpy as np
import json
# Initialize Whisper components globally (these are lightweight)
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base.en")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base.en")
processor = WhisperProcessor(feature_extractor, tokenizer)
# Live transcription handler
def update_live_transcription(audio):
    """Real-time transcription updates."""
    print("update_live_transcription called with:", type(audio))
    if not audio or not isinstance(audio, tuple):
        return ""
    try:
        sample_rate, audio_array = audio
        print(f"got audio tuple – sample_rate={sample_rate}, shape={audio_array.shape}")

        def process_audio(audio_array, sample_rate):
            """Pre-process audio for Whisper."""
            if audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)
            # Convert to tensor for resampling
            audio_tensor = torch.FloatTensor(audio_array)
            # Resample to 16 kHz if needed
            if sample_rate != 16000:
                resampler = T.Resample(sample_rate, 16000)
                audio_tensor = resampler(audio_tensor)
            # Normalize, guarding against silent (all-zero) input
            peak = torch.max(torch.abs(audio_tensor))
            if peak > 0:
                audio_tensor = audio_tensor / peak
            # Return in the dict format the ASR pipeline expects
            return {
                "raw": audio_tensor.numpy(),  # key must be "raw"
                "sampling_rate": 16000,       # key must be "sampling_rate"
            }

        features = process_audio(audio_array, sample_rate)
        asr = get_asr_pipeline()
        result = asr(features)
        return result.get("text", "").strip() if isinstance(result, dict) else str(result).strip()
    except Exception as e:
        print(f"Transcription error: {str(e)}")
        return ""
def get_asr_pipeline():
    """Lazy load ASR pipeline with proper configuration."""
    global transcriber
    if "transcriber" not in globals():
        transcriber = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-base.en",
            chunk_length_s=30,
            stride_length_s=5,
            device="cpu",
            torch_dtype=torch.float32,
        )
    return transcriber
def process_speech(audio_data, symptom_index):
    """Process speech input and convert to text."""
    if not audio_data:
        return []
    if isinstance(audio_data, tuple) and len(audio_data) == 2:
        sample_rate, audio_array = audio_data
        # Audio preprocessing
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        audio_array = audio_array.astype(np.float32)
        # Normalize, guarding against silent (all-zero) input
        peak = np.max(np.abs(audio_array))
        if peak > 0:
            audio_array /= peak
        # Ensure correct sampling rate
        if sample_rate != 16000:
            resampler = T.Resample(sample_rate, 16000)
            audio_tensor = torch.FloatTensor(audio_array)
            audio_tensor = resampler(audio_tensor)
            audio_array = audio_tensor.numpy()
            sample_rate = 16000
        # Format dictionary with the required "raw" and "sampling_rate" keys
        input_features = {
            "raw": audio_array,
            "sampling_rate": sample_rate,
        }
        # Transcribe with error handling
        try:
            result = get_asr_pipeline()(input_features)
        except Exception as e:
            print(f"Transcription error: {str(e)}")
            return []
        # Handle different result types
        if isinstance(result, dict) and "text" in result:
            transcript = result["text"].strip()
        elif isinstance(result, str):
            transcript = result.strip()
        else:
            print(f"Unexpected transcriber result type: {type(result)}")
            return []
        if not transcript:
            print("No transcription generated")
            return []
        # Query symptoms with transcribed text
        diagnosis_query = f"""
        Given these symptoms: '{transcript}'
        Identify the most likely ICD-10 diagnoses and key questions.
        Focus on clinical implications.
        """
        response = symptom_index.as_query_engine().query(diagnosis_query)
        return [
            {"role": "user", "content": transcript},
            {"role": "assistant", "content": json.dumps({
                "diagnoses": [],
                "confidences": [],
                "follow_up": str(response),
            })},
        ]
    else:
        print(f"Invalid audio format: {type(audio_data)}")
        return []
def format_response_for_user(response_dict):
    """Format the assistant's response dictionary into a user-friendly string."""
    diagnoses = response_dict.get("diagnoses", [])
    confidences = response_dict.get("confidences", [])
    follow_up = response_dict.get("follow_up", "")
    result = ""
    if diagnoses:
        result += "Possible Diagnoses:\n"
        for i, diag in enumerate(diagnoses):
            conf = f" ({confidences[i]*100:.1f}%)" if i < len(confidences) else ""
            result += f"- {diag}{conf}\n"
    if follow_up:
        result += f"\nFollow-up: {follow_up}"
    return result.strip()
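# Example with illustrative values:
#   format_response_for_user({"diagnoses": ["J45.909"], "confidences": [0.72],
#                             "follow_up": "How long have the symptoms lasted?"})
# returns "Possible Diagnoses:\n- J45.909 (72.0%)\n\nFollow-up: How long have the symptoms lasted?"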
def enhanced_process_speech(audio_path, symptom_index, history, api_key=None, model_tier="small", temp=0.7):
    """Handle streaming speech processing and chat updates."""
    transcriber = get_asr_pipeline()
    if not audio_path:
        return history
    if isinstance(audio_path, tuple) and len(audio_path) == 2:
        sample_rate, audio_array = audio_path
        # Audio preprocessing
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        audio_array = audio_array.astype(np.float32)
        # Normalize, guarding against silent (all-zero) input
        peak = np.max(np.abs(audio_array))
        if peak > 0:
            audio_array /= peak
        # Ensure correct sampling rate
        if sample_rate != 16000:
            resampler = T.Resample(
                orig_freq=sample_rate,
                new_freq=16000,
            )
            audio_tensor = torch.FloatTensor(audio_array)
            audio_tensor = resampler(audio_tensor)
            audio_array = audio_tensor.numpy()
            sample_rate = 16000
        # Format input dictionary exactly as required
        transcriber_input = {
            "raw": audio_array,
            "sampling_rate": sample_rate,
        }
        # Get transcription from Whisper
        result = transcriber(transcriber_input)
        # Extract text from result
        transcript = ""
        if isinstance(result, dict):
            transcript = result.get("text", "").strip()
        elif isinstance(result, str):
            transcript = result.strip()
        if not transcript:
            return history
        # Process the symptoms
        diagnosis_query = f"""
        Based on these symptoms: '{transcript}'
        Provide relevant ICD-10 codes and diagnostic questions.
        """
        response = symptom_index.as_query_engine().query(diagnosis_query)
        # Format and return chat messages
        return history + [
            {"role": "user", "content": transcript},
            {"role": "assistant", "content": format_response_for_user({
                "diagnoses": [],
                "confidences": [],
                "follow_up": str(response),
            })},
        ]
    # Unsupported audio format: leave the chat history unchanged
    print(f"Invalid audio format: {type(audio_path)}")
    return history
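# Minimal smoke test (a sketch, not part of the original app wiring): feeds one
# second of synthetic 440 Hz audio through the live-transcription handler to
# confirm the Whisper pipeline loads and returns a string. A pure tone is
# expected to produce an empty or meaningless transcript.
if __name__ == "__main__":
    test_sr = 16000
    test_wave = np.sin(2 * np.pi * 440 * np.arange(test_sr) / test_sr).astype(np.float32)
    print("transcript:", repr(update_live_transcription((test_sr, test_wave))))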