File size: 5,169 Bytes
d0cb32e
 
 
 
e9b0e37
d0cb32e
 
 
e9b0e37
 
d0cb32e
 
 
 
 
 
 
 
 
 
e9b0e37
 
d0cb32e
e9b0e37
 
 
 
 
 
 
 
 
d0cb32e
e9b0e37
d0cb32e
e9b0e37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d0cb32e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9b0e37
d0cb32e
 
 
e9b0e37
d0cb32e
 
 
e9b0e37
d0cb32e
 
 
 
07c6db0
 
 
 
 
 
 
 
 
 
 
d0cb32e
e9b0e37
 
 
 
d0cb32e
 
 
 
 
e9b0e37
d0cb32e
 
 
 
e9b0e37
d0cb32e
 
 
 
 
 
 
 
 
 
 
e9b0e37
 
d0cb32e
 
 
 
 
 
e9b0e37
d0cb32e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e9b0e37
 
 
 
d0cb32e
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import gradio as gr
import numpy as np
import torch
import librosa
import soundfile as sf
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import tempfile
import os

# Constants
# Sample rate (Hz) that audio is resampled to before feature extraction and
# that is used for the spectrogram.
SAMPLING_RATE = 16000
# Checkpoint identifier passed to both ``from_pretrained`` calls below.
MODEL_NAME = "MIT/ast-finetuned-audioset-10-10-0.4593"
# Default cut-off for analyze_audio and initial value of the UI slider.
DEFAULT_THRESHOLD = 0.7

# Load model and feature extractor
# Both are created once at import time and shared by every analyze_audio call.
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)

def handle_audio_file(audio_file):
    """Decode an open binary audio stream into a mono sample array.

    The stream's bytes are written to a temporary ``.wav``-suffixed file,
    decoded with ``soundfile`` and, when the decoded array has more than one
    dimension, averaged over axis 1 to a single channel. The temporary file
    is removed on every exit path, including failures.

    Args:
        audio_file: Object with a ``read()`` method returning the raw bytes
            of the audio file. It is not closed here; the caller owns it.

    Returns:
        tuple: ``(audio, sr)`` - the decoded samples as a NumPy array and the
        sample rate reported by ``soundfile``.

    Raises:
        ValueError: If reading, spooling or decoding fails. The original
            exception is chained as ``__cause__``.
    """
    tmp_path = None
    try:
        # Read before creating the temp file so a failing stream cannot
        # leave an empty file behind.
        payload = audio_file.read()

        # delete=False: the file must be closed before soundfile reopens it
        # by name, so removal is handled manually in ``finally``.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp_path = tmp.name
            tmp.write(payload)

        audio, sr = sf.read(tmp_path)

        # Convert to mono if needed
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)

        return audio, sr
    except Exception as e:
        raise ValueError(f"Error processing audio file: {str(e)}") from e
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup: a failed delete must not mask the
                # result or the ValueError already in flight.
                pass

def _load_audio(audio_input):
    """Return ``(samples, sample_rate)`` for any supported input form."""
    if audio_input is None:
        # Reached when the handler is invoked without any audio selected.
        raise ValueError("No audio provided")

    if isinstance(audio_input, str):  # File path
        # Use a context manager so the handle is closed deterministically.
        with open(audio_input, 'rb') as audio_file:
            return handle_audio_file(audio_file)

    if hasattr(audio_input, 'name'):  # File-like object
        return handle_audio_file(audio_input)

    if isinstance(audio_input, tuple):  # (sample_rate, numpy array)
        sr, audio = audio_input
        audio = np.asarray(audio)
        if np.issubdtype(audio.dtype, np.signedinteger):
            # Integer PCM: scale to floating point in roughly [-1, 1], which
            # the resampler and spectrogram code require.
            scale = np.iinfo(audio.dtype).max
            audio = audio.astype(np.float32) / scale
        if audio.ndim > 1:
            # Same mono down-mix as handle_audio_file.
            # assumes layout is (samples, channels) - TODO confirm for callers
            audio = np.mean(audio, axis=1)
        return audio, sr

    raise ValueError("Unsupported audio input format")


