|
import gradio as gr |
|
import numpy as np |
|
import torch |
|
import librosa |
|
import soundfile as sf |
|
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification |
|
import matplotlib.pyplot as plt |
|
from matplotlib.colors import Normalize |
|
import tempfile |
|
import os |
|
|
|
|
|
# Every clip is resampled to this rate before feature extraction and plotting.
SAMPLING_RATE = 16_000
# Hugging Face Hub checkpoint of the Audio Spectrogram Transformer used below.
MODEL_NAME = "MIT/ast-finetuned-audioset-10-10-0.4593"
# Initial value of the UI threshold slider: analyze_audio reports "Normal"
# only when the class-0 probability is strictly greater than the threshold.
DEFAULT_THRESHOLD = 0.7
|
|
|
|
|
# Load the checkpoint once at import time; analyze_audio reuses these two
# module-level objects for every request instead of reloading them per call.
model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
|
|
|
def handle_audio_file(audio_file):
    """Decode an audio file into a mono numpy array.

    Args:
        audio_file: A readable binary file-like object (for example the
            result of ``open(path, "rb")``). Because ``soundfile.read`` also
            accepts filesystem paths, a path works here too.

    Returns:
        ``(audio, sr)``: a 1-D floating-point array with all channels averaged
        together, and the file's native sample rate.

    Raises:
        ValueError: if the data cannot be read or decoded.
    """
    try:
        # soundfile decodes file-like objects directly. The previous version
        # copied the bytes to a temporary file first and leaked that file
        # whenever decoding raised, because the unlink was skipped.
        audio, sr = sf.read(audio_file)
    except Exception as e:
        raise ValueError(f"Error processing audio file: {str(e)}") from e

    # Multi-channel data comes back as (frames, channels); downmix to mono.
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)

    return audio, sr
|
|
|
def _to_float_mono(audio):
    """Return *audio* as a 1-D floating-point array.

    Integer PCM samples are rescaled to roughly [-1, 1). Unsigned formats such
    as 8-bit WAV are centred first. This matters for the ``(sr, samples)``
    tuple path: presumably Gradio's ``type="numpy"`` audio, which is int16
    (TODO confirm for the installed Gradio version). librosa.resample only
    accepts floating-point input. A 2-D ``(samples, channels)`` array is
    averaged to mono, the same convention handle_audio_file uses.
    """
    audio = np.asarray(audio)
    if np.issubdtype(audio.dtype, np.integer):
        info = np.iinfo(audio.dtype)
        offset = (int(info.max) + int(info.min) + 1) / 2  # 0 for signed, mid-scale for unsigned
        scale = (int(info.max) - int(info.min) + 1) / 2  # half the representable range
        audio = (audio.astype(np.float32) - offset) / scale
    elif not np.issubdtype(audio.dtype, np.floating):
        raise ValueError(f"Unsupported audio sample type: {audio.dtype}")
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)
    return audio


def _load_audio(audio_input):
    """Turn any supported input into ``(mono float array, sample rate)``.

    Raises:
        ValueError: for missing, empty, or unsupported input.
    """
    if audio_input is None:
        raise ValueError("No audio provided")
    if isinstance(audio_input, (str, os.PathLike)):
        # A context manager closes the handle even when decoding fails.
        with open(audio_input, 'rb') as fh:
            audio, sr = handle_audio_file(fh)
    elif hasattr(audio_input, 'name'):
        audio, sr = handle_audio_file(audio_input)
    elif isinstance(audio_input, tuple):
        sr, audio = audio_input
    else:
        raise ValueError("Unsupported audio input format")

    audio = _to_float_mono(audio)
    if audio.size == 0:
        raise ValueError("Audio input is empty")
    return audio, sr


def _render_spectrogram(audio):
    """Plot the mel spectrogram of *audio* (at SAMPLING_RATE) and return a PNG path.

    A standalone matplotlib Figure is used instead of pyplot. The handler may
    run on a worker thread, and pyplot's global figure registry and GUI
    backends are not safe there. The Figure is freed by garbage collection,
    so nothing has to be closed, even when plotting raises.
    """
    # Import the submodule explicitly: on older librosa releases a bare
    # `import librosa` does not load `librosa.display`.
    from librosa import display as librosa_display
    from matplotlib.figure import Figure

    mel = librosa.feature.melspectrogram(
        y=audio,
        sr=SAMPLING_RATE,
        n_mels=64,
        fmax=8000,
    )
    db_spec = librosa.power_to_db(mel, ref=np.max)

    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    img = librosa_display.specshow(
        db_spec,
        x_axis='time',
        y_axis='mel',
        sr=SAMPLING_RATE,
        fmax=8000,
        ax=ax,
    )
    fig.colorbar(img, ax=ax, format='%+2.0f dB')
    ax.set(title='Mel Spectrogram')
    fig.tight_layout()

    # Each request gets its own file. A single shared 'spec.png' let
    # concurrent requests overwrite each other's images.
    # NOTE(review): these files are not deleted here; rely on temp-dir
    # cleanup or add a periodic sweep if the app runs for a long time.
    with tempfile.NamedTemporaryFile(suffix='.png', prefix='spec_', delete=False) as tmp:
        fig.savefig(tmp, format='png', bbox_inches='tight')
    return tmp.name


