Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -7,55 +7,46 @@ from diffusers import DiffusionPipeline
|
|
7 |
from pyannote.audio import Pipeline as PyannotePipeline
|
8 |
from dia.model import Dia
|
9 |
from dac.utils import load_model as load_dac_model
|
10 |
-
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
11 |
|
12 |
# Environment token from HF Secrets
|
13 |
HF_TOKEN = os.environ["HF_TOKEN"]
|
14 |
-
device_map = "auto"
|
15 |
|
16 |
print("Loading models...")
|
17 |
|
18 |
-
# 1. RVQ Codec (Descript Audio Codec)
|
19 |
print("Loading RVQ Codec...")
|
20 |
rvq = load_dac_model(tag="latest", model_type="44khz")
|
21 |
rvq.eval()
|
22 |
if torch.cuda.is_available():
|
23 |
rvq = rvq.to("cuda")
|
24 |
|
25 |
-
# 2. Voice Activity Detection
|
26 |
print("Loading VAD...")
|
27 |
vad_pipe = PyannotePipeline.from_pretrained(
|
28 |
"pyannote/voice-activity-detection",
|
29 |
use_auth_token=HF_TOKEN
|
30 |
)
|
31 |
|
32 |
-
# 3. Ultravox ASR+LLM
|
33 |
print("Loading Ultravox...")
|
34 |
ultravox_pipe = pipeline(
|
35 |
model="fixie-ai/ultravox-v0_4",
|
36 |
trust_remote_code=True,
|
37 |
-
device_map=
|
38 |
torch_dtype=torch.float16
|
39 |
)
|
40 |
|
41 |
-
# 4. Audio Diffusion Model
|
42 |
print("Loading Audio Diffusion...")
|
43 |
diff_pipe = DiffusionPipeline.from_pretrained(
|
44 |
"teticio/audio-diffusion-instrumental-hiphop-256",
|
45 |
torch_dtype=torch.float16
|
46 |
).to("cuda")
|
47 |
|
48 |
-
# 5. Dia TTS Model
|
49 |
print("Loading Dia TTS...")
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
dia = load_checkpoint_and_dispatch(
|
54 |
-
dia,
|
55 |
-
"nari-labs/Dia-1.6B",
|
56 |
-
device_map=device_map,
|
57 |
-
dtype=torch.float16
|
58 |
-
)
|
59 |
|
60 |
print("All models loaded successfully!")
|
61 |
|
|
|
7 |
from pyannote.audio import Pipeline as PyannotePipeline
|
8 |
from dia.model import Dia
|
9 |
from dac.utils import load_model as load_dac_model
|
|
|
10 |
|
11 |
# Environment token from HF Secrets
|
12 |
HF_TOKEN = os.environ["HF_TOKEN"]
|
|
|
13 |
|
14 |
print("Loading models...")
|
15 |
|
16 |
+
# 1. Load RVQ Codec (Descript Audio Codec)
|
17 |
print("Loading RVQ Codec...")
|
18 |
rvq = load_dac_model(tag="latest", model_type="44khz")
|
19 |
rvq.eval()
|
20 |
if torch.cuda.is_available():
|
21 |
rvq = rvq.to("cuda")
|
22 |
|
23 |
+
# 2. Load Voice Activity Detection
|
24 |
print("Loading VAD...")
|
25 |
vad_pipe = PyannotePipeline.from_pretrained(
|
26 |
"pyannote/voice-activity-detection",
|
27 |
use_auth_token=HF_TOKEN
|
28 |
)
|
29 |
|
30 |
+
# 3. Load Ultravox ASR+LLM
|
31 |
print("Loading Ultravox...")
|
32 |
ultravox_pipe = pipeline(
|
33 |
model="fixie-ai/ultravox-v0_4",
|
34 |
trust_remote_code=True,
|
35 |
+
device_map="auto",
|
36 |
torch_dtype=torch.float16
|
37 |
)
|
38 |
|
39 |
+
# 4. Load Audio Diffusion Model
|
40 |
print("Loading Audio Diffusion...")
|
41 |
diff_pipe = DiffusionPipeline.from_pretrained(
|
42 |
"teticio/audio-diffusion-instrumental-hiphop-256",
|
43 |
torch_dtype=torch.float16
|
44 |
).to("cuda")
|
45 |
|
46 |
+
# 5. Load Dia TTS Model (WITHOUT meta tensor approach)
|
47 |
print("Loading Dia TTS...")
|
48 |
+
# Direct loading without init_empty_weights to avoid meta tensor issues
|
49 |
+
dia = Dia.from_pretrained("nari-labs/Dia-1.6B")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
print("All models loaded successfully!")
|
52 |
|