class Edit:
    """Bundle an ablator with the hook factories for a vanilla run and an edited run.

    Each ``*_dict`` argument is a factory called as ``factory(block_type, layer_num)``
    that returns a mapping from a module path (e.g.
    ``"transformer.transformer_blocks.3"``) to a hook callable. The consumer of an
    ``Edit`` presumably registers the ``*_pre_forward_dict`` hooks as forward
    pre-hooks and the ``*_forward_dict`` hooks as forward hooks — confirm against
    the caller.

    Args:
        ablator: Object whose methods the hooks call (may be ``None``).
        vanilla_pre_forward_dict: Pre-forward hook factory for the unedited run.
        vanilla_forward_dict: Forward hook factory for the unedited run.
        ablated_pre_forward_dict: Pre-forward hook factory for the edited run.
        ablated_forward_dict: Forward hook factory for the edited run.
    """

    def __init__(self,
                 ablator,
                 vanilla_pre_forward_dict: Callable[[str, int], dict],
                 vanilla_forward_dict: Callable[[str, int], dict],
                 ablated_pre_forward_dict: Callable[[str, int], dict],
                 ablated_forward_dict: Callable[[str, int], dict],):
        self.ablator = ablator
        # Both seeds are fixed at 42; presumably the caller seeds each run with
        # them so the vanilla and edited runs are comparable — confirm.
        self.vanilla_seed = 42
        self.vanilla_pre_forward_dict = vanilla_pre_forward_dict
        self.vanilla_forward_dict = vanilla_forward_dict
        self.ablated_seed = 42
        self.ablated_pre_forward_dict = ablated_pre_forward_dict
        self.ablated_forward_dict = ablated_forward_dict


# Global layer indices below this value address ``transformer.transformer_blocks``;
# indices at or above it address ``transformer.single_transformer_blocks`` after
# subtracting it. 19 presumably matches the model's number of double-stream blocks
# (FLUX.1-style layout) — confirm for other models.
NUM_DOUBLE_STREAM_BLOCKS = 19


def _make_edit_streams_hook(ablator, recompute_fn: Callable, stream: str):
    """Return a hook that forwards its positional args to ``ablator.edit_streams``."""
    # A factory function (instead of a lambda inside a comprehension) gives every
    # hook its own ``recompute_fn``; a lambda would late-bind the loop variable and
    # all hooks would see the last layer.
    def hook(*args):
        return ablator.edit_streams(*args, recompute_fn=recompute_fn, stream=stream)

    return hook


def _edit_stream_interventions(ablator, edit_fn: Callable, layers, stream: str,
                               num_double_blocks: int = NUM_DOUBLE_STREAM_BLOCKS) -> dict:
    """Map module paths to ``edit_streams`` hooks, one per requested global layer.

    ``edit_fn`` must accept a ``layer`` keyword; each hook receives
    ``partial(edit_fn, layer=<global layer index>)`` as its ``recompute_fn``.
    Double-stream entries are inserted before single-stream entries.
    """
    # Materialize once: the layers are scanned twice below, so a generator argument
    # would otherwise be exhausted after the first pass.
    layers = list(layers)
    interventions = {}
    for layer in layers:
        if layer < num_double_blocks:
            interventions[f"transformer.transformer_blocks.{layer}"] = _make_edit_streams_hook(
                ablator, partial(edit_fn, layer=layer), stream)
    for layer in layers:
        if layer >= num_double_blocks:
            interventions[f"transformer.single_transformer_blocks.{layer - num_double_blocks}"] = (
                _make_edit_streams_hook(ablator, partial(edit_fn, layer=layer), stream))
    return interventions


def get_edit(name: str, **kwargs) -> Edit:
    """Build the :class:`Edit` registered under ``name``.

    Currently only ``"edit_streams"`` is supported. It reads these keyword arguments:

    * ``stream`` (str): forwarded unchanged to ``ablator.edit_streams``.
    * ``layers``: iterable of global layer indices (see ``NUM_DOUBLE_STREAM_BLOCKS``
      for how an index maps to a module path).
    * ``edit_fn``: callable accepting a ``layer`` keyword; wrapped as
      ``partial(edit_fn, layer=<index>)`` and passed as ``recompute_fn``.
    * ``num_double_stream_blocks`` (int, optional): overrides
      ``NUM_DOUBLE_STREAM_BLOCKS``.

    Only ``ablated_forward_dict`` installs hooks. It returns the same intervention
    dict for every ``(block_type, layer_num)``. The other three factories return
    empty dicts.

    Raises:
        KeyError: if a required keyword argument is missing.
        ValueError: if ``name`` is not a known edit.
    """
    if name != "edit_streams":
        raise ValueError(f"Unknown edit: {name!r}")

    ablator = TransformerActivationCache()
    interventions = _edit_stream_interventions(
        ablator,
        edit_fn=kwargs["edit_fn"],
        layers=kwargs["layers"],
        stream=kwargs["stream"],
        num_double_blocks=kwargs.get("num_double_stream_blocks", NUM_DOUBLE_STREAM_BLOCKS),
    )
    return Edit(ablator,
                vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                vanilla_forward_dict=lambda block_type, layer_num: {},
                ablated_pre_forward_dict=lambda block_type, layer_num: {},
                ablated_forward_dict=lambda block_type, layer_num: interventions,
                )


