
Commit

internal change
PiperOrigin-RevId: 328541556
DisentanglementLib Team authored and obachem committed Aug 27, 2020
1 parent a64b8b9 commit 86a644d
Showing 77 changed files with 5,179 additions and 81 deletions.
34 changes: 34 additions & 0 deletions README.md
@@ -281,6 +281,40 @@ The scores will be exported to `<output>/results/aggregate/evaluation.json`
under the model_scores attribute. The scores will be presented in the order of
the input model directories.
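
To inspect these results programmatically, a minimal Python sketch along the following lines should work (it assumes the default layout described above; replace `<output>` with the actual output directory, as before):
```
import json

# Read the aggregated evaluation results exported by the pipeline.
with open("<output>/results/aggregate/evaluation.json") as f:
  results = json.load(f)

# Scores are listed in the order of the input model directories.
print(results["model_scores"])
```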

## Weakly-Supervised experiments
The library also includes, in `disentanglement_lib/bin/dlib_reproduce_weakly_supervised`, the code for the weakly-supervised disentanglement methods proposed in the following paper:
> [**Weakly-Supervised Disentanglement Without Compromises**](https://arxiv.org/abs/2002.02886)
> *Francesco Locatello, Ben Poole, Gunnar Rätsch, Bernhard Schölkopf, Olivier Bachem, Michael Tschannen*.
```
dlib_reproduce_weakly_supervised --output_directory=<output> \
--gin_model_config_dir=<dir> \
--gin_model_config_name=<name> \
--gin_postprocess_config_glob=<postprocess_configs> \
--gin_evaluation_config_glob=<eval_configs> \
--pipeline_seed=<seed>
```

## Semi-Supervised experiments
The library also includes, in `disentanglement_lib/bin/dlib_reproduce_semi_supervised`, the code for the semi-supervised disentanglement methods proposed in the following paper:
> [**Disentangling Factors of Variation Using Few Labels**](https://arxiv.org/abs/1905.01258)
> *Francesco Locatello, Michael Tschannen, Stefan Bauer, Gunnar Rätsch, Bernhard Schölkopf, Olivier Bachem*.
```
dlib_reproduce_semi_supervised --output_directory=<output> \
--gin_model_config_dir=<dir> \
--gin_model_config_name=<name> \
--gin_postprocess_config_glob=<postprocess_configs> \
--gin_evaluation_config_glob=<eval_configs> \
--gin_validation_config_glob=<val_configs> \
--pipeline_seed=<seed> \
--eval_seed=<seed> \
--supervised_seed=<seed> \
--num_labelled_samples=<num> \
--train_percentage=0.9 \
--labeller_fn="@perfect_labeller"
```
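
As a rough sketch of how the two labelling flags interact (an illustration based on the paper's setup, not the library's exact code), `--num_labelled_samples` fixes the total number of labelled points and `--train_percentage` splits them between training and validation:
```
# Hypothetical illustration of the labelled-data split implied by the flags.
num_labelled_samples = 100  # --num_labelled_samples=<num>
train_percentage = 0.9      # --train_percentage=0.9
num_train = int(train_percentage * num_labelled_samples)  # 90 for training
num_validation = num_labelled_samples - num_train         # 10 for model selection
print(num_train, num_validation)  # -> 90 10
```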

## Feedback
Please send any feedback to bachem@google.com and francesco.locatello@tuebingen.mpg.de.

2 changes: 1 addition & 1 deletion bin/dlib_evaluate
@@ -22,7 +22,7 @@ import os
from absl import app
from absl import flags
from disentanglement_lib.evaluation import evaluate
-from tensorflow import gfile
+from tensorflow.compat.v1 import gfile
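# Note: TensorFlow 2.x removed the top-level tf.gfile module; the TF1
# file-system helpers remain available under tensorflow.compat.v1, which is
# why the import now goes through the compat layer.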

FLAGS = flags.FLAGS

175 changes: 175 additions & 0 deletions bin/dlib_reproduce_jmlr
@@ -0,0 +1,175 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pipeline that trains a model and then computes multiple scores."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Dependency imports

from absl import app
from absl import flags
from absl import logging
from disentanglement_lib.evaluation import evaluate
from disentanglement_lib.methods.unsupervised import train
from disentanglement_lib.postprocessing import postprocess
from disentanglement_lib.utils import aggregate_results
from disentanglement_lib.visualize import visualize_model
import numpy as np
from six.moves import range
from tensorflow.compat.v1 import gfile



FLAGS = flags.FLAGS

flags.DEFINE_string("output_directory", None,
"Output directory of experiments.")

# Model flags. If the model_dir flag is set, then that directory is used and
# training is skipped.
flags.DEFINE_string("model_dir", None, "Directory to take trained model from.")
# Otherwise, the model is trained using the gin bindings and the gin model
# config file in the gin model config folder.
flags.DEFINE_multi_string("gin_bindings", [],
"Newline separated list of Gin parameter bindings.")
flags.DEFINE_string("gin_model_config_dir", None,
"Path to directory with model configs.")
flags.DEFINE_string("gin_model_config_name", None,
"Filename of the model config.")

# Postprocessing and evaluation is done using glob patterns of gin configs.
flags.DEFINE_string("gin_postprocess_config_glob", None,
"Path to glob pattern to evaluation configs.")
flags.DEFINE_string("gin_evaluation_config_glob", None,
"Path to glob pattern to evaluation configs.")
flags.DEFINE_integer("num_eval", 1,
"Number of times the evaluation measures are computed.")
flags.DEFINE_multi_string("scores_list", [],
"List of scores to be evaluate multiple times.")

flags.DEFINE_multi_string("evaluation_num_samples_train", [],
"List of sample sizes for the evaluation.")

# Other flags.
flags.DEFINE_integer("random_seed", None,
"Integer with random seed for whole pipeline.")

flags.DEFINE_boolean("overwrite", False,
"Whether to overwrite output directory.")


def main(unused_argv):
  if FLAGS.scores_list and not FLAGS.evaluation_num_samples_train:
    raise ValueError("evaluation_num_samples_train must be specified so that "
                     "the scores in scores_list can be computed for the given "
                     "sample sizes.")
# In this pipeline, we manually manage the random seeds of the different steps
# as otherwise different training runs of the model (with different random
# seeds) would be evaluated on exactly the same data (i.e., there would be no
# randomness in evaluation). We use the random seed in the flag to seed a
# random number generator from which random seeds for the different parts of
# the pipeline are drawn.
random_state = np.random.RandomState(FLAGS.random_seed)
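  # For example, fixing --random_seed makes every random_state.randint(2**32)
  # call below reproducible, so the model, postprocessing, and evaluation
  # stages each receive the same derived seed across reruns.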

  # Model training (if model_dir is not provided).

  # It is important that we sample the model random seed regardless of whether
  # we actually train the model, so that later seeds are the same.
model_random_seed = random_state.randint(2**32)
if FLAGS.model_dir is None:
logging.info("Training model...")
model_dir = os.path.join(FLAGS.output_directory, "model")
model_config_file = os.path.join(FLAGS.gin_model_config_dir,
FLAGS.gin_model_config_name)
model_bindings = [
"model.random_seed = {}".format(model_random_seed),
"model.name = '{}'".format(FLAGS.gin_model_config_name).replace(
".gin", "")
] + FLAGS.gin_bindings
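    # For a model config named, e.g., "vae.gin" (a hypothetical example), the
    # bindings above resolve to something like:
    #   ["model.random_seed = 1234", "model.name = 'vae'"]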
train.train_with_gin(model_dir, FLAGS.overwrite, [model_config_file],
model_bindings)
else:
logging.info("Skipped training...")
model_dir = os.path.join(FLAGS.model_dir, "model")

# We visualize reconstruction, samples and latent space traversal.
visualize_dir = os.path.join(FLAGS.output_directory, "visualizations")
visualize_model.visualize(model_dir, visualize_dir, FLAGS.overwrite)

# We extract the different representations and save them to disk.
  postprocess_configs = sorted(gfile.Glob(FLAGS.gin_postprocess_config_glob))
  for config in postprocess_configs:
post_name = os.path.basename(config).replace(".gin", "")
logging.info("Extracting representation %s...", post_name)
post_dir = os.path.join(FLAGS.output_directory, "postprocessed", post_name)
postprocess_bindings = [
"postprocess.random_seed = {}".format(random_state.randint(2**32)),
"postprocess.name = '{}'".format(post_name)
]
postprocess.postprocess_with_gin(model_dir, post_dir, FLAGS.overwrite,
[config], postprocess_bindings)

# Iterate through metrics.
metric_configs = sorted(gfile.Glob(FLAGS.gin_evaluation_config_glob))

  for iteration_num in range(FLAGS.num_eval):
    for config in postprocess_configs:
      post_name = os.path.basename(config).replace(".gin", "")
post_dir = os.path.join(FLAGS.output_directory, "postprocessed",
post_name)
# Now, we compute all the specified scores.
for gin_eval_config in metric_configs:
        metric_name_gin = os.path.basename(gin_eval_config).replace(
            ".gin", "")
        metric_name = "{}_{}".format(metric_name_gin, iteration_num)
logging.info("Computing metric '%s' on '%s'...",
metric_name, post_name)

        # Only the scores in FLAGS.scores_list need to be run for multiple
        # sample sizes.
        if metric_name_gin in FLAGS.scores_list:
          for num_samples in FLAGS.evaluation_num_samples_train:
            metric_name = "{}_{}_{}".format(metric_name_gin, num_samples,
                                            iteration_num)
metric_dir = os.path.join(FLAGS.output_directory, "metrics",
post_name, metric_name)
eval_bindings = [
"{}.num_train = {}".format(metric_name_gin, int(num_samples)),
"evaluation.name = '{}'".format(metric_name),
"evaluation.random_seed = {}".format(
random_state.randint(2**32))]
evaluate.evaluate_with_gin(post_dir, metric_dir, FLAGS.overwrite,
[gin_eval_config], eval_bindings)
else:
metric_dir = os.path.join(FLAGS.output_directory, "metrics",
post_name, metric_name)
eval_bindings = [
"evaluation.name = '{}'".format(metric_name)]
evaluate.evaluate_with_gin(post_dir, metric_dir, FLAGS.overwrite,
[gin_eval_config], eval_bindings)
# Aggregate all the results in a single json file per model.
result_dir = os.path.join(FLAGS.output_directory, "metrics",
"aggregated_results.json")
pattern = os.path.join(FLAGS.output_directory,
"metrics/*/*/results/aggregate/evaluation.json")
aggregate_results.aggregate_results_to_json(pattern, result_dir)


if __name__ == "__main__":
app.run(main)
