Commit 3859913
Parent(s): 9fe6f54
Add full application code and deps
Files changed:
- Dockerfile +15 -0
- app.py +4 -0
- app/__init__.py +0 -0
- app/accent_classifier.py +32 -0
- app/config.py +10 -0
- app/main.py +43 -0
- app/utils.py +70 -0
- requirements.txt +13 -0
- ui/demo.py +54 -0
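
Overview: a video URL is fetched with yt-dlp, ffmpeg extracts 16 kHz mono audio, WebRTC VAD trims silence, and a Hugging Face audio-classification model (dima806/english_accents_classification) returns an accent label plus a confidence percentage. The same pipeline backs both the FastAPI endpoints in app/main.py and the Gradio demo in ui/demo.py.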
Dockerfile
ADDED
@@ -0,0 +1,15 @@
+FROM python:3.10-slim
+
+RUN apt-get update && apt-get install -y --no-install-recommends ffmpeg git && rm -rf /var/lib/apt/lists/*
+
+WORKDIR /app
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY . .
+
+ENV PYTHONUNBUFFERED=1 \
+    TRANSFORMERS_CACHE=/app/.cache/hf
+
+# default: FastAPI, but you can override CMD in HF Spaces Runtime
+CMD ["uvicorn", "app.main:api", "--host", "0.0.0.0", "--port", "8000"]
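
Note: to serve the Gradio demo from the same image instead, the CMD could presumably be overridden to ["python", "app.py"], which launches the UI on port 7860 (see app.py below).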
app.py
ADDED
@@ -0,0 +1,4 @@
+"""Entry-point for Hugging Face Spaces (Gradio SDK)."""
+from ui.demo import demo
+
+demo.launch(server_name="0.0.0.0", server_port=7860)
app/__init__.py
ADDED
File without changes
app/accent_classifier.py
ADDED
@@ -0,0 +1,32 @@
+"""Lightweight Wav2Vec-based English accent ID."""
+from functools import lru_cache
+from typing import Dict
+
+import torch
+import torchaudio
+from transformers import AutoProcessor, AutoModelForAudioClassification
+import numpy as np
+
+from .config import BASE_MODEL_ID, SEGMENT_SECONDS
+
+class AccentClassifier:
+    def __init__(self, device: str | None = None):
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.processor = AutoProcessor.from_pretrained(BASE_MODEL_ID)
+        self.model = AutoModelForAudioClassification.from_pretrained(BASE_MODEL_ID).to(self.device)
+        self.labels = self.model.config.id2label
+
+    @torch.inference_mode()
+    def classify(self, wav: str) -> Dict[str, str | int]:
+        wav_arr, sr = torchaudio.load(wav)
+        wav_arr = wav_arr.squeeze(0).numpy()[: sr * SEGMENT_SECONDS]  # keep only the first SEGMENT_SECONDS
+        inp = self.processor(wav_arr, sampling_rate=sr, return_tensors="pt")
+        inp = {k: v.to(self.device) for k, v in inp.items()}
+        logits = self.model(**inp).logits[0]
+        probs = torch.softmax(logits, dim=-1).cpu().numpy()
+        idx = int(np.argmax(probs))
+        return {"accent": self.labels[idx], "confidence": int(probs[idx] * 100)}  # whole-percent confidence
+
+@lru_cache(maxsize=1)
+def get_classifier() -> AccentClassifier:
+    return AccentClassifier()
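
A minimal smoke test for the classifier; the sample.wav path and the printed output are illustrative placeholders, not part of this commit:

    # Hypothetical quick check, run from the repo root with any local 16 kHz WAV.
    from app.accent_classifier import get_classifier

    clf = get_classifier()             # lru_cache: the model is loaded once per process
    print(clf.classify("sample.wav"))  # e.g. {"accent": "...", "confidence": 87}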
app/config.py
ADDED
@@ -0,0 +1,10 @@
+"""Centralised paths & constants."""
+from pathlib import Path
+
+ROOT = Path(__file__).resolve().parent.parent
+MODELS_DIR = ROOT / "models"
+AUDIO_CACHE = ROOT / "cache" / "audio"
+AUDIO_CACHE.mkdir(parents=True, exist_ok=True)
+
+BASE_MODEL_ID = "dima806/english_accents_classification"  # HF model
+SEGMENT_SECONDS = 30  # audio length fed to the classifier
|
app/main.py
ADDED
@@ -0,0 +1,43 @@
+from fastapi import FastAPI, UploadFile, File, HTTPException
+from pydantic import BaseModel
+from pathlib import Path
+import tempfile
+
+from .utils import download_video, extract_audio, trim_silence
+from .accent_classifier import get_classifier
+
+api = FastAPI()
+
+class URLBody(BaseModel):
+    url: str
+
+@api.post("/analyze/url")
+async def analyze_url(body: URLBody):
+    with tempfile.TemporaryDirectory() as td:
+        tdir = Path(td)
+        video = await download_video(body.url, tdir)
+        wav = tdir / "aud.wav"
+        await extract_audio(video, wav)
+        wav = trim_silence(wav)
+        return get_classifier().classify(str(wav))
+
+@api.post("/analyze/upload")
+async def analyze_upload(file: UploadFile = File(...)):
+    if not file.content_type.startswith(("audio", "video")):
+        raise HTTPException(400, "Unsupported file type")
+    # keep only the extension; using the raw client filename as a suffix is unsafe
+    with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as tmp:
+        tmp.write(await file.read())
+        tmp.flush()
+        path = Path(tmp.name)
+    if file.content_type.startswith("video"):
+        wav = path.with_suffix(".wav")
+        await extract_audio(path, wav)
+    else:
+        wav = path
+    wav = trim_silence(wav)
+    return get_classifier().classify(str(wav))
+
+@api.get("/healthz")
+async def health():
+    return {"status": "ok"}
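
A sketch of calling the running service; the localhost address and the video URL are assumptions for illustration (aiohttp is already pinned in requirements.txt):

    # Minimal async client against a local instance (port 8000, per the Dockerfile CMD).
    import asyncio
    import aiohttp

    async def main():
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "http://localhost:8000/analyze/url",
                json={"url": "https://example.com/clip.mp4"},  # placeholder URL
            ) as resp:
                print(await resp.json())  # {"accent": ..., "confidence": ...}

    asyncio.run(main())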
app/utils.py
ADDED
@@ -0,0 +1,70 @@
+import asyncio
+import subprocess
+from pathlib import Path
+
+import torch
+import torchaudio
+from yt_dlp import YoutubeDL
+import webrtcvad
+
+# ---------------------------------------------------------------------------
+# ffmpeg helpers
+# ---------------------------------------------------------------------------
+
+def _run(cmd: list[str]):
+    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    if proc.returncode != 0:
+        raise RuntimeError(proc.stderr.decode())
+
+# ---------------------------------------------------------------------------
+# Video → Audio
+# ---------------------------------------------------------------------------
+async def download_video(url: str, out_dir: Path) -> Path:
+    """Async wrapper around yt-dlp to pull remote video assets."""
+    ydl_opts = {
+        "quiet": True,
+        "no_warnings": True,
+        "outtmpl": str(out_dir / "download.%(ext)s"),
+        "format": "bestvideo+bestaudio/best",  # merged A/V, else best single file
+    }
+    loop = asyncio.get_running_loop()
+
+    def _job():
+        with YoutubeDL(ydl_opts) as ydl:
+            ydl.download([url])
+
+    await loop.run_in_executor(None, _job)
+    return next(out_dir.glob("download.*"))
+
+async def extract_audio(video_path: Path, wav_path: Path, sr: int = 16000):
+    cmd = [
+        "ffmpeg", "-y", "-i", str(video_path),
+        "-vn", "-ac", "1", "-ar", str(sr), str(wav_path)
+    ]
+    loop = asyncio.get_running_loop()
+    await loop.run_in_executor(None, _run, cmd)
+
+# ---------------------------------------------------------------------------
+# VAD trimming (WebRTC)
+# ---------------------------------------------------------------------------
+
+def _frame_gen(frame_ms, pcm16, sr):
+    n = int(sr * (frame_ms / 1000.0) * 2)  # bytes per frame: samples x 2 (16-bit PCM)
+    for i in range(0, len(pcm16) - n + 1, n):  # drop the trailing partial frame; webrtcvad rejects it
+        yield pcm16[i : i + n]
+
+def trim_silence(wav_path: Path, aggressiveness: int = 3) -> Path:
+    sig, sr = torchaudio.load(str(wav_path))
+    if sr != 16000:  # webrtcvad and the classifier both expect 16 kHz
+        sig, sr = torchaudio.functional.resample(sig, sr, 16000), 16000
+    sig = sig.squeeze(0).numpy()
+    vad = webrtcvad.Vad(aggressiveness)
+    frames = list(_frame_gen(30, (sig * 32767).astype("int16").tobytes(), sr))
+    voiced = [vad.is_speech(f, sr) for f in frames]
+    if not any(voiced):
+        return wav_path
+    first, last = voiced.index(True), len(voiced) - 1 - voiced[::-1].index(True)
+    kept = sig[first * 480 : (last + 1) * 480]  # 480 samples = one 30 ms frame at 16 kHz
+    out = wav_path.with_name(wav_path.stem + "_trim.wav")
+    torchaudio.save(str(out), torch.from_numpy(kept).unsqueeze(0), sr)
+    return out
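
The 30 ms frame arithmetic used by _frame_gen and trim_silence checks out as follows for 16 kHz, 16-bit mono input:

    # 16 kHz x 0.030 s = 480 samples per frame; x2 bytes per sample = 960 bytes,
    # matching _frame_gen's byte stride and the 480-sample slices in trim_silence.
    sr, frame_ms = 16000, 30
    samples_per_frame = int(sr * frame_ms / 1000)  # 480
    bytes_per_frame = samples_per_frame * 2        # 960
    print(samples_per_frame, bytes_per_frame)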
requirements.txt
ADDED
@@ -0,0 +1,13 @@
+fastapi==0.111.0  # only needed if you use /app/main.py via uvicorn
+uvicorn[standard]==0.30.1
+yt-dlp==2024.05.27
+ffmpeg-python==0.2.0
+webrtcvad==2.0.10
+transformers==4.41.1
+accelerate==0.30.1
+torch>=2.3.0
+scikit-learn==1.5.0
+pydantic==2.7.1
+torchaudio==2.3.0
+gradio==4.34.0
+aiohttp==3.9.5
ui/demo.py
ADDED
@@ -0,0 +1,54 @@
+import asyncio
+import gradio as gr
+import tempfile
+from pathlib import Path
+
+from app.utils import download_video, extract_audio, trim_silence
+from app.accent_classifier import get_classifier
+
+clf = get_classifier()
+
+async def _url_pipeline(url: str):
+    with tempfile.TemporaryDirectory() as td:
+        tdir = Path(td)
+        video = await download_video(url, tdir)
+        wav = tdir / "aud.wav"
+        await extract_audio(video, wav)
+        wav = trim_silence(wav)
+        return clf.classify(str(wav))
+
+def analyze_url(url: str):
+    return asyncio.run(_url_pipeline(url))
+
+
+def analyze_file(file):
+    path = Path(getattr(file, "name", file))  # gr.File may yield a filepath string or a tempfile-like object
+    if path.suffix.lower() in {".mp4", ".mov", ".mkv"}:
+        wav = path.with_suffix(".wav")
+        asyncio.run(extract_audio(path, wav))
+    else:
+        wav = path
+    wav = trim_silence(wav)
+    return clf.classify(str(wav))
+
+
+def fmt(res):
+    if not res:
+        return "Analysis failed."
+    return f"**Accent:** {res['accent']}\n\n**Confidence:** {res['confidence']}%"
+
+with gr.Blocks(title="English Accent Detector") as demo:
+    gr.Markdown("## REM Waste – Accent Screening Tool")
+    with gr.Tab("From URL"):
+        url_in = gr.Text(label="Public video URL (Loom, MP4, YouTube, …)")
+        btn = gr.Button("Analyze")
+        out = gr.Markdown()
+        btn.click(lambda u: fmt(analyze_url(u)), inputs=url_in, outputs=out)
+    with gr.Tab("Upload File"):
+        file_in = gr.File()
+        btn2 = gr.Button("Analyze")
+        out2 = gr.Markdown()
+        btn2.click(lambda f: fmt(analyze_file(f)), inputs=file_in, outputs=out2)
+
+if __name__ == "__main__":
+    demo.launch()
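
Locally, python app.py launches the Gradio demo on port 7860, while the Dockerfile's default CMD (uvicorn app.main:api) serves the REST API on port 8000.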