Support for reduced precision (#104) #317

Merged
merged 5 commits into TransformerLensOrg:main on Jul 21, 2023

Conversation

@glerzing (Contributor) commented Jun 9, 2023

Description

HookedTransformerKeyValueCacheEntry wasn't compatible with dtypes other than torch.float32, so the past_kv_cache_entry argument wasn't working.
There is also an optimization: instead of initializing the model in torch.float32 and then converting it to the desired dtype, we now initialize the model layers directly in the desired dtype.
The attribute dtype was added to the configuration, with a default value of torch.float32. I hope this is OK; it's practical to have access to it from anywhere.
Also added a test for 8-bit loading, which is skipped if there is no GPU.

Fixes #104
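For context, a minimal usage sketch of what this change enables (assuming the dtype keyword is exposed through from_pretrained as described above; exact argument names may differ):

import torch
from transformer_lens import HookedTransformer

# Hypothetical usage: initialize the model directly in half precision
# instead of loading in float32 and converting afterwards.
model = HookedTransformer.from_pretrained("gpt2", dtype=torch.bfloat16)
print(model.cfg.dtype)  # the dtype is now stored in the configuration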

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@jbloomAus (Collaborator)

I'd rather not accept this until we have resolved the issue described here: #104 (comment)

Moreover, we should have tests for reduced precision model performance if it's the kind of thing that can be impacted by accident.

The rest seems fine. Thanks for the contribution!

To-do list:

  • write tests for reduced-precision model performance
  • make the reduced-precision model performance test pass

@jbloomAus added the seen_by_maintainers label (Confirms that a maintainer is aware of this card) on Jun 14, 2023
@slavachalnev (Contributor) commented Jun 27, 2023

What is test_hooked_transformer.py::test_8bits actually testing? The test passes because it doesn't assert anything, but if I load the model as is done in the test

model = HookedTransformer.from_pretrained("solu-1l", load_in_8bit=True, device_map="auto")

then the model is entirely in float32.
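(Illustrative check, assuming the model object loaded as above: count the dtypes actually used by the parameters; if 8-bit loading had worked, something other than torch.float32 should appear.)

from collections import Counter

# Tally the storage dtype of every parameter in the loaded model.
print(Counter(p.dtype for p in model.parameters()))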

@glerzing (Contributor Author) commented Jun 27, 2023

Torch doesn't support 8-bit floating point models, so there is no corresponding torch dtype; it relies on separate libraries like bitsandbytes and accelerate. It is bug-prone, so I should probably add further checks. But it would perhaps be wiser to keep whatever is not in 8 bits in bfloat16, since there is no float8 dtype. I'll try to figure it out and update the PR in the coming days.
I will also take into account the comment about the test for model precision performance, sorry for the delay.

@glerzing (Contributor Author) commented Jun 28, 2023

Actually, my previous answer was a bit inaccurate. Quantization doesn't use 8-bit floats; it uses 8-bit integers with a scaling factor. There is an int8 dtype in torch, but it doesn't seem to be used by the transformers library: as I understand it, the hidden states are typically kept in half precision and only some operations are quantized. I recommend this article if you are interested in the details: https://huggingface.co/blog/hf-bitsandbytes-integration.
It's a special type of computation, and I'm not sure it's worth trying to make it work in TransformerLens, probably not (both because it may not be easy to implement and because the lack of precision might make interpretability analyses less reliable). I'll see.
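To illustrate the idea, here is a minimal sketch of absmax int8 quantization in the spirit of that blog post (not code from this PR):

import torch

def absmax_quantize(x: torch.Tensor):
    # Scale so the largest magnitude maps to 127, round to int8,
    # and keep the scale factor for approximate reconstruction.
    scale = 127.0 / x.abs().max()
    return (x * scale).round().to(torch.int8), scale

def absmax_dequantize(x_int8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Undo the scaling; the result only approximates the original tensor.
    return x_int8.float() / scale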

@slavachalnev (Contributor)

That's a great blogpost, thanks!

Yeah, it looks like the HuggingFace bitsandbytes integration only quantises nn.Linear layers, and HookedTransformer has no nn.Linear layers.
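(A hypothetical way to confirm this, assuming hf_model and tl_model are the loaded HuggingFace and TransformerLens models:)

import torch.nn as nn

def count_linear(model) -> int:
    # bitsandbytes' 8-bit path swaps out nn.Linear modules, so a model
    # without any nn.Linear modules leaves it nothing to quantize.
    return sum(isinstance(m, nn.Linear) for m in model.modules())

print(count_linear(hf_model), count_linear(tl_model))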

@glerzing (Contributor Author) commented Jul 3, 2023

The fact that we use einsum instead of nn.Linear seems like the main reason for the remaining differences between the predictions of TransformerLens and Hugging Face in half precision.
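(A small, hypothetical probe of that hypothesis, not code from the PR: run the same matrix multiplication through einsum and through the nn.Linear kernel at reduced precision; any disagreement comes from the different kernels and accumulation order.)

import torch

torch.manual_seed(0)
x = torch.randn(8, 512, dtype=torch.bfloat16)
W = torch.randn(512, 512, dtype=torch.bfloat16)

out_einsum = torch.einsum("bi,io->bo", x, W)      # einsum path, as in TransformerLens
out_linear = torch.nn.functional.linear(x, W.T)   # nn.Linear path, as in Hugging Face
print((out_einsum.float() - out_linear.float()).abs().max())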

@@ -377,7 +409,7 @@ def __init__(
         else:
             raise ValueError(f"Invalid attention type: {self.attn_type}")

-        self.register_buffer("IGNORE", torch.tensor(-1e5))
+        self.register_buffer("IGNORE", torch.tensor(torch.finfo(cfg.dtype).min))
glerzing (Contributor Author):

It may be controversial, but I modified the value of the "IGNORE" buffer from -1e5 to torch.finfo(cfg.dtype).min. This shouldn't change much in practice; I just thought it was cleaner and closer to what Hugging Face does. I can revert that if needed.

glerzing (Contributor Author):

I'll revert that, since it is already done in PR #319.

# If using 16 bits, increase the precision to avoid numerical instabilities
q = q.to(torch.float32)
k = k.to(torch.float32)

glerzing (Contributor Author):

I chose to convert to float32 before the attention layer, instead of dividing earlier by self.attn_scale as proposed, mostly because that is what is done in the GPT-Neo architecture (in _attn) and it looked a bit more reliable.
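As a rough sketch of the pattern being described (hypothetical tensor names and shapes, not the exact PR code):

import torch

def attention_pattern(q, k, attn_scale):
    # q, k: [batch, query_pos, head, d_head], possibly in float16/bfloat16.
    # Compute the scores and softmax in float32 for numerical stability,
    # then cast the pattern back to the working dtype.
    working_dtype = q.dtype
    q32, k32 = q.to(torch.float32), k.to(torch.float32)
    scores = torch.einsum("bqhd,bkhd->bhqk", q32, k32) / attn_scale
    return torch.softmax(scores, dim=-1).to(working_dtype)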

@glerzing (Contributor Author) commented Jul 3, 2023

By the way, it's a problem that we can't self-assign issues, because when we choose an issue, we don't know whether someone is already working on it. Perhaps the permission to self-assign issues could be made public, or we may have to systematically add a comment to indicate when we start working on an issue.
Also, is there an existing discussion channel (e.g. Discord), or does everything happen here on GitHub?

@glerzing (Contributor Author) commented Jul 7, 2023

(rebased on the main branch to solve conflicts)

@glerzing (Contributor Author) commented Jul 15, 2023

@jbloomAus, @slavachalnev, thanks in advance if you can give a review.

@jbloomAus (Collaborator)

I'm on holiday at the moment, but I can review once @slavachalnev has taken a look, depending on what's needed. Thanks for doing this!

@slavachalnev (Contributor)

Looks good, but a few things need fixing:

When I run tests on GPU, I get four failures.

FAILED test_hooked_transformer.py::test_dtypes[dtype0] - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking...
FAILED test_hooked_transformer.py::test_dtypes[dtype1] - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking...
FAILED test_hooked_transformer.py::test_half_precision[dtype0] - RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking...
FAILED test_hooked_transformer.py::test_half_precision[dtype1] - RuntimeError: "baddbmm_with_gemm" not implemented for 'Half'

The first two are easily fixed by moving the tokens to the device in the test, e.g.

def check_performance(tl_model, hf_model, margin=0.01, device='cpu'):  # Added device here
    """
    Check that the TransformerLens model and the HuggingFace model have
    approximately the same confidence in the expected answer.
    """
    prompt = " Unable"
    tokens = tl_model.tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)  # Added device here

    expected_token = tl_model.tokenizer.encode(" to")[
        0
    ]  # Assume this is the expected token to predict

    tl_logits = tl_model(tokens, prepend_bos=False)[0, -1].float()
    hf_logits = hf_model(tokens).logits[0, -1].float()
    tl_prob = torch.softmax(tl_logits, dim=-1)[expected_token].item()
    hf_prob = torch.softmax(hf_logits, dim=-1)[expected_token].item()
    assert tl_prob + margin > hf_prob

Then we are left with a bfloat16 precision error larger than the set margin: assert (0.15567666292190552 + 0.005) > 0.17949192225933075. It's probably acceptable to increase the margin, but I'm not sure.

The final error is caused by a number of operations not being defined for half precision on CPU. In load_from_pretrained we do preprocessing on CPU, which causes an error. I think that when loading from pretrained, we need to use float32 and move the model to the target dtype after it's loaded. What do you think?

I would also modify the move_to_and_update_config function so that it updates the config when you move the model to a different dtype. This matters because the dtype in the config is checked during the Attention forward pass, so it should be up to date.
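Roughly what I have in mind (a sketch only; the real function in the codebase may have a different signature):

import torch
import torch.nn as nn

def move_to_and_update_config(model, device_or_dtype):
    # Keep the config in sync with where and how the weights are stored,
    # since cfg.dtype is checked during the Attention forward pass.
    if isinstance(device_or_dtype, torch.dtype):
        model.cfg.dtype = device_or_dtype
    elif isinstance(device_or_dtype, (torch.device, str)):
        model.cfg.device = device_or_dtype
    return nn.Module.to(model, device_or_dtype)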

@neelnanda-io (Collaborator) commented Jul 16, 2023 via email

@neelnanda-io (Collaborator)

Also, thanks so much for making the PR! Real mixed-precision support has been on my wishlist for ages.

@glerzing (Contributor Author) commented Jul 17, 2023

Oops, sorry @slavachalnev, I missed the bugs introduced by the rebase.

I added a warning to advise using from_pretrained_no_processing instead of from_pretrained, because the effect on the results doesn't seem negligible.

Most of the time the results are pretty similar between TransformerLens and Hugging Face. But it's still a bit weird that for EleutherAI/pythia-70m with bfloat16, TransformerLens with from_pretrained_no_processing gives a probability of 0.836 for the token " to", whereas Hugging Face gives only 0.002. Since " to" is considered the correct token, the test still passes here.
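For reference, the kind of comparison described above looks roughly like this (illustrative only; check_performance is the test helper shown earlier, and argument names may differ):

import torch
from transformers import AutoModelForCausalLM
from transformer_lens import HookedTransformer

# Load both models in bfloat16; from_pretrained_no_processing skips the
# weight-processing steps whose effect seems non-negligible at reduced precision.
tl_model = HookedTransformer.from_pretrained_no_processing(
    "EleutherAI/pythia-70m", dtype=torch.bfloat16
)
hf_model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m", torch_dtype=torch.bfloat16
)
check_performance(tl_model, hf_model, margin=0.05)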

@slavachalnev (Contributor)

Now getting an error in test_half_precision[dtype0]


dtype = torch.bfloat16

    @pytest.mark.skipif(
        torch.backends.mps.is_available() or not torch.cuda.is_available(),
        reason="bfloat16 unsupported by MPS: https://github.com/pytorch/pytorch/issues/78168 or no GPU",
    )
    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_half_precision(dtype):
        """Check the 16 bits loading and inferences.
        Note that bfloat16 is generally preferred to float16 for ML due to numerical instabilities,
        and some float16 operations require having a GPU.
        bfloat16 can be used without GPU, but surprisingly it doesn't give the same results in this case.
        """
>       check_dtype(dtype, margin=0.05, no_processing=True)

tests/acceptance/test_hooked_transformer.py:245: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/acceptance/test_hooked_transformer.py:222: in check_dtype
    _ = model.generate("Hello, World!")
.venv/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27: in decorate_context
    return func(*args, **kwargs)
transformer_lens/HookedTransformer.py:1516: in generate
    sampled_tokens = utils.sample_logits(
transformer_lens/utils.py:348: in sample_logits
    return torch.distributions.categorical.Categorical(logits=final_logits).sample()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = Categorical(probs: torch.Size([1, 50257]), logits: torch.Size([1, 50257])), sample_shape = torch.Size([])

    def sample(self, sample_shape=torch.Size()):
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        probs_2d = self.probs.reshape(-1, self._num_events)
>       samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T
E       RuntimeError: "multinomial_kernel_cuda" not implemented for 'BFloat16'

May want to convert probs to float32 before sampling.
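A minimal sketch of that fix (illustrative; the actual change would go in utils.sample_logits):

import torch

def sample_from_logits(final_logits: torch.Tensor) -> torch.Tensor:
    # torch.multinomial has no bfloat16 CUDA kernel, so cast the logits to
    # float32 before building the Categorical distribution and sampling.
    final_logits = final_logits.to(torch.float32)
    return torch.distributions.categorical.Categorical(logits=final_logits).sample()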

Otherwise, changes look good.

@slavachalnev (Contributor)

@jbloomAus this looks good to me

@jbloomAus merged commit 12ff439 into TransformerLensOrg:main on Jul 21, 2023
4 checks passed