diff --git a/tests/test_utils.py b/tests/test_utils.py index 856e2d67..a3462d78 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -192,8 +192,18 @@ def assert_expected_namedtuple( def init_weights_with_constant(model: nn.Module, constant: float = 1.0) -> None: - for p in model.parameters(): + for n, p in model.named_parameters(): nn.init.constant_(p, constant) + # reduce the change to the tests + for k in { + "text_projection.bias", + "pooled_projection.bias", + "output_projection.bias", + "vision_proj.bias", + }: + if n.endswith(k): + nn.init.constant_(p, 0.0) + break def tensor_hash(x: torch.tensor, scaling=0.05, buckets=1000) -> torch.tensor: