import json

import torch
import torch.nn as nn
from transformers import AutoTokenizer, T5ForConditionalGeneration


def load_glyph_byT5_v2(args, device):
    """
    Loads ByT5 tokenizer and encoder model for glyph encoding.

    Args:
        args (dict): Configuration dictionary containing paths and settings.
        device (str or torch.device): Device to load the model onto.

    Returns:
        dict: Dictionary with keys 'byt5_tokenizer', 'byt5_model', 'byt5_max_length'.
    """
    byt5_tokenizer, byt5_model, byt5_max_length = create_byt5(args, device)
    byt5_model = byt5_model.to(device=device)
    return {
        "byt5_tokenizer": byt5_tokenizer,
        "byt5_model": byt5_model,
        "byt5_max_length": byt5_max_length,
    }
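
# Usage sketch with placeholder settings: build a minimal `args` dict with the
# keys read by create_byt5 and load the frozen glyph encoder. All values below
# are illustrative; the annotation paths reuse the defaults of
# load_byt5_and_byt5_tokenizer and a real config may point elsewhere.
def _demo_load_glyph_byt5_v2():
    args = {
        'byt5_max_length': 256,                      # assumed sequence length
        'byT5_google_path': 'google/byt5-small',     # assumed base checkpoint
        'multilingual_prompt_format_color_path': 'assets/color_idx.json',
        'multilingual_prompt_format_font_path': 'assets/font_idx_512.json',
        'byT5_ckpt_path': None,                      # skip the fine-tuned checkpoint
    }
    glyph = load_glyph_byT5_v2(args, device='cuda:0')
    return glyph['byt5_tokenizer'], glyph['byt5_model'], glyph['byt5_max_length']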

def create_byt5(args, device):
    """
    Create ByT5 tokenizer and encoder, load weights if provided.

    Args:
        args (dict): Configuration dictionary.
        device (str or torch.device): Device to load the model onto.

    Returns:
        tuple: (byt5_tokenizer, byt5_model, byt5_max_length)
    """
    byt5_max_length = args['byt5_max_length']
    byt5_config = dict(
        byt5_name=args['byT5_google_path'],
        special_token=True,
        color_special_token=True,
        font_special_token=True,
        color_ann_path=args['multilingual_prompt_format_color_path'],
        font_ann_path=args['multilingual_prompt_format_font_path'],
        multilingual=True,
    )
    huggingface_cache_dir = None
    byt5_model, byt5_tokenizer = load_byt5_and_byt5_tokenizer(
        **byt5_config,
        huggingface_cache_dir=huggingface_cache_dir,
        device=device,
    )

    # Load custom checkpoint if provided.
    if args['byT5_ckpt_path'] is not None:
        # `device` may be a bare GPU index (e.g. 0 or "0"); prefix it with
        # "cuda:" so torch.load maps the weights onto the intended GPU.
        if "cuda" not in str(device):
            byt5_state_dict = torch.load(args['byT5_ckpt_path'], map_location=f"cuda:{device}")
        else:
            byt5_state_dict = torch.load(args['byT5_ckpt_path'], map_location=device)
        if 'state_dict' in byt5_state_dict:
            # The checkpoint wraps the weights; keep only the text-tower encoder
            # entries and strip their 'module.text_tower.encoder.' prefix so the
            # keys match the standalone ByT5 encoder.
            sd = byt5_state_dict["state_dict"]
            newsd = {}
            for k, v in sd.items():
                if k.startswith('module.text_tower.encoder.'):
                    newsd[k[len('module.text_tower.encoder.'):]] = v
            byt5_state_dict = newsd
        byt5_model.load_state_dict(byt5_state_dict)

    # The glyph encoder is kept frozen; it is not fine-tuned here.
    byt5_model.requires_grad_(False)
    return byt5_tokenizer, byt5_model, byt5_max_length
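
# Self-contained sketch of the key remapping performed above, on made-up tensor
# names: checkpoints saved from a wrapped text tower keep the encoder weights
# under a 'module.text_tower.encoder.' prefix, which must be stripped before
# the standalone ByT5 encoder can load them.
def _demo_strip_text_tower_prefix():
    prefix = 'module.text_tower.encoder.'
    checkpoint = {
        "state_dict": {
            prefix + "block.0.layer.0.SelfAttention.q.weight": torch.zeros(2, 2),
            "module.visual_tower.proj.weight": torch.zeros(2, 2),  # unrelated weight, dropped
        }
    }
    stripped = {
        k[len(prefix):]: v
        for k, v in checkpoint["state_dict"].items()
        if k.startswith(prefix)
    }
    assert list(stripped) == ["block.0.layer.0.SelfAttention.q.weight"]
    return stripped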

def add_special_token(
    tokenizer,
    text_encoder,
    add_color,
    add_font,
    color_ann_path,
    font_ann_path,
    multilingual=False,
):
    """
    Add special tokens for color and font to tokenizer and text encoder.

    Args:
        tokenizer: Huggingface tokenizer.
        text_encoder: Huggingface T5 encoder.
        add_color (bool): Whether to add color tokens.
        add_font (bool): Whether to add font tokens.
        color_ann_path (str): Path to color annotation JSON.
        font_ann_path (str): Path to font annotation JSON.
        multilingual (bool): Whether to use multilingual font tokens.
    """
    with open(font_ann_path, 'r') as f:
        idx_font_dict = json.load(f)
    with open(color_ann_path, 'r') as f:
        idx_color_dict = json.load(f)

    if multilingual:
        font_token = [f'<{font_code[:2]}-font-{idx_font_dict[font_code]}>' for font_code in idx_font_dict]
    else:
        font_token = [f'<font-{i}>' for i in range(len(idx_font_dict))]
    color_token = [f'<color-{i}>' for i in range(len(idx_color_dict))]

    additional_special_tokens = []
    if add_color:
        additional_special_tokens += color_token
    if add_font:
        additional_special_tokens += font_token

    tokenizer.add_tokens(additional_special_tokens, special_tokens=True)
    # Set mean_resizing=False to avoid PyTorch LAPACK dependency
    text_encoder.resize_token_embeddings(len(tokenizer), mean_resizing=False)
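
# Illustrative sketch of the token naming scheme above, using made-up annotation
# entries. In the multilingual case each font key is assumed to start with a
# two-letter language code, giving tokens like '<en-font-0>'; otherwise fonts
# and colors are simply enumerated by index.
def _demo_special_token_names():
    idx_font_dict = {"en-Arial": 0, "cn-SimHei": 1}      # hypothetical font annotation
    idx_color_dict = {"red": 0, "blue": 1, "green": 2}   # hypothetical color annotation
    multilingual_font_tokens = [
        f'<{font_code[:2]}-font-{idx_font_dict[font_code]}>' for font_code in idx_font_dict
    ]
    color_tokens = [f'<color-{i}>' for i in range(len(idx_color_dict))]
    # -> ['<en-font-0>', '<cn-font-1>'] and ['<color-0>', '<color-1>', '<color-2>']
    return multilingual_font_tokens, color_tokens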

def load_byt5_and_byt5_tokenizer(
    byt5_name='google/byt5-small',
    special_token=False,
    color_special_token=False,
    font_special_token=False,
    color_ann_path='assets/color_idx.json',
    font_ann_path='assets/font_idx_512.json',
    huggingface_cache_dir=None,
    multilingual=False,
    device=None,
):
    """
    Load ByT5 encoder and tokenizer from Huggingface, and add special tokens if needed.

    Args:
        byt5_name (str): Model name or path.
        special_token (bool): Whether to add special tokens.
        color_special_token (bool): Whether to add color tokens.
        font_special_token (bool): Whether to add font tokens.
        color_ann_path (str): Path to color annotation JSON.
        font_ann_path (str): Path to font annotation JSON.
        huggingface_cache_dir (str): Huggingface cache directory.
        multilingual (bool): Whether to use multilingual font tokens.
        device (str or torch.device): Device to load the model onto.

    Returns:
        tuple: (byt5_text_encoder, byt5_tokenizer)
    """
    byt5_tokenizer = AutoTokenizer.from_pretrained(
        byt5_name,
        cache_dir=huggingface_cache_dir,
    )
    # Only the encoder of the seq2seq model is needed for glyph embedding.
    byt5_text_encoder = T5ForConditionalGeneration.from_pretrained(
        byt5_name,
        cache_dir=huggingface_cache_dir,
    ).get_encoder()

    # `device` may be a bare GPU index (e.g. 0 or "0"); normalize it to a
    # torch.device before moving the encoder.
    if "cuda" not in str(device):
        device = torch.device(f"cuda:{device}")
    else:
        device = torch.device(device)
    byt5_text_encoder = byt5_text_encoder.to(device)

    if special_token:
        add_special_token(
            byt5_tokenizer,
            byt5_text_encoder,
            add_color=color_special_token,
            add_font=font_special_token,
            color_ann_path=color_ann_path,
            font_ann_path=font_ann_path,
            multilingual=multilingual,
        )
    return byt5_text_encoder, byt5_tokenizer
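
# Usage sketch, assuming a CUDA device and local copies of the default
# annotation files: load the public 'google/byt5-small' encoder with the extra
# tokens and embed one prompt. The prompt string and the special tokens in it
# are illustrative; which '<font-i>'/'<color-i>' tokens exist depends on the
# annotation JSONs that were loaded.
def _demo_encode_glyph_prompt():
    encoder, tokenizer = load_byt5_and_byt5_tokenizer(
        byt5_name='google/byt5-small',
        special_token=True,
        color_special_token=True,
        font_special_token=True,
        color_ann_path='assets/color_idx.json',
        font_ann_path='assets/font_idx_512.json',
        multilingual=False,
        device='cuda:0',
    )
    inputs = tokenizer(
        'Text "Hello" in <color-0>, <font-0>',
        max_length=256,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        hidden = encoder(
            inputs.input_ids.to(encoder.device),
            attention_mask=inputs.attention_mask.to(encoder.device),
        ).last_hidden_state
    return hidden  # (1, 256, hidden_size) byte-level glyph features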

class ByT5Mapper(nn.Module):
    """
    ByT5Mapper: Maps ByT5 encoder outputs to a new space, with optional residual connection.

    Args:
        in_dim (int): Input dimension (must equal out_dim if use_residual).
        out_dim (int): Output dimension after second linear layer.
        hidden_dim (int): Hidden dimension for intermediate layer.
        out_dim1 (int): Final output dimension.
        use_residual (bool): Whether to use residual connection (default: True).
    """

    def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_residual=True):
        super().__init__()
        if use_residual:
            assert in_dim == out_dim
        self.layernorm = nn.LayerNorm(in_dim)
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.fc3 = nn.Linear(out_dim, out_dim1)
        self.use_residual = use_residual
        self.act_fn = nn.GELU()

    def forward(self, x):
        """
        Forward pass for ByT5Mapper.

        Args:
            x (Tensor): Input tensor of shape (..., in_dim).

        Returns:
            Tensor: Output tensor of shape (..., out_dim1).
        """
        residual = x
        x = self.layernorm(x)
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.fc2(x)
        x2 = self.act_fn(x)
        x2 = self.fc3(x2)
        if self.use_residual:
            x2 = x2 + residual
        return x2
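
# Quick shape check with small illustrative dimensions. Note that the residual
# is added after fc3, so with use_residual=True the input dimension must also
# match out_dim1 for the addition to be well defined.
if __name__ == "__main__":
    mapper = ByT5Mapper(in_dim=64, out_dim=64, hidden_dim=128, out_dim1=64, use_residual=True)
    tokens = torch.randn(2, 256, 64)       # (batch, sequence, in_dim)
    print(mapper(tokens).shape)            # torch.Size([2, 256, 64])

    no_res = ByT5Mapper(in_dim=64, out_dim=96, hidden_dim=128, out_dim1=32, use_residual=False)
    print(no_res(tokens).shape)            # torch.Size([2, 256, 32])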