# Tar / t2i_inference.py
# Commit 3f9caff by hanjiaming.0208: "add 512px AR"
import re
from dataclasses import dataclass
from typing import Optional

import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoTokenizer, Qwen2ForCausalLM

from tok.mm_autoencoder import MMAutoEncoder
@dataclass
class T2IConfig:
    """Settings for ``TextToImageInference``.

    Groups the language-model checkpoint, the visual-tokenizer checkpoints,
    the device/dtype everything is loaded with, and default generation
    parameters.
    """

    model_path: str = "csuhan/Tar-1.5B"
    # visual tokenizer config
    encoder_path: str = 'ta_tok.pth'
    decoder_path: str = 'vq_ds16_t2i.pt'
    device: str = "cuda:0"
    dtype: torch.dtype = torch.bfloat16
    # generation parameters
    # NOTE(review): the two option lists below look paired index-wise
    # (scale 0/1/2 <-> seq_len 729/169/81) -- confirm before mixing them.
    scale: int = 0  # choose from [0, 1, 2]
    seq_len: int = 729  # choose from [729, 169, 81]
    temperature: float = 1.0
    top_p: float = 0.95
    top_k: int = 1200
    cfg_scale: float = 4.0
    # Forwarded to MMAutoEncoder as ``ar_path_dict`` (visual tokenizer config);
    # presumably maps a resolution to an AR-decoder checkpoint -- confirm.
    # It must carry a type annotation: an unannotated class attribute is not a
    # dataclass field and cannot be passed to the constructor. It is declared
    # last so existing positional constructor calls keep their meaning.
    ar_path: Optional[dict] = None
class TextToImageInference:
    """Text-to-image generation with a Qwen2 language model plus a visual tokenizer.

    The language model samples image tokens, which appear as ``<I{index}>``
    in decoded text. Their indices are handed to ``MMAutoEncoder``, which
    decodes them into a PIL image.
    """

    # Decoded image tokens look like "<I123>"; group 1 is the code index.
    _IMAGE_TOKEN_RE = re.compile(r'<I(\d+)>')

    def __init__(self, config: T2IConfig):
        """Store ``config`` and load all models onto ``config.device``."""
        self.config = config
        self.device = torch.device(config.device)
        self._load_models()

    def _load_models(self):
        """Load the LM, its tokenizer and the visual tokenizer onto ``self.device``."""
        self.model = Qwen2ForCausalLM.from_pretrained(
            self.config.model_path, torch_dtype=self.config.dtype
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_path)

        # Initialize visual tokenizer
        tokenizer_kwargs = dict(
            ar_path_dict=self.config.ar_path,
            encoder_path=self.config.encoder_path,
            decoder_path=self.config.decoder_path,
            encoder_args={'input_type': 'rec'},
            decoder_args={},
        )
        self.visual_tokenizer = MMAutoEncoder(**tokenizer_kwargs).eval().to(
            dtype=self.config.dtype, device=self.device
        )
        # NOTE(review): cls_token_num presumably is the number of image codes
        # each AR decoder is conditioned on, so it is kept equal to seq_len.
        for ar_model in self.visual_tokenizer.ar_model.values():
            ar_model.cls_token_num = self.config.seq_len
        # Scale index 0/1/2 becomes encoder pool_scale 1/2/3.
        self.visual_tokenizer.encoder.pool_scale = self.config.scale + 1

    def _build_prompt(self, prompt):
        """Return chat-formatted prompt text ending with the image-start markers."""
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ]
        input_text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True)
        # Open an image span; <S{scale}> presumably selects the token-grid scale.
        return input_text + f"<im_start><S{self.config.scale}>"

    @classmethod
    def _codes_from_text(cls, text, seq_len):
        """Extract image-code indices from ``text``, truncated/zero-padded to ``seq_len``."""
        codes = [int(x) for x in cls._IMAGE_TOKEN_RE.findall(text)]
        # Pad with code 0 if the model stopped before emitting seq_len tokens.
        return codes[:seq_len] + [0] * max(0, seq_len - len(codes))

    def generate_image(self, prompt, resolution, top_p=None, top_k=None, cfg_scale=None) -> Image.Image:
        """Generate one image for ``prompt``.

        Args:
            prompt: User text describing the image.
            resolution: Forwarded as ``'resolution'`` to the visual tokenizer's
                ``decode_from_encoder_indices``.
            top_p: Nucleus-sampling threshold; ``None`` uses ``config.top_p``.
            top_k: Top-k sampling cutoff; ``None`` uses ``config.top_k``.
            cfg_scale: Forwarded as ``'cfg_scale'`` to the decoder (presumably
                classifier-free guidance strength); ``None`` uses
                ``config.cfg_scale``.

        Returns:
            A PIL image built from the first decoded sample.
        """
        top_p = self.config.top_p if top_p is None else top_p
        top_k = self.config.top_k if top_k is None else top_k
        cfg_scale = self.config.cfg_scale if cfg_scale is None else cfg_scale

        # Generate tokens
        inputs = self.tokenizer(self._build_prompt(prompt), return_tensors="pt")
        input_ids = inputs.input_ids.to(self.device)
        generate_kwargs = dict(
            max_new_tokens=self.config.seq_len,
            do_sample=True,
            temperature=self.config.temperature,
            top_p=top_p,
            top_k=top_k,
        )
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            # Explicit mask avoids generate()'s "attention mask not set" warning.
            generate_kwargs["attention_mask"] = attention_mask.to(self.device)
        gen_ids = self.model.generate(input_ids, **generate_kwargs)

        # generate() returns prompt + continuation. Decode only the new tokens
        # so image tokens that appear in the prompt are never parsed as output.
        new_ids = gen_ids[:, input_ids.shape[1]:]
        gen_text = self.tokenizer.batch_decode(new_ids)[0]
        codes = self._codes_from_text(gen_text, self.config.seq_len)
        gen_code = torch.tensor(codes).unsqueeze(0).to(self.device)

        with torch.no_grad():
            gen_tensor = self.visual_tokenizer.decode_from_encoder_indices(
                gen_code,
                {'cfg_scale': cfg_scale, 'resolution': resolution},
            )
        # NOTE(review): assumes sample 0 is an (H, W, C) array PIL accepts --
        # confirm. .cpu() is a no-op when the result is already on the CPU.
        return Image.fromarray(gen_tensor[0].cpu().numpy())