tools / run_xl_pivotal.py
patrickvonplaten's picture
all
6c27fdd
raw
history blame
1.26 kB
#!/usr/bin/env python3
import hf_image_uploader as hiu
from safetensors.torch import load_file
from huggingface_hub import snapshot_download
from diffusers import DiffusionPipeline
import torch
import json
import os
# Generate one SDXL image using textual-inversion embeddings plus LoRA weights
# taken from a Hub repository, then upload the result.
model_id = "multimodalart/sdxl-emoji"
folder = snapshot_download(model_id)

# "TOK" is split positionally: the first four characters become the token for
# the first text encoder, the remainder the token for the second one.
# NOTE(review): presumably "TOK" is a string such as "<s0><s1>" (it matches the
# prompt below) -- confirm against special_params.json.
with open(os.path.join(folder, "special_params.json"), "r", encoding="utf-8") as json_file:
    data = json.load(json_file)
token, token_2 = data["TOK"][:4], data["TOK"][4:]

# embeddings.pti is read with the safetensors loader and holds one entry per
# text encoder.
state_dict = load_file(os.path.join(folder, "embeddings.pti"))
text_encoder_sd = state_dict["text_encoders_0"]
text_encoder_2_sd = state_dict["text_encoders_1"]

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)

# Each embedding entry is registered on its own text encoder / tokenizer pair.
pipe.load_textual_inversion(text_encoder_sd, token=token, text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
pipe.load_textual_inversion(text_encoder_2_sd, token=token_2, text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
pipe.load_lora_weights(os.path.join(folder, "lora.safetensors"))

# Re-apply float16 after the extra weights were loaded, then move to the GPU.
pipe.to(torch_dtype=torch.float16)
pipe.to("cuda")

prompt = "A <s0><s1> emoji of a man"
# cross_attention_kwargs["scale"] is forwarded to the pipeline call unchanged.
# NOTE(review): presumably this is the LoRA strength -- confirm in diffusers docs.
image = pipe(prompt, cross_attention_kwargs={"scale": 0.8}, num_inference_steps=20).images[0]

hiu.upload(image, "patrickvonplaten/images")