Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test for mdraw_legend (#88, #89) #94

Merged
merged 21 commits into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions forestplot/mplot_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,41 @@ def mdraw_legend(
mcolor: Union[Sequence[str], None] = ["0", "0.4", ".8", "0.2"],
**kwargs: Any,
) -> Axes:
"""
Add a custom legend to a matplotlib Axes object for the different models.

This function creates and adds a legend to a given Axes object, allowing for customization of
the legend's markers, colors, size, and positioning. It's particularly useful for graphs
representing different models or categories with distinct markers and colors.

Parameters
----------
ax : Axes
The matplotlib Axes object to which the legend will be added.
xlabel : Union[Sequence[str], None]
A sequence of strings for x-axis labels, used to adjust the legend position. If None, the default position is used.
modellabels : Optional[Union[Sequence[str], None]]
A sequence of strings that serve as labels for the legend entries.
msymbols : Union[Sequence[str], None], optional
A sequence of marker symbols for each legend entry, defaults to 'soDx'.
mcolor : Union[Sequence[str], None], optional
A sequence of colors for each legend entry, defaults to ["0", "0.4", ".8", "0.2"].
**kwargs : Any
Additional keyword arguments for further customization. Supported customizations include 'leg_markersize'
(size of the legend markers, default 8), 'bbox_to_anchor' (tuple specifying the anchor point of the legend),
'leg_loc' (location of the legend, default 'lower center' or 'best'), 'leg_ncol' (number of columns in the legend,
default 2 or 1), and 'leg_fontsize' (font size of legend text, default 12).

Returns
-------
Axes
The modified matplotlib Axes object with the legend added.

Notes
-----
- The 'xlabel' parameter is used to adjust the legend's position based on the presence of x-axis labels.
It does not directly set the x-axis labels.
"""
leg_markersize = kwargs.get("leg_markersize", 8)
leg_artists = []
for ix, symbol in enumerate(msymbols):
Expand Down
32 changes: 32 additions & 0 deletions tests/test_mplot_graph_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.lines import Line2D
from matplotlib.pyplot import Axes

from forestplot.mplot_graph_utils import (
mdraw_ci,
mdraw_est_markers,
mdraw_legend,
mdraw_ref_xline,
mdraw_yticklabels,
)
Expand Down Expand Up @@ -95,3 +97,33 @@ def test_mdraw_ci():
# Assertions
assert isinstance(ax, Axes)
assert len(ax.collections) == len(set(models_vector))

def test_mdraw_legend():
# Create a simple plot
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], marker="o", color="0")
ax.plot([0, 1], [1, 0], marker="s", color="0.4")

# Sample parameters for the legend
modellabels = ["Model 1", "Model 2"]
msymbols = ["o", "s"]
mcolor = ["0", "0.4"]

# Call the function
ax = mdraw_legend(ax, None, modellabels, msymbols, mcolor)

# Assertions
legend = ax.get_legend()
assert legend is not None, "Legend was not created."

# Check number of legend entries
assert len(legend.get_texts()) == len(modellabels), "Incorrect number of legend entries."

# Check legend labels
for label, model_label in zip(legend.get_texts(), modellabels):
assert label.get_text() == model_label, "Legend labels do not match."

# Check legend marker colors and symbols
for line, color in zip(legend.legendHandles, mcolor):
assert isinstance(line, Line2D), "Legend entry is not a Line2D instance."
assert line.get_color() == color, "Legend marker color does not match."
Loading