Introduce AdaNet controller.
This AdaNet controller performs a two-phase ensembling of progressively deeper neural network architectures.
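At a high level, each iteration first trains a pool of candidate networks of increasing depth (candidate phase), then trains ensembles that combine the newest candidates with the best ensemble so far (ensemble phase). A minimal usage sketch, mirroring the end-to-end test added in this commit; the dataset and the ModelSearch import path are illustrative assumptions:

import tensorflow as tf

from adanet.experimental.controllers.adanet_controller import AdaNetCandidatePhase
from adanet.experimental.controllers.adanet_controller import AdaNetController
from adanet.experimental.controllers.adanet_controller import AdaNetEnsemblePhase
from adanet.experimental.keras.model_search import ModelSearch  # Path assumed from the BUILD targets.

# Any batched tf.data.Dataset of (features, integer labels) works here.
train_dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([128, 10]),
     tf.random.uniform([128], maxval=10, dtype=tf.int64))).batch(32)

candidate_phase = AdaNetCandidatePhase(
    train_dataset,
    candidates_per_iteration=2,
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    output_units=10)
ensemble_phase = AdaNetEnsemblePhase(
    train_dataset,
    candidates_per_iteration=2,
    optimizer='adam',
    loss='sparse_categorical_crossentropy')

controller = AdaNetController(candidate_phase, ensemble_phase, iterations=5)
model_search = ModelSearch(controller)
model_search.run()
best_ensemble = model_search.get_best_models(num_models=1)[0]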

PiperOrigin-RevId: 284378509
csvillalta authored and cweill committed Dec 17, 2019
1 parent 3a63312 commit 712bc8e
Showing 9 changed files with 237 additions and 10 deletions.
16 changes: 16 additions & 0 deletions adanet/experimental/controllers/BUILD
@@ -30,3 +30,19 @@ py_library(
"//adanet/experimental/work_units:work_unit",
],
)

py_library(
name = "adanet_controller",
srcs = ["adanet_controller.py"],
srcs_version = "PY3",
visibility = ["//adanet/experimental:__subpackages__"],
deps = [
":controller",
"//adanet/experimental/keras:ensemble_model",
"//adanet/experimental/phases:phase",
"//adanet/experimental/storages:in_memory_storage",
"//adanet/experimental/storages:storage",
"//adanet/experimental/work_units:keras_trainer",
"//adanet/experimental/work_units:work_unit",
],
)
151 changes: 151 additions & 0 deletions adanet/experimental/controllers/adanet_controller.py
@@ -0,0 +1,151 @@
# Lint as: python3
# Copyright 2019 The AdaNet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An AdaNet controller for model search."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

from typing import Callable, Iterator, List, Optional, Sequence, Union
from adanet.experimental.controllers.controller import Controller
from adanet.experimental.keras.ensemble_model import MeanEnsemble
from adanet.experimental.phases.phase import Phase
from adanet.experimental.storages.in_memory_storage import InMemoryStorage
from adanet.experimental.storages.storage import Storage
from adanet.experimental.work_units.keras_trainer import KerasTrainer
from adanet.experimental.work_units.work_unit import WorkUnit

import tensorflow as tf


class AdaNetCandidatePhase(Phase):
"""Generates and trains neural networks with various layer depths."""

def __init__(self, dataset: tf.data.Dataset,
candidates_per_iteration: int,
optimizer: Union[str, tf.keras.optimizers.Optimizer],
loss: Union[str, tf.keras.losses.Loss],
output_units: int,
metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None,
units_per_layer: int = 128,
layer_activation: Union[str, Callable[..., tf.Tensor]] = 'relu',
output_activation: Union[str, Callable[..., tf.Tensor]] = 'linear'):
self._dataset = dataset
self._candidates_per_iteration = candidates_per_iteration
self._optimizer = optimizer
self._loss = loss
self._metrics = metrics
self._units_per_layer = units_per_layer
self._output_units = output_units
self._layer_activation = layer_activation
self._output_activation = output_activation
self._candidate_storage = None

# TODO: Add warning about build not being called.
def build(self, candidate_storage: Storage):
self._candidate_storage = candidate_storage

def work_units(self) -> Iterator[WorkUnit]:
for network in self._generate_networks():
yield KerasTrainer(network, self._dataset, self._candidate_storage)

def _generate_networks(self) -> Iterator[tf.keras.Model]:
best_candidate = self._candidate_storage.get_best_models(num_models=1)
if not best_candidate:
num_layers = 0
else:
num_layers = len(best_candidate[0].layers)
for i in range(self._candidates_per_iteration):
model = tf.keras.Sequential()
for _ in range(num_layers + i):
model.add(tf.keras.layers.Dense(units=self._units_per_layer,
activation=self._layer_activation))
model.add(tf.keras.layers.Dense(units=self._output_units,
activation=self._output_activation))
model.compile(optimizer=self._optimizer,
loss=self._loss,
metrics=self._metrics)
yield model


# TODO: Make this a more general phase.
class AdaNetEnsemblePhase(Phase):
"""Ensembles submodels."""

def __init__(self, dataset: tf.data.Dataset,
candidates_per_iteration: int,
optimizer: Union[str, tf.keras.optimizers.Optimizer],
loss: Union[str, tf.keras.losses.Loss],
metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None):
self._dataset = dataset
self._candidates_per_iteration = candidates_per_iteration
self._optimizer = optimizer
self._loss = loss
self._metrics = metrics
self._candidate_storage = None
self._ensemble_storage = None

def build(self, candidate_storage: Storage, ensemble_storage: Storage):
self._candidate_storage = candidate_storage
self._ensemble_storage = ensemble_storage

@property
def ensemble_storage(self):
return self._ensemble_storage

# TODO: Revisit how newest candidates are obtained within this
# phase.
def work_units(self) -> Iterator[WorkUnit]:
best_candidates = self._candidate_storage.get_newest_models(
num_models=self._candidates_per_iteration)
best_ensemble = self._ensemble_storage.get_best_models(num_models=1)
for candidate in best_candidates:
if not best_ensemble:
ensemble = MeanEnsemble([candidate])
else:
ensemble = MeanEnsemble(best_ensemble[0].submodels + [candidate])
ensemble.compile(optimizer=self._optimizer,
loss=self._loss,
metrics=self._metrics)

yield KerasTrainer(ensemble, self._dataset, self._ensemble_storage)


class AdaNetController(Controller):
"""A controller that trains candidate networks and ensembles them."""

def __init__(self,
candidate_phase: AdaNetCandidatePhase,
ensemble_phase: AdaNetEnsemblePhase,
iterations: int,
candidate_storage: Storage = InMemoryStorage(),
ensemble_storage: Storage = InMemoryStorage()):
candidate_phase.build(candidate_storage)
ensemble_phase.build(candidate_storage, ensemble_storage)
self._candidate_phase = candidate_phase
self._ensemble_phase = ensemble_phase
self._iterations = iterations

def work_units(self) -> Iterator[WorkUnit]:
for _ in range(self._iterations):
for work_unit in itertools.chain(self._candidate_phase.work_units(),
self._ensemble_phase.work_units()):
yield work_unit

def get_best_models(self, num_models) -> Sequence[tf.keras.Model]:
return self._ensemble_phase.ensemble_storage.get_best_models(num_models)
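A worked illustration of how _generate_networks deepens candidates across iterations: num_layers counts every layer of the best candidate so far, including its output layer, so each round's proposals start at least as deep as the current best. Assuming candidates_per_iteration=2:

# Iteration 1: candidate storage is empty, so num_layers = 0.
#   Candidate A: output layer only        (1 layer total)
#   Candidate B: 1 hidden layer + output  (2 layers total)
# Iteration 2: if B is best so far, num_layers = len(B.layers) = 2.
#   Candidate C: 2 hidden layers + output (3 layers total)
#   Candidate D: 3 hidden layers + output (4 layers total)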
1 change: 1 addition & 0 deletions adanet/experimental/keras/BUILD
@@ -33,6 +33,7 @@ py_strict_test(
":ensemble_model",
":model_search",
":testing_utils",
"//adanet/experimental/controllers:adanet_controller",
"//adanet/experimental/controllers:sequential_controller",
"//adanet/experimental/phases:keras_tuner_phase",
"//adanet/experimental/phases:train_keras_models_phase",
10 changes: 9 additions & 1 deletion adanet/experimental/keras/ensemble_model.py
@@ -27,14 +27,19 @@
class EnsembleModel(tf.keras.Model):
"""An ensemble of Keras models."""

def __init__(self, submodels: Sequence[tf.keras.Model]):
def __init__(self, submodels: Sequence[tf.keras.Model],
freeze_submodels: bool = True):
"""Initializes an EnsembleModel.
Args:
submodels: A list of `tf.keras.Model` that compose the ensemble.
freeze_submodels: Whether to freeze the weights of submodels.
"""

super().__init__()
if freeze_submodels:
for submodel in submodels:
submodel.trainable = False
self._submodels = submodels

@property
@@ -49,6 +54,9 @@ class MeanEnsemble(EnsembleModel):
"""An ensemble that averages submodel outputs."""

def call(self, inputs):
if len(self._submodels) == 1:
return self._submodels[0](inputs)

submodel_outputs = []
for submodel in self.submodels:
submodel_outputs.append(submodel(inputs))
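The mean reduction in MeanEnsemble.call is truncated in this diff. For reference, a standalone sketch of the intended semantics, assuming the reduction is an element-wise mean over stacked submodel outputs (the helper name is hypothetical):

import tensorflow as tf

def mean_ensemble_call(submodels, inputs):
  # Short-circuit added in this commit: a one-model ensemble just
  # forwards to its submodel, so no stacking or averaging is needed.
  if len(submodels) == 1:
    return submodels[0](inputs)
  submodel_outputs = [submodel(inputs) for submodel in submodels]
  return tf.reduce_mean(tf.stack(submodel_outputs), axis=0)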
50 changes: 42 additions & 8 deletions adanet/experimental/keras/model_search_test.py
@@ -24,6 +24,9 @@

from absl import flags
from absl.testing import parameterized
from adanet.experimental.controllers.adanet_controller import AdaNetCandidatePhase
from adanet.experimental.controllers.adanet_controller import AdaNetController
from adanet.experimental.controllers.adanet_controller import AdaNetEnsemblePhase
from adanet.experimental.controllers.sequential_controller import SequentialController
from adanet.experimental.keras import testing_utils
from adanet.experimental.keras.ensemble_model import MeanEnsemble
@@ -77,7 +80,10 @@ def test_phases_end_to_end(self):
model2.compile(
optimizer=tf.keras.optimizers.Adam(0.01), loss='mse', metrics=['mae'])

ensemble = MeanEnsemble(submodels=[model1, model2])
# TODO: The best model in this test could end up being a
# non-ensemble Keras model; handle that case and remove the
# freeze_submodels flag.
ensemble = MeanEnsemble(submodels=[model1, model2], freeze_submodels=False)
ensemble.compile(
optimizer=tf.keras.optimizers.Adam(0.01), loss='mse', metrics=['mae'])

@@ -89,13 +95,6 @@ def test_phases_end_to_end(self):
TrainKerasModelsPhase([ensemble], dataset=train_dataset),
])

train_dataset, test_dataset = testing_utils.get_test_data(
train_samples=128,
test_samples=64,
input_shape=(10,),
num_classes=2,
random_seed=42)

model_search = ModelSearch(controller)
model_search.run()
self.assertIsInstance(
@@ -158,6 +157,41 @@ def build_ensemble():
self.assertIsInstance(
model_search.get_best_models(num_models=1)[0], MeanEnsemble)

def test_adanet_controller_end_to_end(self):
train_dataset, test_dataset = testing_utils.get_test_data(
train_samples=1280,
test_samples=640,
input_shape=(10,),
num_classes=10,
random_seed=42)

train_dataset = train_dataset.batch(32)
test_dataset = test_dataset.batch(32)

candidate_phase = AdaNetCandidatePhase(
train_dataset,
candidates_per_iteration=2,
optimizer='adam',
loss='sparse_categorical_crossentropy',
output_units=10)
# TODO: Setting candidates_per_iteration here higher than the
# candidate phase's value leads to unexpected behavior.
ensemble_phase = AdaNetEnsemblePhase(
train_dataset,
candidates_per_iteration=2,
optimizer='adam',
loss='sparse_categorical_crossentropy')

adanet_controller = AdaNetController(
candidate_phase,
ensemble_phase,
iterations=5)

model_search = ModelSearch(adanet_controller)
model_search.run()
self.assertIsInstance(
model_search.get_best_models(num_models=1)[0], MeanEnsemble)


if __name__ == '__main__':
tf.enable_v2_behavior()
1 change: 1 addition & 0 deletions adanet/experimental/phases/phase.py
@@ -31,6 +31,7 @@ class Phase(abc.ABC):
A phase is only complete once all its work units complete, as a barrier.
"""

# TODO: Remove this build function.
def build(self, storage: Storage, previous: 'Phase' = None):
self._storage = storage
self._previous = previous
6 changes: 6 additions & 0 deletions adanet/experimental/storages/in_memory_storage.py
@@ -52,3 +52,9 @@ def load_model(self, model_id: int) -> tf.keras.Model:

def get_best_models(self, num_models) -> Sequence[tf.keras.Model]:
return [m for _, _, m in heapq.nsmallest(num_models, self._models)]

def get_newest_models(self, num_models) -> Sequence[tf.keras.Model]:
return [
m for _, m_id, m in self._models
if m_id in [self._id - i for i in range(num_models)]
]
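get_newest_models selects by recency of id rather than by score: it keeps the models whose ids fall among the last num_models ids issued. A standalone sketch of the same selection, assuming self._models holds (loss, model_id, model) tuples and self._id is the most recently assigned id; hoisting the id window into a set also avoids rebuilding the list for every stored model:

def get_newest_models_sketch(models, last_id, num_models):
  # models: iterable of (loss, model_id, model) tuples (assumed layout).
  newest_ids = {last_id - i for i in range(num_models)}
  return [m for _, m_id, m in models if m_id in newest_ids]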
4 changes: 4 additions & 0 deletions adanet/experimental/storages/storage.py
@@ -42,3 +42,7 @@ def get_best_models(self, num_models) -> Sequence[tf.keras.Model]:
# TODO: Rethink get_best_model API since it's defined in Storage,
# Phases, and Controllers.
pass

@abc.abstractmethod
def get_newest_models(self, num_models) -> Sequence[tf.keras.Model]:
pass
8 changes: 7 additions & 1 deletion adanet/experimental/work_units/keras_trainer.py
@@ -36,4 +36,10 @@ def __init__(self, model: tf.keras.Model, dataset: tf.data.Dataset,
def execute(self):
self._model.fit(self._dataset)
results = self._model.evaluate(self._dataset)
self._storage.save_model(self._model, results[0])
# If the model was compiled with metrics, evaluate() returns a list of
# the loss followed by metric values. If it was compiled without
# metrics, evaluate() returns a scalar loss.
if isinstance(results, list):
self._storage.save_model(self._model, results[0])
else:
self._storage.save_model(self._model, results)
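For context, Keras Model.evaluate returns a scalar loss when the model was compiled without metrics, and a list of the loss followed by metric values otherwise; that is the distinction the branch above handles:

model.compile(optimizer='adam', loss='mse')
model.evaluate(dataset)  # e.g. 0.42 (scalar loss)

model.compile(optimizer='adam', loss='mse', metrics=['mae'])
model.evaluate(dataset)  # e.g. [0.42, 0.31] (loss, then mae)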
