Can you provide an index for all your safe tensor model files?

#1 by Le222 - opened

Otherwise, AutoModel cannot be used with your xl version.
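For reference, the missing file is model.safetensors.index.json, which maps each tensor name to the shard that contains it. A minimal sketch of how it could be rebuilt from the shard files (the shard file names and the use of file size for total_size are assumptions, not the repo's actual layout):

# Minimal sketch: rebuild model.safetensors.index.json from existing shards.
# Adjust the glob pattern to the actual shard file names.
import glob
import json
import os

from safetensors import safe_open

shards = sorted(glob.glob("model-*.safetensors"))
weight_map, total_size = {}, 0

for shard in shards:
    total_size += os.path.getsize(shard)  # approximation: file size includes the safetensors header
    with safe_open(shard, framework="pt") as f:
        for key in f.keys():
            weight_map[key] = os.path.basename(shard)

index = {"metadata": {"total_size": total_size}, "weight_map": weight_map}
with open("model.safetensors.index.json", "w") as out:
    json.dump(index, out, indent=2)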

CATIE org

Hi!

Thanks for the feedback. I'll look into it.
My current priority is the release of our FAT5 this month (finishing the technical report + integration into transformers with the HF team + various communications).
What you report should be addressed as part of that, but it could take a few more days nonetheless.

Any update on this? This looks like an incredible model, and it would be great to be able to use it.

Couldn't wait, so I tried fixing it myself. I created a version here with an attempt to restore the index:
https://huggingface.co/Thalesian/FAT5-xl-flan-en/tree/main

I also found what appears to be a flaw in Fast_RMS_LayernormBackward with Triton that caused gradients to blow up immediately.
If you can settle for a partial Triton implementation, you can use the following patch to disable Triton in that particular layer:

import torch
import torch.nn as nn


class SafeRMSNorm(nn.Module):
    """Plain PyTorch RMSNorm, used as a drop-in replacement for the Triton kernel."""

    def __init__(self, dim, eps=1e-6, weight=None):
        super().__init__()
        self.dim = dim
        self.eps = eps
        if weight is None:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            # Copy the existing scale so the replacement keeps the trained values.
            self.weight = nn.Parameter(weight.detach().clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x_norm = x * torch.rsqrt(var + self.eps)
        return x_norm * self.weight


def patch_rmsnorm_in_module(module: nn.Module):
    """Recursively replace any *RMSNorm*-named child module with SafeRMSNorm."""
    for name, child in list(module.named_children()):
        cls_name = child.__class__.__name__.lower()
        if ("rms" in cls_name) and ("norm" in cls_name):
            # Recover the scale parameter and hidden dimension from the child module.
            weight = getattr(child, "weight", None)
            if weight is None:
                weight = getattr(child, "scale", None)
            if weight is not None:
                dim = weight.shape[-1]
            else:
                dim = getattr(child, "normalized_shape", None)
                if isinstance(dim, (tuple, list)):
                    dim = dim[-1]
                if dim is None:
                    raise ValueError(f"Could not infer dim for RMSNorm module: {child}")
            eps = getattr(child, "eps", getattr(child, "epsilon", 1e-6))
            print(f"Replacing {child.__class__.__name__} at '{name}' with SafeRMSNorm(dim={dim}, eps={eps})")
            safe = SafeRMSNorm(dim=dim, eps=eps, weight=weight)
            setattr(module, name, safe)
        else:
            patch_rmsnorm_in_module(child)
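If you would rather swap the norm modules out entirely instead of only flipping a flag, a minimal usage sketch (call it on the loaded model from the snippet further below):

# Optional alternative: replace the RMSNorm modules themselves after loading.
patch_rmsnorm_in_module(model)

The following helpers instead turn the flag off on FAT5's FlashT5LayerNorm modules and let you inspect the relevant flags: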

def disable_fat5_triton_layernorm(model: nn.Module):
    """Turn off the Triton kernel in every FlashT5LayerNorm module of the model."""
    count = 0
    for module in model.modules():
        if module.__class__.__name__ == "FlashT5LayerNorm":
            module.use_triton_layernorm = False
            count += 1
    print(f"Disabled Triton layernorm in {count} FlashT5LayerNorm modules.")


def inspect_fat5_flags(model, max_print=20):
    """Print the Triton / flash-attention flags currently set on the model's modules."""
    for name, module in model.named_modules():
        if module.__class__.__name__ == "FlashT5LayerNorm":
            print(f"[LN] {name}: use_triton_layernorm={getattr(module, 'use_triton_layernorm', None)}")
        if hasattr(module, "use_flash_attention"):
            print(f"[ATTN] {name}: use_flash_attention={module.use_flash_attention}")
        if hasattr(module, "use_triton_crossentropy"):
            print(f"[CE] {name}: use_triton_crossentropy={module.use_triton_crossentropy}")

And run it like this:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model = AutoModelForSeq2SeqLM.from_pretrained(
    "Thalesian/FAT5-xl-flan-en",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl", use_fast=True, legacy=False)

# Keep the Triton cross-entropy / flash-attention paths enabled in the config...
for attr in ["use_triton_crossentropy", "use_triton_layernorm", "use_flash_attention"]:
    if hasattr(model.config, attr):
        setattr(model.config, attr, True)

# ...but turn the Triton layernorm off at the module level (see note below).
disable_fat5_triton_layernorm(model)

# Tell transformers that the encoder/decoder embeddings share the same weights.
tied_keys = [
    "shared.weight",
    "encoder.embed_tokens.weight",
    "decoder.embed_tokens.weight",
]
setattr(model, "_dynamic_tied_weights_keys", tied_keys)
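inspect_fat5_flags from above (not used in the snippet itself) can confirm the flags ended up where you expect:

# Sanity check: FlashT5LayerNorm modules should now report use_triton_layernorm=False.
inspect_fat5_flags(model)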

Note that after loading the model, you must always call disable_fat5_triton_layernorm(model) before training.
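If it helps, here is a minimal sanity check that the loss and gradients stay finite after patching. It is not part of the original patch; the dummy prompt and the standard T5-style forward signature are assumptions:

import torch

# Hypothetical dummy batch; any short input/target pair works for this check.
inputs = tokenizer("translate English to German: Hello, world!", return_tensors="pt")
labels = tokenizer("Hallo, Welt!", return_tensors="pt").input_ids

model.train()
out = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, labels=labels)
out.loss.backward()

# Total gradient norm; it should be a finite number, not inf/nan.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float("inf"))
print(f"loss={out.loss.item():.4f}  grad_norm={grad_norm.item():.4f}")
model.zero_grad()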

Oh, sorry, that point ended up being overlooked, and neither of the two authors who contributed to this project still works at CATIE.
Therefore, no follow-up will be conducted on FAT5 (here on HF or on the GitHub repo).

Thank you for suggesting a solution to the problem. I will add a link in the model card referring to your message 🤗
