# Spaces: Running on Zero
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
from dataclasses import dataclass | |
from typing import Any, Callable, Dict, List, Tuple | |
import torch | |
from torch import nn | |
@dataclass
class MMArg:
    """Container holding one value per branch (``vid`` and ``txt``).

    ``get_args`` / ``get_kwargs`` replace any ``MMArg`` they encounter with
    the attribute named by their ``key`` argument, so wrapping a value in
    ``MMArg`` lets a caller pass different arguments to the two branches.

    The ``@dataclass`` decorator is required: it generates the
    ``__init__`` that stores ``vid`` and ``txt`` on the instance, plus
    ``__eq__`` / ``__repr__``. Without it the class has only bare
    annotations and cannot be constructed with values.
    """

    # Value selected when the "vid" key is requested.
    vid: Any
    # Value selected when the "txt" key is requested.
    txt: Any
def get_args(key: str, args: List[Any]) -> List[Any]:
    """Resolve positional arguments for a single branch.

    Args:
        key: Attribute name to read from each ``MMArg`` (the callers in
            this file pass ``"vid"`` or ``"txt"``).
        args: Positional arguments, any of which may be an ``MMArg``.

    Returns:
        A new list in the same order where every ``MMArg`` has been
        replaced by its ``key`` attribute and every other value is passed
        through unchanged.
    """
    resolved = []
    for value in args:
        if isinstance(value, MMArg):
            # Pick the branch-specific value out of the container.
            value = getattr(value, key)
        resolved.append(value)
    return resolved
def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Resolve keyword arguments for a single branch.

    Args:
        key: Attribute name to read from each ``MMArg`` (the callers in
            this file pass ``"vid"`` or ``"txt"``).
        kwargs: Keyword arguments, any value of which may be an ``MMArg``.

    Returns:
        A new dict with the same keys where every ``MMArg`` value has been
        replaced by its ``key`` attribute and every other value is passed
        through unchanged.
    """
    resolved = {}
    for name, value in kwargs.items():
        if isinstance(value, MMArg):
            # Pick the branch-specific value out of the container.
            value = getattr(value, key)
        resolved[name] = value
    return resolved
class MMModule(nn.Module):
    """Apply a module to a ``vid`` input and a ``txt`` input.

    The wrapped module is built from a factory. Depending on the flags the
    wrapper holds either one instance used for both inputs
    (``shared_weights=True``, stored as ``self.all``) or separate
    instances stored as ``self.vid`` and ``self.txt``. Constructor and
    forward arguments wrapped in ``MMArg`` are resolved per branch via
    ``get_args`` / ``get_kwargs``.
    """

    def __init__(
        self,
        module: Callable[..., nn.Module],
        *args,
        shared_weights: bool = False,
        vid_only: bool = False,
        **kwargs,
    ):
        """Build the underlying module(s).

        Args:
            module: Factory called with the resolved arguments to create
                each underlying module.
            *args: Positional factory arguments; ``MMArg`` entries are
                resolved per branch.
            shared_weights: If true, a single instance serves both
                branches; the resolved ``vid`` and ``txt`` arguments must
                then compare equal.
            vid_only: If true, no ``txt`` module is built (``self.txt`` is
                ``None`` in the non-shared case) and ``forward`` returns
                ``txt`` untouched.
            **kwargs: Keyword factory arguments; ``MMArg`` values are
                resolved per branch.
        """
        super().__init__()
        self.shared_weights = shared_weights
        self.vid_only = vid_only

        if self.shared_weights:
            # A single instance can only be built from one argument set,
            # so both branches must resolve to the same arguments.
            vid_args, txt_args = get_args("vid", args), get_args("txt", args)
            assert vid_args == txt_args
            vid_kwargs, txt_kwargs = get_kwargs("vid", kwargs), get_kwargs("txt", kwargs)
            assert vid_kwargs == txt_kwargs
            self.all = module(*vid_args, **vid_kwargs)
            return

        # Separate weights: the vid module is always built, and built first.
        self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
        if vid_only:
            self.txt = None
        else:
            self.txt = module(*get_args("txt", args), **get_kwargs("txt", kwargs))

    def forward(
        self,
        vid: torch.FloatTensor,
        txt: torch.FloatTensor,
        *args,
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Run the wrapped module(s) on both inputs.

        Args:
            vid: Input for the ``vid`` branch.
            txt: Input for the ``txt`` branch; returned unchanged when
                ``vid_only`` is set.
            *args: Extra positional arguments forwarded to each branch,
                with ``MMArg`` entries resolved per branch.
            **kwargs: Extra keyword arguments forwarded to each branch,
                with ``MMArg`` values resolved per branch.

        Returns:
            The ``(vid, txt)`` pair after processing.
        """
        use_shared = self.shared_weights

        vid_branch = self.all if use_shared else self.vid
        vid = vid_branch(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))

        if self.vid_only:
            # The txt input is passed through without any processing.
            return vid, txt

        txt_branch = self.all if use_shared else self.txt
        txt = txt_branch(txt, *get_args("txt", args), **get_kwargs("txt", kwargs))
        return vid, txt