diff --git a/README.md b/README.md index 2f0f39f7..d8f5395f 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ It supports a variety of different models, metrics and data sets: * *Models*: BetaVAE, FactorVAE, BetaTCVAE, DIP-VAE -* *Metrics*: BetaVAE score, FactorVAE score, Mutual Information Gap, SAP score, DCI, MCE, IRS +* *Metrics*: BetaVAE score, FactorVAE score, Mutual Information Gap, SAP score, DCI, MCE, IRS, UDR * *Data sets*: dSprites, Color/Noisy/Scream-dSprites, SmallNORB, Cars3D, and Shapes3D * It also includes 10'800 pretrained disentanglement models (see below for details). @@ -262,6 +262,25 @@ If you only want to reevaluate an already trained model using the evaluation pro dlib_reproduce --model_dir= --output_directory= --study=fairness_study_v1 ``` +## UDR experiments + +The library also includes the code for the Unsupervised Disentanglement Ranking (UDR) method proposed in the following paper in `disentanglement_lib/bin/dlib_udr`: +> [**Unsupervised Model Selection for Variational Disentangled Representation Learning**](https://arxiv.org/abs/1905.12614) +> *Sunny Duan, Loic Matthey, Andre Saraiva, Nicholas Watters, Christopher P. Burgess, Alexander Lerchner, Irina Higgins*. + +UDR can be applied to newly trained models (e.g. obtained by running +`dlib_reproduce`) or to the existing pretrained models. After the models have +been trained, their UDR scores can be computed by running: + +``` +dlib_udr --model_dirs=, \ + --output_directory= +``` + +The scores will be exported to `/results/aggregate/evaluation.json` +under the model_scores attribute. The scores will be presented in the order of +the input model directories. + ## Feedback Please send any feedback to bachem@google.com and francesco.locatello@tuebingen.mpg.de. diff --git a/bin/dlib_tests b/bin/dlib_tests index 013fba79..82b90e55 100644 --- a/bin/dlib_tests +++ b/bin/dlib_tests @@ -38,6 +38,8 @@ python -m disentanglement_lib.evaluation.metrics.dci_test python -m disentanglement_lib.evaluation.metrics.sap_score_test python -m disentanglement_lib.evaluation.metrics.utils_test python -m disentanglement_lib.evaluation.metrics.unsupervised_metrics_test +python -m disentanglement_lib.evaluation.udr.evaluate_test +python -m disentanglement_lib.evaluation.udr.metrics.udr_test python -m disentanglement_lib.visualize.visualize_util_test python -m disentanglement_lib.visualize.visualize_dataset_test python -m disentanglement_lib.visualize.visualize_model_test diff --git a/bin/dlib_udr b/bin/dlib_udr new file mode 100644 index 00000000..9c883142 --- /dev/null +++ b/bin/dlib_udr @@ -0,0 +1,59 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2018 The DisentanglementLib Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluation module for disentangled representations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +from absl import app +from absl import flags +from disentanglement_lib.evaluation.udr import evaluate +from tensorflow import gfile +import gin.tf + +FLAGS = flags.FLAGS + +flags.DEFINE_list("model_dirs", [], "List of models to run UDR over.") + +flags.DEFINE_string("output_dir", None, "Directory to save representation to.") + +flags.DEFINE_multi_string("gin_config", [], + "List of paths to the config files.") + +flags.DEFINE_multi_string("gin_bindings", [], + "Newline separated list of Gin parameter bindings.") + +flags.DEFINE_string("gin_evaluation_config_glob", None, + "Path to glob pattern to evaluation configs.") + + +def main(unused_argv): + if FLAGS.gin_evaluation_config_glob is not None: + for gin_eval_config in sorted(gfile.Glob(FLAGS.gin_evaluation_config_glob)): + metric_name = os.path.basename(gin_eval_config).replace(".gin", "") + metric_dir = os.path.join(FLAGS.output_dir, metric_name) + gin.parse_config_files_and_bindings( + [gin_eval_config], ["evaluation.name = '{}'".format(metric_name)]) + evaluate.evaluate(FLAGS.model_dirs, metric_dir) + gin.clear_config() + else: + gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings) + evaluate.evaluate(FLAGS.model_dirs, FLAGS.output_dir) + + +if __name__ == "__main__": + app.run(main) diff --git a/disentanglement_lib/config/tests/methods/udr/udr.gin b/disentanglement_lib/config/tests/methods/udr/udr.gin new file mode 100644 index 00000000..8d5d9702 --- /dev/null +++ b/disentanglement_lib/config/tests/methods/udr/udr.gin @@ -0,0 +1,24 @@ +# coding=utf-8 +# Copyright 2018 The DisentanglementLib Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +evaluation.evaluation_fn = @udr_sklearn +dataset.name="auto" +evaluation.random_seed = 0 +udr_sklearn.batch_size=10 +udr_sklearn.num_data_points=1000 +udr_sklearn.correlation_matrix="lasso" +udr_sklearn.filter_low_kl=True +udr_sklearn.include_raw_correlations=True +udr_sklearn.kl_filter_threshold = 0.01 diff --git a/disentanglement_lib/config/tests/methods/udr/udr_spearman.gin b/disentanglement_lib/config/tests/methods/udr/udr_spearman.gin new file mode 100644 index 00000000..3fde15dc --- /dev/null +++ b/disentanglement_lib/config/tests/methods/udr/udr_spearman.gin @@ -0,0 +1,24 @@ +# coding=utf-8 +# Copyright 2018 The DisentanglementLib Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +evaluation.evaluation_fn = @udr_sklearn +dataset.name="auto" +evaluation.random_seed = 0 +udr_sklearn.batch_size=10 +udr_sklearn.num_data_points=1000 +udr_sklearn.correlation_matrix="spearman" +udr_sklearn.filter_low_kl=True +udr_sklearn.include_raw_correlations=True +udr_sklearn.kl_filter_threshold = 0.01 diff --git a/disentanglement_lib/evaluation/udr/evaluate.py b/disentanglement_lib/evaluation/udr/evaluate.py new file mode 100644 index 00000000..e6e8a072 --- /dev/null +++ b/disentanglement_lib/evaluation/udr/evaluate.py @@ -0,0 +1,105 @@ +# coding=utf-8 +# Copyright 2018 The DisentanglementLib Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Evaluation module for computing UDR. + +Binary for computing the UDR and UDR-A2A scores specified in "Unsupervised +Model Selection for Variational Disentangled Representation Learning" +(https://arxiv.org/abs/1905.12614) +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import os +import time + +from absl import flags +from disentanglement_lib.data.ground_truth import named_data +from disentanglement_lib.evaluation.udr.metrics import udr # pylint: disable=unused-import +from disentanglement_lib.utils import results +import numpy as np +import tensorflow as tf +import tensorflow_hub as hub +import gin.tf + +FLAGS = flags.FLAGS + + +@gin.configurable("evaluation", blacklist=["model_dirs", "output_dir"]) +def evaluate(model_dirs, + output_dir, + evaluation_fn=gin.REQUIRED, + random_seed=gin.REQUIRED, + name=""): + """Loads a trained estimator and evaluates it according to beta-VAE metric.""" + # The name will be part of the gin config and can be used to tag results. + del name + + # Set up time to keep track of elapsed time in results. + experiment_timer = time.time() + + # Automatically set the proper dataset if necessary. We replace the active + # gin config as this will lead to a valid gin config file where the dataset + # is present. + if gin.query_parameter("dataset.name") == "auto": + # Obtain the dataset name from the gin config of the previous step. + gin_config_file = os.path.join(model_dirs[0], "results", "gin", "train.gin") + gin_dict = results.gin_dict(gin_config_file) + with gin.unlock_config(): + print(gin_dict["dataset.name"]) + gin.bind_parameter("dataset.name", + gin_dict["dataset.name"].replace("'", "")) + + output_dir = os.path.join(output_dir) + if tf.io.gfile.isdir(output_dir): + tf.io.gfile.rmtree(output_dir) + + dataset = named_data.get_named_ground_truth_data() + + with contextlib.ExitStack() as stack: + representation_functions = [] + eval_functions = [ + stack.enter_context( + hub.eval_function_for_module(os.path.join(model_dir, "tfhub"))) + for model_dir in model_dirs + ] + for f in eval_functions: + + def _representation_function(x, f=f): + + def compute_gaussian_kl(z_mean, z_logvar): + return np.mean( + 0.5 * (np.square(z_mean) + np.exp(z_logvar) - z_logvar - 1), + axis=0) + + encoding = f(dict(images=x), signature="gaussian_encoder", as_dict=True) + + return np.array(encoding["mean"]), compute_gaussian_kl( + np.array(encoding["mean"]), np.array(encoding["logvar"])) + + representation_functions.append(_representation_function) + + results_dict = evaluation_fn( + dataset, + representation_functions, + random_state=np.random.RandomState(random_seed)) + + original_results_dir = os.path.join(model_dirs[0], "results") + results_dir = os.path.join(output_dir, "results") + results_dict["elapsed_time"] = time.time() - experiment_timer + results.update_result_directory(results_dir, "evaluation", results_dict, + original_results_dir) diff --git a/disentanglement_lib/evaluation/udr/evaluate_test.py b/disentanglement_lib/evaluation/udr/evaluate_test.py new file mode 100644 index 00000000..4bc411ca --- /dev/null +++ b/disentanglement_lib/evaluation/udr/evaluate_test.py @@ -0,0 +1,56 @@ +# coding=utf-8 +# Copyright 2018 The DisentanglementLib Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for evaluate.py.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from absl.testing import absltest +from absl.testing import parameterized +from disentanglement_lib.evaluation.udr import evaluate +from disentanglement_lib.methods.unsupervised import train +from disentanglement_lib.utils import resources +import gin.tf + + +class EvaluateTest(parameterized.TestCase): + + def setUp(self): + super(EvaluateTest, self).setUp() + self.model1_dir = self.create_tempdir( + "model1/model", cleanup=absltest.TempFileCleanup.OFF).full_path + self.model2_dir = self.create_tempdir( + "model2/model", cleanup=absltest.TempFileCleanup.OFF).full_path + model_config = resources.get_file( + "config/tests/methods/unsupervised/train_test.gin") + gin.clear_config() + train.train_with_gin(self.model1_dir, True, [model_config]) + train.train_with_gin(self.model2_dir, True, [model_config]) + + self.output_dir = self.create_tempdir( + "output", cleanup=absltest.TempFileCleanup.OFF).full_path + + @parameterized.parameters( + list(resources.get_files_in_folder("config/tests/methods/udr"))) + def test_evaluate(self, gin_config): + # We clear the gin config before running. Otherwise, if a prior test fails, + # the gin config is locked and the current test fails. + gin.clear_config() + gin.parse_config_files_and_bindings([gin_config], None) + evaluate.evaluate([self.model1_dir, self.model2_dir], self.output_dir) + + +if __name__ == "__main__": + absltest.main() diff --git a/disentanglement_lib/evaluation/udr/metrics/udr.py b/disentanglement_lib/evaluation/udr/metrics/udr.py new file mode 100644 index 00000000..b2f9487a --- /dev/null +++ b/disentanglement_lib/evaluation/udr/metrics/udr.py @@ -0,0 +1,258 @@ +# coding=utf-8 +# Copyright 2018 The DisentanglementLib Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementations of the UDR score. + +Methods for computing the UDR and UDR-A2A scores specified in "Unsupervised +Model Selection for Variational Disentangled Representation Learning" +(https://arxiv.org/abs/1905.12614) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl import logging + +import numpy as np +import scipy +from sklearn import linear_model +from sklearn import preprocessing +import gin.tf + + +def relative_strength_disentanglement(corr_matrix): + """Computes disentanglement using relative strength score.""" + score_x = np.nanmean( + np.nan_to_num( + np.power(np.ndarray.max(corr_matrix, axis=0), 2) / + np.sum(corr_matrix, axis=0), 0)) + score_y = np.nanmean( + np.nan_to_num( + np.power(np.ndarray.max(corr_matrix, axis=1), 2) / + np.sum(corr_matrix, axis=1), 0)) + return (score_x + score_y) / 2 + + +def spearman_correlation_conv(vec1, vec2): + """Computes Spearman correlation matrix of two representations. + + Args: + vec1: 2d array of representations with axis 0 the batch dimension and axis 1 + the representation dimension. + vec2: 2d array of representations with axis 0 the batch dimension and axis 1 + the representation dimension. + + Returns: + A 2d array with the correlations between all pairwise combinations of + elements of both representations are computed. Elements of vec1 correspond + to axis 0 and elements of vec2 correspond to axis 1. + """ + assert vec1.shape == vec2.shape + corr_y = [] + for i in range(vec1.shape[1]): + corr_x = [] + for j in range(vec2.shape[1]): + corr, _ = scipy.stats.spearmanr(vec1[:, i], vec2[:, j], nan_policy="omit") + corr_x.append(corr) + corr_y.append(np.stack(corr_x)) + return np.transpose(np.absolute(np.stack(corr_y, axis=1))) + + +def lasso_correlation_matrix(vec1, vec2, random_state=None): + """Computes correlation matrix of two representations using Lasso Regression. + + Args: + vec1: 2d array of representations with axis 0 the batch dimension and axis 1 + the representation dimension. + vec2: 2d array of representations with axis 0 the batch dimension and axis 1 + the representation dimension. + random_state: int used to seed an RNG used for model training. + + Returns: + A 2d array with the correlations between all pairwise combinations of + elements of both representations are computed. Elements of vec1 correspond + to axis 0 and elements of vec2 correspond to axis 1. + """ + assert vec1.shape == vec2.shape + model = linear_model.Lasso(random_state=random_state, alpha=0.1) + model.fit(vec1, vec2) + return np.transpose(np.absolute(model.coef_)) + + +def _generate_representation_batch(ground_truth_data, representation_functions, + batch_size, random_state): + """Sample a single mini-batch of representations from the ground-truth data. + + Args: + ground_truth_data: GroundTruthData to be sampled from. + representation_functions: functions that takes observations as input and + outputs a dim_representation sized representation for each observation and + a vector of the average kl divergence per latent. + batch_size: size of batches of representations to be collected at one time. + random_state: numpy random state used for randomness. + + Returns: + representations: List[batch_size, dim_representation] List of representation + batches for each of the representation_functions. + """ + # Sample a mini batch of latent variables + observations = ground_truth_data.sample_observations(batch_size, random_state) + # Compute representations based on the observations. + return [fn(observations) for fn in representation_functions] + + +def _generate_representation_dataset(ground_truth_data, + representation_functions, batch_size, + num_data_points, random_state): + """Sample dataset of represetations for all of the different models. + + Args: + ground_truth_data: GroundTruthData to be sampled from. + representation_functions: functions that takes observations as input and + outputs a dim_representation sized representation for each observation and + a vector of the average kl divergence per latent. + batch_size: size of batches of representations to be collected at one time. + num_data_points: total number of points to be sampled for training set. + random_state: numpy random state used for randomness. + + Returns: + representation_points: (num_data_points, dim_representation)-sized numpy + array with training set features. + kl: (dim_representation) - The average KL divergence per latent in the + representation. + """ + if num_data_points % batch_size != 0: + raise ValueError("num_data_points must be a multiple of batch_size") + + representation_points = [] + kl_divergence = [] + + for i in range(int(num_data_points / batch_size)): + representation_batch = _generate_representation_batch( + ground_truth_data, representation_functions, batch_size, random_state) + + for j in range(len(representation_functions)): + # Initialize the outputs if it hasn't been created yet. + if len(representation_points) <= j: + kl_divergence.append( + np.zeros((int(num_data_points / batch_size), + representation_batch[j][1].shape[0]))) + representation_points.append( + np.zeros((num_data_points, representation_batch[j][0].shape[1]))) + kl_divergence[j][i, :] = representation_batch[j][1] + representation_points[j][i * batch_size:(i + 1) * batch_size, :] = ( + representation_batch[j][0]) + return representation_points, [np.mean(kl, axis=0) for kl in kl_divergence] + + +@gin.configurable( + "udr_sklearn", + blacklist=["ground_truth_data", "representation_functions", "random_state"]) +def compute_udr_sklearn(ground_truth_data, + representation_functions, + random_state, + batch_size, + num_data_points, + correlation_matrix="lasso", + filter_low_kl=True, + include_raw_correlations=True, + kl_filter_threshold=0.01): + """Computes the UDR score using scikit-learn. + + Args: + ground_truth_data: GroundTruthData to be sampled from. + representation_functions: functions that takes observations as input and + outputs a dim_representation sized representation for each observation. + random_state: numpy random state used for randomness. + batch_size: Number of datapoints to compute in a single batch. Useful for + reducing memory overhead for larger models. + num_data_points: total number of representation datapoints to generate for + computing the correlation matrix. + correlation_matrix: Type of correlation matrix to generate. Can be either + "lasso" or "spearman". + filter_low_kl: If True, filter out elements of the representation vector + which have low computed KL divergence. + include_raw_correlations: Whether or not to include the raw correlation + matrices in the results. + kl_filter_threshold: Threshold which latents with average KL divergence + lower than the threshold will be ignored when computing disentanglement. + + Returns: + scores_dict: a dictionary of the scores computed for UDR with the following + keys: + raw_correlations: (num_models, num_models, latent_dim, latent_dim) - The + raw computed correlation matrices for all models. The pair of models is + indexed by axis 0 and 1 and the matrix represents the computed + correlation matrix between latents in axis 2 and 3. + pairwise_disentanglement_scores: (num_models, num_models, 1) - The + computed disentanglement scores representing the similarity of + representation between pairs of models. + model_scores: (num_models) - List of aggregated model scores corresponding + to the median of the pairwise disentanglement scores for each model. + """ + logging.info("Generating training set.") + inferred_model_reps, kl = _generate_representation_dataset( + ground_truth_data, representation_functions, batch_size, num_data_points, + random_state) + + num_models = len(inferred_model_reps) + logging.info("Number of Models: %s", num_models) + + logging.info("Training sklearn models.") + latent_dim = inferred_model_reps[0].shape[1] + corr_matrix_all = np.zeros((num_models, num_models, latent_dim, latent_dim)) + + # Normalize and calculate mask based off of kl divergence to remove + # uninformative latents. + kl_mask = [] + for i in range(len(inferred_model_reps)): + scaler = preprocessing.StandardScaler() + scaler.fit(inferred_model_reps[i]) + inferred_model_reps[i] = scaler.transform(inferred_model_reps[i]) + inferred_model_reps[i] = inferred_model_reps[i] * np.greater(kl[i], 0.01) + kl_mask.append(kl[i] > kl_filter_threshold) + + disentanglement = np.zeros((num_models, num_models, 1)) + for i in range(num_models): + for j in range(num_models): + if i == j: + continue + + if correlation_matrix == "lasso": + corr_matrix = lasso_correlation_matrix(inferred_model_reps[i], + inferred_model_reps[j], + random_state) + else: + corr_matrix = spearman_correlation_conv(inferred_model_reps[i], + inferred_model_reps[j]) + + corr_matrix_all[i, j, :, :] = corr_matrix + if filter_low_kl: + corr_matrix = corr_matrix[kl_mask[i], ...][..., kl_mask[j]] + disentanglement[i, j] = relative_strength_disentanglement(corr_matrix) + + scores_dict = {} + if include_raw_correlations: + scores_dict["raw_correlations"] = corr_matrix_all.tolist() + scores_dict["pairwise_disentanglement_scores"] = disentanglement.tolist() + + model_scores = [] + for i in range(num_models): + model_scores.append(np.median(np.delete(disentanglement[:, i], i))) + + scores_dict["model_scores"] = model_scores + + return scores_dict diff --git a/disentanglement_lib/evaluation/udr/metrics/udr_test.py b/disentanglement_lib/evaluation/udr/metrics/udr_test.py new file mode 100644 index 00000000..0ea78f79 --- /dev/null +++ b/disentanglement_lib/evaluation/udr/metrics/udr_test.py @@ -0,0 +1,189 @@ +# coding=utf-8 +# Copyright 2018 The DisentanglementLib Authors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for udr.py.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from absl.testing import absltest +from disentanglement_lib.data.ground_truth import dummy_data +from disentanglement_lib.evaluation.udr.metrics import udr +import numpy as np + + +class UdrTest(absltest.TestCase): + + def test_metric_spearman(self): + ground_truth_data = dummy_data.DummyData() + random_state = np.random.RandomState(0) + num_factors = ground_truth_data.num_factors + batch_size = 10 + num_data_points = 1000 + + permutation = np.random.permutation(num_factors) + sign_inverse = np.random.choice(num_factors, int(num_factors / 2)) + + def rep_fn1(data): + return (np.reshape(data, (batch_size, -1))[:, :num_factors], + np.ones(num_factors)) + + # Should be invariant to permutation and sign inverse. + def rep_fn2(data): + raw_representation = np.reshape(data, (batch_size, -1))[:, :num_factors] + perm_rep = raw_representation[:, permutation] + perm_rep[:, sign_inverse] = -1.0 * perm_rep[:, sign_inverse] + return perm_rep, np.ones(num_factors) + + scores = udr.compute_udr_sklearn( + ground_truth_data, [rep_fn1, rep_fn2], + random_state, + batch_size, + num_data_points, + correlation_matrix="spearman") + self.assertBetween(scores["model_scores"][0], 0.8, 1.0) + self.assertBetween(scores["model_scores"][1], 0.8, 1.0) + + def test_metric_lasso(self): + ground_truth_data = dummy_data.DummyData() + random_state = np.random.RandomState(0) + num_factors = ground_truth_data.num_factors + batch_size = 10 + num_data_points = 1000 + + permutation = np.random.permutation(num_factors) + sign_inverse = np.random.choice(num_factors, int(num_factors / 2)) + + def rep_fn1(data): + return (np.reshape(data, (batch_size, -1))[:, :num_factors], + np.ones(num_factors)) + + # Should be invariant to permutation and sign inverse. + def rep_fn2(data): + raw_representation = np.reshape(data, (batch_size, -1))[:, :num_factors] + perm_rep = raw_representation[:, permutation] + perm_rep[:, sign_inverse] = -1.0 * perm_rep[:, sign_inverse] + return perm_rep, np.ones(num_factors) + + scores = udr.compute_udr_sklearn( + ground_truth_data, [rep_fn1, rep_fn2], + random_state, + batch_size, + num_data_points, + correlation_matrix="lasso") + self.assertBetween(scores["model_scores"][0], 0.8, 1.0) + self.assertBetween(scores["model_scores"][1], 0.8, 1.0) + + def test_metric_kl(self): + ground_truth_data = dummy_data.DummyData() + random_state = np.random.RandomState(0) + num_factors = ground_truth_data.num_factors + batch_size = 10 + num_data_points = 1000 + + # Representation without KL Mask where only first latent is valid. + def rep_fn(data): + rep = np.concatenate([ + np.reshape(data, (batch_size, -1))[:, :1], + np.random.random_sample((batch_size, num_factors - 1)) + ], + axis=1) + kl_mask = np.zeros(num_factors) + kl_mask[0] = 1.0 + return rep, kl_mask + + scores = udr.compute_udr_sklearn( + ground_truth_data, [rep_fn, rep_fn], + random_state, + batch_size, + num_data_points, + filter_low_kl=False) + self.assertBetween(scores["model_scores"][0], 0.0, 0.2) + self.assertBetween(scores["model_scores"][1], 0.0, 0.2) + + scores = udr.compute_udr_sklearn( + ground_truth_data, [rep_fn, rep_fn], + random_state, + batch_size, + num_data_points, + filter_low_kl=True) + self.assertBetween(scores["model_scores"][0], 0.8, 1.0) + self.assertBetween(scores["model_scores"][1], 0.8, 1.0) + + def test_relative_strength_disentanglement(self): + corr_matrix = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + self.assertEqual(udr.relative_strength_disentanglement(corr_matrix), 1.0) + + corr_matrix = np.array([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + self.assertBetween( + udr.relative_strength_disentanglement(corr_matrix), 0.6, 0.7) + + corr_matrix = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]) + self.assertBetween( + udr.relative_strength_disentanglement(corr_matrix), 0.3, 0.4) + + corr_matrix = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + self.assertEqual(udr.relative_strength_disentanglement(corr_matrix), 0.0) + + def test_spearman_correlation(self): + random_state = np.random.RandomState(0) + vec1 = random_state.random_sample((1000, 3)) + vec2 = np.copy(vec1) + expected_matrix = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0]]) + self.assertTrue( + np.allclose( + udr.spearman_correlation_conv(vec1, vec2), + expected_matrix, + atol=0.1)) + + vec1 = random_state.random_sample((1000, 3)) + vec2 = np.copy(vec1) + vec2[:, 1] = vec2[:, 0] + + expected_matrix = np.array([[1.0, 1.0, 0.0], [0.0, 0.0, 0.0], + [0.0, 0.0, 1.0]]) + self.assertTrue( + np.allclose( + udr.spearman_correlation_conv(vec1, vec2), + expected_matrix, + atol=0.1)) + + def test_lasso_correlation(self): + random_state = np.random.RandomState(0) + vec1 = random_state.random_sample((1000, 3)) * 10.0 + vec2 = np.copy(vec1) + expected_matrix = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0]]) + self.assertTrue( + np.allclose( + udr.lasso_correlation_matrix(vec1, vec2, random_state=random_state), + expected_matrix, + atol=0.2)) + + vec1 = random_state.random_sample((1000, 3)) * 10.0 + vec2 = np.copy(vec1) + vec2[:, 1] = vec2[:, 0] + + expected_matrix = np.array([[1.0, 1.0, 0.0], [0.0, 0.0, 0.0], + [0.0, 0.0, 1.0]]) + self.assertTrue( + np.allclose( + udr.lasso_correlation_matrix(vec1, vec2, random_state=random_state), + expected_matrix, + atol=0.2)) + + +if __name__ == "__main__": + absltest.main() diff --git a/setup.py b/setup.py index b5cd2423..6574ba79 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ setup( name='disentanglement_lib', - version='1.3', + version='1.4', description=('Library for research on disentangled representations.'), author='DisentanglementLib Authors', author_email='no-reply@google.com', @@ -34,6 +34,7 @@ 'bin/dlib_reason', 'bin/dlib_visualize_dataset', 'bin/dlib_evaluate', + 'bin/dlib_udr', 'bin/dlib_postprocess', 'bin/dlib_train', 'bin/dlib_visualize_dataset',