Spaces:
Runtime error
Runtime error
zero gpu
Browse files- apg_guidance.py +6 -5
- pipeline_ace_step.py +1 -0
apg_guidance.py
CHANGED
|
@@ -17,14 +17,15 @@ def project(
|
|
| 17 |
dims=[-1, -2],
|
| 18 |
):
|
| 19 |
dtype = v0.dtype
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
| 24 |
v1 = torch.nn.functional.normalize(v1, dim=dims)
|
| 25 |
v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
|
| 26 |
v0_orthogonal = v0 - v0_parallel
|
| 27 |
-
return v0_parallel.to(dtype), v0_orthogonal.to(dtype)
|
| 28 |
|
| 29 |
|
| 30 |
def apg_forward(
|
|
|
|
| 17 |
dims=[-1, -2],
|
| 18 |
):
|
| 19 |
dtype = v0.dtype
|
| 20 |
+
device_type = v0.device.type
|
| 21 |
+
if device_type == "mps":
|
| 22 |
+
v0, v1 = v0.cpu(), v1.cpu()
|
| 23 |
+
|
| 24 |
+
v0, v1 = v0.double(), v1.double()
|
| 25 |
v1 = torch.nn.functional.normalize(v1, dim=dims)
|
| 26 |
v0_parallel = (v0 * v1).sum(dim=dims, keepdim=True) * v1
|
| 27 |
v0_orthogonal = v0 - v0_parallel
|
| 28 |
+
return v0_parallel.to(dtype).to(device_type), v0_orthogonal.to(dtype).to(device_type)
|
| 29 |
|
| 30 |
|
| 31 |
def apg_forward(
|
pipeline_ace_step.py
CHANGED
|
@@ -955,6 +955,7 @@ class ACEStepPipeline:
|
|
| 955 |
latents, _ = self.music_dcae.encode(input_audio, sr=sr)
|
| 956 |
return latents
|
| 957 |
|
|
|
|
| 958 |
def __call__(
|
| 959 |
self,
|
| 960 |
audio_duration: float = 60.0,
|
|
|
|
| 955 |
latents, _ = self.music_dcae.encode(input_audio, sr=sr)
|
| 956 |
return latents
|
| 957 |
|
| 958 |
+
@spaces.GPU
|
| 959 |
def __call__(
|
| 960 |
self,
|
| 961 |
audio_duration: float = 60.0,
|