Skip to content

Commit

Permalink
Add mixture backend tell_examples support
Browse files Browse the repository at this point in the history
  • Loading branch information
evhub committed Jun 30, 2019
1 parent 29bf95a commit d14e8a0
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 19 deletions.
22 changes: 17 additions & 5 deletions bbopt-source/backends/mixture.coco
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,40 @@ class MixtureBackend(Backend):
total_weight = sum(weight for alg, weight in distribution)

# generate cutoff points
cum_probs = []
self.cum_probs = []
prev_cutoff = 0
for alg, weight in distribution:
cutoff = prev_cutoff + weight / total_weight
cum_probs.append((alg, cutoff))
self.cum_probs.append((alg, cutoff))
prev_cutoff = cutoff

self.backend_store = {}
self.tell_examples(examples, params)

def tell_examples(self, examples, params):
"""Special method that allows fast updating of the backend with new examples."""
# randomly select algorithm
rand_val = random.random()
self.selected_alg = None
for alg, cutoff in cum_probs:
for alg, cutoff in self.cum_probs:
if rand_val <= cutoff:
self.selected_alg = alg
break

# initialize backend
self.selected_backend, options = alg_registry[self.selected_alg]
self.backend = init_backend(self.selected_backend, examples, params, **options)
self.current_backend = init_backend(
self.selected_backend,
examples,
params,
attempt_to_update_backend=self.backend_store.get(self.selected_alg),
**options,
)
self.backend_store[self.selected_alg] = self.current_backend

def param(self, name, func, *args, **kwargs) =
"""Defer parameter selection to the selected backend."""
self.backend.param(name, func, *args, **kwargs)
self.current_backend.param(name, func, *args, **kwargs)


# Registered names:
Expand Down
7 changes: 5 additions & 2 deletions bbopt-source/backends/util.coco
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class Backend:
"""Attempt to update this backend with new arguments. False indicates that the
update failed while True indicates a successful update."""
if (self.tell_examples is NotImplemented
or not params
or not self._params
or params != self._params
or args != self._args
or kwargs != self._kwargs):
Expand All @@ -172,7 +172,10 @@ class Backend:
if old_examples != self._examples:
return False
if new_examples:
self.tell_examples(new_examples, params)
try:
self.tell_examples(new_examples, params)
except NotImplementedError:
return False
return True

# implement tell_examples(new_examples, params) to allow fast updating on new data
Expand Down
2 changes: 1 addition & 1 deletion bbopt-source/constants.coco
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Constants for use across all of BBopt.

# Installation constants:
name = "bbopt"
version = "1.1.4"
version = "1.1.5"
description = "The easiest hyperparameter optimization you'll ever do."
long_description = """
See BBopt's GitHub_ for more information.
Expand Down
18 changes: 12 additions & 6 deletions bbopt/backends/mixture.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# __coconut_hash__ = 0x2637c052
# __coconut_hash__ = 0x18531269

# Compiled with Coconut version 1.4.0-post_dev40 [Ernest Scribbler]

Expand Down Expand Up @@ -43,28 +43,34 @@ def __init__(self, examples, params, distribution):
total_weight = sum((weight for alg, weight in distribution))

# generate cutoff points
cum_probs = []
self.cum_probs = []
prev_cutoff = 0
for alg, weight in distribution:
cutoff = prev_cutoff + weight / total_weight
cum_probs.append((alg, cutoff))
self.cum_probs.append((alg, cutoff))
prev_cutoff = cutoff

self.backend_store = {}
self.tell_examples(examples, params)

def tell_examples(self, examples, params):
"""Special method that allows fast updating of the backend with new examples."""
# randomly select algorithm
rand_val = random.random()
self.selected_alg = None
for alg, cutoff in cum_probs:
for alg, cutoff in self.cum_probs:
if rand_val <= cutoff:
self.selected_alg = alg
break

# initialize backend
self.selected_backend, options = alg_registry[self.selected_alg]
self.backend = init_backend(self.selected_backend, examples, params, **options)
self.current_backend = init_backend(self.selected_backend, examples, params, attempt_to_update_backend=self.backend_store.get(self.selected_alg), **options)
self.backend_store[self.selected_alg] = self.current_backend

def param(self, name, func, *args, **kwargs):
"""Defer parameter selection to the selected backend."""
return self.backend.param(name, func, *args, **kwargs)
return self.current_backend.param(name, func, *args, **kwargs)


# Registered names:
Expand Down
9 changes: 6 additions & 3 deletions bbopt/backends/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# __coconut_hash__ = 0x6f664d1
# __coconut_hash__ = 0x189fad0

# Compiled with Coconut version 1.4.0-post_dev40 [Ernest Scribbler]

Expand Down Expand Up @@ -218,13 +218,16 @@ def __init__(self, examples=None, params=None):
def attempt_update(self, examples=None, params=None, *args, **kwargs):
"""Attempt to update this backend with new arguments. False indicates that the
update failed while True indicates a successful update."""
if (self.tell_examples is NotImplemented or not params or params != self._params or args != self._args or kwargs != self._kwargs):
if (self.tell_examples is NotImplemented or not self._params or params != self._params or args != self._args or kwargs != self._kwargs):
return False
old_examples, new_examples = examples[:len(self._examples)], examples[len(self._examples):]
if old_examples != self._examples:
return False
if new_examples:
self.tell_examples(new_examples, params)
try:
self.tell_examples(new_examples, params)
except NotImplementedError:
return False
return True

# implement tell_examples(new_examples, params) to allow fast updating on new data
Expand Down
4 changes: 2 additions & 2 deletions bbopt/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# __coconut_hash__ = 0x6996f040
# __coconut_hash__ = 0x4564326b

# Compiled with Coconut version 1.4.0-post_dev40 [Ernest Scribbler]

Expand Down Expand Up @@ -29,7 +29,7 @@

# Installation constants:
name = "bbopt"
version = "1.1.4"
version = "1.1.5"
description = "The easiest hyperparameter optimization you'll ever do."
long_description = """
See BBopt's GitHub_ for more information.
Expand Down

0 comments on commit d14e8a0

Please sign in to comment.