import gradio as gr
import numpy as np
import torch
import librosa
import librosa.display  # specshow lives in the display submodule
import soundfile as sf
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; figures are only saved to disk
import matplotlib.pyplot as plt
import tempfile
import os
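
# Assumed dependencies (install with pip): gradio, torch, transformers,
# librosa, soundfile, matplotlib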
# Constants
SAMPLING_RATE = 16000
MODEL_NAME = "MIT/ast-finetuned-audioset-10-10-0.4593"
DEFAULT_THRESHOLD = 0.7
# Load model and feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
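
# NOTE: this checkpoint is AST fine-tuned on AudioSet (527 sound classes),
# not a binary normal/anomaly classifier. analyze_audio() below treats the
# probability of class index 0 as the "Normal" score, which is a simplifying
# heuristic; a model fine-tuned on labeled equipment recordings would be more
# reliable in production.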

def handle_audio_file(audio_file):
    """Handle an uploaded audio file and convert it to a mono numpy array."""
    try:
        # Save to a temp file and load with soundfile
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp.write(audio_file.read())
            tmp_path = tmp.name
        audio, sr = sf.read(tmp_path)
        os.unlink(tmp_path)  # Clean up the temp file
        # Convert to mono if needed
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)
        return audio, sr
    except Exception as e:
        raise ValueError(f"Error processing audio file: {str(e)}")
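
# Note: sf.read() can also read from file-like objects directly, so the
# temp-file round-trip above could probably be skipped; it is kept for
# compatibility with backends that expect a real file path.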

def analyze_audio(audio_input, threshold=DEFAULT_THRESHOLD):
    """Process audio and detect anomalies."""
    try:
        # Handle the different input types Gradio may pass
        if isinstance(audio_input, str):  # File path
            with open(audio_input, 'rb') as f:
                audio, sr = handle_audio_file(f)
        elif hasattr(audio_input, 'name'):  # Gradio file object
            audio, sr = handle_audio_file(audio_input)
        elif isinstance(audio_input, tuple):  # (sample_rate, numpy samples)
            sr, audio = audio_input
            if np.issubdtype(audio.dtype, np.integer):
                audio = audio.astype(np.float32) / 32768.0  # 16-bit PCM -> [-1, 1]
            if audio.ndim > 1:
                audio = np.mean(audio, axis=1)  # mix down to mono
        else:
            raise ValueError("Unsupported audio input format")
        # Resample if needed
        if sr != SAMPLING_RATE:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLING_RATE)

        # Extract features; the AST feature extractor pads/truncates its
        # input to a fixed length internally, so no padding arguments needed
        inputs = feature_extractor(
            audio,
            sampling_rate=SAMPLING_RATE,
            return_tensors="pt"
        )
        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            probs = torch.softmax(logits, dim=-1)

        # Get results (the probability of class index 0 is used as the
        # "Normal" score; see the note under the model load above)
        predicted_class = "Normal" if probs[0][0] > threshold else "Anomaly"
        confidence = probs[0][0].item() if predicted_class == "Normal" else 1 - probs[0][0].item()
        # Create a mel spectrogram for visualization
        spectrogram = librosa.feature.melspectrogram(
            y=audio,
            sr=SAMPLING_RATE,
            n_mels=64,
            fmax=8000
        )
        db_spec = librosa.power_to_db(spectrogram, ref=np.max)
        fig, ax = plt.subplots(figsize=(10, 4))
        img = librosa.display.specshow(
            db_spec,
            x_axis='time',
            y_axis='mel',
            sr=SAMPLING_RATE,
            fmax=8000,
            ax=ax
        )
        fig.colorbar(img, ax=ax, format='%+2.0f dB')
        ax.set(title='Mel Spectrogram')
        plt.tight_layout()

        # Save to a temp file for the Gradio Image component
        spec_path = os.path.join(tempfile.gettempdir(), 'spec.png')
        plt.savefig(spec_path, bbox_inches='tight')
        plt.close(fig)

        return (
            predicted_class,
            f"{confidence:.1%}",
            spec_path,
            str(probs.tolist()[0])
        )
    except Exception as e:
        return f"Error: {str(e)}", "", None, ""

# Gradio interface
with gr.Blocks(title="Industrial Audio Analyzer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Industrial Equipment Sound Analyzer
    ### Powered by Audio Spectrogram Transformer (AST)
    """)
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="Upload Equipment Audio (.wav)",
                type="filepath"
            )
            threshold = gr.Slider(
                minimum=0.5,
                maximum=0.95,
                step=0.05,
                value=DEFAULT_THRESHOLD,
                label="Anomaly Detection Threshold"
            )
            analyze_btn = gr.Button("Analyze Sound", variant="primary")
        with gr.Column():
            result_label = gr.Label(label="Detection Result")
            confidence = gr.Textbox(label="Confidence Score")
            spectrogram = gr.Image(label="Spectrogram Visualization")
            raw_probs = gr.Textbox(
                label="Model Output Probabilities",
                visible=False
            )
    analyze_btn.click(
        fn=analyze_audio,
        inputs=[audio_input, threshold],
        outputs=[result_label, confidence, spectrogram, raw_probs]
    )
    gr.Markdown("""
    **Instructions:**
    - Upload .wav audio recordings (5-10 seconds recommended)
    - Adjust the threshold to control sensitivity
    - Results show a Normal/Anomaly classification with confidence
    """)

if __name__ == "__main__":
    demo.launch()
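    # To expose the app beyond localhost, demo.launch(server_name="0.0.0.0")
    # can be used; Gradio serves on http://127.0.0.1:7860 by default.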