Skip to content

Commit

Permalink
Update data utils docs
Browse files Browse the repository at this point in the history
  • Loading branch information
itskalvik committed Aug 17, 2024
1 parent 5ee220e commit 409d0d9
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 24 deletions.
2 changes: 2 additions & 0 deletions docs/API-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ ________________________________________________________________________________
::: sgptools.utils.metrics
---
::: sgptools.utils.gpflow
---
::: sgptools.utils.data
____________________________________________________________________________________________________________________________________________________________
---
::: sgptools.kernels.neural_kernel
Expand Down
1 change: 1 addition & 0 deletions sgptools/models/cma_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class CMA_ES:
as waypoints of a path
num_robots (int): Number of robots, used when modeling
multi-robot IPP with a distance budget
transform (Transform): Transform object
"""
def __init__(self, X_train, noise_variance, kernel,
distance_budget=None,
Expand Down
103 changes: 85 additions & 18 deletions sgptools/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,47 +18,86 @@
from sklearn.preprocessing import StandardScaler
from hkb_diamondsquare.DiamondSquare import diamond_square

# Load optional dependency
try:
from osgeo import gdal
except:
pass


####################################################
# Utils used to prepare synthetic datasets

'''
Remove points inside polygons
'''
def remove_polygons(X, Y, polygons):
'''
Remove points inside polygons.
Args:
X (ndarray): (N,); array of x-coordinate
Y (ndarray): (N,); array of y-coordinate
polygons (list of matplotlib path polygon): Polygons to remove from the X, Y points
Returns:
X (ndarray): (N,); array of x-coordinate
Y (ndarray): (N,); array of y-coordinate
'''
points = np.array([X.flatten(), Y.flatten()]).T
for polygon in polygons:
p = path.Path(polygon)
points = points[~p.contains_points(points)]
return points[:, 0], points[:, 1]

'''
Remove points inside circle patches
'''
def remove_circle_patches(X, Y, circle_patches):
'''
Remove points inside polycircle patchesgons.
Args:
X (ndarray): (N,); array of x-coordinate
Y (ndarray): (N,); array of y-coordinate
polygons (list of matplotlib circle patches): Circle patches to remove from the X, Y points
Returns:
X (ndarray): (N,); array of x-coordinate
Y (ndarray): (N,); array of y-coordinate
'''
points = np.array([X.flatten(), Y.flatten()]).T
for circle_patch in circle_patches:
points = points[~circle_patch.contains_points(points)]
return points[:, 0], points[:, 1]

'''
Generate a point at a distance d from a point at angle theta
Args:
point: (N, 2) array of points
d: distance
theta: angle in radians
'''
def point_pos(point, d, theta):
'''
Generate a point at a distance d from a point at angle theta.
Args:
point (ndarray): (N, 2); array of points
d (float): distance
theta (float): angle in radians
Returns:
X (ndarray): (N,); array of x-coordinate
Y (ndarray): (N,); array of y-coordinate
'''
return np.c_[point[:, 0] + d*np.cos(theta), point[:, 1] + d*np.sin(theta)]

####################################################

def prep_tif_dataset(dataset_path=None):
def prep_tif_dataset(dataset_path):
'''Load and preprocess a dataset from a GeoTIFF file (.tif file). The input features
are set to the x and y pixel block coordinates and the labels are read from the file.
The method also removes all invalid points.
Large tif files
need to be downsampled using the following command:
```gdalwarp -tr 50 50 <input>.tif <output>.tif```
Args:
dataset_path (str): Path to the dataset file, used only when dataset_type is 'tif'.
Returns:
X: (n, d); Dataset input features
y: (n, 1); Dataset labels
'''
ds = gdal.Open(dataset_path)
cols = ds.RasterXSize
rows = ds.RasterYSize
Expand All @@ -75,6 +114,15 @@ def prep_tif_dataset(dataset_path=None):
####################################################

def prep_synthetic_dataset():
'''Generates a 50x50 grid of synthetic elevation data using the diamond square algorithm.
```https://github.com/buckinha/DiamondSquare```
Args:
Returns:
X: (n, d); Dataset input features
y: (n, 1); Dataset labels
'''
data = diamond_square(shape=(50,50),
min_height=0,
max_height=30,
Expand All @@ -92,15 +140,34 @@ def prep_synthetic_dataset():

####################################################

def get_dataset(dataset, dataset_path=None,
def get_dataset(dataset_type, dataset_path=None,
num_train=1000,
num_test=2500,
num_candidates=150):
"""Method to generate/load datasets and preprocess them for SP/IPP. The method uses kmeans to
generate train and test sets.
Args:
dataset_type (str): 'tif' or 'synthetic'. 'tif' will load and proprocess data from a GeoTIFF file.
'synthetic' will use the diamond square algorithm to generate synthetic elevation data.
dataset_path (str): Path to the dataset file, used only when dataset_type is 'tif'.
num_train (int): Number of training samples to generate.
num_test (int): Number of testing samples to generate.
num_candidates (int): Number of candidate locations to generate.
Returns:
X_train (ndarray): (n, d); Training set inputs
y_train (ndarray): (n, 1); Training set labels
X_test (ndarray): (n, d); Testing set inputs
y_test (ndarray): (n, 1); Testing set labels
candidates (ndarray): (n, d); Candidate sensor placement locations
X: (n, d); Full dataset inputs
y: (n, 1); Full dataset labels
"""
# Load the data
if dataset == 'tif':
if dataset_type == 'tif':
X, y = prep_tif_dataset(dataset_path=dataset_path)
elif dataset == 'synthetic':
elif dataset_type == 'synthetic':
X, y = prep_synthetic_dataset()

X_train = get_inducing_pts(X, num_train)
Expand Down
28 changes: 22 additions & 6 deletions sgptools/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,33 @@ def interpolate_path(waypoints, sampling_rate=0.05):
interpolated_path.extend(points)
return np.array(interpolated_path)

# Reorder the waypoints to match the order of the points in the path
# The waypoints are mathched to the closest points in the path
def reoder_path(path, waypoints):
def _reoder_path(path, waypoints):
"""Reorder the waypoints to match the order of the points in the path.
The waypoints are mathched to the closest points in the path. Used by project_waypoints.
Args:
path (n, d): Robot path, i.e., waypoints in the path traversal order
waypoints (n, d): Waypoints that need to be reordered to match the target path
Returns:
waypoints (n, d): Reordered waypoints of the robot's path
"""
dists = pairwise_distances(path, Y=waypoints, metric='euclidean')
_, col_ind = linear_sum_assignment(dists)
Xu = waypoints[col_ind].copy()
return Xu

# Project the waypoints back to the candidate set while retaining the
# waypoint visitation order
def project_waypoints(waypoints, candidates):
"""Project the waypoints back to the candidate set while retaining the
waypoint visitation order.
Args:
waypoints (n, d): Waypoints of the robot's path
candidates (ndarray): (n, 2); Discrete set of candidate locations
Returns:
waypoints (n, d): Projected waypoints of the robot's path
"""
waypoints_disc = cont2disc(waypoints, candidates)
waypoints_valid = reoder_path(waypoints, waypoints_disc)
waypoints_valid = _reoder_path(waypoints, waypoints_disc)
return waypoints_valid

0 comments on commit 409d0d9

Please sign in to comment.