from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
from transformers.pipelines import pipeline
import torch
import torchaudio.transforms as T
import numpy as np
import json

# Initialize Whisper components globally (these are lightweight)
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base.en")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base.en")
processor = WhisperProcessor(feature_extractor, tokenizer)


# Live transcription update handler
def update_live_transcription(audio):
    """Real-time transcription updates."""
    print("update_live_transcription called with:", type(audio))
    if not audio or not isinstance(audio, tuple):
        return ""

    try:
        sample_rate, audio_array = audio
        print(f"got audio tuple - sample_rate={sample_rate}, shape={audio_array.shape}")

        def process_audio(audio_array, sample_rate):
            """Pre-process audio for Whisper."""
            if audio_array.ndim > 1:
                audio_array = audio_array.mean(axis=1)

            # Convert to tensor for resampling
            audio_tensor = torch.FloatTensor(audio_array)

            # Resample to 16 kHz if needed
            if sample_rate != 16000:
                resampler = T.Resample(sample_rate, 16000)
                audio_tensor = resampler(audio_tensor)

            # Normalize
            audio_tensor = audio_tensor / torch.max(torch.abs(audio_tensor))

            # Convert back to a numpy array and return in the format the pipeline expects
            return {
                "raw": audio_tensor.numpy(),  # Key must be "raw"
                "sampling_rate": 16000        # Key must be "sampling_rate"
            }

        features = process_audio(audio_array, sample_rate)
        asr = get_asr_pipeline()
        result = asr(features)
        return result.get("text", "").strip() if isinstance(result, dict) else str(result).strip()

    except Exception as e:
        print(f"Transcription error: {str(e)}")
        return ""


def get_asr_pipeline():
    """Lazy-load the ASR pipeline with proper configuration."""
    global transcriber
    if "transcriber" not in globals():
        transcriber = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-base.en",
            chunk_length_s=30,
            stride_length_s=5,
            device="cpu",
            torch_dtype=torch.float32
        )
    return transcriber


def process_speech(audio_data, symptom_index):
    """Process speech input and convert it to text."""
    if not audio_data:
        return []

    if isinstance(audio_data, tuple) and len(audio_data) == 2:
        sample_rate, audio_array = audio_data

        # Audio preprocessing
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        audio_array = audio_array.astype(np.float32)
        audio_array /= np.max(np.abs(audio_array))

        # Ensure correct sampling rate
        if sample_rate != 16000:
            resampler = T.Resample(sample_rate, 16000)
            audio_tensor = torch.FloatTensor(audio_array)
            audio_tensor = resampler(audio_tensor)
            audio_array = audio_tensor.numpy()
            sample_rate = 16000

        # Transcribe with error handling;
        # format the input dictionary with the required keys
        input_features = {
            "raw": audio_array,
            "sampling_rate": sample_rate
        }
        transcriber = get_asr_pipeline()  # ensure the pipeline is loaded before use
        result = transcriber(input_features)

        # Handle different result types
        if isinstance(result, dict) and "text" in result:
            transcript = result["text"].strip()
        elif isinstance(result, str):
            transcript = result.strip()
        else:
            print(f"Unexpected transcriber result type: {type(result)}")
            return []

        if not transcript:
            print("No transcription generated")
            return []

        # Query symptoms with the transcribed text
        diagnosis_query = f"""
        Given these symptoms: '{transcript}'
        Identify the most likely ICD-10 diagnoses and key questions.
        Focus on clinical implications.
        """
        response = symptom_index.as_query_engine().query(diagnosis_query)
        return [
            {"role": "user", "content": transcript},
            {"role": "assistant", "content": json.dumps({
                "diagnoses": [],
                "confidences": [],
                "follow_up": str(response)
            })}
        ]
    else:
        print(f"Invalid audio format: {type(audio_data)}")
        return []


def format_response_for_user(response_dict):
    """Format the assistant's response dictionary into a user-friendly string."""
    diagnoses = response_dict.get("diagnoses", [])
    confidences = response_dict.get("confidences", [])
    follow_up = response_dict.get("follow_up", "")

    result = ""
    if diagnoses:
        result += "Possible Diagnoses:\n"
        for i, diag in enumerate(diagnoses):
            conf = f" ({confidences[i] * 100:.1f}%)" if i < len(confidences) else ""
            result += f"- {diag}{conf}\n"
    if follow_up:
        result += f"\nFollow-up: {follow_up}"
    return result.strip()


def enhanced_process_speech(audio_path, symptom_index, history, api_key=None, model_tier="small", temp=0.7):
    """Handle streaming speech processing and chat updates."""
    transcriber = get_asr_pipeline()
    if not audio_path:
        return history

    if isinstance(audio_path, tuple) and len(audio_path) == 2:
        sample_rate, audio_array = audio_path

        # Audio preprocessing
        if audio_array.ndim > 1:
            audio_array = audio_array.mean(axis=1)
        audio_array = audio_array.astype(np.float32)
        audio_array /= np.max(np.abs(audio_array))

        # Ensure correct sampling rate
        if sample_rate != 16000:
            resampler = T.Resample(orig_freq=sample_rate, new_freq=16000)
            audio_tensor = torch.FloatTensor(audio_array)
            audio_tensor = resampler(audio_tensor)
            audio_array = audio_tensor.numpy()
            sample_rate = 16000

        # Format the input dictionary exactly as required
        transcriber_input = {
            "raw": audio_array,
            "sampling_rate": sample_rate
        }

        # Get transcription from Whisper
        result = transcriber(transcriber_input)

        # Extract text from the result
        transcript = ""
        if isinstance(result, dict):
            transcript = result.get("text", "").strip()
        elif isinstance(result, str):
            transcript = result.strip()

        if not transcript:
            return history

        # Process the symptoms
        diagnosis_query = f"""
        Based on these symptoms: '{transcript}'
        Provide relevant ICD-10 codes and diagnostic questions.
        """
        response = symptom_index.as_query_engine().query(diagnosis_query)

        # Format and return chat messages
        return history + [
            {"role": "user", "content": transcript},
            {"role": "assistant", "content": format_response_for_user({
                "diagnoses": [],
                "confidences": [],
                "follow_up": str(response)
            })}
        ]

    # Fall back to the existing history if the audio payload is not a (rate, array) tuple
    return history