diff --git a/forestplot/mplot_graph_utils.py b/forestplot/mplot_graph_utils.py index 52b3b03..de790d3 100644 --- a/forestplot/mplot_graph_utils.py +++ b/forestplot/mplot_graph_utils.py @@ -242,6 +242,41 @@ def mdraw_legend( mcolor: Union[Sequence[str], None] = ["0", "0.4", ".8", "0.2"], **kwargs: Any, ) -> Axes: + """ + Add a custom legend to a matplotlib Axes object for the different models. + + This function creates and adds a legend to a given Axes object, allowing for customization of + the legend's markers, colors, size, and positioning. It's particularly useful for graphs + representing different models or categories with distinct markers and colors. + + Parameters + ---------- + ax : Axes + The matplotlib Axes object to which the legend will be added. + xlabel : Union[Sequence[str], None] + A sequence of strings for x-axis labels, used to adjust the legend position. If None, the default position is used. + modellabels : Optional[Union[Sequence[str], None]] + A sequence of strings that serve as labels for the legend entries. + msymbols : Union[Sequence[str], None], optional + A sequence of marker symbols for each legend entry, defaults to 'soDx'. + mcolor : Union[Sequence[str], None], optional + A sequence of colors for each legend entry, defaults to ["0", "0.4", ".8", "0.2"]. + **kwargs : Any + Additional keyword arguments for further customization. Supported customizations include 'leg_markersize' + (size of the legend markers, default 8), 'bbox_to_anchor' (tuple specifying the anchor point of the legend), + 'leg_loc' (location of the legend, default 'lower center' or 'best'), 'leg_ncol' (number of columns in the legend, + default 2 or 1), and 'leg_fontsize' (font size of legend text, default 12). + + Returns + ------- + Axes + The modified matplotlib Axes object with the legend added. + + Notes + ----- + - The 'xlabel' parameter is used to adjust the legend's position based on the presence of x-axis labels. + It does not directly set the x-axis labels. + """ leg_markersize = kwargs.get("leg_markersize", 8) leg_artists = [] for ix, symbol in enumerate(msymbols): diff --git a/tests/test_mplot_graph_utils.py b/tests/test_mplot_graph_utils.py index 8d2626f..8caa331 100644 --- a/tests/test_mplot_graph_utils.py +++ b/tests/test_mplot_graph_utils.py @@ -1,10 +1,12 @@ import matplotlib.pyplot as plt import pandas as pd +from matplotlib.lines import Line2D from matplotlib.pyplot import Axes from forestplot.mplot_graph_utils import ( mdraw_ci, mdraw_est_markers, + mdraw_legend, mdraw_ref_xline, mdraw_yticklabels, ) @@ -95,3 +97,33 @@ def test_mdraw_ci(): # Assertions assert isinstance(ax, Axes) assert len(ax.collections) == len(set(models_vector)) + +def test_mdraw_legend(): + # Create a simple plot + fig, ax = plt.subplots() + ax.plot([0, 1], [0, 1], marker="o", color="0") + ax.plot([0, 1], [1, 0], marker="s", color="0.4") + + # Sample parameters for the legend + modellabels = ["Model 1", "Model 2"] + msymbols = ["o", "s"] + mcolor = ["0", "0.4"] + + # Call the function + ax = mdraw_legend(ax, None, modellabels, msymbols, mcolor) + + # Assertions + legend = ax.get_legend() + assert legend is not None, "Legend was not created." + + # Check number of legend entries + assert len(legend.get_texts()) == len(modellabels), "Incorrect number of legend entries." + + # Check legend labels + for label, model_label in zip(legend.get_texts(), modellabels): + assert label.get_text() == model_label, "Legend labels do not match." + + # Check legend marker colors and symbols + for line, color in zip(legend.legendHandles, mcolor): + assert isinstance(line, Line2D), "Legend entry is not a Line2D instance." + assert line.get_color() == color, "Legend marker color does not match."