def analyze_audio(audio_input, threshold=DEFAULT_THRESHOLD):
    """Classify a recording as "Normal" or "Anomaly" and plot its mel spectrogram.

    Args:
        audio_input: One of the following:
            - a file path (``str`` or ``os.PathLike``);
            - an open binary file object (anything with a ``name`` attribute);
            - a ``(sample_rate, samples)`` tuple.
            ``None`` (nothing uploaded) is reported as an error.
        threshold: The clip is "Normal" only when the model's class-0
            probability is strictly greater than this value.

    Returns:
        ``(label, confidence_text, spectrogram_png_path, probabilities_text)``.
        On any failure it returns ``("Error: <message>", "", None, "")``. The
        function never raises, so the UI always receives four outputs.
    """
    try:
        audio, sr = _load_audio(audio_input)
        if sr != SAMPLING_RATE:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLING_RATE)

        # The AST model's forward() has no attention_mask parameter, so one
        # is not requested from the feature extractor.
        inputs = feature_extractor(
            audio,
            sampling_rate=SAMPLING_RATE,
            return_tensors="pt",
        )
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)

        # NOTE(review): MODEL_NAME is an AudioSet tagger, not a trained
        # normal/anomaly classifier. Index 0 is simply the first AudioSet label
        # (believed to be "Speech"; verify via model.config.id2label[0]), so
        # this decision rule needs confirming with the product owner.
        p_normal = probs[0, 0].item()
        if p_normal > threshold:
            predicted_class, confidence = "Normal", p_normal
        else:
            predicted_class, confidence = "Anomaly", 1 - p_normal

        spec_path = _render_spectrogram(audio)

        return (
            predicted_class,
            f"{confidence:.1%}",
            spec_path,
            str(probs[0].tolist()),
        )
    except Exception as e:
        # UI boundary: report every failure to the user instead of raising.
        return f"Error: {str(e)}", "", None, ""
|
|
|
|
|
with gr.Blocks(title="Industrial Audio Analyzer", theme=gr.themes.Soft()) as demo:
    # The Markdown is built from explicit lines rather than an indented
    # triple-quoted string. Leading indentation would otherwise be Markdown
    # syntax and could turn the text into a code block.
    gr.Markdown(
        "# 🏭 Industrial Equipment Sound Analyzer\n"
        "### Powered by Audio Spectrogram Transformer (AST)"
    )

    with gr.Row():
        with gr.Column():
            # type="filepath" means analyze_audio receives a path string.
            audio_input = gr.Audio(
                label="Upload Equipment Audio (.wav)",
                type="filepath",
            )
            threshold = gr.Slider(
                minimum=0.5,
                maximum=0.95,
                step=0.05,
                value=DEFAULT_THRESHOLD,
                label="Anomaly Detection Threshold",
            )
            analyze_btn = gr.Button("🔍 Analyze Sound", variant="primary")

        with gr.Column():
            result_label = gr.Label(label="Detection Result")
            confidence = gr.Textbox(label="Confidence Score")
            spectrogram = gr.Image(label="Spectrogram Visualization")
            raw_probs = gr.Textbox(
                label="Model Output Probabilities",
                visible=False,
            )

    # The output order must match analyze_audio's 4-tuple return value.
    analyze_btn.click(
        fn=analyze_audio,
        inputs=[audio_input, threshold],
        outputs=[result_label, confidence, spectrogram, raw_probs],
    )

    gr.Markdown(
        "**Instructions:**\n"
        "\n"
        "- Upload .wav audio recordings (5-10 seconds recommended)\n"
        "- Adjust the threshold to control sensitivity: higher values flag more recordings as anomalies\n"
        "- Results show Normal/Anomaly classification with confidence"
    )

if __name__ == "__main__":
    demo.launch()
|
|