Commit
PiperOrigin-RevId: 328541556
Showing 77 changed files with 5,179 additions and 81 deletions.
@@ -0,0 +1,175 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2018 The DisentanglementLib Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pipeline that trains a model and then computes multiple scores."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

# Dependency imports

from absl import app
from absl import flags
from absl import logging
from disentanglement_lib.evaluation import evaluate
from disentanglement_lib.methods.unsupervised import train
from disentanglement_lib.postprocessing import postprocess
from disentanglement_lib.utils import aggregate_results
from disentanglement_lib.visualize import visualize_model
import numpy as np
from six.moves import range
from tensorflow.compat.v1 import gfile


FLAGS = flags.FLAGS

flags.DEFINE_string("output_directory", None,
                    "Output directory of experiments.")

# Model flags. If the model_dir flag is set, then that directory is used and
# training is skipped.
flags.DEFINE_string("model_dir", None, "Directory to take trained model from.")
# Otherwise, the model is trained using the gin bindings and the gin model
# config file in the gin model config folder.
flags.DEFINE_multi_string("gin_bindings", [],
                          "Newline separated list of Gin parameter bindings.")
flags.DEFINE_string("gin_model_config_dir", None,
                    "Path to directory with model configs.")
flags.DEFINE_string("gin_model_config_name", None,
                    "Filename of the model config.")

# Postprocessing and evaluation are done using glob patterns of gin configs.
flags.DEFINE_string("gin_postprocess_config_glob", None,
                    "Path to glob pattern of postprocessing configs.")
flags.DEFINE_string("gin_evaluation_config_glob", None,
                    "Path to glob pattern of evaluation configs.")
flags.DEFINE_integer("num_eval", 1,
                     "Number of times the evaluation measures are computed.")
flags.DEFINE_multi_string("scores_list", [],
                          "List of scores to be evaluated multiple times.")

flags.DEFINE_multi_string("evaluation_num_samples_train", [],
                          "List of sample sizes for the evaluation.")

# Other flags.
flags.DEFINE_integer("random_seed", None,
                     "Integer with random seed for whole pipeline.")

flags.DEFINE_boolean("overwrite", False,
                     "Whether to overwrite output directory.")


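
# A minimal sketch of how this pipeline might be invoked. The script name,
# paths, and config file names below are illustrative placeholders, not files
# that ship with this commit:
#
#   python pipeline.py \
#     --output_directory=/tmp/experiment \
#     --gin_model_config_dir=/path/to/model_configs \
#     --gin_model_config_name=model.gin \
#     --gin_postprocess_config_glob="/path/to/postprocess/*.gin" \
#     --gin_evaluation_config_glob="/path/to/metrics/*.gin" \
#     --num_eval=5 \
#     --random_seed=0
#
# Passing --model_dir instead of the model config flags skips training and
# reuses an already trained model from that directory.

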
def main(unused_argv):
  if FLAGS.scores_list and not FLAGS.evaluation_num_samples_train:
    raise ValueError("How many samples should be used to compute the scores in"
                     " scores_list must be specified.")
  # In this pipeline, we manually manage the random seeds of the different
  # steps, as otherwise different training runs of the model (with different
  # random seeds) would be evaluated on exactly the same data (i.e., there
  # would be no randomness in evaluation). We use the random seed in the flag
  # to seed a random number generator from which random seeds for the
  # different parts of the pipeline are drawn.
  random_state = np.random.RandomState(FLAGS.random_seed)

  # Model training (skipped if the model_dir flag is provided).

  # It is important that we sample the model random seed regardless of whether
  # we actually train the model, so that later seeds are the same.
  model_random_seed = random_state.randint(2**32)
  if FLAGS.model_dir is None:
    logging.info("Training model...")
    model_dir = os.path.join(FLAGS.output_directory, "model")
    model_config_file = os.path.join(FLAGS.gin_model_config_dir,
                                     FLAGS.gin_model_config_name)
    model_bindings = [
        "model.random_seed = {}".format(model_random_seed),
        "model.name = '{}'".format(FLAGS.gin_model_config_name).replace(
            ".gin", "")
    ] + FLAGS.gin_bindings
    train.train_with_gin(model_dir, FLAGS.overwrite, [model_config_file],
                         model_bindings)
  else:
    logging.info("Skipped training...")
    model_dir = os.path.join(FLAGS.model_dir, "model")

  # We visualize reconstructions, samples and latent space traversals.
  visualize_dir = os.path.join(FLAGS.output_directory, "visualizations")
  visualize_model.visualize(model_dir, visualize_dir, FLAGS.overwrite)

  # We extract the different representations and save them to disk.
  postprocess_configs = sorted(gfile.Glob(FLAGS.gin_postprocess_config_glob))
  for config in postprocess_configs:
    post_name = os.path.basename(config).replace(".gin", "")
    logging.info("Extracting representation %s...", post_name)
    post_dir = os.path.join(FLAGS.output_directory, "postprocessed", post_name)
    postprocess_bindings = [
        "postprocess.random_seed = {}".format(random_state.randint(2**32)),
        "postprocess.name = '{}'".format(post_name)
    ]
    postprocess.postprocess_with_gin(model_dir, post_dir, FLAGS.overwrite,
                                     [config], postprocess_bindings)

  # Iterate through the metrics, computing each one on each representation.
  metric_configs = sorted(gfile.Glob(FLAGS.gin_evaluation_config_glob))

  for iteration_num in range(FLAGS.num_eval):
    for config in postprocess_configs:
      post_name = os.path.basename(config).replace(".gin", "")
      post_dir = os.path.join(FLAGS.output_directory, "postprocessed",
                              post_name)
      # Now, we compute all the specified scores.
      for gin_eval_config in metric_configs:
        metric_name_gin = os.path.basename(gin_eval_config).replace(
            ".gin", "")
        metric_name = "{}_{}".format(metric_name_gin, iteration_num)
        logging.info("Computing metric '%s' on '%s'...",
                     metric_name, post_name)

        # Only the scores in FLAGS.scores_list need to be run for multiple
        # sample sizes.
        if metric_name_gin in FLAGS.scores_list:
          for num_samples in FLAGS.evaluation_num_samples_train:
            metric_name = "{}_{}_{}".format(metric_name_gin, num_samples,
                                            iteration_num)
            metric_dir = os.path.join(FLAGS.output_directory, "metrics",
                                      post_name, metric_name)
            eval_bindings = [
                "{}.num_train = {}".format(metric_name_gin, int(num_samples)),
                "evaluation.name = '{}'".format(metric_name),
                "evaluation.random_seed = {}".format(
                    random_state.randint(2**32))]
            evaluate.evaluate_with_gin(post_dir, metric_dir, FLAGS.overwrite,
                                       [gin_eval_config], eval_bindings)
        else:
          metric_dir = os.path.join(FLAGS.output_directory, "metrics",
                                    post_name, metric_name)
          eval_bindings = [
              "evaluation.name = '{}'".format(metric_name)]
          evaluate.evaluate_with_gin(post_dir, metric_dir, FLAGS.overwrite,
                                     [gin_eval_config], eval_bindings)

  # Aggregate all the results into a single json file per model.
  result_dir = os.path.join(FLAGS.output_directory, "metrics",
                            "aggregated_results.json")
  pattern = os.path.join(FLAGS.output_directory,
                         "metrics/*/*/results/aggregate/evaluation.json")
  aggregate_results.aggregate_results_to_json(pattern, result_dir)


if __name__ == "__main__":
  app.run(main)
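
The final step above writes every metric result into a single aggregated_results.json under the metrics folder of the output directory. Below is a minimal sketch of inspecting that file after a run. The path reuses the hypothetical --output_directory from the invocation example, and the snippet assumes the aggregated file parses as a JSON object; the exact layout is determined by aggregate_results.aggregate_results_to_json, so treat this only as an illustration.

import json

# Hypothetical path; substitute the value passed as --output_directory.
result_path = "/tmp/experiment/metrics/aggregated_results.json"

with open(result_path) as f:
  results = json.load(f)

# Print the top-level keys to see which aggregated fields are available.
print(sorted(results.keys()))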