# app.py by zainulabedin949 (Hugging Face Space)
# Last commit: e9b0e37 "Update app.py" (verified), 5.17 kB
import gradio as gr
import numpy as np
import torch
import librosa
import soundfile as sf
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import tempfile
import os
# Constants
SAMPLING_RATE = 16000  # Hz; every clip is resampled to this rate before feature extraction
MODEL_NAME = "MIT/ast-finetuned-audioset-10-10-0.4593"  # checkpoint id passed to from_pretrained
DEFAULT_THRESHOLD = 0.7  # default for analyze_audio's threshold and the UI slider's initial value

# Load the feature extractor and model once at import time so every request reuses them.
# Tuple elements are evaluated left to right, so the extractor is still loaded first.
feature_extractor, model = (
    AutoFeatureExtractor.from_pretrained(MODEL_NAME),
    AutoModelForAudioClassification.from_pretrained(MODEL_NAME),
)
def handle_audio_file(audio_file):
    """Decode an uploaded audio file into a mono numpy array.

    The bytes returned by ``audio_file.read()`` are spooled to a temporary
    ``.wav`` file and decoded with soundfile. Multi-channel audio (soundfile
    returns it as ``(frames, channels)``) is averaged down to one channel.

    Args:
        audio_file: binary file-like object exposing ``read()``.

    Returns:
        ``(audio, sample_rate)`` as returned by ``soundfile.read`` (after downmixing).

    Raises:
        ValueError: if reading or decoding fails. The underlying exception is
            chained as ``__cause__``.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            # Record the path before writing so cleanup still runs if the write fails.
            tmp_path = tmp.name
            tmp.write(audio_file.read())
        audio, sr = sf.read(tmp_path)
        # Convert to mono if needed
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)
        return audio, sr
    except Exception as e:
        raise ValueError(f"Error processing audio file: {str(e)}") from e
    finally:
        # Remove the temp file on every path. Previously it leaked whenever decoding raised.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass  # best-effort cleanup must not mask the real result or error
def _to_float_mono(audio):
    """Return *audio* as a 1-D floating-point numpy array.

    Signed integer PCM (e.g. a Gradio numpy payload) is scaled to [-1, 1).
    Unsigned PCM is re-centred around zero. Input shaped ``(samples, channels)``
    is averaged to mono, matching ``handle_audio_file``. librosa's resampling
    and spectrogram functions reject integer arrays, hence the conversion.
    """
    audio = np.asarray(audio)
    if np.issubdtype(audio.dtype, np.signedinteger):
        audio = audio.astype(np.float32) / -np.iinfo(audio.dtype).min
    elif np.issubdtype(audio.dtype, np.unsignedinteger):
        mid = (np.iinfo(audio.dtype).max + 1) / 2
        audio = (audio.astype(np.float32) - mid) / mid
    elif not np.issubdtype(audio.dtype, np.floating):
        audio = audio.astype(np.float32)
    if audio.ndim > 1:
        audio = np.mean(audio, axis=1)
    return audio


def _load_audio(audio_input):
    """Turn a supported audio payload into ``(samples, sample_rate)``.

    Accepts a file path (``str``), a file-like upload exposing ``name``, or a
    ``(sample_rate, samples)`` tuple. Raises ValueError for anything else.
    """
    if audio_input is None:
        raise ValueError("No audio provided; please upload an audio file")
    if isinstance(audio_input, str):  # File path
        # Close the handle deterministically (it used to be leaked).
        with open(audio_input, 'rb') as fh:
            return handle_audio_file(fh)
    if hasattr(audio_input, 'name'):  # Gradio file object
        return handle_audio_file(audio_input)
    if isinstance(audio_input, tuple):  # (sample_rate, samples) numpy payload
        sr, audio = audio_input
        return _to_float_mono(audio), sr
    raise ValueError("Unsupported audio input format")


def _render_spectrogram(audio):
    """Plot the mel spectrogram of *audio* (at SAMPLING_RATE) to a new PNG and return its path."""
    # `librosa.display` must be imported explicitly; `import librosa` alone does not guarantee it.
    # This import also binds the local name `librosa` to the top-level package.
    import librosa.display
    from matplotlib.figure import Figure

    spectrogram = librosa.feature.melspectrogram(
        y=audio,
        sr=SAMPLING_RATE,
        n_mels=64,
        fmax=8000,
    )
    db_spec = librosa.power_to_db(spectrogram, ref=np.max)

    # Use a standalone Figure instead of pyplot. This avoids global figure state
    # shared between concurrent request threads, and nothing needs closing on error.
    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    img = librosa.display.specshow(
        db_spec,
        x_axis='time',
        y_axis='mel',
        sr=SAMPLING_RATE,
        fmax=8000,
        ax=ax,
    )
    fig.colorbar(img, ax=ax, format='%+2.0f dB')
    ax.set(title='Mel Spectrogram')
    fig.tight_layout()

    # A unique file per call. The old fixed 'spec.png' let simultaneous
    # requests overwrite each other's image.
    fd, spec_path = tempfile.mkstemp(suffix='.png')
    os.close(fd)
    try:
        fig.savefig(spec_path, bbox_inches='tight')
    except Exception:
        os.unlink(spec_path)
        raise
    return spec_path


def analyze_audio(audio_input, threshold=DEFAULT_THRESHOLD):
    """Classify an audio clip as "Normal" or "Anomaly" and render its mel spectrogram.

    Args:
        audio_input: file path, file-like upload, or ``(sample_rate, samples)`` tuple.
        threshold: the probability of model output index 0 must exceed this for a
            "Normal" verdict.

    Returns:
        ``(label, confidence_text, spectrogram_png_path, raw_probabilities_text)``.
        Any failure is reported as ``("Error: <message>", "", None, "")`` instead
        of raising, so the four Gradio outputs always receive values.
    """
    try:
        audio, sr = _load_audio(audio_input)
        if audio.size == 0:
            raise ValueError("Audio input is empty")

        # Resample if needed
        if sr != SAMPLING_RATE:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLING_RATE)

        # Extract features
        inputs = feature_extractor(
            audio,
            sampling_rate=SAMPLING_RATE,
            return_tensors="pt",
            padding=True,
            return_attention_mask=True,
        )

        # Run inference
        with torch.no_grad():
            logits = model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)

        # NOTE(review): "Normal" is decided by the probability of output index 0 alone.
        # For this AudioSet-finetuned checkpoint, index 0 is presumably a generic
        # AudioSet class (check model.config.id2label[0]; typically "Speech"), not a
        # learned "normal machine" class. Softmax over a multi-label head may also be
        # poorly calibrated. Confirm before trusting these verdicts.
        p_normal = probs[0][0].item()
        predicted_class = "Normal" if p_normal > threshold else "Anomaly"
        confidence = p_normal if predicted_class == "Normal" else 1 - p_normal

        spec_path = _render_spectrogram(audio)

        return (
            predicted_class,
            f"{confidence:.1%}",
            spec_path,
            str(probs.tolist()[0]),
        )
    except Exception as e:
        return f"Error: {str(e)}", "", None, ""
# Gradio interface
with gr.Blocks(title="Industrial Audio Analyzer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🏭 Industrial Equipment Sound Analyzer
### Powered by Audio Spectrogram Transformer (AST)
""")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="Upload Equipment Audio (.wav)",
                # analyze_audio receives the upload as a file path string
                type="filepath",
            )
            threshold = gr.Slider(
                minimum=0.5,
                maximum=0.95,
                step=0.05,
                value=DEFAULT_THRESHOLD,
                label="Anomaly Detection Threshold",
            )
            # U+1F50D (magnifying glass). The source held its mis-decoded UTF-8 bytes.
            analyze_btn = gr.Button("🔍 Analyze Sound", variant="primary")
        with gr.Column():
            result_label = gr.Label(label="Detection Result")
            confidence = gr.Textbox(label="Confidence Score")
            spectrogram = gr.Image(label="Spectrogram Visualization")
            # Hidden output that receives the stringified probability list from analyze_audio.
            raw_probs = gr.Textbox(
                label="Model Output Probabilities",
                visible=False,
            )
    # The outputs order must match the 4-tuple returned by analyze_audio.
    analyze_btn.click(
        fn=analyze_audio,
        inputs=[audio_input, threshold],
        outputs=[result_label, confidence, spectrogram, raw_probs],
    )
    gr.Markdown("""
**Instructions:**
- Upload .wav audio recordings (5-10 seconds recommended)
- Adjust threshold to control sensitivity
- Results show Normal/Anomaly classification with confidence
""")

if __name__ == "__main__":
    demo.launch()