Commit update to hook_mlp_in (#316)
ArthurConmy authored Jul 5, 2023
1 parent 1723788 commit feefcf1
Showing 6 changed files with 38 additions and 36 deletions.
15 changes: 0 additions & 15 deletions tests/acceptance/test_activation_cache.py
@@ -289,18 +289,3 @@ def test_stack_neuron_results_with_apply_ln():
     assert torch.isclose(
         ref_scaled_residual_stack, scaled_residual_stack, atol=1e-7
     ).all()
-
-
-def test_hook_mlp_in_memory():
-    """Test that the new hook_mlp_in is not taking any extra memory"""
-
-    model = load_model("solu-2l")
-
-    # Run the model and cache all activations
-    tokens, _ = get_ioi_tokens_and_answer_tokens(model)
-    _, cache = model.run_with_cache(tokens)
-
-    assert (
-        cache["blocks.0.hook_resid_mid"].data_ptr()
-        == cache["blocks.0.hook_mlp_in"].data_ptr()
-    )
1 change: 0 additions & 1 deletion tests/unit/test_cache_hook_names.py
@@ -19,7 +19,6 @@
     "blocks.0.attn.hook_z",
     "blocks.0.hook_attn_out",
     "blocks.0.hook_resid_mid",
-    "blocks.0.hook_mlp_in",
     "blocks.0.ln2.hook_scale",
     "blocks.0.ln2.hook_normalized",
     "blocks.0.mlp.hook_pre",
29 changes: 12 additions & 17 deletions tests/unit/test_hooks.py
@@ -132,28 +132,23 @@ def test_conditional_hooks():
     def identity_hook(z, hook):
         return z
 
-    model.reset_hooks()
-    model.set_use_attn_result(False)
-    with pytest.raises(AssertionError):
-        model.add_hook("blocks.0.attn.hook_result", identity_hook)
-
-    model.reset_hooks()
-    model.set_use_split_qkv_input(False)
-    with pytest.raises(AssertionError):
-        model.add_hook("blocks.0.hook_q_input", identity_hook)
-
-    # now when we set these conditions to true, should be no errors!
-
-    model.reset_hooks()
-    model.set_use_attn_result(True)
-    model.add_hook("blocks.0.attn.hook_result", identity_hook)
+    for hook_name, set_use_hook_function in [
+        ("blocks.0.attn.hook_result", model.set_use_attn_result),
+        ("blocks.0.hook_q_input", model.set_use_split_qkv_input),
+        ("blocks.0.hook_mlp_in", model.set_use_hook_mlp_in),
+    ]:
+        model.reset_hooks()
+        set_use_hook_function(False)
+        with pytest.raises(AssertionError):
+            model.add_hook(hook_name, identity_hook)
+        set_use_hook_function(True)
+        model.add_hook(hook_name, identity_hook)
 
+    # check that things are the right shape in the split_q case
     model.reset_hooks()
     model.set_use_split_qkv_input(True)
     model.add_hook("blocks.0.hook_q_input", identity_hook)
 
-    # check that things are the right shape
-
     cache = model.run_with_cache(
         prompt,
         names_filter=lambda x: x == "blocks.0.hook_q_input",
10 changes: 10 additions & 0 deletions transformer_lens/HookedTransformer.py
@@ -180,6 +180,10 @@ def check_hooks_to_add(
         assert (
             self.cfg.use_split_qkv_input
         ), f"Cannot add hook {hook_point_name} if use_split_qkv_input is False"
+        if hook_point_name.endswith("mlp_in"):
+            assert (
+                self.cfg.use_hook_mlp_in
+            ), f"Cannot add hook {hook_point_name} if use_hook_mlp_in is False"
 
     @overload
     def forward(
@@ -1313,6 +1317,12 @@ def set_use_split_qkv_input(self, use_split_qkv_input: bool):
         """
         self.cfg.use_split_qkv_input = use_split_qkv_input
 
+    def set_use_hook_mlp_in(self, use_hook_mlp_in: bool):
+        """
+        Toggles whether to allow storing and editing inputs to each MLP layer.
+        """
+        self.cfg.use_hook_mlp_in = use_hook_mlp_in
+
     def process_weights_(
         self,
         fold_ln: bool = True,
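The new setter mirrors the existing `set_use_attn_result` and `set_use_split_qkv_input` toggles: the flag must be switched on before `add_hook` will accept a `hook_mlp_in` hook point, otherwise the new assertion in `check_hooks_to_add` fires. A minimal usage sketch (the model name matches the one used in the tests; the prompt and the zero-ablation hook are illustrative, not part of this commit):

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("solu-2l")
tokens = model.to_tokens("The quick brown fox")

# With the flag off, adding the hook trips the new assertion in check_hooks_to_add.
model.set_use_hook_mlp_in(False)
try:
    model.add_hook("blocks.0.hook_mlp_in", lambda value, hook: value)
except AssertionError as e:
    print(e)  # Cannot add hook blocks.0.hook_mlp_in if use_hook_mlp_in is False

# With the flag on, the hook point is usable, e.g. to zero-ablate the MLP input of layer 0.
model.set_use_hook_mlp_in(True)
model.add_hook("blocks.0.hook_mlp_in", lambda value, hook: torch.zeros_like(value))
ablated_logits = model(tokens)
model.reset_hooks()
```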
3 changes: 3 additions & 0 deletions transformer_lens/HookedTransformerConfig.py
@@ -41,6 +41,8 @@ class HookedTransformerConfig:
             for large models, so defaults to False
         use_split_qkv_input (bool): whether to explicitly calculate the input of
             each head separately, with a hook. Defaults to false to save memory.
+        use_hook_mlp_in (bool): whether to use a hook to get the input to the
+            MLP layer. Defaults to false to save memory.
         use_attn_scale (bool): whether to scale the attention weights by
             1/sqrt(d_head)
         model_name (str): the name of the model, used to load
@@ -140,6 +142,7 @@ class HookedTransformerConfig:
    use_attn_result: bool = False
    use_attn_scale: bool = True
    use_split_qkv_input: bool = False
+    use_hook_mlp_in: bool = False
    use_local_attn: bool = False
    original_architecture: Optional[str] = None
    from_checkpoint: bool = False
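A sketch of what the memory trade-off in the docstring means in practice (model and prompt are illustrative): with the default `use_hook_mlp_in=False` the forward pass never calls `hook_mlp_in`, so `run_with_cache` stores nothing for it; enabling the flag adds one `[batch, pos, d_model]` activation per layer to the cache.

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("solu-2l")
tokens = model.to_tokens("Hello world")

# Default: hook_mlp_in is skipped in the forward pass, so it never appears in the cache.
_, cache = model.run_with_cache(tokens)
assert "blocks.0.hook_mlp_in" not in cache.keys()

# Opting in: each layer now caches its MLP input as an extra tensor.
model.set_use_hook_mlp_in(True)
_, cache = model.run_with_cache(tokens)
print(cache["blocks.0.hook_mlp_in"].shape)  # [batch, pos, d_model]
```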
16 changes: 13 additions & 3 deletions transformer_lens/components.py
@@ -929,7 +929,12 @@ def add_head_dimension(tensor):
         resid_mid = self.hook_resid_mid(
             resid_pre + attn_out
         )  # [batch, pos, d_model]
-        normalized_resid_mid = self.ln2(self.hook_mlp_in(resid_mid))
+        mlp_in = (
+            resid_mid
+            if not self.cfg.use_hook_mlp_in
+            else self.hook_mlp_in(resid_mid.clone())
+        )
+        normalized_resid_mid = self.ln2(mlp_in)
         mlp_out = self.hook_mlp_out(
             self.mlp(normalized_resid_mid)
         )  # [batch, pos, d_model]
@@ -1012,9 +1017,14 @@ def add_head_dimension(tensor):
             )
         )
         resid_mid = self.hook_resid_mid(resid_pre + attn_out)
-        normalized_resid_mid = self.ln1(resid_mid)
 
-        mlp_out = self.hook_mlp_out(self.mlp(self.hook_mlp_in(normalized_resid_mid)))
+        mlp_in = (
+            resid_mid
+            if not self.cfg.use_hook_mlp_in
+            else self.hook_mlp_in(resid_mid.clone())
+        )
+        normalized_resid_mid = self.ln1(mlp_in)
+        mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid))
         resid_post = self.hook_resid_post(normalized_resid_mid + mlp_out)
         normalized_resid_post = self.hook_normalized_resid_post(self.ln2(resid_post))
 
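When the flag is enabled, the hook now receives a `.clone()` of `resid_mid`, so the cached MLP input is value-equal to `hook_resid_mid` but no longer shares its storage, and editing it in place no longer mutates the tensor cached at `hook_resid_mid`. That is presumably why the `data_ptr`-based memory test above was removed. A small sketch of the resulting behaviour (model choice is illustrative):

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("solu-2l")
model.set_use_hook_mlp_in(True)

_, cache = model.run_with_cache(model.to_tokens("Hello world"))
resid_mid = cache["blocks.0.hook_resid_mid"]
mlp_in = cache["blocks.0.hook_mlp_in"]

# Same values, but distinct storage, now that hook_mlp_in sees a clone of resid_mid.
assert torch.equal(resid_mid, mlp_in)
assert resid_mid.data_ptr() != mlp_in.data_ptr()
```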
