zainulabedin949 committed on
Commit
65c3f40
·
verified ·
1 Parent(s): e9b0e37

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -81
app.py CHANGED
@@ -18,140 +18,195 @@ DEFAULT_THRESHOLD = 0.7
18
  feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
19
  model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
20
 
21
def handle_audio_file(audio_file):
    """Decode an uploaded audio file object into a mono sample array.

    The bytes returned by ``audio_file.read()`` are spooled to a temporary
    ``.wav`` file, which is then decoded with ``sf.read``. The temporary file
    is always removed, including when writing or decoding fails.

    Args:
        audio_file: Object exposing ``read()`` that returns the raw file bytes.

    Returns:
        Tuple ``(audio, sr)``: the decoded samples (multi-dimensional input is
        averaged along axis 1 to produce a single channel) and the sample
        rate reported by ``sf.read``.

    Raises:
        ValueError: If spooling or decoding fails. The original exception is
            chained as ``__cause__``.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            # Record the path before writing so a failed write is still cleaned up.
            tmp_path = tmp.name
            tmp.write(audio_file.read())

        audio, sr = sf.read(tmp_path)

        # Convert to mono if needed
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)

        return audio, sr
    except Exception as e:
        raise ValueError(f"Error processing audio file: {str(e)}") from e
    finally:
        # delete=False means nothing else removes the file; do it on every path.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def analyze_audio(audio_input, threshold=DEFAULT_THRESHOLD):
41
  """Process audio and detect anomalies"""
42
  try:
43
- # Handle different input types
44
- if isinstance(audio_input, str): # File path
45
- audio, sr = handle_audio_file(open(audio_input, 'rb'))
46
- elif hasattr(audio_input, 'name'): # Gradio file object
47
- audio, sr = handle_audio_file(audio_input)
48
- elif isinstance(audio_input, tuple): # Direct numpy array
49
- sr, audio = audio_input
50
- else:
51
- raise ValueError("Unsupported audio input format")
52
 
53
- # Resample if needed
 
 
54
  if sr != SAMPLING_RATE:
55
  audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLING_RATE)
56
 
57
- # Extract features
58
- inputs = feature_extractor(
59
- audio,
60
- sampling_rate=SAMPLING_RATE,
61
- return_tensors="pt",
62
- padding=True,
63
- return_attention_mask=True
64
- )
65
-
66
- # Run inference
67
  with torch.no_grad():
68
  outputs = model(**inputs)
69
- logits = outputs.logits
70
- probs = torch.softmax(logits, dim=-1)
71
-
72
  # Get results
73
  predicted_class = "Normal" if probs[0][0] > threshold else "Anomaly"
74
  confidence = probs[0][0].item() if predicted_class == "Normal" else 1 - probs[0][0].item()
75
 
76
- # Create spectrogram
77
- spectrogram = librosa.feature.melspectrogram(
78
- y=audio,
79
- sr=SAMPLING_RATE,
80
- n_mels=64,
81
- fmax=8000
82
- )
83
  db_spec = librosa.power_to_db(spectrogram, ref=np.max)
84
 
85
  fig, ax = plt.subplots(figsize=(10, 4))
86
- img = librosa.display.specshow(
87
- db_spec,
88
- x_axis='time',
89
- y_axis='mel',
90
- sr=SAMPLING_RATE,
91
- fmax=8000,
92
- ax=ax
93
- )
94
- fig.colorbar(img, ax=ax, format='%+2.0f dB')
95
- ax.set(title='Mel Spectrogram')
96
- plt.tight_layout()
97
 
98
- # Save to temp file
99
  spec_path = os.path.join(tempfile.gettempdir(), 'spec.png')
100
  plt.savefig(spec_path, bbox_inches='tight')
101
  plt.close()
102
 
 
 
 
103
  return (
104
  predicted_class,
105
  f"{confidence:.1%}",
106
  spec_path,
107
- str(probs.tolist()[0])
108
  )
109
 
110
  except Exception as e:
111
  return f"Error: {str(e)}", "", None, ""
112
 
113
# Gradio interface: upload + threshold on the left, results on the right.
with gr.Blocks(title="Industrial Audio Analyzer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏭 Industrial Equipment Sound Analyzer
    ### Powered by Audio Spectrogram Transformer (AST)
    """)

    with gr.Row():
        with gr.Column():
            # type="filepath" hands analyze_audio a path string.
            audio_input = gr.Audio(
                label="Upload Equipment Audio (.wav)",
                type="filepath"
            )
            threshold = gr.Slider(
                minimum=0.5,
                maximum=0.95,
                step=0.05,
                value=DEFAULT_THRESHOLD,
                label="Anomaly Detection Threshold"
            )
            analyze_btn = gr.Button("🔍 Analyze Sound", variant="primary")

        with gr.Column():
            result_label = gr.Label(label="Detection Result")
            confidence = gr.Textbox(label="Confidence Score")
            spectrogram = gr.Image(label="Spectrogram Visualization")
            # Hidden field that still receives the raw probability list.
            raw_probs = gr.Textbox(
                label="Model Output Probabilities",
                visible=False
            )

    # Output order must match the 4-tuple returned by analyze_audio.
    analyze_btn.click(
        fn=analyze_audio,
        inputs=[audio_input, threshold],
        outputs=[result_label, confidence, spectrogram, raw_probs]
    )

    gr.Markdown("""
    **Instructions:**
    - Upload .wav audio recordings (5-10 seconds recommended)
    - Adjust threshold to control sensitivity
    - Results show Normal/Anomaly classification with confidence
    """)
