import base64
import os
import tempfile
from io import BytesIO
from typing import Any

from huggingface_hub import InferenceClient
from langchain_core.tools import tool as langchain_tool
from pyAudioAnalysis import audioSegmentation as aS
from pydub import AudioSegment
from smolagents.tools import Tool, tool


class TranscribeAudioTool(Tool):
    name = "transcribe_audio"
    description = "Transcribe an audio file (in base64 format or as an AudioSegment object)"
    inputs = {
        "audio": {
            "type": "any",
            "description": "The audio file in base64 format or as an AudioSegment object only",
        }
    }
    output_type = "string"

    def setup(self):
        self.model = InferenceClient(
            model="openai/whisper-large-v3",
            provider="hf-inference",
            token=os.getenv("HUGGINGFACE_API_KEY"),
        )

    def _convert_audio_segment_to_wav(self, audio_segment: AudioSegment) -> bytes:
        """Convert an AudioSegment to WAV-format bytes (16 kHz, mono, 16-bit, as Whisper expects)."""
        try:
            # Convert to mono if stereo
            if audio_segment.channels > 1:
                audio_segment = audio_segment.set_channels(1)
            # Resample to 16 kHz if the sample rate differs
            if audio_segment.frame_rate != 16000:
                audio_segment = audio_segment.set_frame_rate(16000)
            # Convert to 16-bit samples if the bit depth differs (2 bytes = 16 bits)
            if audio_segment.sample_width != 2:
                audio_segment = audio_segment.set_sample_width(2)
            # Export to WAV format
            buffer = BytesIO()
            audio_segment.export(buffer, format="wav")
            return buffer.getvalue()
        except Exception as e:
            raise RuntimeError(f"Error converting audio segment: {e}")

    def forward(self, audio: Any) -> str:
        # Handle an AudioSegment object: convert directly to WAV bytes
        if isinstance(audio, AudioSegment):
            audio_data = self._convert_audio_segment_to_wav(audio)
        # Handle a base64 string: decode, then round-trip through AudioSegment
        # to standardize the format
        elif isinstance(audio, str):
            try:
                audio_data = base64.b64decode(audio)
                audio_segment = AudioSegment.from_file(BytesIO(audio_data))
                audio_data = self._convert_audio_segment_to_wav(audio_segment)
            except Exception as e:
                raise ValueError(f"Invalid base64 audio data: {e}")
        else:
            raise ValueError(
                f"Unsupported audio type: {type(audio)}. "
                "Expected base64 string or AudioSegment object."
            )
        # Transcribe using the model
        try:
            result = self.model.automatic_speech_recognition(audio_data)
            return result["text"]
        except Exception as e:
            raise RuntimeError(f"Error in transcription: {e}")


transcribe_audio_tool = TranscribeAudioTool()
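

def example_transcribe(file_path: str = "sample.mp3") -> str:
    """Sketch of end-to-end usage, not part of the original toolset.

    Assumes HUGGINGFACE_API_KEY is set and that `file_path` (a placeholder
    name here) points to a real audio file. smolagents Tool instances are
    callable, which runs setup() once and then forward().
    """
    audio_b64 = get_audio_from_file_path(file_path)
    return transcribe_audio_tool(audio_b64)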


def get_audio_from_file_path(file_path: str) -> str:
    """
    Load an audio file from a file path and convert it to a base64 string.

    Args:
        file_path: Path to the audio file (any format pydub/ffmpeg can read, e.g. mp3)

    Returns:
        The audio file in base64 format (re-encoded as WAV)
    """
    # Load the audio file; if the path does not resolve, retry relative to this module
    try:
        audio = AudioSegment.from_file(file_path)
    except Exception:
        current_file_dir = os.path.dirname(os.path.abspath(__file__))
        file_path = os.path.join(current_file_dir, file_path)
        audio = AudioSegment.from_file(file_path)
    # Export the audio to a BytesIO object as WAV
    buffer = BytesIO()
    audio.export(buffer, format="wav")
    # Encode the audio data to base64
    audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return audio_base64
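

def get_audio_segment_from_base64(audio: str) -> AudioSegment:
    """Sketch of the inverse helper, not in the original module: decode a
    base64 string (such as the output of get_audio_from_file_path) back into
    a pydub AudioSegment, mirroring the decode step the functions below
    repeat inline.
    """
    return AudioSegment.from_file(BytesIO(base64.b64decode(audio)))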


def noise_reduction(audio: str) -> str:
    """
    Reduce noise in an audio file. This is a crude approach: a 3 kHz low-pass
    filter that removes high-frequency hiss at the cost of some speech detail.

    Args:
        audio: The audio file in base64 format

    Returns:
        The denoised audio file in base64 format
    """
    # Decode the base64 audio
    audio_data = base64.b64decode(audio)
    audio_segment = AudioSegment.from_file(BytesIO(audio_data))
    # Apply noise reduction (simple low-pass filter)
    denoised_audio = audio_segment.low_pass_filter(3000)
    # Encode back to base64
    buffer = BytesIO()
    denoised_audio.export(buffer, format="wav")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
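

def band_pass_noise_reduction(audio: str, low_cutoff: int = 100, high_cutoff: int = 3000) -> str:
    """Sketch of a slightly stronger variant, not in the original module:
    chain pydub's high_pass_filter and low_pass_filter into a band-pass that
    also trims low-frequency rumble. The cutoff values are illustrative
    defaults, not tuned settings.
    """
    audio_segment = AudioSegment.from_file(BytesIO(base64.b64decode(audio)))
    # Keep roughly the speech band: drop rumble below low_cutoff, hiss above high_cutoff
    filtered = audio_segment.high_pass_filter(low_cutoff).low_pass_filter(high_cutoff)
    buffer = BytesIO()
    filtered.export(buffer, format="wav")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")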


def audio_segmentation(audio: str, segment_length: int = 30) -> list:
    """
    Segment an audio file into smaller chunks.

    Args:
        audio: The audio file in base64 format
        segment_length: Length of each segment in seconds

    Returns:
        List of audio segments in base64 format. Each of these segments can be
        used as input for the `transcribe_audio` tool.
    """
    # Decode the base64 audio
    audio_data = base64.b64decode(audio)
    audio_segment = AudioSegment.from_file(BytesIO(audio_data))
    # Slice the audio into segment_length-second chunks (pydub indexes in milliseconds)
    segments = []
    for i in range(0, len(audio_segment), segment_length * 1000):
        segment = audio_segment[i:i + segment_length * 1000]
        buffer = BytesIO()
        segment.export(buffer, format="wav")
        segments.append(base64.b64encode(buffer.getvalue()).decode("utf-8"))
    return segments
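

def transcribe_long_audio(audio: str, segment_length: int = 30) -> str:
    """Sketch of chaining the helpers above, not in the original module:
    split audio that exceeds the ASR model's comfortable input length into
    chunks, transcribe each chunk with transcribe_audio_tool, and join the
    pieces in time order. Word breaks at chunk boundaries may be imperfect.
    """
    segments = audio_segmentation(audio, segment_length)
    return " ".join(transcribe_audio_tool(segment) for segment in segments)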


def speaker_diarization(audio: str) -> list:
    """
    Diarize an audio file into speakers.

    Args:
        audio: The audio file in base64 format

    Returns:
        List of (window_index, speaker_id) tuples, one per analysis window
    """
    # Decode the base64 audio
    audio_data = base64.b64decode(audio)
    audio_segment = AudioSegment.from_file(BytesIO(audio_data))
    # pyAudioAnalysis reads from a file path, not a file-like object,
    # so export the audio to a temporary WAV file first
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
        temp_path = temp_file.name
    audio_segment.export(temp_path, format="wav")
    try:
        # Perform speaker diarization, assuming 2 speakers. Note: the name and
        # return signature vary across pyAudioAnalysis versions (older releases
        # expose speakerDiarization); only the per-window flags are used here.
        flags, _, _ = aS.speaker_diarization(temp_path, 2)
    finally:
        os.unlink(temp_path)
    # Pair each analysis window index with its assigned speaker id
    speaker_segments = [(i, flag) for i, flag in enumerate(flags)]
    return speaker_segments
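

def collapse_speaker_flags(speaker_segments: list, window_step: float = 0.2) -> list:
    """Sketch of post-processing, not in the original module: collapse the
    per-window (index, speaker) pairs from speaker_diarization into contiguous
    (start_s, end_s, speaker) turns. window_step is an assumed analysis step
    in seconds; match it to the pyAudioAnalysis settings actually used.
    """
    turns = []
    for index, speaker in speaker_segments:
        end = (index + 1) * window_step
        if turns and turns[-1][2] == speaker:
            # Same speaker as the previous window: extend the current turn
            turns[-1] = (turns[-1][0], end, speaker)
        else:
            # Speaker changed: start a new turn
            turns.append((index * window_step, end, speaker))
    return turns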