Skip to content

Commit

Permalink
Fix test by using sets directly with id
Browse files Browse the repository at this point in the history
Behavior of tensorflow ObjectIdentitySet has change in 2.15,
resulting in the set equality comparing wrappers instead of
wrapped values. See tensorflow/tensorflow@bc28335.
  • Loading branch information
khurram-ghani committed Jan 8, 2024
1 parent 13ba0ed commit 3a785fc
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions tests/gpflux/layers/test_dedup_trackable_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import pytest
import tensorflow as tf
from tensorflow.python.ops.resource_variable_ops import ResourceVariable
from tensorflow.python.util import object_identity

import gpflow
from gpflow.utilities import parameter_dict
Expand Down Expand Up @@ -141,11 +140,10 @@ def test_weights_equals_deduplicated_parameter_dict(model):
# We filter out the parameters of type ResourceVariable.
# They have been added to the model by the `add_metric` call in the layer.
parameters = [p for p in parameter_dict(model).values() if not isinstance(p, ResourceVariable)]
variables = map(lambda p: p.unconstrained_variable, parameters)
deduplicate_variables = object_identity.ObjectIdentitySet(variables)
variables = {id(p.unconstrained_variable) for p in parameters}

weights = model.trainable_weights
assert len(weights) == len(deduplicate_variables)
assert len(weights) == len(variables)

weights_set = object_identity.ObjectIdentitySet(weights)
assert weights_set == deduplicate_variables
weights_set = {id(w) for w in weights}
assert weights_set == variables

0 comments on commit 3a785fc

Please sign in to comment.