156
 
157
  if __name__ == "__main__":
 
18
  feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
19
  model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
20
 
21
# Equipment knowledge base.
# Maps equipment type -> detected pattern name -> maintenance advice text.
# Looked up by generate_recommendation() using the pattern names produced by
# analyze_frequency_patterns(); pattern names with no entry here are skipped.
EQUIPMENT_RECOMMENDATIONS = {
    "bearing": {
        "high_frequency": "Recommend bearing replacement. High-frequency noise indicates wear or lubrication issues.",
        "low_frequency": "Check for improper installation or contamination in bearings.",
        "irregular": "Possible bearing cage damage. Schedule vibration analysis.",
    },
    "pump": {
        "cavitation": "Pump cavitation detected. Check suction conditions and NPSH.",
        "impeller": "Impeller damage likely. Inspect and balance if needed.",
        "misalignment": "Misalignment detected. Perform laser shaft alignment.",
    },
    "motor": {
        "electrical": "Electrical fault suspected. Check windings and connections.",
        "mechanical": "Mechanical imbalance detected. Perform dynamic balancing.",
        "bearing": "Motor bearing wear detected. Schedule replacement.",
    },
    "compressor": {
        "valve": "Compressor valve leakage suspected. Perform valve test.",
        "pulsation": "Pulsation issues detected. Check dampeners and piping.",
        "surge": "Compressor surge condition. Review control settings.",
    },
}
44
+
45
def analyze_frequency_patterns(audio, sr):
    """Label a recording with coarse spectral pattern names.

    Args:
        audio: Sample array passed to librosa as ``y``.
        sr: Sample rate of ``audio``.

    Returns:
        List containing at most one of ``"high_frequency"`` /
        ``"low_frequency"`` (mean spectral centroid above 3000 or below
        1000), followed by ``"harmonic_rich"`` when the mean spectral
        roll-off exceeds 8000. Empty when no threshold is crossed.
    """
    # Per-frame tracks; [0] selects the single feature row librosa returns.
    centroid_track = librosa.feature.spectral_centroid(y=audio, sr=sr)[0]
    rolloff_track = librosa.feature.spectral_rolloff(y=audio, sr=sr)[0]

    avg_centroid = np.mean(centroid_track)
    avg_rolloff = np.mean(rolloff_track)

    detected = []

    # The centroid labels are mutually exclusive; the mid band gets neither.
    if avg_centroid > 3000:
        detected.append("high_frequency")
    elif avg_centroid < 1000:
        detected.append("low_frequency")

    # NOTE(review): roll-off cannot exceed sr / 2, so this label needs
    # sr > 16000 to ever fire - confirm against the configured sampling rate.
    if avg_rolloff > 8000:
        detected.append("harmonic_rich")

    return detected
