Skip to content

Commit

Permalink
Heatmap plotting (#13)
Browse files Browse the repository at this point in the history
* Create heatmap plot
* Implement product in log domain
* Remove obsolete parameters
  • Loading branch information
duembgen authored Sep 11, 2020
1 parent a6aea99 commit 75536fa
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 56 deletions.
2 changes: 1 addition & 1 deletion crazyflie-audio
Submodule crazyflie-audio updated 106 files
43 changes: 15 additions & 28 deletions src/audio_stack/audio_stack/doa_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import numpy as np

from audio_interfaces.msg import Spectrum, DoaEstimates
from .spectrum_estimator import normalize_each_row, NORMALIZE
from .spectrum_estimator import normalize_rows, combine_rows, NORMALIZE

N_ESTIMATES = 3
COMBINATION_N = 5
Expand All @@ -35,7 +35,7 @@ def __init__(self):
Spectrum, "audio/spectrum", self.listener_callback_spectrum, 10
)
self.publisher_spectrum = self.create_publisher(
Spectrum, "audio/combined_spectrum", 10
Spectrum, "audio/dynamic_spectrum", 10
)
self.publisher_doa = self.create_publisher(
DoaEstimates, "geometry/doa_estimates", 10
Expand All @@ -45,17 +45,17 @@ def __init__(self):

# create ROS parameters that can be changed from command line.
self.declare_parameter("combination_n")
self.combination_n = COMBINATION_N
self.declare_parameter("combination_method")
self.combination_method = COMBINATION_METHOD
parameters = [
rclpy.parameter.Parameter(
"combination_method",
rclpy.Parameter.Type.STRING,
self.combination_method,
COMBINATION_METHOD,
),
rclpy.parameter.Parameter(
"combination_n", rclpy.Parameter.Type.INTEGER, self.combination_n
"combination_n",
rclpy.Parameter.Type.INTEGER,
COMBINATION_N
),
]
self.set_parameters_callback(self.set_params)
Expand All @@ -65,7 +65,7 @@ def set_params(self, params):
for param in params:
if param.name == "combination_method":
self.combination_method = param.get_parameter_value().string_value
if param.name == "combination_n":
elif param.name == "combination_n":
self.combination_n = param.get_parameter_value().integer_value
else:
return SetParametersResult(successful=False)
Expand All @@ -92,32 +92,19 @@ def listener_callback_spectrum(self, msg_spec):
# TODO(FD) use a rolling buffer (linked list or so) instead of the copies here.
spectra_shifted.append(np.c_[spectrum[:, index:], spectrum[:, :index]])

if self.combination_method == "sum":
combined_spectrum = np.sum(
spectra_shifted, axis=0
) # n_frequencies x n_angles
elif self.combination_method == "product":
combined_spectrum = np.product(
spectra_shifted, axis=0
) # n_frequencies x n_angles

combined_spectrum = normalize_each_row(combined_spectrum, NORMALIZE)
dynamic_spectrum = combine_rows(spectra_shifted, self.combination_method, keepdims=False) # n_frequencies x n_angles
dynamic_spectrum = normalize_rows(dynamic_spectrum, NORMALIZE)

# publish
msg_new = msg_spec
msg_new.spectrum_vect = list(combined_spectrum.astype(float).flatten())
msg_new.spectrum_vect = list(dynamic_spectrum.astype(float).flatten())
self.publisher_spectrum.publish(msg_new)
self.get_logger().info(f"Published combined spectrum.")
self.get_logger().info(f"Published dynamic spectrum.")

# calculate and publish doa estimates
if self.combination_method == "product":
# need to make sure spectrum is not too small before multiplying.
final_spectrum = np.product(normalize_each_row(combined_spectrum, "zero_to_one"),
axis=0, keepdims=True) # n_angles
elif self.combination_method == "sum":
final_spectrum = np.sum(combined_spectrum, axis=0, keepdims=True) # n_angles

final_spectrum = normalize_each_row(final_spectrum, NORMALIZE)
final_spectrum = combine_rows(dynamic_spectrum, self.combination_method, keepdims=True)
final_spectrum = normalize_rows(final_spectrum, NORMALIZE)

angles = np.linspace(0, 360, msg_spec.n_angles)
sorted_indices = np.argsort(final_spectrum.flatten()) # sorts in ascending order
doa_estimates = angles[sorted_indices[-N_ESTIMATES:][::-1]]
Expand All @@ -127,7 +114,7 @@ def listener_callback_spectrum(self, msg_spec):
msg_doa.timestamp = msg_spec.timestamp
msg_doa.doa_estimates_deg = list(doa_estimates.astype(float).flatten())
self.publisher_doa.publish(msg_doa)
self.get_logger().info(f"Published estimates: {doa_estimates}.")
self.get_logger().info(f"Published doa estimates: {doa_estimates}.")


def main(args=None):
Expand Down
29 changes: 23 additions & 6 deletions src/audio_stack/audio_stack/spectrum_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,42 @@
# - "mvdr": minimum-variacne distortionless response
BF_METHOD = "das"

NORMALIZE = "zero_to_one"
NORMALIZE = "zero_to_one_all"
#NORMALIZE = "zero_to_one"
#NORMALIZE = "sum_to_one"


def normalize_each_row(matrix, method="zero_to_one"):
def normalize_rows(matrix, method="zero_to_one"):
if method == "zero_to_one":
normalized = (matrix - np.min(matrix, axis=1, keepdims=True)) / (np.max(matrix, axis=1, keepdims=True) - np.min(matrix, axis=1, keepdims=True))
#assert np.max(normalized) == 1.0
#assert np.min(normalized) == 0.0
return normalized
elif method == "zero_to_one_all":
normalized = (matrix - np.min(matrix)) / (np.max(matrix) - np.min(matrix))
#assert np.max(normalized) == 1.0
#assert np.min(normalized) == 0.0
elif method == "sum_to_one":

# first make sure values are between 0 and 1 (otherwise division can lead to errors)
denom = np.max(matrix, axis=1, keepdims=True) - np.min(matrix, axis=1, keepdims=True)
matrix = (matrix - np.min(matrix, axis=1, keepdims=True)) / denom
sum_matrix = np.sum(matrix, axis=1, keepdims=True)
normalized = matrix / sum_matrix
np.testing.assert_allclose(np.sum(normalized, axis=1), 1.0, rtol=1e-5)
return normalized
else:
raise ValueError(method)
return normalized


def combine_rows(matrix, method="product", keepdims=False):
if method == "product":
# do the product in log domain for numerical reasons
# sum(log10(matrix)) = log10(product(matrix))
combined_matrix = np.power(10, np.sum(np.log10(matrix), axis=0, keepdims=keepdims))
elif method == "sum":
combined_matrix = np.sum(matrix, axis=0, keepdims=keepdims)
else:
raise ValueError(method)
return combined_matrix


class SpectrumEstimator(Node):
Expand Down Expand Up @@ -129,7 +146,7 @@ def listener_callback_correlations(self, msg_cor):
else:
orientation = message.yaw_deg

spectrum = normalize_each_row(spectrum, NORMALIZE)
spectrum = normalize_rows(spectrum, NORMALIZE)

# publish
msg_spec = Spectrum()
Expand Down
10 changes: 4 additions & 6 deletions src/crazyflie_crtp/crazyflie_crtp/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,12 @@
PARAMETERS_TUPLES = [
("debug", rclpy.Parameter.Type.INTEGER, 0),
("send_audio_enable", rclpy.Parameter.Type.INTEGER, 1),
("filter_snr_enable", rclpy.Parameter.Type.INTEGER, 0),
("filter_prop_enable", rclpy.Parameter.Type.INTEGER, 0),
("max_freq", rclpy.Parameter.Type.INTEGER, 10000),
("min_freq", rclpy.Parameter.Type.INTEGER, 100),
("max_freq", rclpy.Parameter.Type.INTEGER, 10000),
("delta_freq", rclpy.Parameter.Type.INTEGER, 100),
("use_iir", rclpy.Parameter.Type.INTEGER, 0),
("ma_window", rclpy.Parameter.Type.INTEGER, 1),
("alpha_iir", rclpy.Parameter.Type.DOUBLE, 0.5)
("n_average", rclpy.Parameter.Type.INTEGER, 1),
("filter_snr_enable", rclpy.Parameter.Type.INTEGER, 0),
("filter_prop_enable", rclpy.Parameter.Type.INTEGER, 0),
]


Expand Down
32 changes: 19 additions & 13 deletions src/topic_plotter/topic_plotter/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from audio_interfaces.msg import Spectrum, Signals, SignalsFreq, PoseRaw
from audio_stack.spectrum_estimator import normalize_each_row, NORMALIZE
from audio_stack.spectrum_estimator import normalize_rows, combine_rows, NORMALIZE
from audio_stack.topic_synchronizer import TopicSynchronizer
from .live_plotter import LivePlotter

Expand All @@ -29,8 +29,8 @@ def __init__(self):
Spectrum, "audio/spectrum", self.listener_callback_spectrum, 10
)

