Skip to content

Commit

Permalink
Update trajectory formats
Browse files Browse the repository at this point in the history
  • Loading branch information
teubert committed Dec 15, 2023
1 parent bda3c9e commit cf3b525
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 22 deletions.
43 changes: 21 additions & 22 deletions src/progpy/models/aircraft_model/small_rotorcraft.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from progpy.models.aircraft_model.vehicles.aero import aerodynamics as aero
from progpy.models.aircraft_model.vehicles import vehicles
from progpy.utils.traj_gen import geometry as geom
from progpy.utils.traj_gen.trajectory import TrajectoryFigure


class SmallRotorcraft(AircraftModel):
Expand Down Expand Up @@ -321,7 +322,7 @@ def linear_model(self, phi, theta, psi, p, q, r, T):

return A, B

def visualize_traj(self, pred, ref=None, prefix=''):
def visualize_traj(self, pred, ref=None, prefix='', fig=None, **kwargs):
"""
This method provides functionality to visualize a predicted trajectory generated, plotted with the reference trajectory.
Expand All @@ -338,21 +339,21 @@ def visualize_traj(self, pred, ref=None, prefix=''):
Reference trajectory - dict with keys for each state in the vehicle model and corresponding values as numpy arrays
prefix : str, optional
Prefix added to keys in predicted values. This is used to plot the trajectory using the results from a composite model
fig : TrajectoryFigure, optional
Figure where the additional diagrams are to be added. Creates a new figure if not provided
Returns
Returns
-------
fig : Visualization of trajectory generation results
TrajectoryFigure : Visualization of trajectory generation results
"""

# Extract predicted trajectory information
pred_time = pred.times
pred_x = [pred.outputs[iter][prefix+'x'] for iter in range(len(pred_time))]
pred_y = [pred.outputs[iter][prefix+'y'] for iter in range(len(pred_time))]
pred_z = [pred.outputs[iter][prefix+'z'] for iter in range(len(pred_time))]
if fig is None:
fig = plt.figure(FigureClass=TrajectoryFigure)
elif not isinstance(fig, TrajectoryFigure):
raise TypeError(f"fig must be a TrajectorFigure, was {type(fig)}")

# Initialize Figure
params = dict(figsize=(13, 9), fontsize=14, linewidth=2.0, alpha_preds=0.6)
fig, (ax1, ax2) = plt.subplots(2)
params = {'linewidth': 2.0, 'alpha': 0.6}
params.update(kwargs)

# Handle reference information
if ref is not None:
Expand All @@ -363,19 +364,17 @@ def visualize_traj(self, pred, ref=None, prefix=''):
ref_z = ref['z'].tolist()

# Plot reference trajectories
ax1.plot(ref_x, ref_y, '--', linewidth=params['linewidth'], color='tab:orange', alpha=0.5, label='reference trajectory')
ax2.plot(time, ref_z, '-', color='tab:orange', alpha=params['alpha_preds'], linewidth=params['linewidth'], label='reference trajectory')
fig.plot_traj(ref_x, ref_y, linestyle='--', color='tab:orange', label='reference trajectory', **params)
fig.plot_alt(time, ref_z, linestyle='-', color='tab:orange', label='reference trajectory', **params)

# Extract predicted trajectory information
pred_x = [pred.outputs[iter][prefix+'x'] for iter in range(len(pred.times))]
pred_y = [pred.outputs[iter][prefix+'y'] for iter in range(len(pred.times))]
pred_z = [pred.outputs[iter][prefix+'z'] for iter in range(len(pred.times))]

# Plot predictions
ax1.plot(pred_x, pred_y,'-', color='tab:blue', alpha=params['alpha_preds'], linewidth=params['linewidth'], label='prediction')
ax2.plot(pred_time, pred_z,'-', color='tab:blue',alpha=params['alpha_preds'], linewidth=params['linewidth'], label='prediction')

# Add labels
ax1.set_xlabel('x', fontsize=params['fontsize'])
ax1.set_ylabel('y', fontsize=params['fontsize'])
ax1.legend(fontsize=params['fontsize'])
ax2.set_xlabel('time stamp, -', fontsize=params['fontsize'])
ax2.set_ylabel('z', fontsize=params['fontsize'])
fig.plot_traj(pred_x, pred_y, linestyle='--', color='tab:orange', label='prediction', **params)
fig.plot_alt(pred.times, pred_z, linestyle='-', color='tab:orange', label='prediction', **params)

return fig

65 changes: 65 additions & 0 deletions src/progpy/utils/traj_gen/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,50 @@
Auxiliary functions for trajectories and aircraft routes
"""

from matplotlib import pyplot as plt
from matplotlib.figure import Figure
import numpy as np
from warnings import warn

from progpy.utils.traj_gen import geometry as geom
from progpy.utils.traj_gen.nurbs import NURBS


class TrajectoryFigure(Figure):
"""
Figure visualizing a trajectory
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
(ax_traj, ax_alt) = self.subplots(2)
ax_traj.set_xlabel('x', fontsize=14)
ax_traj.set_ylabel('y', fontsize=14)

ax_alt.set_xlabel('time stamp, -', fontsize=14)
ax_alt.set_ylabel('z', fontsize=14)

def plot_traj(self, x, y, **kwargs) -> None:
"""
Plot the trajectory in 2d space (upper graph)
Args:
x (list[float]): x location
y (list[float]): y location
"""
ax_traj = self.axes[0]
ax_traj.plot(x, y, **kwargs)

def plot_alt(self, t, z, **kwargs) -> None:
"""
Plot the altitude vs time (lower graph)
Args:
t (list[float]): Times
z (list[float]): z location
"""
ax_alt = self.axes[1]
ax_alt.plot(t, z, **kwargs)

def compute_derivatives(position_profile, timevec):
# Compute derivatives of position: velocity and acceleration
# (optional: jerk, not needed)
Expand Down Expand Up @@ -288,6 +325,34 @@ def ref_traj(self,):

't': self.trajectory['time']}

def plot(self, fig=None, **kwargs) -> TrajectoryFigure:
"""
Plot the reference trajectory
Args:
fig (TrajectoryFigure, optional): Figure where the additional diagrams are to be added. Creates a new figure if not provided
Raises:
TypeError: if Figure is not a TrajectoryFigure
Returns:
TrajectoryFigure
"""
params = {'figsize': (13, 9), 'linewidth': 2.0, 'alpha_preds': 0.6}
params.update(kwargs)
if fig is None:
fig = plt.figure(FigureClass=TrajectoryFigure)
elif not isinstance(fig, TrajectoryFigure):
raise TypeError(f"fig must be a TrajectorFigure, was {type(fig)}")

x = self.trajectory['position'][:, 0].tolist()
y = self.trajectory['position'][:, 1].tolist()
z = self.trajectory['position'][:, 2].tolist()
t = self.trajectory['time'].tolist()

fig.plot_traj(x, y, **params)
fig.plot_alt(t, z, **params)

def compute_attitude(self, heading_profile, acceleration_profile, timestep_size):
"""
Compute attitude defined by Euler's angles as a function of time given heading and acceleration profiles.
Expand Down

0 comments on commit cf3b525

Please sign in to comment.