Support for reduced precision (#104)

glerzing committed Jul 15, 2023
1 parent 218ebd6 commit bfceeaf
Showing 10 changed files with 209 additions and 75 deletions.
23 changes: 13 additions & 10 deletions tests/acceptance/test_hooked_encoder.py
@@ -136,16 +136,6 @@ def test_run_with_cache(our_bert, huggingface_bert, hello_world_tokens):
assert "mlm_head.ln.hook_normalized" in cache


@pytest.mark.skipif(
torch.backends.mps.is_available(),
reason="bfloat16 unsupported by MPS: https://github.com/pytorch/pytorch/issues/78168",
)
def test_from_pretrained_dtype():
"""Check that the parameter `torch_dtype` works"""
model = HookedEncoder.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
assert model.W_K.dtype == torch.bfloat16


def test_from_pretrained_revision():
"""
Check that the from_pretrained parameter `revision` (= git version) works
@@ -161,6 +151,19 @@ def test_from_pretrained_revision():
raise AssertionError("Should have raised an error")


@pytest.mark.skipif(
torch.backends.mps.is_available() or not torch.cuda.is_available(),
reason="bfloat16 unsupported by MPS: https://github.com/pytorch/pytorch/issues/78168 or no GPU",
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_half_precision(dtype):
"""Check the 16 bits loading and inferences."""
model = HookedEncoder.from_pretrained(MODEL_NAME, torch_dtype=dtype)
assert model.W_K.dtype == dtype

_ = model(model.tokenizer("Hello, world", return_tensors="pt")["input_ids"])


def test_predictions(our_bert, huggingface_bert, tokenizer):
input_ids = tokenizer("The [MASK] sat on the mat", return_tensors="pt")["input_ids"]

88 changes: 77 additions & 11 deletions tests/acceptance/test_hooked_transformer.py
@@ -3,7 +3,7 @@

import pytest
import torch
from transformers import AutoConfig
from transformers import AutoConfig, AutoModelForCausalLM

from transformer_lens import HookedTransformer
from transformer_lens.components import LayerNormPre
@@ -136,16 +136,6 @@ def test_from_pretrained_no_processing(name, expected_loss):
assert (reff_loss.item() - expected_loss) < 4e-5


@pytest.mark.skipif(
torch.backends.mps.is_available(),
reason="bfloat16 unsupported by MPS: https://github.com/pytorch/pytorch/issues/78168",
)
def test_from_pretrained_dtype():
"""Check that the parameter `torch_dtype` works"""
model = HookedTransformer.from_pretrained("solu-1l", torch_dtype=torch.bfloat16)
assert model.W_K.dtype == torch.bfloat16


def test_process_weights_inplace():
"""Check that process_weights_ works"""
model = HookedTransformer.from_pretrained_no_processing("gpt2-small")
@@ -170,6 +160,82 @@ def test_from_pretrained_revision():
raise AssertionError("Should have raised an error")


def check_similarity_with_hf_model(tl_model, hf_model, prompt="Hello, world!"):
"""
Check that the TransformerLens model and the HuggingFace model
give approximately the same results.
The logits typically differ by a constant value, but check only the results
after the softmax because this is what matters most.
"""
tokens = tl_model.tokenizer.encode(prompt, return_tensors="pt")
logits = tl_model(tokens, prepend_bos=False)
hf_logits = hf_model(tokens).logits
assert torch.allclose(
torch.softmax(logits, dim=-1), torch.softmax(hf_logits, dim=-1), atol=1e-5
)
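The docstring's claim that a constant offset in the logits is harmless rests on the standard shift-invariance of softmax: adding the same constant c to every logit cancels,

\[
\operatorname{softmax}(x + c)_i = \frac{e^{x_i + c}}{\sum_j e^{x_j + c}} = \frac{e^{c}\, e^{x_i}}{e^{c} \sum_j e^{x_j}} = \operatorname{softmax}(x)_i ,
\]

so the post-softmax probabilities compared above are unaffected by any such offset.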


def check_performance(tl_model, hf_model, margin=0.01):
"""
    Check that the TransformerLens model and the HuggingFace model have
    approximately the same confidence in the expected answer.
"""
prompt = " Unable"
tokens = tl_model.tokenizer(prompt, return_tensors="pt")["input_ids"]

expected_token = tl_model.tokenizer.encode(" to")[
0
] # Assume this is the expected token to predict

tl_logits = tl_model(tokens, prepend_bos=False)[0, -1].float()
hf_logits = hf_model(tokens).logits[0, -1].float()
tl_prob = torch.softmax(tl_logits, dim=-1)[expected_token].item()
hf_prob = torch.softmax(hf_logits, dim=-1)[expected_token].item()
assert tl_prob + margin > hf_prob


def check_dtype(dtype, margin=0.01):
"""Check the loading and inferences for different dtypes."""
for model_path in ["gpt2", "roneneldan/TinyStories-33M", "EleutherAI/pythia-70m"]:
model = HookedTransformer.from_pretrained(model_path, torch_dtype=dtype)
hf_model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=dtype,
).to("cuda" if torch.cuda.is_available() else "cpu")

for layer_name, layer in model.state_dict().items():
assert layer.dtype in [dtype, torch.bool] or "IGNORE" in layer_name

check_performance(model, hf_model, margin)

# Check that generate doesn't throw an error
_ = model.generate("Hello, World!")

del model
del hf_model
gc.collect()


@pytest.mark.parametrize("dtype", [torch.float64, torch.float32])
def test_dtypes(dtype):
check_dtype(dtype, margin=5e-5)


@pytest.mark.skipif(
torch.backends.mps.is_available() or not torch.cuda.is_available(),
reason="bfloat16 unsupported by MPS: https://github.com/pytorch/pytorch/issues/78168 or no GPU",
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_half_precision(dtype):
"""Check the 16 bits loading and inferences.
Note that bfloat16 is generally preferred to float16 for ML due to numerical instabilities,
and some float16 operations require having a GPU.
bfloat16 can be used without GPU, but surprisingly it doesn't give the same results in this case.
"""
check_dtype(dtype, margin=0.005)


@torch.no_grad()
def test_pos_embed_hook():
"""
9 changes: 5 additions & 4 deletions transformer_lens/HookedEncoder.py
@@ -217,6 +217,11 @@ def from_pretrained(
"that the last LayerNorm in a block cannot be folded."
)

assert not (
from_pretrained_kwargs.get("load_in_8bit", False)
or from_pretrained_kwargs.get("load_in_4bit", False)
), "Quantization not supported"

official_model_name = loading.get_official_model_name(model_name)

cfg = loading.get_pretrained_model_config(
@@ -235,10 +240,6 @@

model = cls(cfg, tokenizer, move_to_device=False)

dtype = from_pretrained_kwargs.get("torch_dtype", None)
if dtype is not None:
model = model.to(dtype)

model.load_state_dict(state_dict, strict=False)

if move_to_device:
11 changes: 5 additions & 6 deletions transformer_lens/HookedTransformer.py
@@ -869,6 +869,10 @@ def from_pretrained(
functions when compatible. For some models or arguments it doesn't work, especially for
models that are not internally loaded with HuggingFace's from_pretrained (e.g. SoLU models).
"""
assert not (
from_pretrained_kwargs.get("load_in_8bit", False)
or from_pretrained_kwargs.get("load_in_4bit", False)
), "Quantization not supported"

# Get the model name used in HuggingFace, rather than the alias.
official_model_name = loading.get_official_model_name(model_name)
@@ -915,10 +919,6 @@ # Create the HookedTransformer object
# Create the HookedTransformer object
model = cls(cfg, tokenizer, move_to_device=False)

dtype = from_pretrained_kwargs.get("torch_dtype", None)
if dtype is not None:
model = model.to(dtype)

model.load_and_process_state_dict(
state_dict,
fold_ln=fold_ln,
@@ -1019,7 +1019,6 @@ def load_and_process_state_dict(
model_name (str, optional): checks the model name for special cases of state dict loading. Only used for
Redwood 2L model currently
"""

state_dict = self.fill_missing_keys(state_dict)
if fold_ln:
if self.cfg.normalization_type not in ["LN", "LNPre"]:
@@ -1407,7 +1406,7 @@ def generate(
prepend_bos: Optional[bool] = None,
return_type: Optional[str] = "input",
verbose: bool = True,
) -> Float[torch.Tensor, "batch pos_plus_new_tokens"]:
) -> Union[Int[torch.Tensor, "batch pos_plus_new_tokens"], str]:
"""
Sample tokens from the model until the model outputs eos_token or max_new_tokens is reached.
2 changes: 2 additions & 0 deletions transformer_lens/HookedTransformerConfig.py
@@ -134,6 +134,7 @@ class HookedTransformerConfig:
trained with this, heads often use the first position as a resting position and accordingly lose information from
the first token, so this empirically seems to give better results. Call set_default_prepend_bos(False) to change
this default value to False.
dtype (torch.dtype, *optional*): The model's dtype. Defaults to torch.float32.
"""

n_layers: int
@@ -178,6 +179,7 @@ class HookedTransformerConfig:
use_hook_tokens: bool = False
gated_mlp: bool = False
default_prepend_bos: bool = True
dtype: torch.dtype = torch.float32

def __post_init__(self):
if self.n_heads == -1:
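Since the dtype now lives on the config, a model built directly from a `HookedTransformerConfig` should pick it up as well. A hedged sketch of what that looks like; the field values below are arbitrary toy settings, not taken from this commit:

import torch
from transformer_lens import HookedTransformer, HookedTransformerConfig

cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=64,
    n_ctx=128,
    d_head=16,        # n_heads is inferred as d_model // d_head in __post_init__
    act_fn="relu",
    d_vocab=1000,
    dtype=torch.bfloat16,  # new field; defaults to torch.float32
)
model = HookedTransformer(cfg)  # parameters are expected to be created in cfg.dtype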