self.subscription_combined_spectrum = self.create_subscription(
Spectrum, "audio/combined_spectrum", self.listener_callback_combined_spectrum, 10
self.subscription_dynamic_spectrum = self.create_subscription(
Spectrum, "audio/dynamic_spectrum", self.listener_callback_dynamic_spectrum, 10
)

self.plotter_dict = {}
Expand All @@ -49,10 +49,11 @@ def init_plotter(self, name, xlabel='x', ylabel='y', log=True, ymin=-np.inf, yma


def listener_callback_spectrum(self, msg_spec, name="static", eps=YLIM_MIN):
xlabel = "angle [rad]"
xlabel = "angle [deg]"
ylabel = "magnitude [-]"
self.init_plotter(f"{name} raw spectra", xlabel=xlabel, ylabel=ylabel, ymin=eps, ymax=2)
self.init_plotter(f"{name} combined spectra", xlabel=xlabel, ylabel=ylabel, ymin=eps, ymax=2)
self.init_plotter(f"{name} raw spectra heatmap", xlabel=xlabel, ylabel=ylabel, ymin=eps, ymax=2)

frequencies = np.array(msg_spec.frequencies)
spectrum = np.array(msg_spec.spectrum_vect).reshape((msg_spec.n_frequencies, msg_spec.n_angles))
Expand All @@ -64,15 +65,15 @@ def listener_callback_spectrum(self, msg_spec, name="static", eps=YLIM_MIN):
self.plotter_dict[f"{name} raw spectra"].update_lines(
spectrum[mask] + eps, theta_scan, labels=labels
)
self.plotter_dict[f"{name} raw spectra heatmap"].update_mesh(
spectrum[mask] + eps, y_labels=labels
)

# compute and plot combinations.
spectrum_sum = np.sum(spectrum, axis=0, keepdims=True)
spectrum_sum = normalize_each_row(spectrum_sum, NORMALIZE)

# need to make sure spectrum is not too small before multiplying.
spectrum_product = np.product(normalize_each_row(spectrum, "zero_to_one"),
axis=0, keepdims=True)
spectrum_product = normalize_each_row(spectrum_product, NORMALIZE)
spectrum_sum = combine_rows(spectrum, "sum", keepdims=True)
spectrum_sum = normalize_rows(spectrum_sum, NORMALIZE)
spectrum_product = combine_rows(spectrum, "product", keepdims=True)
spectrum_product = normalize_rows(spectrum_product, NORMALIZE)

spectrum_plot = np.r_[spectrum_product, spectrum_sum]
labels = ["product", "sum"]
Expand All @@ -86,12 +87,17 @@ def listener_callback_spectrum(self, msg_spec, name="static", eps=YLIM_MIN):
self.plotter_dict[f"{name} raw spectra"].update_axvlines([orientation])
self.plotter_dict[f"{name} combined spectra"].update_axvlines([orientation])

def listener_callback_combined_spectrum(self, msg_spec):
angles = np.linspace(0, 360, msg_spec.n_angles)
orientation_index = np.argmin(abs(angles - orientation))
self.plotter_dict[f"{name} raw spectra heatmap"].update_axvlines([orientation_index], color='orange')


def listener_callback_dynamic_spectrum(self, msg_spec):
return self.listener_callback_spectrum(msg_spec, name="dynamic")


def listener_callback_signals_f(self, msg):
self.init_plotter("signals frequency", xlabel="frequency [Hz]", ylabel="magnitude [-]")
self.init_plotter("signals frequency", xlabel="frequency [Hz]", ylabel="magnitude [-]", ymin=1e-10, ymax=1e3)

if msg.n_frequencies != self.current_n_frequencies:
self.plotter_dict["signals frequency"].clear()
Expand Down
32 changes: 30 additions & 2 deletions src/topic_plotter/topic_plotter/live_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import matplotlib
import matplotlib.pylab as plt
import numpy as np

matplotlib.use("TkAgg")

Expand All @@ -34,6 +35,7 @@ def __init__(self, max_ylim=MAX_YLIM, min_ylim=MIN_YLIM, log=True, label='', max
self.axvlines = {}
self.arrows = {}
self.scatter = {}
self.meshes = {}

self.fig.canvas.mpl_connect("close_event", self.handle_close)

Expand Down Expand Up @@ -117,12 +119,38 @@ def update_lines(self, data_matrix, x_data=None, labels=None):
# without this, the plot does not get updated live.
self.fig.canvas.draw()

def update_axvlines(self, data_vector):
def update_mesh(self, data_matrix, y_labels=None, name="standard"):
""" Plot each row of data_matrix in an image.
"""
if name in self.meshes.keys():
self.meshes[name].set_array(data_matrix.flatten())
else:
mesh = self.ax.pcolormesh(data_matrix)
self.meshes[name] = mesh
angles = np.linspace(0, 360, data_matrix.shape[1], dtype=str)
xticks = self.ax.get_xticks()

# for some reason, xticks has an extra element which is not shown on the plot
# so we need to exclude that from the indices.
new_xticks = angles[xticks[:-1].astype(int)]

self.ax.set_xticklabels(new_xticks)
if y_labels is not None:
yticks = self.ax.get_yticks()
new_yticks = np.array(y_labels)[yticks[:-1].astype(int)]
self.ax.set_yticklabels(new_yticks)
# without this, the plot does not get updated live.
self.fig.canvas.draw()


def update_axvlines(self, data_vector, color=None):
for i, xcoord in enumerate(data_vector):
if i in self.axvlines.keys():
self.axvlines[i].set_xdata(xcoord)
else:
axvline = self.ax.axvline(xcoord, color=f"C{i % 10}", ls=":")
if color is None:
color = f"C{i % 10}"
axvline = self.ax.axvline(xcoord, color=color, ls=":")
self.axvlines[i] = axvline

self.fig.canvas.draw()
Expand Down

0 comments on commit 75536fa

Please sign in to comment.