import torch
import numpy as np
import torch.nn.functional as F
from transformers import AutoTokenizer

from models import MMadaModelLM


def add_gumbel_noise(logits, temperature):
    '''
    The Gumbel-max trick is a method for sampling from categorical distributions.
    According to arXiv:2409.02908, for MDMs, low-precision Gumbel-max improves the
    perplexity score but reduces generation quality. Thus, we use float64.

    Taking the argmax of exp(logits) / (-log(u))^temperature with u ~ U(0, 1) is
    equivalent to sampling from softmax(logits / temperature).
    '''
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (- torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def get_num_transfer_tokens(mask_index, steps):
    '''
    In the reverse process, the interval [0, 1] is uniformly discretized into
    `steps` intervals. Because LLaDA employs a linear noise schedule (as defined
    in Eq. (8)), the expected number of tokens transitioned at each step should
    be the same. This function precomputes the number of tokens to transition at
    each step: every step unmasks `mask_num // steps` tokens, and the remainder
    is spread over the first `mask_num % steps` steps.
    '''
    mask_num = mask_index.sum(dim=1, keepdim=True)

    base = mask_num // steps
    remainder = mask_num % steps

    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base

    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, :remainder[i]] += 1

    return num_transfer_tokens
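# Illustrative sanity check (not executed anywhere in this script): with 10
# masked tokens and steps=4, base = 10 // 4 = 2 and remainder = 10 % 4 = 2,
# so the first two steps unmask one extra token each:
#
#   >>> get_num_transfer_tokens(torch.ones(1, 10, dtype=torch.bool), steps=4)
#   tensor([[3, 3, 2, 2]])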
@torch.no_grad()
def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
             cfg_scale=0., remasking='low_confidence', mask_id=126336, attention_mask=None):
    '''
    Args:
        model: Mask predictor.
        prompt: A tensor of shape (B, L), where B is batch size.
        steps: Sampling steps, less than or equal to gen_length.
        gen_length: Generated answer length.
        block_length: Block length, less than or equal to gen_length. If less than
            gen_length, semi-autoregressive remasking is used.
        temperature: Categorical distribution sampling temperature.
        cfg_scale: Unsupervised classifier-free guidance scale.
        remasking: Remasking strategy. 'low_confidence' or 'random'.
        mask_id: The token id of [MASK] is 126336.
        attention_mask: Optional 0/1 mask of shape (B, L + gen_length); positions
            marked 0 (e.g., left padding) are excluded from attention.
    '''
    if attention_mask is not None and 0.0 in attention_mask:
        # Build a 2-D bias so padded positions neither attend nor are attended to.
        # Cast to bool first: bitwise & is not defined for float masks.
        attention_mask = attention_mask.bool()
        attention_bias = (attention_mask[:, :, None] & attention_mask[:, None, :]).unsqueeze(1)
    else:
        attention_bias = None

    batch_size = prompt.shape[0]
    x = torch.full((batch_size, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps = steps // num_blocks

    for num_block in range(num_blocks):
        block_start = prompt.shape[1] + num_block * block_length
        block_end = prompt.shape[1] + (num_block + 1) * block_length

        block_mask_index = (x[:, block_start:block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
        for i in range(steps):
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                # Duplicate the bias for the doubled (conditional + unconditional) batch.
                cfg_bias = None if attention_bias is None else torch.cat([attention_bias, attention_bias], dim=0)
                logits = model(x_, attention_bias=cfg_bias).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x, attention_bias=attention_bias).logits

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # (b, l)

            if remasking == 'low_confidence':
                p = F.softmax(logits.to(torch.float64), dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # (b, l)
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            # Never unmask tokens beyond the current block.
            x0_p[:, block_end:] = -np.inf

            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            # Unmask the num_transfer_tokens[j, i] highest-confidence positions in
            # each sequence; everything else stays masked for the next step.
            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    return x


def main():
    device = 'cuda'

    model = MMadaModelLM.from_pretrained(
        "/data_storage/ty/MMaDA/mmada-training-stage4-llada-instruct/checkpoint-170000/unwrapped_model",
        trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        "/data_storage/ty/MMaDA/mmada-training-stage4-llada-instruct/checkpoint-170000/unwrapped_model",
        trust_remote_code=True)
    tokenizer.chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}"

    prompt = "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?"

    m = [{"role": "user", "content": prompt}]
    prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)
    # With return_tensors="pt" the tokenizer already returns a tensor, so it can
    # be moved to the device directly.
    input_ids = tokenizer(text=prompt, return_tensors="pt", padding=True, padding_side="left")['input_ids'].to(device)

    out = generate(model, input_ids, steps=128, gen_length=128, block_length=128,
                   temperature=1, cfg_scale=0., remasking='low_confidence')
    print(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True))


if __name__ == '__main__':
    main()
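# Hypothetical batched usage (the prompt strings are placeholders): with left
# padding, the tokenizer's attention_mask can be extended over the gen_length
# masked positions and forwarded to `generate`, which turns it into a 2-D
# attention bias so padding is ignored. Setting block_length=32 (< gen_length)
# switches to semi-autoregressive block-by-block remasking.
#
#   batch = tokenizer(["first question", "second question"], return_tensors="pt",
#                     padding=True, padding_side="left")
#   ids = batch["input_ids"].to(device)
#   am = batch["attention_mask"].to(device)
#   am = torch.cat([am, torch.ones(am.shape[0], 128, dtype=am.dtype, device=device)], dim=1)
#   out = generate(model, ids, attention_mask=am, steps=128, gen_length=128, block_length=32)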