Skip to content

Commit

Permalink
Clean and comment code/example, increment sgptools version
Browse files Browse the repository at this point in the history
  • Loading branch information
itskalvik committed Sep 2, 2024
1 parent fd2610b commit ee45aa8
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 285 deletions.
81 changes: 38 additions & 43 deletions benchmarks/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@


'''
Online IPP (data collection phase is excluded from the total runtime)
The parameters GP is initilized with the parameters of the kernel and variance in ipp_model
Online/Adaptive IPP (data collection phase is excluded from the total runtime)
The hyperparameters GP is initilized with the hyperparameters of the kernel and variance in ipp_model
Args:
X_train: numpy array (n, 2), Inputs X data used to initilize the OSGPR inducing points
Expand All @@ -38,24 +38,24 @@
path2data: Function that takes the path and returns the data from the path
continuous_ipp: bool, If True, model continuous sensing robots
ipp_method: str, 'SGP' or 'CMA', method used for IPP
param_method: str, 'GP' or 'SSGP', method used for parameter updates
param_method: str, 'GP' or 'SSGP', method used for hyperparameter updates
plot: bool, If True, all intermediate IPP solutions are plotted and saved to disk
Returns:
sol_data_X: Locations X where the robot traversed
sol_data_y: Ground truth label data from the dataset corresponding to sol_data_X
total_time_param : Total runtime of the parameter update phase of the online IPP approach
total_time_param : Total runtime of the hyperparameter update phase of the online IPP approach
excluding the time taken to get the
data collected along the solution paths.
total_time_ipp : Total runtime of the IPP update phase of the online IPP approach
excluding the time taken to get the
data collected along the solution paths.
'''
def online_ipp(X_train, ipp_model, Xu_init, path2data,
continuous_ipp=False,
ipp_method='SGP',
param_method='GP',
plot=False):
def run_ipp(X_train, ipp_model, Xu_init, path2data,
continuous_ipp=False,
ipp_method='SGP',
param_method='GP',
plot=False):
total_time_param = 0
total_time_ipp = 0
num_robots = Xu_init.shape[0]
Expand Down Expand Up @@ -100,7 +100,7 @@ def online_ipp(X_train, ipp_model, Xu_init, path2data,
if time_step == num_waypoints:
break # Skip param and path update for the last waypoint

# Init/update parameter model
# Init/update hyperparameters model
start_time = time()
if param_method=='GP':
# Starting from initial params ensures recovery from bad params
Expand Down Expand Up @@ -155,8 +155,7 @@ def online_ipp(X_train, ipp_model, Xu_init, path2data,
return np.array(sol_data_X), np.array(sol_data_y), total_time_param, total_time_ipp


def main(dataset_type,
dataset_path,
def main(dataset_path,
num_mc,
num_robots,
max_dist,
Expand All @@ -180,9 +179,7 @@ def main(dataset_type,
path2data = lambda x : cont2disc(x, X, y)

# Get the data
X_train, y_train, X_test, y_test, candidates, X, y = get_dataset(dataset_type,
dataset_path,
num_train=1000)
X_train, y_train, X_test, y_test, candidates, X, y = get_dataset(dataset_path)

# Get oracle hyperparameters to benchmark rmse
start_time = time()
Expand Down Expand Up @@ -250,15 +247,15 @@ def main(dataset_type,
transform,
Xu_init=Xu_init.reshape(-1, 2),
max_steps=0)
online_X, online_y, param_time, ipp_time = online_ipp(X_train,
ipp_sgpr,
Xu_init,
path2data,
continuous_ipp,
'SGP',
'SSGP' if continuous_ipp else 'GP')
solution_X, solution_y, param_time, ipp_time = run_ipp(X_train,
ipp_sgpr,
Xu_init,
path2data,
continuous_ipp,
'SGP',
'SSGP' if continuous_ipp else 'GP')
# Get RMSE from oracle hyperparameters
y_pred, _ = get_reconstruction((online_X, online_y),
y_pred, _ = get_reconstruction((solution_X, solution_y),
X_test,
noise_variance_opt,
kernel_opt)
Expand All @@ -284,15 +281,15 @@ def main(dataset_type,
aggregate_fov=True),
Xu_init=Xu_init.reshape(-1, 2),
max_steps=0)
online_X, online_y, param_time, ipp_time = online_ipp(X_train,
ipp_sgpr,
Xu_init,
path2data,
continuous_ipp,
'SGP',
'SSGP' if continuous_ipp else 'GP')
solution_X, solution_y, param_time, ipp_time = run_ipp(X_train,
ipp_sgpr,
Xu_init,
path2data,
continuous_ipp,
'SGP',
'SSGP' if continuous_ipp else 'GP')
# Get RMSE from oracle hyperparameters
y_pred, _ = get_reconstruction((online_X, online_y),
y_pred, _ = get_reconstruction((solution_X, solution_y),
X_test,
noise_variance_opt,
kernel_opt)
Expand All @@ -313,15 +310,15 @@ def main(dataset_type,
kernel,
num_robots=num_robots,
transform=transform)
online_X, online_y, param_time, ipp_time = online_ipp(X_train,
cma_es,
Xu_init,
path2data,
continuous_ipp,
'CMA',
'SSGP' if continuous_ipp else 'GP')
solution_X, solution_y, param_time, ipp_time = run_ipp(X_train,
cma_es,
Xu_init,
path2data,
continuous_ipp,
'CMA',
'SSGP' if continuous_ipp else 'GP')
# Get RMSE from oracle hyperparameters
y_pred, _ = get_reconstruction((online_X, online_y),
y_pred, _ = get_reconstruction((solution_X, solution_y),
X_test,
noise_variance_opt,
kernel_opt)
Expand Down Expand Up @@ -457,15 +454,13 @@ def main(dataset_type,
parser.add_argument("--num_robots", type=int, default=1)
parser.add_argument("--sampling_rate", type=int, default=2)
parser.add_argument("--dataset_path", type=str,
default='datasets/bathymetry/bathymetry.tif')
default='../datasets/bathymetry/bathymetry.tif')
args=parser.parse_args()

max_dist = 350 if args.num_robots==1 else 150
dataset_type = 'tif'
max_range = 101 if args.num_robots==1 and args.sampling_rate==2 else 51
xrange = range(5, max_range, 5)
main(dataset_type,
args.dataset_path,
main(args.dataset_path,
args.num_mc,
args.num_robots,
max_dist,
Expand Down
Loading

0 comments on commit ee45aa8

Please sign in to comment.