| from transformers.configuration_utils import PretrainedConfig | |
| import sys | |
| from transformers import ( | |
| AutoConfig, | |
| AutoModelForCausalLM, | |
| LlamaConfig, | |
| LlamaForCausalLM, | |
| PreTrainedModel, | |
| ) | |
| from .attrdict_config import AttrDict | |
class VisionConfig(PretrainedConfig):
    """Config with ``model_type = "vision"`` holding a class name and its params.

    Keyword Args:
        cls: A string, or any object exposing ``__name__`` (e.g. a class);
            non-string values are stored as their ``__name__``. Defaults to "".
        params: Mapping wrapped in ``AttrDict``. Defaults to an empty mapping.

    All keyword arguments are also forwarded unchanged to
    ``PretrainedConfig.__init__``.
    """

    model_type = "vision"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Normalise ``cls`` to a plain string so only the name is stored.
        target = kwargs.get("cls", "")
        self.cls = target if isinstance(target, str) else target.__name__

        raw_params = kwargs.get("params", {})
        self.params = AttrDict(raw_params)
class AlignerConfig(PretrainedConfig):
    """Config with ``model_type = "aligner"`` holding a class name and its params.

    Keyword Args:
        cls: A string, or any object exposing ``__name__`` (e.g. a class);
            non-string values are stored as their ``__name__``. Defaults to "".
        params: Mapping wrapped in ``AttrDict``. Defaults to an empty mapping.

    All keyword arguments are also forwarded unchanged to
    ``PretrainedConfig.__init__``.
    """

    model_type = "aligner"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Normalise ``cls`` to a plain string so only the name is stored.
        target = kwargs.get("cls", "")
        self.cls = target if isinstance(target, str) else target.__name__

        raw_params = kwargs.get("params", {})
        self.params = AttrDict(raw_params)
class GenVisionConfig(PretrainedConfig):
    """Config with ``model_type = "gen_vision"`` holding a class name and its params.

    Keyword Args:
        cls: A string, or any object exposing ``__name__`` (e.g. a class);
            non-string values are stored as their ``__name__``. Defaults to "".
        params: Mapping wrapped in ``AttrDict``. Defaults to an empty mapping.

    All keyword arguments are also forwarded unchanged to
    ``PretrainedConfig.__init__``.
    """

    model_type = "gen_vision"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Normalise ``cls`` to a plain string so only the name is stored.
        target = kwargs.get("cls", "")
        self.cls = target if isinstance(target, str) else target.__name__

        raw_params = kwargs.get("params", {})
        self.params = AttrDict(raw_params)
class GenAlignerConfig(PretrainedConfig):
    """Config with ``model_type = "gen_aligner"`` holding a class name and its params.

    Keyword Args:
        cls: A string, or any object exposing ``__name__`` (e.g. a class);
            non-string values are stored as their ``__name__``. Defaults to "".
        params: Mapping wrapped in ``AttrDict``. Defaults to an empty mapping.

    All keyword arguments are also forwarded unchanged to
    ``PretrainedConfig.__init__``.
    """

    model_type = "gen_aligner"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Normalise ``cls`` to a plain string so only the name is stored.
        target = kwargs.get("cls", "")
        self.cls = target if isinstance(target, str) else target.__name__

        raw_params = kwargs.get("params", {})
        self.params = AttrDict(raw_params)
class GenHeadConfig(PretrainedConfig):
    """Config with ``model_type = "gen_head"`` holding a class name and its params.

    Keyword Args:
        cls: A string, or any object exposing ``__name__`` (e.g. a class);
            non-string values are stored as their ``__name__``. Defaults to "".
        params: Mapping wrapped in ``AttrDict``. Defaults to an empty mapping.

    All keyword arguments are also forwarded unchanged to
    ``PretrainedConfig.__init__``.
    """

    model_type = "gen_head"
    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Normalise ``cls`` to a plain string so only the name is stored.
        target = kwargs.get("cls", "")
        self.cls = target if isinstance(target, str) else target.__name__

        raw_params = kwargs.get("params", {})
        self.params = AttrDict(raw_params)
class MultiModalityConfig(PretrainedConfig):
    """Composite config with ``model_type = "multi_modality"``.

    Bundles five component configs (vision, aligner, gen_vision, gen_aligner,
    gen_head) together with a ``LlamaConfig`` for the language model.

    Keyword Args:
        vision_config, aligner_config, gen_vision_config, gen_aligner_config,
        gen_head_config, language_config: Each may be

            * a dict of keyword arguments for the matching config class,
            * an already-built instance of that config class (stored as-is), or
            * omitted / ``None``, in which case the config class is built
              with no arguments.

    All keyword arguments are also forwarded unchanged to
    ``PretrainedConfig.__init__``.
    """

    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig
    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig
    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        build = self._build_sub_config
        self.vision_config = build(VisionConfig, kwargs.get("vision_config"))
        self.aligner_config = build(AlignerConfig, kwargs.get("aligner_config"))
        self.gen_vision_config = build(
            GenVisionConfig, kwargs.get("gen_vision_config")
        )
        self.gen_aligner_config = build(
            GenAlignerConfig, kwargs.get("gen_aligner_config")
        )
        self.gen_head_config = build(GenHeadConfig, kwargs.get("gen_head_config"))
        self.language_config = build(LlamaConfig, kwargs.get("language_config"))

    @staticmethod
    def _build_sub_config(config_cls, value):
        """Return ``value`` as a ``config_cls`` instance.

        Previously only ``language_config`` tolerated a ready-made instance;
        the other components were unconditionally ``**``-unpacked, which
        raises ``TypeError`` for config objects and for ``None``. Routing all
        components through this helper makes their handling uniform.
        """
        if isinstance(value, config_cls):
            return value
        if value is None:
            value = {}
        return config_cls(**value)