Skip to content

Commit

Permalink
internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 281531627
  • Loading branch information
DisentanglementLib Team authored and obachem committed Nov 20, 2019
1 parent a070670 commit a64b8b9
Show file tree
Hide file tree
Showing 10 changed files with 739 additions and 2 deletions.
21 changes: 20 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
It supports a variety of different models, metrics and data sets:

* *Models*: BetaVAE, FactorVAE, BetaTCVAE, DIP-VAE
* *Metrics*: BetaVAE score, FactorVAE score, Mutual Information Gap, SAP score, DCI, MCE, IRS
* *Metrics*: BetaVAE score, FactorVAE score, Mutual Information Gap, SAP score, DCI, MCE, IRS, UDR
* *Data sets*: dSprites, Color/Noisy/Scream-dSprites, SmallNORB, Cars3D, and Shapes3D
* It also includes 10'800 pretrained disentanglement models (see below for details).

Expand Down Expand Up @@ -262,6 +262,25 @@ If you only want to reevaluate an already trained model using the evaluation pro
dlib_reproduce --model_dir=<model_output_directory> --output_directory=<output> --study=fairness_study_v1
```

## UDR experiments

The library also includes the code for the Unsupervised Disentanglement Ranking (UDR) method in `disentanglement_lib/bin/dlib_udr`. The method was proposed in the following paper:
> [**Unsupervised Model Selection for Variational Disentangled Representation Learning**](https://arxiv.org/abs/1905.12614)
> *Sunny Duan, Loic Matthey, Andre Saraiva, Nicholas Watters, Christopher P. Burgess, Alexander Lerchner, Irina Higgins*.

UDR can be applied to newly trained models (e.g. obtained by running
`dlib_reproduce`) or to the existing pretrained models. After the models have
been trained, their UDR scores can be computed by running:

```
dlib_udr --model_dirs=<model_output_directory1>,<model_output_directory2> \
--output_dir=<output> \
--gin_config=<udr_gin_config>
```

The scores will be exported to `<output>/results/aggregate/evaluation.json`
under the model_scores attribute. The scores will be presented in the order of
the input model directories.

## Feedback
Please send any feedback to bachem@google.com and francesco.locatello@tuebingen.mpg.de.

Expand Down
2 changes: 2 additions & 0 deletions bin/dlib_tests
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ python -m disentanglement_lib.evaluation.metrics.dci_test
python -m disentanglement_lib.evaluation.metrics.sap_score_test
python -m disentanglement_lib.evaluation.metrics.utils_test
python -m disentanglement_lib.evaluation.metrics.unsupervised_metrics_test
python -m disentanglement_lib.evaluation.udr.evaluate_test
python -m disentanglement_lib.evaluation.udr.metrics.udr_test
python -m disentanglement_lib.visualize.visualize_util_test
python -m disentanglement_lib.visualize.visualize_dataset_test
python -m disentanglement_lib.visualize.visualize_model_test
59 changes: 59 additions & 0 deletions bin/dlib_udr
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation module for disentangled representations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import app
from absl import flags
from disentanglement_lib.evaluation.udr import evaluate
from tensorflow import gfile
import gin.tf

FLAGS = flags.FLAGS

# Model directories whose representations are compared against each other.
flags.DEFINE_list("model_dirs", [], "List of models to run UDR over.")

# Root directory that receives the evaluation results (one sub-directory per
# config when --gin_evaluation_config_glob is used).
flags.DEFINE_string("output_dir", None,
                    "Directory to save the UDR evaluation results to.")

# The two flags below are only read when --gin_evaluation_config_glob is unset.
flags.DEFINE_multi_string("gin_config", [],
                          "List of paths to the config files.")

flags.DEFINE_multi_string("gin_bindings", [],
                          "Newline separated list of Gin parameter bindings.")

# When set, every matching gin file is evaluated in turn.
flags.DEFINE_string("gin_evaluation_config_glob", None,
                    "Path to glob pattern to evaluation configs.")


def main(unused_argv):
  """Runs the UDR evaluation over the models given on the command line.

  Two modes are supported:

  * If --gin_evaluation_config_glob is set, every gin file matching the glob
    is evaluated in sorted order. Each config writes to its own sub-directory
    of --output_dir named after the config file (without ".gin"), and the gin
    config is cleared between runs. --gin_config and --gin_bindings are not
    used in this mode.
  * Otherwise a single evaluation is run with --gin_config and
    --gin_bindings, writing directly to --output_dir.

  Args:
    unused_argv: Positional command line arguments; ignored.

  Raises:
    app.UsageError: If --model_dirs is empty or --output_dir is not set.
  """
  # Fail early with a readable message instead of a TypeError/IndexError
  # raised deep inside the evaluation code.
  if not FLAGS.model_dirs:
    raise app.UsageError(
        "--model_dirs must list at least one model directory.")
  if FLAGS.output_dir is None:
    raise app.UsageError("--output_dir must be specified.")

  if FLAGS.gin_evaluation_config_glob is None:
    gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
    evaluate.evaluate(FLAGS.model_dirs, FLAGS.output_dir)
    return

  for gin_eval_config in sorted(gfile.Glob(FLAGS.gin_evaluation_config_glob)):
    # The config file name doubles as the metric name and output folder name.
    metric_name = os.path.basename(gin_eval_config).replace(".gin", "")
    metric_dir = os.path.join(FLAGS.output_dir, metric_name)
    gin.parse_config_files_and_bindings(
        [gin_eval_config], ["evaluation.name = '{}'".format(metric_name)])
    evaluate.evaluate(FLAGS.model_dirs, metric_dir)
    # Reset gin so that bindings do not leak into the next config.
    gin.clear_config()


if __name__ == "__main__":
  app.run(main)
24 changes: 24 additions & 0 deletions disentanglement_lib/config/tests/methods/udr/udr.gin
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Bindings for the "evaluation" configurable.
evaluation.evaluation_fn = @udr_sklearn
evaluation.random_seed = 0

# "auto" makes the evaluation read the dataset name from the train.gin of the
# first model directory.
dataset.name = "auto"

# Bindings for udr_sklearn (its implementation is not part of this file).
# This config selects the "lasso" correlation matrix.
udr_sklearn.correlation_matrix = "lasso"
udr_sklearn.num_data_points = 1000
udr_sklearn.batch_size = 10
udr_sklearn.filter_low_kl = True
udr_sklearn.kl_filter_threshold = 0.01
udr_sklearn.include_raw_correlations = True
24 changes: 24 additions & 0 deletions disentanglement_lib/config/tests/methods/udr/udr_spearman.gin
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Bindings for the "evaluation" configurable.
evaluation.evaluation_fn = @udr_sklearn
evaluation.random_seed = 0

# "auto" makes the evaluation read the dataset name from the train.gin of the
# first model directory.
dataset.name = "auto"

# Bindings for udr_sklearn (its implementation is not part of this file).
# This config selects the "spearman" correlation matrix.
udr_sklearn.correlation_matrix = "spearman"
udr_sklearn.num_data_points = 1000
udr_sklearn.batch_size = 10
udr_sklearn.filter_low_kl = True
udr_sklearn.kl_filter_threshold = 0.01
udr_sklearn.include_raw_correlations = True
105 changes: 105 additions & 0 deletions disentanglement_lib/evaluation/udr/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation module for computing UDR.
Binary for computing the UDR and UDR-A2A scores specified in "Unsupervised
Model Selection for Variational Disentangled Representation Learning"
(https://arxiv.org/abs/1905.12614)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import os
import time

from absl import flags
from disentanglement_lib.data.ground_truth import named_data
from disentanglement_lib.evaluation.udr.metrics import udr # pylint: disable=unused-import
from disentanglement_lib.utils import results
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import gin.tf

FLAGS = flags.FLAGS


def _gaussian_kl(z_mean, z_logvar):
  """Returns the KL of N(z_mean, exp(z_logvar)) from N(0, 1), averaged on axis 0.

  Axis 0 is presumably the batch axis, which would make the result one KL
  value per latent dimension -- TODO confirm against the encoder signature.
  """
  return np.mean(
      0.5 * (np.square(z_mean) + np.exp(z_logvar) - z_logvar - 1), axis=0)


def _make_representation_function(eval_function):
  """Wraps a TF-Hub eval function into a `x -> (mean, kl)` callable.

  Using a factory binds `eval_function` per model, which avoids the
  late-binding closure pitfall when wrapping several models in a loop.
  """

  def _representation_function(x):
    encoding = eval_function(
        dict(images=x), signature="gaussian_encoder", as_dict=True)
    z_mean = np.array(encoding["mean"])
    return z_mean, _gaussian_kl(z_mean, np.array(encoding["logvar"]))

  return _representation_function


@gin.configurable("evaluation", blacklist=["model_dirs", "output_dir"])
def evaluate(model_dirs,
             output_dir,
             evaluation_fn=gin.REQUIRED,
             random_seed=gin.REQUIRED,
             name=""):
  """Loads trained models from TF-Hub exports and computes their UDR scores.

  For every model directory the "tfhub" sub-directory is loaded and wrapped
  into a representation function returning the encoder mean together with the
  Gaussian KL computed from the encoder mean and log-variance. All
  representation functions are passed to `evaluation_fn` at once, and the
  returned dictionary is written to `<output_dir>/results`.

  Note: if `output_dir` already exists it is deleted before evaluation.

  Args:
    model_dirs: Non-empty list of model directories. The first entry is also
      used to look up the dataset name (when `dataset.name` is "auto") and as
      the original results directory passed to the result writer.
    output_dir: Directory to write the evaluation results to.
    evaluation_fn: Callable invoked as `evaluation_fn(dataset,
      representation_functions, random_state=...)`; must return a dict.
    random_seed: Seed for the `np.random.RandomState` given to
      `evaluation_fn`.
    name: Optional tag; it only exists so that it shows up in the gin config.

  Raises:
    ValueError: If `model_dirs` is empty.
  """
  # The name will be part of the gin config and can be used to tag results.
  del name

  if not model_dirs:
    raise ValueError("model_dirs must contain at least one model directory.")

  # Set up time to keep track of elapsed time in results.
  experiment_timer = time.time()

  # Automatically set the proper dataset if necessary. We replace the active
  # gin config as this will lead to a valid gin config file where the dataset
  # is present.
  if gin.query_parameter("dataset.name") == "auto":
    # Obtain the dataset name from the gin config of the training step.
    gin_config_file = os.path.join(model_dirs[0], "results", "gin", "train.gin")
    gin_dict = results.gin_dict(gin_config_file)
    with gin.unlock_config():
      # The stored value is quoted; strip the quotes before binding it.
      gin.bind_parameter("dataset.name",
                         gin_dict["dataset.name"].replace("'", ""))

  # Start from a clean output directory.
  if tf.io.gfile.isdir(output_dir):
    tf.io.gfile.rmtree(output_dir)

  dataset = named_data.get_named_ground_truth_data()

  # The eval functions are context managers; the ExitStack keeps all of them
  # open while the metric runs and closes them afterwards.
  with contextlib.ExitStack() as stack:
    representation_functions = [
        _make_representation_function(
            stack.enter_context(
                hub.eval_function_for_module(os.path.join(model_dir, "tfhub"))))
        for model_dir in model_dirs
    ]
    results_dict = evaluation_fn(
        dataset,
        representation_functions,
        random_state=np.random.RandomState(random_seed))

  original_results_dir = os.path.join(model_dirs[0], "results")
  results_dir = os.path.join(output_dir, "results")
  results_dict["elapsed_time"] = time.time() - experiment_timer
  results.update_result_directory(results_dir, "evaluation", results_dict,
                                  original_results_dir)
56 changes: 56 additions & 0 deletions disentanglement_lib/evaluation/udr/evaluate_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for evaluate.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import absltest
from absl.testing import parameterized
from disentanglement_lib.evaluation.udr import evaluate
from disentanglement_lib.methods.unsupervised import train
from disentanglement_lib.utils import resources
import gin.tf


class EvaluateTest(parameterized.TestCase):
  """Runs the UDR evaluation end to end on two freshly trained models."""

  def setUp(self):
    """Trains two models with the test config and creates an output dir."""
    super(EvaluateTest, self).setUp()
    keep_files = absltest.TempFileCleanup.OFF
    self.model1_dir, self.model2_dir = [
        self.create_tempdir(subdir, cleanup=keep_files).full_path
        for subdir in ("model1/model", "model2/model")
    ]
    train_config = resources.get_file(
        "config/tests/methods/unsupervised/train_test.gin")
    gin.clear_config()
    for trained_dir in (self.model1_dir, self.model2_dir):
      train.train_with_gin(trained_dir, True, [train_config])

    self.output_dir = self.create_tempdir(
        "output", cleanup=keep_files).full_path

  @parameterized.parameters(
      list(resources.get_files_in_folder("config/tests/methods/udr")))
  def test_evaluate(self, gin_config):
    """Evaluates both trained models with one of the UDR test configs."""
    # A failure in an earlier test leaves the gin config locked, which would
    # make this test fail as well; hence the config is reset first.
    gin.clear_config()
    gin.parse_config_files_and_bindings([gin_config], None)
    trained_models = [self.model1_dir, self.model2_dir]
    evaluate.evaluate(trained_models, self.output_dir)


if __name__ == "__main__":
  absltest.main()
Loading

0 comments on commit a64b8b9

Please sign in to comment.