Merge pull request #745 from TransformerLensOrg/dev
release 2.7.1
bryce13950 authored Oct 4, 2024
2 parents cc3103f + fa2989d commit 1d8b1d8
Showing 2 changed files with 17 additions and 4 deletions.
19 changes: 16 additions & 3 deletions transformer_lens/HookedTransformer.py
@@ -10,7 +10,18 @@
 """
 import logging
 import os
-from typing import Dict, List, NamedTuple, Optional, Tuple, Union, cast, overload
+from typing import (
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    overload,
+)
 
 import einops
 import numpy as np
@@ -67,6 +78,8 @@
     "bf16": torch.bfloat16,
 }
 
+T = TypeVar("T", bound="HookedTransformer")
+
 
 class Output(NamedTuple):
     """Output Named Tuple.
@@ -1053,7 +1066,7 @@ def move_model_modules_to_device(self):
 
     @classmethod
     def from_pretrained(
-        cls,
+        cls: Type[T],
        model_name: str,
        fold_ln: bool = True,
        center_writing_weights: bool = True,
@@ -1072,7 +1085,7 @@
         dtype="float32",
         first_n_layers: Optional[int] = None,
         **from_pretrained_kwargs,
-    ) -> "HookedTransformer":
+    ) -> T:
         """Load in a Pretrained Model.
 
         Load in pretrained model weights to the HookedTransformer format and optionally to do some
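The diff above annotates from_pretrained with cls: Type[T] and a return type of T, where T is a TypeVar bound to HookedTransformer. This makes the classmethod generic over the calling class, so type checkers infer the subclass type (rather than the base HookedTransformer) when a subclass calls it. A minimal, self-contained sketch of the pattern; the Base and Sub class names below are hypothetical and only illustrate the typing behaviour, not the library's actual loading logic:

from typing import Type, TypeVar

T = TypeVar("T", bound="Base")


class Base:
    @classmethod
    def from_pretrained(cls: Type[T], model_name: str) -> T:
        # cls is the class the method was called on, so subclasses construct themselves
        return cls()


class Sub(Base):
    pass


# A type checker now infers the type of `model` as Sub, not Base
model = Sub.from_pretrained("gpt2-small")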
2 changes: 1 addition & 1 deletion transformer_lens/utils.py
@@ -59,7 +59,7 @@ def download_file_from_hf(
     )
 
     if file_path.endswith(".pth") or force_is_torch:
-        return torch.load(file_path, map_location="cpu")
+        return torch.load(file_path, map_location="cpu", weights_only=False)
     elif file_path.endswith(".json"):
         return json.load(open(file_path, "r"))
     else:
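Passing weights_only=False explicitly keeps torch.load's full-unpickling behaviour for .pth checkpoints, since newer PyTorch releases warn that the default for weights_only is changing to True (which restricts loading to plain tensors and primitive containers). A minimal usage sketch, assuming a trusted local checkpoint; "model.pth" is a placeholder path, not a file from the repository:

import torch

# weights_only=False keeps the legacy behaviour of unpickling arbitrary Python objects,
# so it should only be used with checkpoints from trusted sources.
state_dict = torch.load("model.pth", map_location="cpu", weights_only=False)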