65
+
66
def generate_recommendation(prediction, confidence, audio, sr):
    """Build the maintenance recommendation text for one analysis result.

    Args:
        prediction: Classification label; ``"Normal"`` short-circuits to a
            fixed "no action" message without touching ``audio``.
        confidence: Confidence in ``prediction`` as a fraction (formatted as
            a percentage; values above 0.8 add an urgent notice).
        audio: Sample array used for the spectral heuristics.
        sr: Sample rate of ``audio``.

    Returns:
        Newline-joined recommendation text.
    """
    if prediction == "Normal":
        return "No immediate action required. Equipment operating within normal parameters."

    patterns = analyze_frequency_patterns(audio, sr)

    # Simple equipment type classifier based on frequency profile
    mean_flatness = np.mean(librosa.feature.spectral_flatness(y=audio)[0])

    if mean_flatness < 0.2:
        equipment_type = "bearing"
    elif mean_flatness < 0.6:
        equipment_type = "pump"
    else:
        # Use the mean absolute amplitude as the loudness proxy: the signed
        # mean of a waveform is just its DC offset (close to zero), which
        # made the "compressor" branch practically unreachable.
        # NOTE(review): the 0.1 cut-off is a heuristic - confirm on real data.
        equipment_type = "motor" if np.mean(np.abs(audio)) < 0.1 else "compressor"

    # Generate specific recommendations
    recommendations = ["🔧 Maintenance Recommendations:"]
    recommendations.append(f"Detected issues in {equipment_type} with {confidence:.1%} confidence")

    # Only patterns that have an entry for this equipment type produce a line.
    known_advice = EQUIPMENT_RECOMMENDATIONS.get(equipment_type, {})
    recommendations.extend(
        f"→ {known_advice[pattern]}" for pattern in patterns if pattern in known_advice
    )

    # General recommendations
    if prediction == "Anomaly":
        recommendations.append("\n🛠️ Suggested Actions:")
        recommendations.append("1. Isolate equipment if possible")
        recommendations.append("2. Perform visual inspection")
        recommendations.append("3. Schedule detailed diagnostics")
        recommendations.append(f"4. Review last maintenance records ({equipment_type})")

    if confidence > 0.8:
        recommendations.append("\n🚨 Urgent: High confidence abnormality detected. Recommend immediate inspection!")

    return "\n".join(recommendations)
104
 
105
  def analyze_audio(audio_input, threshold=DEFAULT_THRESHOLD):
106
  """Process audio and detect anomalies"""
107
  try:
108
+ # Handle file upload
109
+ if isinstance(audio_input, str):
110
+ audio, sr = sf.read(audio_input)
111
+ else: # Gradio file object
112
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
113
+ tmp.write(audio_input.read())
114
+ tmp_path = tmp.name
115
+ audio, sr = sf.read(tmp_path)
116
+ os.unlink(tmp_path)
117
 
118
+ # Convert to mono and resample if needed
119
+ if len(audio.shape) > 1:
120
+ audio = np.mean(audio, axis=1)
121
  if sr != SAMPLING_RATE:
122
  audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLING_RATE)
123
 
124
+ # Feature extraction and prediction
125
+ inputs = feature_extractor(audio, sampling_rate=SAMPLING_RATE, return_tensors="pt")
 
 
 
 
 
 
 
 
126
  with torch.no_grad():