# NOTE(review): disabled legacy factory, kept for reference as a bare string literal
# (it is never executed). It refers to names not defined in this chunk (Ablation,
# AttentionAblationCacheHook, ablate_block, insert_extra_registers,
# discard_extra_registers). Its "edit_streams" branch still has the late-binding
# lambda bug that get_edit avoids.
"""
def get_ablation(name: str, **kwargs):
    if name == "intermediate_text_stream_to_input":
        ablator = TransformerActivationCache()
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, stream="text")},
                        ablated_forward_dict=lambda block_type, layer_num: {})
    elif name == "input_to_intermediate_text_stream":
        ablator = TransformerActivationCache()
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.replace_stream_input(*args, stream="text")},
                        ablated_forward_dict=lambda block_type, layer_num: {})
    elif name == "set_input_text":
        tensor: torch.Tensor = kwargs["tensor"]
        ablator = TransformerActivationCache()
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.replace_stream_input(*args, use_tensor=tensor, stream="text")},
                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.clamp_output(*args)})
    elif name == "replace_text_stream_activation":
        ablator = AttentionAblationCacheHook()
        weight = kwargs["weight"] if "weight" in kwargs else 1.0
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.set_ablated_attention(*args, weight=weight)})
    elif name == "replace_text_stream":
        ablator = TransformerActivationCache()
        weight = kwargs["weight"] if "weight" in kwargs else 1.0
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
                        ablated_forward_dict=lambda block_type, layer_num: {})
    elif name == "input=output":
        return Ablation(None,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {},
                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablate_block(*args)})
    elif name == "reweight_text_stream":
        ablator = TransformerActivationCache()
        residual_w=kwargs["residual_w"]
        activation_w=kwargs["activation_w"]
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {},
                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.reweight_text_stream(*args, residual_w=residual_w, activation_w=activation_w)})
    elif name == "add_input_text":
        tensor: torch.Tensor = kwargs["tensor"]
        ablator = TransformerActivationCache()
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.add_text_stream_input(*args, use_tensor=tensor)},
                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.clamp_output(*args)})
    elif name == "nothing":
        ablator = TransformerActivationCache()
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {},
                        ablated_forward_dict=lambda block_type, layer_num: {})
    elif name == "reweight_image_stream":
        ablator = TransformerActivationCache()
        residual_w=kwargs["residual_w"]
        activation_w=kwargs["activation_w"]
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {},
                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.reweight_image_stream(*args, residual_w=residual_w, activation_w=activation_w)})
    if name == "intermediate_image_stream_to_input":
        ablator = TransformerActivationCache()
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, stream='image')},
                        ablated_forward_dict=lambda block_type, layer_num: {})
    elif name == "replace_text_stream_one_layer":
        ablator = AttentionAblationCacheHook()
        weight = kwargs["weight"] if "weight" in kwargs else 1.0
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.restore_text_stream})
    elif name == "replace_intermediate_representation":
        ablator = TransformerActivationCache()
        tensor: torch.Tensor = kwargs["tensor"]
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, use_tensor=tensor, stream='text_image')},
                        ablated_forward_dict=lambda block_type, layer_num: {})
    elif name == "destroy_registers":
        ablator = TransformerActivationCache()
        layers: List[int] = kwargs['layers']
        k: float = kwargs["k"]
        stream: str = kwargs['stream']
        random: bool = kwargs["random"] if "random" in kwargs else False
        lowest_norm: bool = kwargs["lowest_norm"] if "lowest_norm" in kwargs else False
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.destroy_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
                        ablated_forward_dict=lambda block_type, layer_num: {})
    elif name == "patch_registers":
        ablator = TransformerActivationCache()
        layers: List[int] = kwargs['layers']
        k: float = kwargs["k"]
        stream: str = kwargs['stream']
        random: bool = kwargs["random"] if "random" in kwargs else False
        lowest_norm: bool = kwargs["lowest_norm"] if "lowest_norm" in kwargs else False
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.destroy_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.set_cached_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
                        ablated_forward_dict=lambda block_type, layer_num: {})
    elif name == "add_registers":
        ablator = TransformerActivationCache()
        num_registers: int = kwargs["num_registers"]
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer": lambda *args: insert_extra_registers(*args, num_registers=num_registers)},
                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer": lambda *args: discard_extra_registers(*args, num_registers=num_registers)},)
    elif name == "edit_streams":
        ablator = TransformerActivationCache()
        stream: str = kwargs["stream"]
        layers = kwargs["layers"]
        edit_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = kwargs["edit_fn"]
        interventions = {f"transformer.transformer_blocks.{layer}": lambda *args: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer < 19}
        interventions.update({f"transformer.single_transformer_blocks.{layer - 19}": lambda *args: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer >= 19})
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {},
                        ablated_forward_dict=lambda block_type, layer_num: interventions, )
"""