def _render_spectrogram(audio):
    """Plot a mel spectrogram of *audio* and return the saved PNG's path."""
    # Explicit submodule import: older librosa releases do not expose
    # ``librosa.display`` after a plain ``import librosa``.
    import librosa.display

    spectrogram = librosa.feature.melspectrogram(
        y=audio,
        sr=SAMPLING_RATE,
        n_mels=64,
        fmax=8000
    )
    db_spec = librosa.power_to_db(spectrogram, ref=np.max)

    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        img = librosa.display.specshow(
            db_spec,
            x_axis='time',
            y_axis='mel',
            sr=SAMPLING_RATE,
            fmax=8000,
            ax=ax
        )
        fig.colorbar(img, ax=ax, format='%+2.0f dB')
        ax.set(title='Mel Spectrogram')
        fig.tight_layout()

        # NOTE(review): fixed file name - simultaneous calls overwrite each
        # other's image.
        spec_path = os.path.join(tempfile.gettempdir(), 'spec.png')
        fig.savefig(spec_path, bbox_inches='tight')
    finally:
        # Close this specific figure even when plotting fails, so figures do
        # not accumulate across requests.
        plt.close(fig)

    return spec_path


def analyze_audio(audio_input, threshold=DEFAULT_THRESHOLD):
    """Classify an audio clip as "Normal"/"Anomaly" and render its spectrogram.

    Args:
        audio_input: One of: a file path (``str``), an object with a ``name``
            attribute and a ``read()`` method, or a ``(sample_rate, samples)``
            tuple.
        threshold: The clip is labelled "Normal" when the softmax probability
            of class index 0 is strictly greater than this value.

    Returns:
        tuple: ``(label, confidence, spectrogram_path, probabilities)`` where
        ``confidence`` is a percentage string, ``spectrogram_path`` is the
        path of the saved PNG and ``probabilities`` is the stringified list
        of per-class probabilities. On any failure the tuple is
        ``("Error: <message>", "", None, "")``; no exception is raised.
    """
    try:
        audio, sr = _load_audio(audio_input)

        # Resample if needed
        if sr != SAMPLING_RATE:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLING_RATE)

        # Extract features
        inputs = feature_extractor(
            audio,
            sampling_rate=SAMPLING_RATE,
            return_tensors="pt",
            padding=True,
            return_attention_mask=True
        )

        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            probs = torch.softmax(logits, dim=-1)

        # Get results
        # NOTE(review): the decision rests solely on class index 0 of the
        # checkpoint's label set - confirm that class really means "normal".
        predicted_class = "Normal" if probs[0][0] > threshold else "Anomaly"
        class0_prob = probs[0][0].item()
        confidence = class0_prob if predicted_class == "Normal" else 1 - class0_prob

        spec_path = _render_spectrogram(audio)

        return (
            predicted_class,
            f"{confidence:.1%}",
            spec_path,
            str(probs.tolist()[0])
        )

    except Exception as e:
        # UI boundary: report the failure in the result slot instead of raising.
        return f"Error: {str(e)}", "", None, ""

# Gradio interface
# Two columns: inputs (audio, threshold, button) on the left, results on the
# right. The button click passes the inputs to analyze_audio and maps its
# 4-tuple onto the four output components in order.
with gr.Blocks(title="Industrial Audio Analyzer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏭 Industrial Equipment Sound Analyzer
    ### Powered by Audio Spectrogram Transformer (AST)
    """)

    with gr.Row():
        with gr.Column():
            # type="filepath": the handler receives the audio as a path string.
            audio_input = gr.Audio(
                label="Upload Equipment Audio (.wav)",
                type="filepath"
            )
            threshold = gr.Slider(
                minimum=0.5,
                maximum=0.95,
                step=0.05,
                value=DEFAULT_THRESHOLD,
                label="Anomaly Detection Threshold"
            )
            # Label previously contained a mis-decoded emoji byte sequence.
            analyze_btn = gr.Button("🔍 Analyze Sound", variant="primary")

        with gr.Column():
            result_label = gr.Label(label="Detection Result")
            confidence = gr.Textbox(label="Confidence Score")
            spectrogram = gr.Image(label="Spectrogram Visualization")
            # Hidden by default; still receives the raw probability list.
            raw_probs = gr.Textbox(
                label="Model Output Probabilities",
                visible=False
            )

    analyze_btn.click(
        fn=analyze_audio,
        inputs=[audio_input, threshold],
        outputs=[result_label, confidence, spectrogram, raw_probs]
    )

    gr.Markdown("""
    **Instructions:**
    - Upload .wav audio recordings (5-10 seconds recommended)
    - Adjust threshold to control sensitivity
    - Results show Normal/Anomaly classification with confidence
    """)

# Start the Gradio app only when this file is executed as a script, so that
# importing the module does not launch it.
if __name__ == "__main__":
    demo.launch()