from functools import partial
from typing import Callable

import torch

# NOTE: TransformerActivationCache (and the helpers referenced in the commented-out
# code below) are assumed to be imported from elsewhere in this project; those
# imports are not part of this snippet.


class Edit:
    """Bundles an ablator with four factories that map (block_type, layer_num)
    to {module_name: hook} dicts for the vanilla and ablated forward passes."""

    def __init__(self, ablator,
                 vanilla_pre_forward_dict: Callable[[str, int], dict],
                 vanilla_forward_dict: Callable[[str, int], dict],
                 ablated_pre_forward_dict: Callable[[str, int], dict],
                 ablated_forward_dict: Callable[[str, int], dict]):
        self.ablator = ablator
        self.vanilla_seed = 42
        self.vanilla_pre_forward_dict = vanilla_pre_forward_dict
        self.vanilla_forward_dict = vanilla_forward_dict
        self.ablated_seed = 42
        self.ablated_pre_forward_dict = ablated_pre_forward_dict
        self.ablated_forward_dict = ablated_forward_dict
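

# How an Edit is presumably consumed (hypothetical sketch; the actual runner lives
# elsewhere in this project and its exact API is assumed here, not confirmed by
# this file): each factory is called once per block with (block_type, layer_num),
# and the returned {module_name: hook} entries are attached to the named
# submodules for the corresponding pass, roughly:
#
#     for layer_num in range(num_blocks):
#         hooks = edit.ablated_forward_dict("transformer_blocks", layer_num)
#         for module_name, hook in hooks.items():
#             ...  # locate `module_name` in the pipeline and register `hook`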


def get_edit(name: str, **kwargs):
    if name == "edit_streams":
        ablator = TransformerActivationCache()
        stream: str = kwargs["stream"]
        layers = kwargs["layers"]
        edit_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = kwargs["edit_fn"]
        # Layers < 19 address transformer_blocks; layers >= 19 address
        # single_transformer_blocks (index offset by 19). `layer=layer` binds each
        # hook to its own layer index rather than late-binding to the last loop value.
        interventions = {f"transformer.transformer_blocks.{layer}": lambda *args, layer=layer: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer < 19}
        interventions.update({f"transformer.single_transformer_blocks.{layer - 19}": lambda *args, layer=layer: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer >= 19})
        return Edit(ablator,
                    vanilla_pre_forward_dict=lambda block_type, layer_num: {},
                    vanilla_forward_dict=lambda block_type, layer_num: {},
                    ablated_pre_forward_dict=lambda block_type, layer_num: {},
                    ablated_forward_dict=lambda block_type, layer_num: interventions,
                    )
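
# Hypothetical usage sketch (edit_fn's exact tensor arguments depend on
# TransformerActivationCache.edit_streams and are assumed here; it is wrapped as
# partial(edit_fn, layer=layer), so it must accept a `layer` keyword):
#
#     def scale_stream(*streams, layer):
#         return streams[0] * 0.5  # e.g. damp the chosen stream at each layer
#
#     edit = get_edit("edit_streams", stream="text", layers=[5, 20],
#                     edit_fn=scale_stream)

# The triple-quoted block below is a commented-out get_ablation variant of this
# factory, kept for reference.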
"""
def get_ablation(name: str, **kwargs):
if name == "intermediate_text_stream_to_input":
ablator = TransformerActivationCache()
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, stream="text")},
ablated_forward_dict=lambda block_type, layer_num: {})
elif name == "input_to_intermediate_text_stream":
ablator = TransformerActivationCache()
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.replace_stream_input(*args, stream="text")},
ablated_forward_dict=lambda block_type, layer_num: {})
elif name == "set_input_text":
tensor: torch.Tensor = kwargs["tensor"]
ablator = TransformerActivationCache()
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.replace_stream_input(*args, use_tensor=tensor, stream="text")},
ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.clamp_output(*args)})
elif name == "replace_text_stream_activation":
ablator = AttentionAblationCacheHook()
weight = kwargs["weight"] if "weight" in kwargs else 1.0
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.set_ablated_attention(*args, weight=weight)})
elif name == "replace_text_stream":
ablator = TransformerActivationCache()
weight = kwargs["weight"] if "weight" in kwargs else 1.0
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
ablated_forward_dict=lambda block_type, layer_num: {})
elif name == "input=output":
return Ablation(None,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {},
ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablate_block(*args)})
elif name == "reweight_text_stream":
ablator = TransformerActivationCache()
residual_w=kwargs["residual_w"]
activation_w=kwargs["activation_w"]
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {},
ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.reweight_text_stream(*args, residual_w=residual_w, activation_w=activation_w)})
elif name == "add_input_text":
tensor: torch.Tensor = kwargs["tensor"]
ablator = TransformerActivationCache()
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.add_text_stream_input(*args, use_tensor=tensor)},
ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.clamp_output(*args)})
elif name == "nothing":
ablator = TransformerActivationCache()
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {},
ablated_forward_dict=lambda block_type, layer_num: {})
elif name == "reweight_image_stream":
ablator = TransformerActivationCache()
residual_w=kwargs["residual_w"]
activation_w=kwargs["activation_w"]
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {},
ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.reweight_image_stream(*args, residual_w=residual_w, activation_w=activation_w)})
if name == "intermediate_image_stream_to_input":
ablator = TransformerActivationCache()
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, stream='image')},
ablated_forward_dict=lambda block_type, layer_num: {})
elif name == "replace_text_stream_one_layer":
ablator = AttentionAblationCacheHook()
weight = kwargs["weight"] if "weight" in kwargs else 1.0
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.restore_text_stream})
elif name == "replace_intermediate_representation":
ablator = TransformerActivationCache()
tensor: torch.Tensor = kwargs["tensor"]
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, use_tensor=tensor, stream='text_image')},
ablated_forward_dict=lambda block_type, layer_num: {})
elif name == "destroy_registers":
ablator = TransformerActivationCache()
layers: List[int] = kwargs['layers']
k: float = kwargs["k"]
stream: str = kwargs['stream']
random: bool = kwargs["random"] if "random" in kwargs else False
lowest_norm: bool = kwargs["lowest_norm"] if "lowest_norm" in kwargs else False
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.destroy_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
ablated_forward_dict=lambda block_type, layer_num: {})
elif name == "patch_registers":
ablator = TransformerActivationCache()
layers: List[int] = kwargs['layers']
k: float = kwargs["k"]
stream: str = kwargs['stream']
random: bool = kwargs["random"] if "random" in kwargs else False
lowest_norm: bool = kwargs["lowest_norm"] if "lowest_norm" in kwargs else False
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.destroy_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.set_cached_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
ablated_forward_dict=lambda block_type, layer_num: {})
elif name == "add_registers":
ablator = TransformerActivationCache()
num_registers: int = kwargs["num_registers"]
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer": lambda *args: insert_extra_registers(*args, num_registers=num_registers)},
ablated_forward_dict=lambda block_type, layer_num: {f"transformer": lambda *args: discard_extra_registers(*args, num_registers=num_registers)},)
elif name == "edit_streams":
ablator = TransformerActivationCache()
stream: str = kwargs["stream"]
layers = kwargs["layers"]
edit_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = kwargs["edit_fn"]
interventions = {f"transformer.transformer_blocks.{layer}": lambda *args: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer < 19}
interventions.update({f"transformer.single_transformer_blocks.{layer - 19}": lambda *args: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer >= 19})
return Ablation(ablator,
vanilla_pre_forward_dict=lambda block_type, layer_num: {},
vanilla_forward_dict=lambda block_type, layer_num: {},
ablated_pre_forward_dict=lambda block_type, layer_num: {},
ablated_forward_dict=lambda block_type, layer_num: interventions,
)
"""