gramt-binaural-time / Patcher.py
GokseninYuksel's picture
Upload model
f0e612b verified
from abc import ABC
from .patching_utils import combine_patches, generate_patches, get_shape
class PatchStrategy(ABC):
def __init__(self, tstride, tshape, fstride, fshape, input_fdim, input_tdim):
self.tstride = tstride
self.tshape = tshape
self.fstride = fstride
self.fshape = fshape
self.input_fdim = input_fdim
self.input_tdim = input_tdim
def _patch(self, x):
patches = generate_patches(
input=x,
fstride=self.fstride,
tstride=self.tstride,
fshape=self.fshape,
tshape=self.tshape,
)
return patches
def patch(self, x):
return self._patch(x)
def embed(self, x, patch_embedder):
return patch_embedder(x)
def patch_and_embed(self, x, patch_embedder):
"""
Generate patches from the input spectrogram and embed them.
This method creates patches based on the frequency and temporal stride/shape
parameters, and then applies the given patch embedding function.
Parameters
----------
x : torch.Tensor
The input spectrogram tensor to be patched and embedded.
patch_embedder : Callable
A function that applies embedding to the patches.
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
The generated patches and their embeddings.
"""
# Generate patches for knowing the input.
patches = generate_patches(
input=x,
fstride=self.fstride,
tstride=self.tstride,
fshape=self.fshape,
tshape=self.tshape,
)
x = patch_embedder(x)
return patches, x
def get_patch_size(self):
p_f_dim, p_t_dim = get_shape(
fstride=self.fstride,
tstride=self.tstride,
input_fdim=self.input_fdim,
input_tdim=self.input_tdim,
fshape=self.fshape,
tshape=self.tshape,
)
return p_f_dim, p_t_dim
def combine_patches(self, patches, original_size):
return combine_patches(
patches, original_size, self.fstride, self.tstride, self.fshape, self.tshape
)
class TimePatching(PatchStrategy):
def __init__(
self, input_tdim, tstride=2, tshape=2, fstride=128, fshape=128, input_fdim=128
):
super().__init__(
tstride=tstride,
tshape=tshape,
fstride=fstride,
fshape=fshape,
input_fdim=input_fdim,
input_tdim=input_tdim,
)
class FramePatching(PatchStrategy):
def __init__(
self, input_tdim, tstride=16, tshape=16, fstride=16, fshape=16, input_fdim=128
):
super().__init__(
tstride=tstride,
tshape=tshape,
fstride=fstride,
fshape=fshape,
input_fdim=input_fdim,
input_tdim=input_tdim,
)