# sae_flux/app.py
from functools import partial
import json
import gradio as gr
import os
# environment
os.environ['HF_HOME'] = '/dlabscratch1/anmari'
os.environ['TRANSFORMERS_CACHE'] = '/dlabscratch1/anmari'
os.environ['HF_DATASETS_CACHE'] = '/dlabscratch1/anmari'
# os.environ["HF_TOKEN"] = ""
import torch
from PIL import Image
from SDLens import HookedStableDiffusionXLPipeline, CachedPipeline as CachedFLuxPipeline
from SDLens.cache_and_edit.flux_pipeline import EditedFluxPipeline
from SAE import SparseAutoencoder
from utils import TimedHook, add_feature_on_area_base, replace_with_feature_base, add_feature_on_area_turbo, replace_with_feature_turbo, add_feature_on_area_flux
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import threading
from einops import rearrange
import spaces
# from retrieval import FeatureRetriever
code_to_block_sd = {
"down.2.1": "unet.down_blocks.2.attentions.1",
"mid.0": "unet.mid_block.attentions.0",
"up.0.1": "unet.up_blocks.0.attentions.1",
"up.0.0": "unet.up_blocks.0.attentions.0"
}
code_to_block_flux = {"18": "transformer.transformer_blocks.18"}
FLUX_NAMES = ["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-dev"]
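# Per-model settings: sampling defaults, hookable blocks, SAE size, the feature-injection
# helper, and (for FLUX) feature indices excluded from the top-feature ranking.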
MODELS_CONFIG = {
"stabilityai/stable-diffusion-xl-base-1.0": {
"steps": 25,
"guidance_scale": 8.0,
"choices": ["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
"value": "down.2.1 (composition)",
"code_to_block": code_to_block_sd,
"max_steps": 50,
"is_flux": False,
"downsample_factor": 16,
"add_feature_on_area": add_feature_on_area_base,
"num_features": 5120,
},
"stabilityai/sdxl-turbo": {
"steps": 1,
"guidance_scale": 0.0,
"choices": ["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
"value": "down.2.1 (composition)",
"code_to_block": code_to_block_sd,
"max_steps": 4,
"is_flux": False,
"downsample_factor": 32,
"add_feature_on_area": add_feature_on_area_turbo,
"num_features": 5120,
},
"black-forest-labs/FLUX.1-schnell": {
"steps": 1,
"guidance_scale": 0.0,
"choices": ["18"],
"value": "18",
"code_to_block": code_to_block_flux,
"max_steps": 4,
"is_flux": True,
"exclude_list": [2462, 2974, 1577, 786, 3188, 9986, 4693, 8472, 8248, 325, 9596, 2813, 10803, 11773, 11410, 1067, 2965, 10488, 4537, 2102],
"downsample_factor": 8,
"add_feature_on_area": add_feature_on_area_flux,
"num_features": 12288
},
"black-forest-labs/FLUX.1-dev": {
"steps": 25,
"guidance_scale": 0.0,
"choices": ["18"],
"value": "18",
"code_to_block": code_to_block_flux,
"max_steps": 50,
"is_flux": True,
"exclude_list": [2462, 2974, 1577, 786, 3188, 9986, 4693, 8472, 8248, 325, 9596, 2813, 10803, 11773, 11410, 1067, 2965, 10488, 4537, 2102],
"downsample_factor": 8,
"add_feature_on_area": add_feature_on_area_flux,
"num_features": 12288
}
}
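# Serializes GPU work across concurrent Gradio callbacks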
lock = threading.Lock()
def process_cache(cache, saes_dict, model_config, timestep=None):
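    """Encode cached activations with each block's SAE.

    Returns, per block code, the 40 most active feature indices and the full
    sparse feature maps. For FLUX the SAE encodes the cached image-stream
    activations directly; for SDXL it encodes the residual update (output
    minus input) of the hooked attention block. If `timestep` is given, only
    that step is kept; otherwise features are averaged over all steps.
    """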
top_features_dict = {}
sparse_maps_dict = {}
for code in model_config['code_to_block'].keys():
block = model_config["code_to_block"][code]
sae = saes_dict[code]
if model_config["is_flux"]:
with torch.no_grad():
features = sae.encode(torch.stack(cache.image_activation)) # shape: [timestep, batch, seq_len, num_features]
features[..., model_config["exclude_list"]] = 0
if timestep is not None and timestep < features.shape[0]:
features = features[timestep:timestep+1]
            # Rearrange to [batch, timestep, 64, 64, num_features], then drop the singleton batch and timestep dims
            sparse_maps = rearrange(features, "t b (w h) n -> b t w h n", w=64, h=64).squeeze(0).squeeze(0)
else:
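            # SDXL: the SAE encodes the block's residual update (output minus input)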
diff = cache["output"][block] - cache["input"][block]
if diff.shape[0] == 2: # guidance is on and we need to select the second output
diff = diff[1].unsqueeze(0)
# If a specific timestep is provided, select that timestep from the cached activations
if timestep is not None and timestep < diff.shape[1]:
diff = diff[:, timestep:timestep+1]
diff = diff.permute(0, 1, 3, 4, 2).squeeze(0).squeeze(0)
with torch.no_grad():
sparse_maps = sae.encode(diff)
averages = torch.mean(sparse_maps, dim=(0, 1))
top_features = torch.topk(averages, 40).indices
top_features_dict[code] = top_features.cpu().tolist()
sparse_maps_dict[code] = sparse_maps.cpu().numpy()
return top_features_dict, sparse_maps_dict
def plot_image_heatmap(cache, block_select, radio, model_config):
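    """Overlay the selected feature's activation map on the generated image
    using a semi-transparent jet colormap."""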
code = block_select.split()[0]
feature = int(radio)
heatmap = cache["heatmaps"][code][:, :, feature]
scaling_factor = 16 if model_config["is_flux"] else 32
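    # Upsample by block replication to the image resolution
    # (64x64 -> 1024 for FLUX, 32x32 -> 1024 for SDXL)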
heatmap = np.kron(heatmap, np.ones((scaling_factor, scaling_factor)))
image = cache["image"].convert("RGBA")
jet = plt.cm.jet
cmap = jet(np.arange(jet.N))
cmap[:1, -1] = 0
cmap[1:, -1] = 0.6
cmap = ListedColormap(cmap)
    # Normalize to [0, 1]; the epsilon guards against an all-constant map
    heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap) + 1e-8)
heatmap_rgba = cmap(heatmap)
heatmap_image = Image.fromarray((heatmap_rgba * 255).astype(np.uint8))
heatmap_with_transparency = Image.alpha_composite(image, heatmap_image)
return heatmap_with_transparency
def create_prompt_part(pipe, saes_dict, demo):
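    """Build the "Explore" tab: generate an image, cache activations, and
    visualize per-feature heatmaps for the selected block."""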
model_config = MODELS_CONFIG[pipe.pipe.name_or_path]
@spaces.GPU
def image_gen(prompt, timestep=None, num_steps=None, guidance_scale=None):
lock.acquire()
try:
# Default values
default_n_steps = model_config["steps"]
default_guidance = model_config["guidance_scale"]
# Use provided values if available, otherwise use defaults
n_steps = default_n_steps if num_steps is None else int(num_steps)
guidance = default_guidance if guidance_scale is None else float(guidance_scale)
# Convert timestep to integer if it's not None
timestep_int = None if timestep is None else int(timestep)
if "FLUX" in pipe.pipe.name_or_path:
images = pipe.run(
prompt,
num_inference_steps=n_steps,
width=1024,
height=1024,
cache_activations=True,
guidance_scale=guidance,
                    positions_to_cache=list(model_config["code_to_block"].values()),
inverse=False,
)
cache = pipe.activation_cache
else:
images, cache = pipe.run_with_cache(
prompt,
positions_to_cache=list(model_config["code_to_block"].values()),
num_inference_steps=n_steps,
generator=torch.Generator(device="cpu").manual_seed(42),
guidance_scale=guidance,
save_input=True,
save_output=True
)
finally:
lock.release()
top_features_dict, top_sparse_maps_dict = process_cache(cache, saes_dict, model_config, timestep_int)
return images.images[0], {
"image": images.images[0],
"heatmaps": top_sparse_maps_dict,
"features": top_features_dict
}
def update_radio(cache, block_select):
code = block_select.split()[0]
return gr.update(choices=cache["features"][code])
def update_img(cache, block_select, radio):
new_img = plot_image_heatmap(cache, block_select, radio, model_config)
return new_img
with gr.Tab("Explore", elem_classes="tabs") as explore_tab:
cache = gr.State(value={
"image": None,
"heatmaps": None,
"features": []
})
with gr.Row():
with gr.Column(scale=7):
with gr.Row(equal_height=True):
prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A cinematic shot of a professor sloth wearing a tuxedo at a BBQ party and eathing a dish with peas.")
button = gr.Button("Generate", elem_classes="generate_button1")
with gr.Row():
image = gr.Image(width=512, height=512, image_mode="RGB", label="Generated image")
with gr.Column(scale=4):
block_select = gr.Dropdown(
                    choices=model_config["choices"],
value=model_config["value"],
label="Select block",
elem_id="block_select",
interactive=True
)
with gr.Group() as sdxl_base_controls:
steps_slider = gr.Slider(
minimum=1,
maximum=model_config["max_steps"],
                        value=model_config["steps"],
step=1,
label="Number of steps",
elem_id="steps_slider",
interactive=True,
visible=True
)
# Add timestep selector
# TODO: check this
timestep_selector = gr.Slider(
minimum=0,
maximum=model_config["max_steps"]-1,
value=None,
step=1,
label="Timestep (leave empty for average across all steps)",
elem_id="timestep_selector",
interactive=True,
visible=True,
)
recompute_button = gr.Button("Recompute", elem_id="recompute_button")
# Update max timestep when steps change
steps_slider.change(lambda s: gr.update(maximum=s-1), [steps_slider], [timestep_selector])
radio = gr.Radio(choices=[], label="Select a feature", interactive=True)
button.click(image_gen, [prompt_field, timestep_selector, steps_slider], outputs=[image, cache])
cache.change(update_radio, [cache, block_select], outputs=[radio])
block_select.select(update_radio, [cache, block_select], outputs=[radio])
radio.select(update_img, [cache, block_select, radio], outputs=[image])
recompute_button.click(image_gen, [prompt_field, timestep_selector, steps_slider], outputs=[image, cache])
demo.load(image_gen, [prompt_field, timestep_selector, steps_slider], outputs=[image, cache])
return explore_tab
def downsample_mask(image, factor):
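    """Average-pool a 2D mask by `factor`.

    For example, a (512, 512) brush mask with factor 32 becomes a (16, 16)
    map of per-patch coverage in [0, 1]:

        >>> downsample_mask(np.ones((512, 512)), 32).shape
        (16, 16)
    """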
downsampled = image.reshape(
(image.shape[0] // factor, factor,
image.shape[1] // factor, factor)
)
downsampled = downsampled.mean(axis=(1, 3))
return downsampled
def create_intervene_part(pipe: HookedStableDiffusionXLPipeline, saes_dict, means_dict, demo):
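    """Build the "Paint!" tab: generate an image, then re-generate it with the
    selected SAE feature injected into the painted region."""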
model_config = MODELS_CONFIG[pipe.pipe.name_or_path]
@spaces.GPU
def image_gen(prompt, num_steps, guidance_scale=None):
lock.acquire()
guidance = model_config["guidance_scale"] if guidance_scale is None else float(guidance_scale)
try:
if "FLUX" in pipe.pipe.name_or_path:
images = pipe.run(
prompt,
num_inference_steps=int(num_steps),
width=1024,
height=1024,
cache_activations=False,
guidance_scale=guidance,
inverse=False,
)
else:
images = pipe.run_with_hooks(
prompt,
position_hook_dict={},
num_inference_steps=int(num_steps),
generator=torch.Generator(device="cpu").manual_seed(42),
guidance_scale=guidance,
)
finally:
lock.release()
if images.images[0].size == (1024, 1024):
return images.images[0].resize((512, 512))
else:
return images.images[0]
@spaces.GPU
    def image_mod(prompt, block_str, brush_index, strength, num_steps, input_image, start_index=None, end_index=None, guidance_scale=None):
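        """Re-run generation with the selected SAE feature added inside the
        painted mask, scaled by the feature's mean activation and `strength`,
        and applied only on steps [start_index, end_index) via TimedHook."""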
block = block_str.split(" ")[0]
mask = (input_image["layers"][0] > 0)[:, :, -1].astype(float)
mask = downsample_mask(mask, model_config["downsample_factor"])
mask = torch.tensor(mask, dtype=torch.float32, device="cuda")
if mask.sum() == 0:
gr.Info("No mask selected, please draw on the input image")
# Set default values for start_index and end_index if not provided
if start_index is None:
start_index = 0
if end_index is None:
end_index = int(num_steps)
# Ensure start_index and end_index are within valid ranges
start_index = max(0, min(int(start_index), int(num_steps)))
end_index = max(0, min(int(end_index), int(num_steps)))
# Ensure start_index is less than end_index
if start_index >= end_index:
start_index = max(0, end_index - 1)
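        # Hook: add the selected feature (scaled by its mean activation and the
        # chosen strength) to the block's activations inside the painted mask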
def myhook(module, input, output):
return model_config["add_feature_on_area"](
saes_dict[block],
brush_index,
mask * means_dict[block][brush_index] * strength,
module,
input,
output)
hook = TimedHook(myhook, int(num_steps), np.arange(start_index, end_index))
lock.acquire()
guidance = model_config["guidance_scale"] if guidance_scale is None else float(guidance_scale)
try:
if model_config["is_flux"]:
image = pipe.run_with_edit(
prompt,
seed=42,
num_inference_steps=int(num_steps),
                    edit_fn=lambda input, output: hook(None, input, output),
                    layers_for_edit_fn=list(range(18, 57)),
stream="image").images[0]
else:
image = pipe.run_with_hooks(
prompt,
position_hook_dict={model_config["code_to_block"][block]: hook},
num_inference_steps=int(num_steps),
generator=torch.Generator(device="cpu").manual_seed(42),
guidance_scale=guidance
).images[0]
finally:
lock.release()
return image
def feature_icon(block_str, brush_index, guidance_scale=None):
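        """Generate from an empty prompt with the block's activations replaced
        by a single SAE feature, to visualize what that feature encodes."""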
block = block_str.split(" ")[0]
if block in ["mid.0", "up.0.0"]:
gr.Info("Note that Feature Icon works best with down.2.1 and up.0.1 blocks but feel free to explore", duration=3)
def hook(module, input, output):
if is_base_model:
return replace_with_feature_base(
saes_dict[block],
brush_index,
means_dict[block][brush_index] * saes_dict[block].k,
module,
input,
output
)
else:
return replace_with_feature_turbo(
saes_dict[block],
brush_index,
means_dict[block][brush_index] * saes_dict[block].k,
module,
input,
output)
lock.acquire()
guidance = model_config["guidance_scale"] if guidance_scale is None else float(guidance_scale)
try:
image = pipe.run_with_hooks(
"",
position_hook_dict={model_config["code_to_block"][block]: hook},
num_inference_steps=model_config["steps"],
generator=torch.Generator(device="cpu").manual_seed(42),
guidance_scale=guidance,
).images[0]
finally:
lock.release()
return image
with gr.Tab("Paint!", elem_classes="tabs") as intervene_tab:
image_state = gr.State(value=None)
with gr.Row():
with gr.Column(scale=3):
# Generation column
with gr.Row():
# prompt and num_steps
prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A dog plays with a ball, cartoon", elem_id="prompt_input")
with gr.Row():
num_steps = gr.Number(value=model_config["steps"], label="Number of steps", minimum=1, maximum=model_config["max_steps"], elem_id="num_steps", precision=0)
with gr.Row():
# Generate button
button_generate = gr.Button("Generate", elem_id="generate_button")
with gr.Column(scale=3):
# Intervention column
with gr.Row():
# dropdowns and number inputs
with gr.Column(scale=7):
with gr.Row():
block_select = gr.Dropdown(
choices=model_config["choices"],
value=model_config["value"],
label="Select block",
elem_id="block_select"
)
brush_index = gr.Number(value=0, label="Brush index", minimum=0, maximum=model_config["num_features"]-1, elem_id="brush_index", precision=0)
# with gr.Row():
# button_icon = gr.Button('Feature Icon', elem_id="feature_icon_button")
with gr.Row():
gr.Markdown("**TimedHook Range** (which steps to apply the feature)", visible=True)
with gr.Row():
start_index = gr.Number(value=0, label="Start index", minimum=0, maximum=model_config["max_steps"], elem_id="start_index", precision=0, visible=True)
end_index = gr.Number(value=model_config["steps"], label="End index", minimum=0, maximum=model_config["max_steps"], elem_id="end_index", precision=0, visible=True)
with gr.Column(scale=3):
with gr.Row():
strength = gr.Number(value=10, label="Strength", minimum=-40, maximum=40, elem_id="strength", precision=2)
with gr.Row():
button = gr.Button('Apply', elem_id="apply_button")
with gr.Row():
with gr.Column():
# Input image
i_image = gr.Sketchpad(
height=610,
layers=False, transforms=[], placeholder="Generate and paint!",
brush=gr.Brush(default_size=64, color_mode="fixed", colors=['black']),
container=False,
canvas_size=(512, 512),
label="Input Image")
clear_button = gr.Button("Clear")
clear_button.click(lambda x: x, [image_state], [i_image])
# Output image
o_image = gr.Image(width=512, height=512, label="Output Image")
# Set up the click events
button_generate.click(image_gen, inputs=[prompt_field, num_steps], outputs=[image_state])
image_state.change(lambda x: x, [image_state], [i_image])
# Update max values for start_index and end_index when num_steps changes
def update_index_maxes(steps):
return gr.update(maximum=steps), gr.update(maximum=steps)
num_steps.change(update_index_maxes, [num_steps], [start_index, end_index])
button.click(image_mod,
inputs=[prompt_field, block_select, brush_index, strength, num_steps, i_image, start_index, end_index],
outputs=o_image)
# button_icon.click(feature_icon, inputs=[block_select, brush_index], outputs=o_image)
demo.load(image_gen, [prompt_field, num_steps], outputs=[image_state])
return intervene_tab
def create_top_images_part(demo, pipe):
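    """Build the "Top Images" tab: fetch the pre-rendered grid of top-activating
    images for a feature from the Hugging Face Hub."""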
model_config = MODELS_CONFIG[pipe.pipe.name_or_path]
if isinstance(pipe, HookedStableDiffusionXLPipeline):
is_flux = False
elif isinstance(pipe, CachedFLuxPipeline):
is_flux = True
    else:
        raise TypeError(f"Unknown pipe class: {type(pipe)}")
def update_top_images(block_select, brush_index):
block = block_select.split(" ")[0]
# Define path for fetching image
if is_flux:
part = 1 if brush_index <= 7000 else 2
url = f"https://huggingface.co/datasets/antoniomari/flux_sae_images/resolve/main/{block}/part{part}/{brush_index}.jpg"
else:
url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{brush_index}.jpg"
return url
with gr.Tab("Top Images", elem_classes="tabs") as top_images_tab:
with gr.Row():
block_select = gr.Dropdown(
choices=["flux_18"] if is_flux else ["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
value="flux_18" if is_flux else "down.2.1 (composition)",
label="Select block"
)
brush_index = gr.Number(value=0, label="Brush index", minimum=0, maximum=model_config["num_features"]-1, precision=0)
with gr.Row():
image = gr.Image(width=600, height=600, label="Top Images")
block_select.select(update_top_images, [block_select, brush_index], outputs=[image])
brush_index.change(update_top_images, [block_select, brush_index], outputs=[image])
demo.load(update_top_images, [block_select, brush_index], outputs=[image])
return top_images_tab
def create_top_images_plus_search_part(retriever, demo, pipe):
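    """Variant of the "Top Images" tab with an optional text-search box;
    requires a retriever and is currently not wired into create_demo."""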
model_config = MODELS_CONFIG[pipe.pipe.name_or_path]
if isinstance(pipe, HookedStableDiffusionXLPipeline):
is_flux = False
elif isinstance(pipe, CachedFLuxPipeline):
is_flux = True
    else:
        raise TypeError(f"Unknown pipe class: {type(pipe)}")
def update_cache(block_select, search_by_text, search_by_index):
if search_by_text == "":
top_indices = []
index = search_by_index
block = block_select.split(" ")[0]
# Define path for fetching image
if is_flux:
part = 1 if index <= 7000 else 2
url = f"https://huggingface.co/antoniomari/flux_sae_images/resolve/main/{block}/part{part}/{index}.jpg"
else:
url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{index}.jpg"
return url, {"image": url, "feature_idx": index, "features": top_indices}
else:
# TODO
if retriever is None:
raise ValueError("Feature retrieval is not enabled")
lock.acquire()
try:
top_indices = list(retriever.query_text(search_by_text, block_select.split(" ")[0]).keys())
finally:
lock.release()
block = block_select.split(" ")[0]
top_indices = list(map(int, top_indices))
index = top_indices[0]
url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{index}.jpg"
return url, {"image": url, "feature_idx": index, "features": top_indices[:20]}
def update_radio(cache):
return gr.update(choices=cache["features"], value=cache["feature_idx"])
def update_img(cache, block_select, index):
block = block_select.split(" ")[0]
url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{index}.jpg"
return url
with gr.Tab("Top Images", elem_classes="tabs") as explore_tab:
cache = gr.State(value={
"image": None,
"feature_idx": None,
"features": []
})
with gr.Row():
with gr.Column(scale=7):
with gr.Row():
# top images
image = gr.Image(width=600, height=600, image_mode="RGB", label="Top images")
with gr.Column(scale=4):
block_select = gr.Dropdown(
choices=["flux_18"] if is_flux else ["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
value="flux_18" if is_flux else "down.2.1 (composition)",
label="Select block",
elem_id="block_select",
interactive=True
)
search_by_index = gr.Number(value=0, label="Search by index", minimum=0, maximum=model_config["num_features"]-1, precision=0)
search_by_text = gr.Textbox(lines=1, label="Search by text", value="", visible=False)
radio = gr.Radio(choices=[], label="Select a feature", interactive=True, visible=False)
search_by_text.change(update_cache,
[block_select, search_by_text, search_by_index],
outputs=[image, cache])
block_select.select(update_cache,
[block_select, search_by_text, search_by_index],
outputs=[image, cache])
cache.change(update_radio, [cache], outputs=[radio])
radio.select(update_img, [cache, block_select, radio], outputs=[image])
search_by_index.change(update_img, [cache, block_select, search_by_index], outputs=[image])
demo.load(update_img,
[cache, block_select, search_by_index],
outputs=[image])
return explore_tab
def create_intro_part():
with gr.Tab("Instructions", elem_classes="tabs") as intro_tab:
gr.Markdown(
'''# Unpacking SDXL Turbo with Sparse Autoencoders
## Demo Overview
This demo showcases the use of Sparse Autoencoders (SAEs) to understand the features learned by the Stable Diffusion XL Turbo model.
## How to Use
### Explore
* Enter a prompt in the text box and click on the "Generate" button to generate an image.
* You can observe the active features in different blocks plotted on top of the generated image.
### Top Images
* For each feature, you can view the top images that activate the feature the most.
### Paint!
* Generate an image using the prompt.
* Paint on the generated image to apply interventions.
* Use the "Feature Icon" button to understand how the selected brush functions.
### Remarks
* Not all brushes mix well with all images. Experiment with different brushes and strengths.
* Feature Icon works best with `down.2.1 (composition)` and `up.0.1 (style)` blocks.
* This demo is provided for research purposes only. We do not take responsibility for the content generated by the demo.
### Interesting features to try
To get started, try the following features:
- down.2.1 (composition): 2301 (evil), 3747 (image frame), 4998 (cartoon)
- up.0.1 (style): 4977 (tiger stripes), 90 (fur), 2615 (twilight blur)
'''
)
return intro_tab
def create_demo(pipe, saes_dict, means_dict, use_retrieval=True):
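    """Assemble the Gradio app from the Explore, Top Images, and Paint! tabs."""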
custom_css = """
.tabs button {
font-size: 20px !important; /* Adjust font size for tab text */
padding: 10px !important; /* Adjust padding to make the tabs bigger */
font-weight: bold !important; /* Adjust font weight to make the text bold */
}
.generate_button1 {
max-width: 160px !important;
margin-top: 20px !important;
margin-bottom: 20px !important;
}
"""
    # Text-based feature retrieval is currently disabled (the FeatureRetriever import is commented out above)
    retriever = None
with gr.Blocks(css=custom_css) as demo:
# with create_intro_part():
# pass
with create_prompt_part(pipe, saes_dict, demo):
pass
with create_top_images_part(demo, pipe):
pass
with create_intervene_part(pipe, saes_dict, means_dict, demo):
pass
return demo
if __name__ == "__main__":
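    # Entry point: load FLUX.1-schnell, download the block-18 SAE (weights,
    # config, and per-feature means) from the Hub, and launch the demo.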
    # torch, gradio, SDLens, and SAE are already imported at module level;
    # only the Hub helpers are needed here.
    from huggingface_hub import hf_hub_download, login
login(token=os.environ["HF_TOKEN"])
dtype = torch.float16
pipe = EditedFluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
device_map="balanced",
torch_dtype=dtype
)
pipe.set_progress_bar_config(disable=True)
pipe = CachedFLuxPipeline(pipe)
# Parameters
DEVICE = "cuda"
# Hugging Face repo setup
HF_REPO_ID = "antoniomari/SAE_flux_18"
HF_BRANCH = "main"
# Command-line arguments
block_code = "18"
block_name = code_to_block_flux[block_code]
saes_dict = {}
means_dict = {}
# Download files from the root of the repo
state_dict_path = hf_hub_download(
repo_id=HF_REPO_ID,
filename="state_dict.pth",
revision=HF_BRANCH
)
config_path = hf_hub_download(
repo_id=HF_REPO_ID,
filename="config.json",
revision=HF_BRANCH
)
mean_path = hf_hub_download(
repo_id=HF_REPO_ID,
filename="mean.pt",
revision=HF_BRANCH
)
# Load config and model
with open(config_path, "r") as f:
config = json.load(f)
sae = SparseAutoencoder(**config)
checkpoint = torch.load(state_dict_path, map_location=DEVICE)
state_dict = checkpoint["state_dict"]
sae.load_state_dict(state_dict)
sae = sae.to(DEVICE, dtype=torch.float16).eval()
means = torch.load(mean_path, map_location=DEVICE).to(dtype)
saes_dict[block_code] = sae
means_dict[block_code] = means
demo = create_demo(pipe, saes_dict, means_dict)
demo.launch()