from functools import partial
from typing import Callable, List

import torch

# Project-local helpers referenced below (TransformerActivationCache,
# AttentionAblationCacheHook, Ablation, ablate_block, insert_extra_registers,
# discard_extra_registers) are assumed to be imported elsewhere in this module.


class Edit:
    """Bundles the hook factories describing a vanilla and an ablated forward pass."""

    def __init__(self, ablator,
                 vanilla_pre_forward_dict: Callable[[str, int], dict],
                 vanilla_forward_dict: Callable[[str, int], dict],
                 ablated_pre_forward_dict: Callable[[str, int], dict],
                 ablated_forward_dict: Callable[[str, int], dict]):
        self.ablator = ablator

        # Fixed generation seed for the vanilla pass.
        self.vanilla_seed = 42
        self.vanilla_pre_forward_dict = vanilla_pre_forward_dict
        self.vanilla_forward_dict = vanilla_forward_dict

        # Same seed for the ablated pass so both runs start from identical noise.
        self.ablated_seed = 42
        self.ablated_pre_forward_dict = ablated_pre_forward_dict
        self.ablated_forward_dict = ablated_forward_dict


def get_edit(name: str, **kwargs):

    if name == "edit_streams":
        ablator = TransformerActivationCache()
        stream: str = kwargs["stream"]
        layers = kwargs["layers"]
        edit_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = kwargs["edit_fn"]

        # Layers below 19 index the dual-stream transformer_blocks; 19 and above map onto
        # the single_transformer_blocks. `layer=layer` binds the loop variable at definition
        # time so each hook edits its own layer.
        interventions = {
            f"transformer.transformer_blocks.{layer}":
                lambda *args, layer=layer: ablator.edit_streams(
                    *args, recompute_fn=partial(edit_fn, layer=layer), stream=stream)
            for layer in layers if layer < 19
        }
        interventions.update({
            f"transformer.single_transformer_blocks.{layer - 19}":
                lambda *args, layer=layer: ablator.edit_streams(
                    *args, recompute_fn=partial(edit_fn, layer=layer), stream=stream)
            for layer in layers if layer >= 19
        })

        return Edit(ablator,
                    vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                    vanilla_forward_dict=lambda block_type, layer_num: {},
                    ablated_pre_forward_dict=lambda block_type, layer_num: {},
                    ablated_forward_dict=lambda block_type, layer_num: interventions,
                    )
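

# Usage sketch (illustrative): how get_edit("edit_streams", ...) might be driven.
# The recompute function and layer indices below are hypothetical, and the exact
# tensors handed to edit_fn depend on TransformerActivationCache.edit_streams;
# registering the returned hooks on the FLUX transformer is assumed to happen in
# the surrounding runner code via the Edit's hook dictionaries.
if __name__ == "__main__":
    def _scale_stream(*tensors: torch.Tensor, layer: int = 0) -> torch.Tensor:
        # Hypothetical edit: damp the selected stream, more strongly in single-stream layers.
        return tensors[0] * (0.5 if layer >= 19 else 0.9)

    edit = get_edit("edit_streams", stream="text", layers=[5, 12, 25], edit_fn=_scale_stream)
    # One hook per requested layer, keyed by module path inside the transformer.
    print(sorted(edit.ablated_forward_dict("transformer_blocks", 0).keys()))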


"""
def get_ablation(name: str, **kwargs):

    if name == "intermediate_text_stream_to_input":

        ablator = TransformerActivationCache()
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, stream="text")},
                        ablated_forward_dict=lambda block_type, layer_num: {})

    elif name == "input_to_intermediate_text_stream":
        ablator = TransformerActivationCache()
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.replace_stream_input(*args, stream="text")},
                        ablated_forward_dict=lambda block_type, layer_num: {})

    elif name == "set_input_text":

        tensor: torch.Tensor = kwargs["tensor"]

        ablator = TransformerActivationCache()
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.replace_stream_input(*args, use_tensor=tensor, stream="text")},
                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.clamp_output(*args)})

    elif name == "replace_text_stream_activation":
        ablator = AttentionAblationCacheHook()
        weight = kwargs["weight"] if "weight" in kwargs else 1.0

        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.set_ablated_attention(*args, weight=weight)})

    elif name == "replace_text_stream":
        ablator = TransformerActivationCache()
        weight = kwargs["weight"] if "weight" in kwargs else 1.0

        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
                        ablated_forward_dict=lambda block_type, layer_num: {})

    elif name == "input=output":
        return Ablation(None,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {},
                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablate_block(*args)})

    elif name == "reweight_text_stream":
        ablator = TransformerActivationCache()

        residual_w = kwargs["residual_w"]
        activation_w = kwargs["activation_w"]

        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {},
                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.reweight_text_stream(*args, residual_w=residual_w, activation_w=activation_w)})

    elif name == "add_input_text":

        tensor: torch.Tensor = kwargs["tensor"]

        ablator = TransformerActivationCache()
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.add_text_stream_input(*args, use_tensor=tensor)},
                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.clamp_output(*args)})

    elif name == "nothing":
        ablator = TransformerActivationCache()
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {},
                        ablated_forward_dict=lambda block_type, layer_num: {})

    elif name == "reweight_image_stream":
        ablator = TransformerActivationCache()
        residual_w = kwargs["residual_w"]
        activation_w = kwargs["activation_w"]

        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {},
                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.reweight_image_stream(*args, residual_w=residual_w, activation_w=activation_w)})

    elif name == "intermediate_image_stream_to_input":

        ablator = TransformerActivationCache()
        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, stream='image')},
                        ablated_forward_dict=lambda block_type, layer_num: {})

    elif name == "replace_text_stream_one_layer":
        ablator = AttentionAblationCacheHook()
        weight = kwargs["weight"] if "weight" in kwargs else 1.0

        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.restore_text_stream})

    elif name == "replace_intermediate_representation":
        ablator = TransformerActivationCache()
        tensor: torch.Tensor = kwargs["tensor"]

        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, use_tensor=tensor, stream='text_image')},
                        ablated_forward_dict=lambda block_type, layer_num: {})

    elif name == "destroy_registers":
        ablator = TransformerActivationCache()
        layers: List[int] = kwargs['layers']
        k: float = kwargs["k"]
        stream: str = kwargs['stream']
        random: bool = kwargs["random"] if "random" in kwargs else False
        lowest_norm: bool = kwargs["lowest_norm"] if "lowest_norm" in kwargs else False

        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.destroy_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
                        ablated_forward_dict=lambda block_type, layer_num: {})

    elif name == "patch_registers":
        ablator = TransformerActivationCache()
        layers: List[int] = kwargs['layers']
        k: float = kwargs["k"]
        stream: str = kwargs['stream']
        random: bool = kwargs["random"] if "random" in kwargs else False
        lowest_norm: bool = kwargs["lowest_norm"] if "lowest_norm" in kwargs else False

        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.destroy_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.set_cached_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
                        ablated_forward_dict=lambda block_type, layer_num: {})

    elif name == "add_registers":
        ablator = TransformerActivationCache()
        num_registers: int = kwargs["num_registers"]

        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer": lambda *args: insert_extra_registers(*args, num_registers=num_registers)},
                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer": lambda *args: discard_extra_registers(*args, num_registers=num_registers)},)

    elif name == "edit_streams":
        ablator = TransformerActivationCache()
        stream: str = kwargs["stream"]
        layers = kwargs["layers"]
        edit_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = kwargs["edit_fn"]

        interventions = {f"transformer.transformer_blocks.{layer}": lambda *args: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer < 19}
        interventions.update({f"transformer.single_transformer_blocks.{layer - 19}": lambda *args: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer >= 19})

        return Ablation(ablator,
                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                        vanilla_forward_dict=lambda block_type, layer_num: {},
                        ablated_pre_forward_dict=lambda block_type, layer_num: {},
                        ablated_forward_dict=lambda block_type, layer_num: interventions,
                        )
"""