127
  outputs = model(**inputs)
128
+ probs = torch.softmax(outputs.logits, dim=-1)
129
+
 
130
  # Get results
131
  predicted_class = "Normal" if probs[0][0] > threshold else "Anomaly"
132
  confidence = probs[0][0].item() if predicted_class == "Normal" else 1 - probs[0][0].item()
133
 
134
+ # Generate spectrogram
135
+ spectrogram = librosa.feature.melspectrogram(y=audio, sr=SAMPLING_RATE, n_mels=64, fmax=8000)
 
 
 
 
 
136
  db_spec = librosa.power_to_db(spectrogram, ref=np.max)
137
 
138
  fig, ax = plt.subplots(figsize=(10, 4))
139
+ librosa.display.specshow(db_spec, x_axis='time', y_axis='mel', sr=SAMPLING_RATE, fmax=8000, ax=ax)
140
+ plt.colorbar(format='%+2.0f dB')
141
+ plt.title('Mel Spectrogram with Anomaly Detection')
142
+
143
+ # Mark anomalies on plot
144
+ if predicted_class == "Anomaly":
145
+ plt.text(0.5, 0.9, 'ANOMALY DETECTED', color='red',
146
+ ha='center', va='center', transform=ax.transAxes,
147
+ fontsize=14, bbox=dict(facecolor='white', alpha=0.8))
 
 
148
 
 
149
  spec_path = os.path.join(tempfile.gettempdir(), 'spec.png')
150
  plt.savefig(spec_path, bbox_inches='tight')
151
  plt.close()
152
 
153
+ # Generate detailed recommendations
154
+ recommendations = generate_recommendation(predicted_class, confidence, audio, SAMPLING_RATE)
155
+
156
  return (
157
  predicted_class,
158
  f"{confidence:.1%}",
159
  spec_path,
160
+ recommendations
161
  )
162
 
163
  except Exception as e:
164
  return f"Error: {str(e)}", "", None, ""
165
 
166
# Gradio Interface: upload + sensitivity on the left, diagnosis on the right.
with gr.Blocks(title="Industrial Diagnostic Assistant 👨‍🔧", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏭 Industrial Equipment Diagnostic Assistant
    ## Acoustic Anomaly Detection & Maintenance Recommendation System
    """)

    with gr.Row():
        with gr.Column():
            # type="filepath" hands analyze_audio a path string.
            # NOTE(review): Gradio 4 renamed `source` to `sources=["upload"]`;
            # confirm the pinned Gradio version still accepts this keyword.
            audio_input = gr.Audio(
                label="Upload Equipment Recording (.wav)",
                type="filepath",
                source="upload"
            )
            threshold = gr.Slider(
                minimum=0.5, maximum=0.95, step=0.05, value=DEFAULT_THRESHOLD,
                label="Detection Sensitivity", interactive=True
            )
            analyze_btn = gr.Button("🔍 Analyze & Diagnose", variant="primary")

        with gr.Column():
            result_label = gr.Label(label="Diagnosis Result")
            confidence = gr.Textbox(label="Confidence Score")
            spectrogram = gr.Image(label="Acoustic Analysis")
            recommendations = gr.Textbox(
                label="Maintenance Recommendations",
                lines=10,
                interactive=False
            )

    # Output order must match the 4-tuple returned by analyze_audio.
    analyze_btn.click(
        fn=analyze_audio,
        inputs=[audio_input, threshold],
        outputs=[result_label, confidence, spectrogram, recommendations]
    )

    gr.Markdown("""
    ### System Capabilities:
    - Automatic anomaly detection in industrial equipment sounds
    - Frequency pattern analysis to identify failure modes
    - Equipment-specific maintenance recommendations
    - Confidence-based urgency classification

    **Tip:** For best results, use 5-10 second recordings of steady operation
    """)
211
 
212
  if __name__ == "__main__":