forked from sarmueller/gibo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
loop.py
83 lines (71 loc) · 2.75 KB
/
loop.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from typing import Optional, Dict, Callable, Union, Tuple
import torch
from src.environment_api import EnvironmentObjective
from src.optimizers import (
RandomSearch,
VanillaBayesianOptimization,
BayesianGradientAscent,
)
def call_counter(func) -> Callable:
    """Wrap a callable so that the number of calls made to it is counted.

    Args:
        func: Callable whose calls should be counted.

    Returns:
        Wrapper that forwards all positional and keyword arguments to
        ``func`` and returns its result. The wrapper carries two attributes:
        ``_calls`` (number of calls made through the wrapper, starting at 0)
        and ``_func`` (the original callable).
    """
    # Function-scope import keeps this block independent of the file header.
    import functools

    # Preserve the metadata of the wrapped callable (__name__, __doc__,
    # __wrapped__, ...) so the wrapper stays identifiable in logs/debuggers.
    @functools.wraps(func)
    def helper(*args, **kwargs):
        # Counted before the call, so a call that raises is still counted.
        helper._calls += 1
        return func(*args, **kwargs)

    # Assigned after wrapping so they take priority over any attribute of the
    # same name copied over from ``func``.
    helper._func = func
    helper._calls = 0
    return helper
def loop(
    params_init: torch.Tensor,
    max_iterations: Optional[int],
    max_objective_calls: Optional[int],
    objective: Union[Callable[[torch.Tensor], torch.Tensor], EnvironmentObjective],
    Optimizer: Callable[
        ..., Union[RandomSearch, BayesianGradientAscent, VanillaBayesianOptimization]
    ],
    optimizer_config: Optional[Dict],
    verbose: bool = True,
) -> Tuple[list, list]:
    """Connects parameters with objective and optimizer.

    The objective is wrapped in a call counter, the optimizer is constructed
    from ``params_init``, the counted objective and ``optimizer_config``, and
    the optimizer is then called repeatedly until a stopping criterion is met.

    Args:
        params_init: Initial parameters, passed as first argument to
            ``Optimizer``.
        max_iterations: Stopping criterion for optimization after maximum of
            iterations (update steps of parameters). If truthy, it takes
            precedence and ``max_objective_calls`` is ignored.
        max_objective_calls: Stopping criterion for optimization after maximum
            of function calls of objective function. Only used when
            ``max_iterations`` is falsy. The budget is checked before each
            optimizer step, so the final step may exceed it.
        objective: Objective function to be optimized (search for maximum).
        Optimizer: One of the implemented optimizers (the class, not an
            instance); it is instantiated here.
        optimizer_config: Configuration dictionary for optimizer, unpacked as
            keyword arguments. ``None`` is treated as an empty configuration.
        verbose: If True an output is logged.

    Returns:
        Tuple of
            - list of parameters history (``optimizer.params_history_list``)
            - list with the cumulative number of objective function calls
              recorded after every iteration

        If both ``max_iterations`` and ``max_objective_calls`` are falsy, no
        optimizer step is performed and the second list is empty.
    """
    calls_in_iteration = []
    objective_w_counter = call_counter(objective)
    # optimizer_config is Optional; unpacking None with ** would raise a
    # TypeError, so fall back to an empty configuration.
    optimizer = Optimizer(
        params_init, objective_w_counter, **(optimizer_config or {})
    )

    if max_iterations:
        # Fixed number of optimizer steps.
        for iteration in range(max_iterations):
            if verbose:
                print(f"--- Iteration {iteration+1} ---")
            optimizer()
            calls_in_iteration.append(objective_w_counter._calls)
    elif max_objective_calls:
        # Budget on objective evaluations; checked only between steps.
        iteration = 0
        while objective_w_counter._calls < max_objective_calls:
            if verbose:
                print(
                    f"--- Iteration {iteration+1} ({objective_w_counter._calls} objective calls so far) ---"
                )
            optimizer()
            iteration += 1
            calls_in_iteration.append(objective_w_counter._calls)

    if verbose:
        print(
            f"\nObjective function was called {objective_w_counter._calls} times (sample complexity).\n"
        )

    return optimizer.params_history_list, calls_in_iteration