-
Notifications
You must be signed in to change notification settings - Fork 203
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PiperOrigin-RevId: 281531627
- Loading branch information
Showing
10 changed files
with
739 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#!/usr/bin/env python | ||
# coding=utf-8 | ||
# Copyright 2018 The DisentanglementLib Authors. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Evaluation module for disentangled representations.""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
import os | ||
from absl import app | ||
from absl import flags | ||
from disentanglement_lib.evaluation.udr import evaluate | ||
from tensorflow import gfile | ||
import gin.tf | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
flags.DEFINE_list("model_dirs", [], "List of models to run UDR over.") | ||
|
||
flags.DEFINE_string("output_dir", None, "Directory to save representation to.") | ||
|
||
flags.DEFINE_multi_string("gin_config", [], | ||
"List of paths to the config files.") | ||
|
||
flags.DEFINE_multi_string("gin_bindings", [], | ||
"Newline separated list of Gin parameter bindings.") | ||
|
||
flags.DEFINE_string("gin_evaluation_config_glob", None, | ||
"Path to glob pattern to evaluation configs.") | ||
|
||
|
||
def main(unused_argv): | ||
if FLAGS.gin_evaluation_config_glob is not None: | ||
for gin_eval_config in sorted(gfile.Glob(FLAGS.gin_evaluation_config_glob)): | ||
metric_name = os.path.basename(gin_eval_config).replace(".gin", "") | ||
metric_dir = os.path.join(FLAGS.output_dir, metric_name) | ||
gin.parse_config_files_and_bindings( | ||
[gin_eval_config], ["evaluation.name = '{}'".format(metric_name)]) | ||
evaluate.evaluate(FLAGS.model_dirs, metric_dir) | ||
gin.clear_config() | ||
else: | ||
gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings) | ||
evaluate.evaluate(FLAGS.model_dirs, FLAGS.output_dir) | ||
|
||
|
||
if __name__ == "__main__": | ||
app.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# coding=utf-8 | ||
# Copyright 2018 The DisentanglementLib Authors. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
evaluation.evaluation_fn = @udr_sklearn | ||
dataset.name="auto" | ||
evaluation.random_seed = 0 | ||
udr_sklearn.batch_size=10 | ||
udr_sklearn.num_data_points=1000 | ||
udr_sklearn.correlation_matrix="lasso" | ||
udr_sklearn.filter_low_kl=True | ||
udr_sklearn.include_raw_correlations=True | ||
udr_sklearn.kl_filter_threshold = 0.01 |
24 changes: 24 additions & 0 deletions
24
disentanglement_lib/config/tests/methods/udr/udr_spearman.gin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# coding=utf-8 | ||
# Copyright 2018 The DisentanglementLib Authors. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
evaluation.evaluation_fn = @udr_sklearn | ||
dataset.name="auto" | ||
evaluation.random_seed = 0 | ||
udr_sklearn.batch_size=10 | ||
udr_sklearn.num_data_points=1000 | ||
udr_sklearn.correlation_matrix="spearman" | ||
udr_sklearn.filter_low_kl=True | ||
udr_sklearn.include_raw_correlations=True | ||
udr_sklearn.kl_filter_threshold = 0.01 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# coding=utf-8 | ||
# Copyright 2018 The DisentanglementLib Authors. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Evaluation module for computing UDR. | ||
Binary for computing the UDR and UDR-A2A scores specified in "Unsupervised | ||
Model Selection for Variational Disentangled Representation Learning" | ||
(https://arxiv.org/abs/1905.12614) | ||
""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import contextlib | ||
import os | ||
import time | ||
|
||
from absl import flags | ||
from disentanglement_lib.data.ground_truth import named_data | ||
from disentanglement_lib.evaluation.udr.metrics import udr # pylint: disable=unused-import | ||
from disentanglement_lib.utils import results | ||
import numpy as np | ||
import tensorflow as tf | ||
import tensorflow_hub as hub | ||
import gin.tf | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
|
||
@gin.configurable("evaluation", blacklist=["model_dirs", "output_dir"]) | ||
def evaluate(model_dirs, | ||
output_dir, | ||
evaluation_fn=gin.REQUIRED, | ||
random_seed=gin.REQUIRED, | ||
name=""): | ||
"""Loads a trained estimator and evaluates it according to beta-VAE metric.""" | ||
# The name will be part of the gin config and can be used to tag results. | ||
del name | ||
|
||
# Set up time to keep track of elapsed time in results. | ||
experiment_timer = time.time() | ||
|
||
# Automatically set the proper dataset if necessary. We replace the active | ||
# gin config as this will lead to a valid gin config file where the dataset | ||
# is present. | ||
if gin.query_parameter("dataset.name") == "auto": | ||
# Obtain the dataset name from the gin config of the previous step. | ||
gin_config_file = os.path.join(model_dirs[0], "results", "gin", "train.gin") | ||
gin_dict = results.gin_dict(gin_config_file) | ||
with gin.unlock_config(): | ||
print(gin_dict["dataset.name"]) | ||
gin.bind_parameter("dataset.name", | ||
gin_dict["dataset.name"].replace("'", "")) | ||
|
||
output_dir = os.path.join(output_dir) | ||
if tf.io.gfile.isdir(output_dir): | ||
tf.io.gfile.rmtree(output_dir) | ||
|
||
dataset = named_data.get_named_ground_truth_data() | ||
|
||
with contextlib.ExitStack() as stack: | ||
representation_functions = [] | ||
eval_functions = [ | ||
stack.enter_context( | ||
hub.eval_function_for_module(os.path.join(model_dir, "tfhub"))) | ||
for model_dir in model_dirs | ||
] | ||
for f in eval_functions: | ||
|
||
def _representation_function(x, f=f): | ||
|
||
def compute_gaussian_kl(z_mean, z_logvar): | ||
return np.mean( | ||
0.5 * (np.square(z_mean) + np.exp(z_logvar) - z_logvar - 1), | ||
axis=0) | ||
|
||
encoding = f(dict(images=x), signature="gaussian_encoder", as_dict=True) | ||
|
||
return np.array(encoding["mean"]), compute_gaussian_kl( | ||
np.array(encoding["mean"]), np.array(encoding["logvar"])) | ||
|
||
representation_functions.append(_representation_function) | ||
|
||
results_dict = evaluation_fn( | ||
dataset, | ||
representation_functions, | ||
random_state=np.random.RandomState(random_seed)) | ||
|
||
original_results_dir = os.path.join(model_dirs[0], "results") | ||
results_dir = os.path.join(output_dir, "results") | ||
results_dict["elapsed_time"] = time.time() - experiment_timer | ||
results.update_result_directory(results_dir, "evaluation", results_dict, | ||
original_results_dir) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# coding=utf-8 | ||
# Copyright 2018 The DisentanglementLib Authors. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Tests for evaluate.py.""" | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
from absl.testing import absltest | ||
from absl.testing import parameterized | ||
from disentanglement_lib.evaluation.udr import evaluate | ||
from disentanglement_lib.methods.unsupervised import train | ||
from disentanglement_lib.utils import resources | ||
import gin.tf | ||
|
||
|
||
class EvaluateTest(parameterized.TestCase): | ||
|
||
def setUp(self): | ||
super(EvaluateTest, self).setUp() | ||
self.model1_dir = self.create_tempdir( | ||
"model1/model", cleanup=absltest.TempFileCleanup.OFF).full_path | ||
self.model2_dir = self.create_tempdir( | ||
"model2/model", cleanup=absltest.TempFileCleanup.OFF).full_path | ||
model_config = resources.get_file( | ||
"config/tests/methods/unsupervised/train_test.gin") | ||
gin.clear_config() | ||
train.train_with_gin(self.model1_dir, True, [model_config]) | ||
train.train_with_gin(self.model2_dir, True, [model_config]) | ||
|
||
self.output_dir = self.create_tempdir( | ||
"output", cleanup=absltest.TempFileCleanup.OFF).full_path | ||
|
||
@parameterized.parameters( | ||
list(resources.get_files_in_folder("config/tests/methods/udr"))) | ||
def test_evaluate(self, gin_config): | ||
# We clear the gin config before running. Otherwise, if a prior test fails, | ||
# the gin config is locked and the current test fails. | ||
gin.clear_config() | ||
gin.parse_config_files_and_bindings([gin_config], None) | ||
evaluate.evaluate([self.model1_dir, self.model2_dir], self.output_dir) | ||
|
||
|
||
if __name__ == "__main__": | ||
absltest.main() |
Oops, something went wrong.