diff --git a/src/progpy/models/aircraft_model/small_rotorcraft.py b/src/progpy/models/aircraft_model/small_rotorcraft.py index 3cc50cb..06cee9d 100644 --- a/src/progpy/models/aircraft_model/small_rotorcraft.py +++ b/src/progpy/models/aircraft_model/small_rotorcraft.py @@ -9,6 +9,7 @@ from progpy.models.aircraft_model.vehicles.aero import aerodynamics as aero from progpy.models.aircraft_model.vehicles import vehicles from progpy.utils.traj_gen import geometry as geom +from progpy.utils.traj_gen.trajectory import TrajectoryFigure class SmallRotorcraft(AircraftModel): @@ -321,7 +322,7 @@ def linear_model(self, phi, theta, psi, p, q, r, T): return A, B - def visualize_traj(self, pred, ref=None, prefix=''): + def visualize_traj(self, pred, ref=None, prefix='', fig=None, **kwargs): """ This method provides functionality to visualize a predicted trajectory generated, plotted with the reference trajectory. @@ -338,21 +339,21 @@ def visualize_traj(self, pred, ref=None, prefix=''): Reference trajectory - dict with keys for each state in the vehicle model and corresponding values as numpy arrays prefix : str, optional Prefix added to keys in predicted values. This is used to plot the trajectory using the results from a composite model + fig : TrajectoryFigure, optional + Figure where the additional diagrams are to be added. Creates a new figure if not provided - Returns + Returns ------- - fig : Visualization of trajectory generation results + TrajectoryFigure : Visualization of trajectory generation results """ - # Extract predicted trajectory information - pred_time = pred.times - pred_x = [pred.outputs[iter][prefix+'x'] for iter in range(len(pred_time))] - pred_y = [pred.outputs[iter][prefix+'y'] for iter in range(len(pred_time))] - pred_z = [pred.outputs[iter][prefix+'z'] for iter in range(len(pred_time))] + if fig is None: + fig = plt.figure(FigureClass=TrajectoryFigure) + elif not isinstance(fig, TrajectoryFigure): + raise TypeError(f"fig must be a TrajectorFigure, was {type(fig)}") - # Initialize Figure - params = dict(figsize=(13, 9), fontsize=14, linewidth=2.0, alpha_preds=0.6) - fig, (ax1, ax2) = plt.subplots(2) + params = {'linewidth': 2.0, 'alpha': 0.6} + params.update(kwargs) # Handle reference information if ref is not None: @@ -363,19 +364,17 @@ def visualize_traj(self, pred, ref=None, prefix=''): ref_z = ref['z'].tolist() # Plot reference trajectories - ax1.plot(ref_x, ref_y, '--', linewidth=params['linewidth'], color='tab:orange', alpha=0.5, label='reference trajectory') - ax2.plot(time, ref_z, '-', color='tab:orange', alpha=params['alpha_preds'], linewidth=params['linewidth'], label='reference trajectory') + fig.plot_traj(ref_x, ref_y, linestyle='--', color='tab:orange', label='reference trajectory', **params) + fig.plot_alt(time, ref_z, linestyle='-', color='tab:orange', label='reference trajectory', **params) + + # Extract predicted trajectory information + pred_x = [pred.outputs[iter][prefix+'x'] for iter in range(len(pred.times))] + pred_y = [pred.outputs[iter][prefix+'y'] for iter in range(len(pred.times))] + pred_z = [pred.outputs[iter][prefix+'z'] for iter in range(len(pred.times))] # Plot predictions - ax1.plot(pred_x, pred_y,'-', color='tab:blue', alpha=params['alpha_preds'], linewidth=params['linewidth'], label='prediction') - ax2.plot(pred_time, pred_z,'-', color='tab:blue',alpha=params['alpha_preds'], linewidth=params['linewidth'], label='prediction') - - # Add labels - ax1.set_xlabel('x', fontsize=params['fontsize']) - ax1.set_ylabel('y', fontsize=params['fontsize']) - ax1.legend(fontsize=params['fontsize']) - ax2.set_xlabel('time stamp, -', fontsize=params['fontsize']) - ax2.set_ylabel('z', fontsize=params['fontsize']) + fig.plot_traj(pred_x, pred_y, linestyle='--', color='tab:orange', label='prediction', **params) + fig.plot_alt(pred.times, pred_z, linestyle='-', color='tab:orange', label='prediction', **params) return fig \ No newline at end of file diff --git a/src/progpy/utils/traj_gen/trajectory.py b/src/progpy/utils/traj_gen/trajectory.py index b511a34..6800ba9 100644 --- a/src/progpy/utils/traj_gen/trajectory.py +++ b/src/progpy/utils/traj_gen/trajectory.py @@ -5,6 +5,8 @@ Auxiliary functions for trajectories and aircraft routes """ +from matplotlib import pyplot as plt +from matplotlib.figure import Figure import numpy as np from warnings import warn @@ -12,6 +14,41 @@ from progpy.utils.traj_gen.nurbs import NURBS +class TrajectoryFigure(Figure): + """ + Figure visualizing a trajectory + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + (ax_traj, ax_alt) = self.subplots(2) + ax_traj.set_xlabel('x', fontsize=14) + ax_traj.set_ylabel('y', fontsize=14) + + ax_alt.set_xlabel('time stamp, -', fontsize=14) + ax_alt.set_ylabel('z', fontsize=14) + + def plot_traj(self, x, y, **kwargs) -> None: + """ + Plot the trajectory in 2d space (upper graph) + + Args: + x (list[float]): x location + y (list[float]): y location + """ + ax_traj = self.axes[0] + ax_traj.plot(x, y, **kwargs) + + def plot_alt(self, t, z, **kwargs) -> None: + """ + Plot the altitude vs time (lower graph) + + Args: + t (list[float]): Times + z (list[float]): z location + """ + ax_alt = self.axes[1] + ax_alt.plot(t, z, **kwargs) + def compute_derivatives(position_profile, timevec): # Compute derivatives of position: velocity and acceleration # (optional: jerk, not needed) @@ -288,6 +325,34 @@ def ref_traj(self,): 't': self.trajectory['time']} + def plot(self, fig=None, **kwargs) -> TrajectoryFigure: + """ + Plot the reference trajectory + + Args: + fig (TrajectoryFigure, optional): Figure where the additional diagrams are to be added. Creates a new figure if not provided + + Raises: + TypeError: if Figure is not a TrajectoryFigure + + Returns: + TrajectoryFigure + """ + params = {'figsize': (13, 9), 'linewidth': 2.0, 'alpha_preds': 0.6} + params.update(kwargs) + if fig is None: + fig = plt.figure(FigureClass=TrajectoryFigure) + elif not isinstance(fig, TrajectoryFigure): + raise TypeError(f"fig must be a TrajectorFigure, was {type(fig)}") + + x = self.trajectory['position'][:, 0].tolist() + y = self.trajectory['position'][:, 1].tolist() + z = self.trajectory['position'][:, 2].tolist() + t = self.trajectory['time'].tolist() + + fig.plot_traj(x, y, **params) + fig.plot_alt(t, z, **params) + def compute_attitude(self, heading_profile, acceleration_profile, timestep_size): """ Compute attitude defined by Euler's angles as a function of time given heading and acceleration profiles.