import os

import gradio as gr
from huggingface_hub import hf_hub_download, login, snapshot_download
from PIL import Image
from torchvision.transforms.functional import to_tensor

from tok.ar_dtok.ar_model import ARModel
from t2i_inference import T2IConfig, TextToImageInference
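# `tok` and `t2i_inference` are local modules from the Tar repository
# (https://github.com/csuhan/Tar); this script assumes it runs from the repo
# root. `ARModel` is not referenced below and may be imported for its side
# effects, so it is kept as-is.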

def generate_text(self, image: Image.Image, prompt: str) -> str:
    """Answer `prompt` about `image` using the multimodal LM."""
    image = image.convert('RGB')
    image = to_tensor(image).unsqueeze(0).to(self.device)

    # Encode the image into discrete codes, then render them as the special
    # <I...> tokens the language model was trained on.
    image_code = self.visual_tokenizer.encoder(image.to(self.config.dtype))['bottleneck_rep']
    image_text = "".join(f"<I{x}>" for x in image_code[0].cpu().tolist())

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": f"{image_text}\n{prompt}"}
    ]

    # Build the chat prompt and sample a response.
    input_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = self.tokenizer(input_text, return_tensors="pt")

    gen_ids = self.model.generate(
        inputs.input_ids.to(self.device),
        max_new_tokens=512,
        do_sample=True)
    # Decode only the newly generated tokens (strip the prompt prefix).
    return self.tokenizer.batch_decode(gen_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)[0]

# Log in to the Hugging Face Hub when a token is available (required for
# gated or private checkpoints; `login(token=None)` would fall back to an
# interactive prompt and fail in a headless environment).
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    login(token=hf_token)
config = T2IConfig()
config.model = snapshot_download("csuhan/Tar-7B-v0.1")
config.ar_path = {
    "1024px": hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_1024px.pth"),
    "512px": hf_hub_download("csuhan/TA-Tok", "ar_dtok_lp_512px.pth"),
}
config.encoder_path = hf_hub_download("csuhan/TA-Tok", "ta_tok.pth")
config.decoder_path = hf_hub_download("peizesun/llamagen_t2i", "vq_ds16_t2i.pt")
inference = TextToImageInference(config)
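# A minimal sketch of direct programmatic use, bypassing the Gradio UI. The
# positional argument order follows the `generate_image` wrapper below; keyword
# names are not assumed:
#   sample = inference.generate_image("a red bicycle", "1024px", 0.95, 1200, 4.0)
#   sample.save("sample.png")  # assumes a PIL image is returned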

def generate_image(prompt, resolution, top_p, top_k, cfg_scale):
    # Thin wrapper so Gradio can pass the UI values straight to the inference engine.
    return inference.generate_image(prompt, resolution, top_p, top_k, cfg_scale)

def clear_inputs_t2i():
    # Reset the text-to-image tab: clear the prompt and the output image.
    return "", None

def understand_image(image, prompt):
    # `generate_text` is written method-style, so pass the inference object
    # explicitly as `self`.
    return generate_text(inference, image, prompt)

def clear_inputs_i2t():
    # Reset the understanding tab: image, instruction, and response
    # (three values to match the three `outputs` wired below).
    return None, "", ""

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        <div align="center">

        ### Tar: Unifying Visual Understanding and Generation via Text-Aligned Representations  

        [πŸ•ΈοΈ Project Page](http://tar.csuhan.com) β€’ [πŸ“„ Paper](http://arxiv.org/abs/2506.18898) β€’ [πŸ’» Code](https://github.com/csuhan/Tar) β€’ [πŸ“¦ Model](https://huggingface.co/collections/csuhan/tar-68538273b5537d0bee712648)

        </div>
        """,
        elem_id="title",
    )
    with gr.Tab("Image Generation"):
      with gr.Row():
          with gr.Column(scale=1):
              prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt")
              with gr.Accordion("Advanced Settings", open=False):
                resolution = gr.Radio(
                    ["512px", "1024px"], value="1024px", label="Resolution"
                )
                top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p")
                top_k = gr.Slider(1, 2000, value=1200, step=10, label="Top-k")
                cfg_scale = gr.Slider(1.0, 20.0, value=4.0, step=0.5, label="CFG Scale")
              with gr.Row():
                  generate_btn = gr.Button("Generate")
                  clear_btn = gr.Button("Clear")
          with gr.Column(scale=2):
              output_image = gr.Image(label="Generated Image")

      generate_btn.click(
          generate_image, 
          inputs=[prompt, resolution, top_p, top_k, cfg_scale], 
          outputs=output_image
      )
      clear_btn.click(
          clear_inputs_t2i, 
          outputs=[prompt, output_image]
      )

    with gr.Tab("Image Understanding"):
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(label="Upload Image", type="pil")
                question_input = gr.Textbox(label="Instruction", value="Describe the image briefly.")
                with gr.Row():
                    qa_btn = gr.Button("Generate")
                    clear_btn_i2t = gr.Button("Clear")
            with gr.Column(scale=1):
                answer_output = gr.Textbox(label="Response", lines=4)

        qa_btn.click(
            understand_image,
            inputs=[image_input, question_input],
            outputs=answer_output
        )

        clear_btn_i2t.click(
            clear_inputs_i2t,
            outputs=[image_input, question_input, answer_output]
        )

demo.launch(share=True)
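# Note: when run locally, share=True opens a public gradio.live tunnel; on
# Hugging Face Spaces the flag is ignored. For long-running generation,
# demo.queue().launch(share=True) is a common alternative (assumption, untested here).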