from typing import Callable, List, Optional

import torch


class TimedHook:
    """Forward-hook wrapper that applies `hook_fn` only at selected denoising steps.

    At every other step the module output is passed through unchanged.
    """

    def __init__(self, hook_fn, total_steps, apply_at_steps=None):
        self.hook_fn = hook_fn
        self.total_steps = total_steps
        self.apply_at_steps = apply_at_steps
        self.current_step = 0

    def identity(self, module, input, output):
        return output

    def __call__(self, module, input, output):
        if self.apply_at_steps is not None:
            if self.current_step in self.apply_at_steps:
                self.__increment()
                return self.hook_fn(module, input, output)
            else:
                self.__increment()
                return self.identity(module, input, output)
        return self.identity(module, input, output)

    def __increment(self):
        # wrap back to 0 after the last step so the same hook instance can be
        # reused across several generations
        if self.current_step < self.total_steps:
            self.current_step += 1
        else:
            self.current_step = 0
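
# Usage sketch (illustrative, not part of the original file): wrap one of the
# steering hooks below in a TimedHook and register it as a forward hook. The
# hooked module is assumed to return a tuple whose first element is the
# hidden-states tensor (as `output[0]` in the hooks below implies); the module
# handle `hooked_block`, the SAE object and the step range are assumptions
# made only for this example.
#
#   import functools
#   timed = TimedHook(
#       functools.partial(add_feature, sae, feature_idx, 5.0),
#       total_steps=50,
#       apply_at_steps=set(range(10, 20)),
#   )
#   handle = hooked_block.register_forward_hook(timed)
#   image = pipe(prompt).images[0]
#   handle.remove()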
@torch.no_grad()
def add_feature(sae, feature_idx, value, module, input, output):
    # encode the block's residual update (output minus input) in channel-last
    # (B, H, W, C) layout; `activated` is only needed for its shape here
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    activated = sae.encode(diff)
    # decode a one-hot feature activation of strength `value` and add it to
    # the block output everywhere
    mask = torch.zeros_like(activated, device=diff.device)
    mask[..., feature_idx] = value
    to_add = mask @ sae.decoder.weight.T
    return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
@torch.no_grad()
def add_feature_on_area_base(sae, feature_idx, activation_map, module, input, output):
    return add_feature_on_area_base_both(sae, feature_idx, activation_map, module, input, output)


@torch.no_grad()
def add_feature_on_area_base_both(sae, feature_idx, activation_map, module, input, output):
    # add the feature to the cond sample and subtract it from the uncond one;
    # this assumes diff.shape[0] == 2 (a cond/uncond classifier-free-guidance pair)
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    activated = sae.encode(diff)
    mask = torch.zeros_like(activated, device=diff.device)
    if activation_map.dim() == 2:
        # a plain (H, W) map gets a leading batch dim so it broadcasts
        activation_map = activation_map.unsqueeze(0)
    mask[..., feature_idx] = activation_map.to(mask.device)
    to_add = mask @ sae.decoder.weight.T
    to_add = to_add.chunk(2)
    output[0][0] -= to_add[0].permute(0, 3, 1, 2).to(output[0].device)[0]
    output[0][1] += to_add[1].permute(0, 3, 1, 2).to(output[0].device)[0]
    return output
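
# The `activation_map` passed to the area hooks is a spatial map at the latent
# resolution of the hooked block: wherever it is non-zero, the feature is
# steered with that strength. A minimal sketch of building one (the 16x16
# resolution and the strength of 8.0 are assumptions for illustration only):
#
#   activation_map = torch.zeros(16, 16)
#   activation_map[4:12, 4:12] = 8.0   # steer only inside this square
#   hook = functools.partial(add_feature_on_area_base, sae, feature_idx, activation_map)
#   handle = hooked_block.register_forward_hook(hook)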
@torch.no_grad()
def add_feature_on_area_base_cond(sae, feature_idx, activation_map, module, input, output):
    # add the feature only to the cond sample
    # this assumes diff.shape[0] == 2
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    diff_uncond, diff_cond = diff.chunk(2)
    activated = sae.encode(diff_cond)
    mask = torch.zeros_like(activated, device=diff_cond.device)
    if activation_map.dim() == 2:
        activation_map = activation_map.unsqueeze(0)
    mask[..., feature_idx] = activation_map.to(mask.device)
    to_add = mask @ sae.decoder.weight.T
    output[0][1] += to_add.permute(0, 3, 1, 2).to(output[0].device)[0]
    return output
@torch.no_grad()
def replace_with_feature_base(sae, feature_idx, value, module, input, output):
    # this assumes diff.shape[0] == 2
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    diff_uncond, diff_cond = diff.chunk(2)
    activated = sae.encode(diff_cond)
    mask = torch.zeros_like(activated, device=diff_cond.device)
    mask[..., feature_idx] = value
    to_add = mask @ sae.decoder.weight.T
    # returning the modified input tuple replaces the block output entirely:
    # the uncond sample gets the unchanged input, the cond sample gets the
    # input plus the decoded feature (the block's own update is dropped)
    input[0][1] += to_add.permute(0, 3, 1, 2).to(output[0].device)[0]
    return input
@torch.no_grad()
def add_feature_on_area_turbo(sae, feature_idx, activation_map, module, input, output):
    # turbo/distilled models typically run without classifier-free guidance,
    # so the feature is added to every sample in the batch
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    activated = sae.encode(diff)
    mask = torch.zeros_like(activated, device=diff.device)
    if activation_map.dim() == 2:
        activation_map = activation_map.unsqueeze(0)
    mask[..., feature_idx] = activation_map.to(mask.device)
    to_add = mask @ sae.decoder.weight.T
    return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
@torch.no_grad()
def add_feature_on_area_flux(
    sae,
    feature_idx,
    activation_map,
    module,
    input: torch.Tensor,
    output: torch.Tensor,
):
    # Flux transformer blocks operate on token sequences rather than
    # (B, C, H, W) feature maps, so no permute is needed and the spatial map
    # is flattened to the token dimension
    diff = (output - input).to(sae.device)
    activated = sae.encode(diff)
    # TODO: check
    if activation_map.dim() == 2:
        activation_map = activation_map.unsqueeze(0)
    mask = torch.zeros_like(activated, device=diff.device)
    activation_map = activation_map.flatten()
    mask[..., feature_idx] = activation_map.to(mask.device)
    to_add = mask @ sae.decoder.weight.T
    return output + to_add.to(output.device, output.dtype)
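
# Note: unlike the UNet hooks above, this variant receives the hidden-state
# tensors directly rather than 1-tuples, and the (H, W) activation map is
# flattened to match the token sequence. A sketch of building such a map
# (the 64x64 token grid is an assumption for illustration):
#
#   activation_map = torch.zeros(64, 64)
#   activation_map[:, :32] = 4.0   # steer only tokens from the left half
#   hook = functools.partial(add_feature_on_area_flux, sae, feature_idx, activation_map)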
@torch.no_grad()
def replace_with_feature_turbo(sae, feature_idx, value, module, input, output):
    # replace the block's residual update with the decoded feature alone
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    activated = sae.encode(diff)
    mask = torch.zeros_like(activated, device=diff.device)
    mask[..., feature_idx] = value
    to_add = mask @ sae.decoder.weight.T
    return (input[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
@torch.no_grad()
def reconstruct_sae_hook(sae, module, input, output):
    # pass the block's residual update through the SAE (encode then decode)
    # and use the reconstruction in place of the original update
    diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
    activated = sae.encode(diff)
    reconstructed = sae.decoder(activated) + sae.pre_bias
    return (input[0] + reconstructed.permute(0, 3, 1, 2).to(output[0].device),)


@torch.no_grad()
def ablate_block(module, input, output):
    # skip the block entirely by returning its input as the output
    return input
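
# Sanity-check sketch (illustrative): compare a normal generation against one
# where the hooked block's update is replaced by its SAE reconstruction, and
# against one where the block is ablated. `hooked_block`, `pipe` and `sae` are
# assumptions for this example.
#
#   h1 = hooked_block.register_forward_hook(functools.partial(reconstruct_sae_hook, sae))
#   recon_image = pipe(prompt).images[0]
#   h1.remove()
#   h2 = hooked_block.register_forward_hook(ablate_block)
#   ablated_image = pipe(prompt).images[0]
#   h2.remove()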