diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 150a24d..72b64b0 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,6 +8,7 @@ Contributors to this version: Trevor James Smith (:user:`Zeitsperre`), Marco Bra New features and enhancements ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +* `figanos` now has timeseries available in the hvplot plotting library (:pull:`198`). * `figanos` now supports Python 3.12. (:pull:`210`). * Use list or ndarray as levels for colorbar in gridmap and small bug fixes (:pull:`176`). * Added style sheet ``transparent.mplstyle`` (:issue:`183`, :pull:`185`) diff --git a/docs/notebooks/figanos_hvplot.ipynb b/docs/notebooks/figanos_hvplot.ipynb new file mode 100644 index 0000000..095ded4 --- /dev/null +++ b/docs/notebooks/figanos_hvplot.ipynb @@ -0,0 +1,170 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "0", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import xarray as xr\n", + "from xclim import ensembles\n", + "xr.set_options(keep_attrs=True)" + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "## HvPlot ouranos theme " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [], + "source": [ + "from figanos.hvplot.utils import set_hv_style\n", + "\n", + "set_hv_style('ouranos')" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "## Timeseries\n", + "\n", + "The main elements of a plot are dependent on four arguments, each accepting dictionaries:\n", + "\n", + "1. `data` : a dictionary containing the Xarray objects and their respective keys, used as labels on the plot.\n", + "2. `use_attrs`: a dictionary linking attributes from the Xarray object to plot text elements.\n", + "3. `plot_kw` : a dictionary using the same keys as `data` to pass arguments to the underlying plotting function, in this case [hvplot.line](https://hvplot.holoviz.org/reference/tabular/line.html).\n", + "4. `opts_kw`: a dictionary using the same keys as `data` plus overlay (to be passed to the combined elements of all `data` values) to pass to the underlying `.opts` [holoviz funciton](https://holoviews.org/user_guide/Customizing_Plots.html).\n", + "\n", + "When labels are passed in `data`, any 'label' argument passed in `plot_kw` will be ignored." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [], + "source": [ + "# creates dataset\n", + "\n", + "time = pd.date_range(start='1960-01-01', end='2020-01-01', periods=10)\n", + "np.random.seed(1231)\n", + "dat_1 = np.random.rand(10) * 20\n", + "dat_2 = np.random.rand(10) * 20\n", + "dat_3 = np.random.rand(10) * 20\n", + "\n", + "dt = xr.Dataset(data_vars={'tas': (['realization', \n", + " #'group', \n", + " 'time'], \n", + " #np.array([[dat_1, dat_2], [dat_2, dat_3], [dat_3, dat_1]])\n", + " np.array([dat_1, dat_2, dat_3])\n", + " )\n", + " },\n", + " coords={'time': time, \n", + " 'lat': 41,\n", + " 'lon':-73, \n", + " #'group': ['a', 'b'], \n", + " 'realization': [0, 1, 2]},\n", + " )\n", + "dt.tas.attrs={'long_name': 'Randomly generated time-series',\n", + " 'standart_name': 'air_temp',\n", + " 'description': \"Synthetic time-series\",\n", + " 'units': 'degC',}\n", + " \n", + "data2 = dt+10\n", + "\n", + "perc = ensembles.ensemble_percentiles(dt, values=[25, 50, 75], split=False)\n", + "stat = ensembles.ensemble_mean_std_max_min(dt).drop_vars('tas_stdev')\n", + "perc2 = ensembles.ensemble_percentiles(data2, values=[10, 50, 90], split=False)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [], + "source": [ + "from figanos.hvplot import timeseries\n", + "timeseries(data2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [], + "source": [ + "timeseries(data2, legend='in_plot', show_lat_lon='lower left')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [], + "source": [ + "timeseries({'rcp85': data2, 'rcp45': dt}, opts_kw={'overlay': {'legend_position': 'bottom_right', 'legend_cols': 2}}, show_lat_lon='lower left')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [], + "source": [ + "timeseries({'val1': data2, 'val2': dt}, plot_kw={'val1': {'color': 'yellow'}, 'val2': {'line_width': 5, 'color': 'teal'}}, opts_kw={'overlay': {'legend_position': 'right'}}, show_lat_lon='lower left')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [], + "source": [ + "timeseries({'ssp245': perc, 'ssp585': perc2}, legend='full' , show_lat_lon=False, opts_kw={'overlay': {'legend_position': 'right'}})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/environment-dev.yml b/environment-dev.yml index c0de91f..ccd1789 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -18,6 +18,11 @@ dependencies: - scikit-image - xarray - xclim >=0.47 + # for interactive figures + - hvplot + - holoviews + - geoviews + - bokeh # To make the package and notebooks usable - dask - h5py diff --git a/src/figanos/__init__.py b/src/figanos/__init__.py index 13ad4ad..29b27a8 100644 --- a/src/figanos/__init__.py +++ b/src/figanos/__init__.py @@ -22,6 +22,6 @@ __email__ = "bourdeau-goulet.sarah-claude@ouranos.ca" __version__ = "0.3.1-dev.11" -from . import matplotlib +from . import hvplot, matplotlib from ._data import data from ._logo import Logos diff --git a/src/figanos/hvplot/__init__.py b/src/figanos/hvplot/__init__.py new file mode 100644 index 0000000..d19b2c8 --- /dev/null +++ b/src/figanos/hvplot/__init__.py @@ -0,0 +1,4 @@ +"""Figanos hvplot plotting module.""" + +from .plot import timeseries +from .utils import get_hv_styles, set_hv_style diff --git a/src/figanos/hvplot/plot.py b/src/figanos/hvplot/plot.py new file mode 100644 index 0000000..4ce4eb5 --- /dev/null +++ b/src/figanos/hvplot/plot.py @@ -0,0 +1,375 @@ +"""Hvplot figanos plotting functions.""" + +import copy +import warnings +from functools import partial +from pathlib import Path +from typing import Any + +import holoviews as hv +import hvplot # noqa: F401 +import hvplot.xarray # noqa: F401 +import xarray as xr + +from figanos.matplotlib.utils import ( + check_timeindex, + fill_between_label, + get_array_categ, + get_scen_color, + sort_lines, +) + +from .utils import ( + add_default_opts_overlay, + create_dict_timeseries, + curve_hover_hook, + curve_hover_tool, + defaults_curves, + formatters_data, + get_all_values, + get_glyph_param, + set_plot_attrs_hv, + x_timeseries, +) + + +def _plot_ens_reals( + name: str, + array_categ: dict[str, str], + arr: xr.DataArray, + non_dict_data: bool, + plot_kw: dict[str, Any], + opts_kw: dict[str, Any] | list, + form: str, + use_attrs: dict[str, Any], +) -> dict: + """Plot realizations ensembles""" + hv_fig = {} + + if array_categ[name] == "ENS_REALS_DS": + if len(arr.data_vars) >= 2: + raise TypeError( + "To plot multiple ensembles containing realizations, use DataArrays outside a Dataset" + ) + else: + arr = arr[list(arr.data_vars)[0]] + + if non_dict_data is True: + if not ( + "groupby" in plot_kw[name].keys() + and plot_kw[name]["groupby"] == "realization" + ): + plot_kw[name] = {"by": "realization", "x": "time"} | plot_kw[ + name + ] # why did i put this two times? + plot_kw[name] = {"by": "realization", "x": "time"} | plot_kw[name] + opts_kw[name].setdefault( + "hooks", + [ + partial( + curve_hover_hook, + att=use_attrs, + form=form, + x=list(plot_kw.values())[-1]["x"], + ) + ], + ) + return arr.hvplot.line(**plot_kw[name]).opts(**opts_kw[name]) + else: + plot_kw[name].setdefault("label", name) + # opts_kw[name].setdefault("hooks", [ + # partial(curve_hover_hook, att=use_attrs, form=form, x=list(plot_kw.values())[-1]["x"])]) + for r in arr.realization: + hv_fig[f"realization_{r.values.item()}"] = ( + arr.sel(realization=r) + .hvplot.line(hover=False, **plot_kw[name]) + .opts( + tools=curve_hover_tool(use_attrs, form, r=r.values.item()), + **opts_kw[name], + ) + ) + return hv_fig + + +def _plot_ens_pct_stats( + name: str, + arr: xr.DataArray, + array_categ: dict[str, str], + array_data: dict[str, xr.DataArray], + plot_kw: dict[str, Any], + opts_kw: dict[str, Any], + legend: str, + form: str, + use_attrs: dict[str, Any], + sub_name: str | None = None, +) -> dict: + """Plot ensembles with percentiles and statistics (min/moy/max)""" + hv_fig = {} + + # create a dictionary labeling the middle, upper and lower line + sorted_lines = sort_lines(array_data) + + # which label to use + if sub_name: + lab = sub_name + else: + lab = name + + plot_kw_line = copy.deepcopy(plot_kw[name]) + plot_kw_line = {"label": lab, "hover": False} | plot_kw_line + # plot + hv_fig["line"] = ( + array_data[sorted_lines["middle"]] + .hvplot.line(**plot_kw_line) + .opts(**opts_kw[name]) + ) + + c = get_glyph_param(hv_fig["line"], "line_color") + lab_area = fill_between_label(sorted_lines, name, array_categ, legend) + opts_kw_area = copy.deepcopy(opts_kw[name]) + opts_kw_area.setdefault("tools", curve_hover_tool(use_attrs, form)) + plot_kw[name].setdefault("color", c) + if "ENS_PCT_DIM" in array_categ[name]: + arr = arr.to_dataset(dim="percentiles") + arr = arr.rename({k: str(k) for k in arr.keys()}) + hv_fig["area"] = arr.hvplot.area( + y=sorted_lines["lower"], + y2=sorted_lines["upper"], + label=lab_area, + line_color=None, + alpha=0.2, + **plot_kw[name], + ).opts(**opts_kw_area) + return hv_fig + + +def _plot_timeseries( + name: str, + arr: xr.DataArray | xr.Dataset, + array_categ: dict[str, str], + plot_kw: dict[str, Any], + opts_kw: dict[str, Any], + non_dict_data: bool, + legend: str, + form: str, + use_attrs: dict[str, Any], +) -> dict | hv.element.chart.Curve | hv.core.overlay.Overlay: + """Plot time series from 1D Xarray Datasets or DataArrays as line plots.""" + hv_fig = {} + + if ( + array_categ[name] == "ENS_REALS_DA" or array_categ[name] == "ENS_REALS_DS" + ): # ensemble with 'realization' dim, as DataArray or Dataset + return _plot_ens_reals( + name, + array_categ, + arr, + non_dict_data, + plot_kw, + opts_kw, + form, + use_attrs, + ) + elif ( + array_categ[name] == "ENS_PCT_DIM_DS" + ): # ensemble percentiles stored as dimension coordinates, DataSet + for k, sub_arr in arr.data_vars.items(): + sub_name = ( + sub_arr.name if non_dict_data is True else (name + "_" + sub_arr.name) + ) + hv_fig[sub_name] = {} + # extract each percentile array from the dims + array_data = {} + for pct in sub_arr.percentiles.values: + array_data[str(pct)] = sub_arr.sel(percentiles=pct) + + hv_fig[sub_name] = _plot_ens_pct_stats( + name, + sub_arr, + array_categ, + array_data, + plot_kw, + opts_kw, + legend, + form, + use_attrs, + sub_name, + ) + elif array_categ[name] in [ + "ENS_PCT_VAR_DS", # ensemble statistics (min, mean, max) stored as variables + "ENS_STATS_VAR_DS", # ensemble percentiles stored as variables + "ENS_PCT_DIM_DA", # ensemble percentiles stored as dimension coordinates, DataArray + ]: + # extract each array from the datasets + array_data = {} + if array_categ[name] == "ENS_PCT_DIM_DA": + for pct in arr.percentiles: + array_data[str(int(pct))] = arr.sel(percentiles=int(pct)) + else: + for k, v in arr.data_vars.items(): + array_data[k] = v + + return _plot_ens_pct_stats( + name, + arr, + array_categ, + array_data, + plot_kw, + opts_kw, + legend, + form, + use_attrs, + ) + # non-ensemble Datasets + elif array_categ[name] == "DS": + ignore_label = False + for k, sub_arr in arr.data_vars.items(): + sub_name = ( + sub_arr.name if non_dict_data is True else (name + "_" + sub_arr.name) + ) + # if kwargs are specified by user, all lines are the same and we want one legend entry + if plot_kw[name]: + label = name if not ignore_label else "" + ignore_label = True + else: + label = sub_name + hv_fig[sub_name] = sub_arr.hvplot.line( + x="time", label=label, **plot_kw[name] + ).opts(**opts_kw[name]) + + # non-ensemble DataArrays + elif array_categ[name] in ["DA"]: + return arr.hvplot.line(label=name, **plot_kw[name]).opts(**opts_kw[name]) + else: + raise ValueError( + "Data structure not supported" + ) # can probably be removed along with elif logic above, + # given that get_array_categ() also does this check + if hv_fig: + return hv_fig + + +def timeseries( + data: dict[str, Any] | xr.DataArray | xr.Dataset, + use_attrs: dict[str, Any] | None = None, + plot_kw: dict[str, Any] | None = None, + opts_kw: dict[str, Any] | None = None, + legend: str = "lines", + show_lat_lon: bool | str | int | tuple[float, float] = True, +) -> hv.element.chart.Curve | hv.core.overlay.Overlay: + """Plot time series from 1D Xarray Datasets or DataArrays as line plots. + + Parameters + ---------- + data : dict or Dataset/DataArray + Input data to plot. It can be a DataArray, Dataset or a dictionary of DataArrays and/or Datasets. + use_attrs : dict, optional + A dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description'). + Default value is {'title': 'description', 'ylabel': 'long_name', 'yunits': 'units'}. + Only the keys found in the default dict can be used. + plot_kw : dict, optional + Arguments to pass to the `hvplot.line()` or hvplot.area() function. Changes how the line looks. + If 'data' is a dictionary, must be a nested dictionary with the same keys as 'data'. + opts_kw: dict, optional + Arguments to pass to the `holoviews/hvplot.opts()` function. Changes figure options and access to hooks. + If 'data' is a dictionary, must be a nested dictionary with the same keys as 'data' to pass nested to each + individual figure or key 'overlay' to pass to overlayed figures. + legend : str (default 'lines') or dict + 'full' (lines and shading), 'lines' (lines only), 'in_plot' (end of lines), 'none' (no legend). + show_lat_lon : bool, tuple, str or int + If True, show latitude and longitude at the bottom right of the figure. + Can be a tuple of axis coordinates (from 0 to 1, as a fraction of the axis length) representing + the location of the text. If a string or an int, the same values as those of the 'loc' parameter + of matplotlib's legends are accepted. + + Returns + ------- + hvplot.Overlay + + """ + # timeseries dict/data + use_attrs, data, plot_kw, opts_kw, non_dict_data = create_dict_timeseries( + use_attrs, data, plot_kw, opts_kw + ) + + # add ouranos default cycler colors + defaults_curves() + + # assign keys to plot_kw if not there + if non_dict_data is False: + for name in data: + if name not in plot_kw: + plot_kw[name] = {} + warnings.warn( + f"Key {name} not found in plot_kw. Using empty dict instead." + ) + if name not in opts_kw: + opts_kw[name] = {} + warnings.warn( + f"Key {name} not found in opts_kw. Using empty dict instead." + ) + for key in plot_kw: + # add "x" to plot_kw if not there + x_timeseries(data[key], plot_kw[key]) + if key not in data: + raise KeyError( + 'plot_kw must be a nested dictionary with keys corresponding to the keys in "data"' + ) + else: + x_timeseries(data["_no_label"], plot_kw["_no_label"]) + + # check: type + for name, arr in data.items(): + if not isinstance(arr, (xr.Dataset, xr.DataArray)): + raise TypeError( + '"data" must be a xr.Dataset, a xr.DataArray or a dictionary of such objects.' + ) + + # check: 'time' dimension and calendar format + data = check_timeindex(data) + + # add use attributes defaults + use_attrs, plot_kw = set_plot_attrs_hv(use_attrs, list(data.values())[-1], plot_kw) + + # dict of array 'categories' + array_categ = {name: get_array_categ(array) for name, array in data.items()} + + # formatters for data + form = formatters_data(data) + + # dictionary of hvplots plots + figs = {} + + # get data and plot + for name, arr in data.items(): + # remove 'label' to avoid error due to double 'label' args + if "label" in plot_kw[name]: + del plot_kw[name]["label"] + warnings.warn(f'"label" entry in plot_kw[{name}] will be ignored.') + + # SSP, RCP, CMIP model colors + cat_colors = ( + Path(__file__).parents[1] / "data/ipcc_colors/categorical_colors.json" + ) + if get_scen_color(name, cat_colors): + plot_kw[name].setdefault("color", get_scen_color(name, cat_colors)) + + figs[name] = _plot_timeseries( + name, + arr, + array_categ, + plot_kw, + opts_kw, + non_dict_data, + legend, + form, + use_attrs, + ) + + # overlay opts_kw + if "overlay" not in opts_kw: + opts_kw["overlay"] = {} + opts_overlay = add_default_opts_overlay( + opts_kw["overlay"].copy(), legend, show_lat_lon, list(data.values())[0] + ) + return hv.Overlay(list(get_all_values(figs))).opts(**opts_overlay) diff --git a/src/figanos/hvplot/style/ouranos.yml b/src/figanos/hvplot/style/ouranos.yml new file mode 100644 index 0000000..66f5235 --- /dev/null +++ b/src/figanos/hvplot/style/ouranos.yml @@ -0,0 +1,93 @@ +#Ouranos bokeh theme +#caliber matplotlib version: https://github.com/bokeh/bokeh/blob/8a5a35eb078e386c14b5580a51e8a6673d57d197/src/bokeh/themes/_caliber.py +#refer to IPCC style guide + +attrs: + Plot: + background_fill_color: white + border_fill_color: white #color around plot - could be null to match behind plot + outline_line_color: !!null + + Axis: + major_tick_in: 4 + major_tick_out: 0 + major_tick_line_alpha: 1 + major_tick_line_color: black + + minor_tick_line_alpha: 0 + minor_tick_line_color: !!null + + axis_line_alpha: 1 + axis_line_color: black + + major_label_text_color: black + major_label_text_font: DejaVu Sans + major_label_text_font_size: 10pt + major_label_text_font_style: normal + + axis_label_standoff: 10 + axis_label_text_color: black + axis_label_text_font: DejaVu Sans + axis_label_text_font_size: 11pt + axis_label_text_font_style: normal + + Legend: + spacing: 8 + glyph_width: 15 + + label_standoff: 8 + label_text_color: black + label_text_font: DejaVu Sans + label_text_font_size: 10pt + label_text_font_style: normal + + title_text_font: DejaVu Sans + title_text_font_style: normal + title_text_font_size: 11pt + title_text_color: black + + border_line_alpha: 0 + background_fill_alpha: 1 + background_fill_color: white + + BaseColorBar: + title_text_color: black + title_text_font: DejaVu Sans + title_text_font_size: 11pt + title_text_font_style: normal + + major_label_text_color: black + major_label_text_font: DejaVu Sans + major_label_text_font_size: 10pt + major_label_text_font_style: normal + + major_tick_line_alpha: 0 + bar_line_alpha: 0 + + Grid: + grid_line_width: 0 + grid_line_color: black + grid_line_alpha: 0.4 + + Title: + text_color: black + text_font: DejaVu Sans + text_font_style: normal + text_font_size: 14pt + + Toolbar: + logo: !!null + autohide: True + + figure: + toolbar_location: below + + CategoricalColorMapper: + palette: + - '#052946' + - '#ce153d' + - '#18bbbb' + - '#fdc414' + - '#6850af' + - '#196a5e' + - '#7a5315' diff --git a/src/figanos/hvplot/utils.py b/src/figanos/hvplot/utils.py new file mode 100644 index 0000000..a50d326 --- /dev/null +++ b/src/figanos/hvplot/utils.py @@ -0,0 +1,586 @@ +"""Utility functions for figanos hvplot figure-creation.""" + +import collections.abc +import pathlib +import warnings +from functools import partial +from typing import Any + +import holoviews as hv +import xarray as xr +from bokeh.models import ( + ColumnDataSource, + GlyphRenderer, + HoverTool, + Label, + LegendItem, + Line, + Range1d, + Text, +) +from bokeh.themes import Theme + +from figanos.matplotlib.utils import ( + convert_scen_name, + empty_dict, + get_attributes, + process_keys, + wrap_text, +) + + +def get_hv_styles() -> dict[str, str]: + """Get the available matplotlib styles and their paths, as a dictionary.""" + folder = pathlib.Path(__file__).parent / "style/" + paths = sorted(p for ext in ["*.yaml", "*.json", "*.yml"] for p in folder.glob(ext)) + names = [ + str(p) + .split("/")[-1] + .removesuffix(".yaml") + .removesuffix(".yml") + .removesuffix(".json") + for p in paths + ] + return {str(name): path for name, path in zip(names, paths)} + + +def set_hv_style(*args: str | dict) -> None: + """Set the holoviews bokeh style using a yaml file or a dict. + + Parameters + ---------- + args : str or dict + Name(s) of figanos bokeh style ('ouranos'), build-ing bokeh theme, path(s) to json or yaml or dict. + + Returns + ------- + None + """ + for s in args: + if isinstance(s, dict): + hv.renderer("bokeh").theme = Theme(json=s) + elif s.endswith(".json") is True or s.endswith(".yaml") is True: + hv.renderer("bokeh").theme = Theme(filename=s) + elif s in get_hv_styles(): + hv.renderer("bokeh").theme = Theme(get_hv_styles()[s]) + elif s in [ + "light_minimal", + "dark_minimal", + "caliber", + "night_sky", + "contrast", + ]: # bokeh build in themes + hv.renderer("bokeh").theme = s + else: + warnings.warn(f"Style {s} not found.") + + # Add Ouranos defaults that can't be directly added to the bokeh theme yaml file + if "ouranos" in args: + defaults_curves() + + +def defaults_curves() -> None: + """Adds Ouranos defaults to curves that can't be added to bokeh theme.""" + return hv.opts.defaults( + hv.opts.Curve( + color=hv.Cycle( + [ + "#052946", + "#ce153d", + "#18bbbb", + "#fdc414", + "#6850af", + "#196a5e", + "#7a5315", + ] + ), # ouranos colors + gridstyle={"ygrid_line_width": 0.5}, + show_grid=True, + ) + ) + + +def get_glyph_param(hplt, param) -> str: + """Returns bokeh glyph parameters from hvplot object.""" + # Get the Bokeh renderer + renderer = hv.renderer("bokeh") + plot = renderer.get_plot(hplt) + return getattr(plot.handles["glyph"], param) + + +def hook_real_dict(plot, element) -> None: + """Creates hook between hvplot and bokeh to have custom legend to link all realizations insde a key to the same label in legend.""" + # Iterate over the elements in the overlay to get their labels + if isinstance(element, hv.Overlay): + labels = [] + for sub_element in element.values(): + labels.append( + sub_element.label + ) # would be better if added check for LineType of sub elements + + rends = {} # store glyph to create legend + colors = [] # store new colors to know which glyph to add to legend + n = -1 + for renderer in plot.handles["plot"].renderers: + if isinstance(renderer.glyph, Line): + if renderer.glyph.line_color not in colors: + n += 1 # would be better if found a link between label and glyphs... + colors.append(renderer.glyph.line_color) + rends[labels[n]] = [renderer] + else: + rends[labels[n]].append(renderer) + + plot.state.legend.items = [ + LegendItem(label=k, renderers=v) + for k, v in zip(list(rends.keys()), list(rends.values())) + ] + + +def edge_legend(plot, element) -> None: + """Creates hook between hvplot and bokeh to have custom legend to link all realizations insde a key to the same label in legend.""" + l_txt = [] + + for renderer in plot.handles["plot"].renderers: + + if isinstance(renderer.glyph, Line): + data = renderer.data_source.data + ll = {t: v[-1] for t, v in data.items()} + + # if only one curve does not have label linked to the curve in z + if len(data) < 3 and isinstance(element.label, str): + txt = element.label + else: + txt = f"{list(ll.keys())[-1]}: {list(ll.values())[-1]}" + + l_txt.append(len(txt)) # store len of texts + + source = ColumnDataSource( + data=dict( + x=[ll[renderer.glyph.x]], y=[ll[renderer.glyph.y]], text=[txt] + ) + ) + + tt = Text( + x="x", + y="y", + text="text", + text_align="left", + text_baseline="middle", + text_color=renderer.glyph.line_color, + text_font_size="12px", + x_offset=5, + ) + glyph_renderer = GlyphRenderer(data_source=source, glyph=tt) + plot.state.renderers.append(glyph_renderer) + + if plot.state.legend: + plot.state.legend.items = [] + + # increase x_range to show text when plotting + plot.state.x_range = Range1d( + start=plot.state.x_range.start, + end=plot.state.x_range.end + + 0.0125 * max(l_txt) * (plot.state.x_range.end - plot.state.x_range.start), + ) + + +def get_all_values(nested_dictionary) -> list: + """Get all values from a nested dictionary.""" + for key, value in nested_dictionary.items(): + if isinstance(value, collections.abc.Mapping): + yield from get_all_values(value) + else: + yield value + + +def curve_hover_hook(plot, element, att, form, x) -> None: + """Hook function to be passed to hvplot.opts to modify hover tooltips.""" + for hov_id, hover in plot.handles["hover_tools"].items(): + if hover.tooltips[0][0] != x: + hover.tooltips[-2:] = [ + (att["xhover"], "$x{%F}"), + (att["yhover"], "$y{" + form + "}"), + ] + else: + hover.tooltips = [ + (att["xhover"], "$x{%F}"), + (att["yhover"], "$y{" + form + "}"), + ] + hover.formatters = { + "$x": "datetime", + } + + +def curve_hover_tool(att, form, r=None) -> list[HoverTool]: + """Tool to be passed to hvplot.opts to modify hover tooltips.""" + tips = [ + (att["xhover"], "$x{%F}"), + (att["yhover"], "$y{" + form + "}"), + ] + if r is not None: + tips.insert(0, ("realization", f"{r}")) + return [HoverTool(tooltips=tips, formatters={"$x": "datetime"})] + + +# can probably delete this function +def rm_curve_hover_hook(plot, element) -> None: + """Hook to remove hover curve.""" + plot.handles["hover"].tooltips = None + + +def get_min_max(data) -> tuple: + """Get min and max values from data.""" + minn = [] + maxx = [] + if isinstance(data, dict): + for v in data.values(): + if isinstance(v, xr.Dataset): + for vv in v.values(): + minn.append(vv.min().values.item()) + maxx.append(vv.max().values.item()) + else: + minn.append(v.min().values.item()) + maxx.append(v.max().values.item()) + elif isinstance(data, xr.Dataset): + for v in data.values(): + minn.append(v.min().values.item()) + maxx.append(v.max().values.item()) + else: + minn.append(data.min().values.item()) + maxx.append(data.max().values.item()) + return min(minn), max(maxx) + + +def formatters_data(data) -> str: + """Get the correct formatter for the data.""" + ymin, ymax = get_min_max(data) + diff = ymax - ymin + + if abs(ymin) > 1000: + form = "0 a" + if diff < 1000 and diff > 100: + form = "0.00 a" + elif diff <= 100: + form = "0.000 a" + elif diff > 50: + form = "0" + elif 10 < diff <= 50: + form = "0.0" + elif 1 < diff <= 10: + form = "0.00" + elif diff <= 1: + form = "0.000" + else: + form = "0.00" + return form + + +def add_default_opts_overlay(opts_kw, legend, show_lat_lon, data) -> dict: + """Add default opts to curve plot. + + Parameters + ---------- + opts_kw : dict + Custom options to be passed to opts() of holoviews overlay figure. + legend : str + Type of legend. + show_lat_lon : bool, tuple, str or int + Show latitude and longitude on figure + data : xr.DataArray | xr.Dataset | dict + Data to be plotted. + + Returns + ------- + dict + + """ + if legend == "edge": + warnings.warn( + "Legend 'edge' is not supported in hvplot. Using 'in_plot' instead." + ) + legend = "in_plot" + elif legend == "in_plot": + opts_kw["show_legend"] = False + if "hooks" in list(opts_kw.keys()): + opts_kw["hooks"].append(edge_legend) + else: + opts_kw["hooks"] = [edge_legend] + if not legend: + opts_kw.setdefault("show_legend", False) + + if show_lat_lon: + sll = plot_coords( + list(data.values())[0], + loc=show_lat_lon, + param="location", + backgroundalpha=1, + ) + if "hooks" in list(opts_kw.keys()): + opts_kw["hooks"].append(sll) + else: + opts_kw["hooks"] = [sll] + + return opts_kw + + +def create_dict_timeseries( + use_attrs, data, plot_kw, opts_kw +) -> [dict, dict, dict, dict, dict]: + """Create default dicts for timeseries plot.""" + # convert SSP, RCP, CMIP formats in keys + if isinstance(data, dict): + data = process_keys(data, convert_scen_name) + if isinstance(plot_kw, dict): + plot_kw = process_keys(plot_kw, convert_scen_name) + if isinstance(opts_kw, dict): + opts_kw = process_keys(opts_kw, convert_scen_name) + + # create empty dicts if None + use_attrs = empty_dict(use_attrs) + opts_kw = empty_dict(opts_kw) + plot_kw = empty_dict(plot_kw) + + # if only one data input, insert in dict. + non_dict_data = False + if not isinstance(data, dict): + non_dict_data = True + data = {"_no_label": data} # mpl excludes labels starting with "_" from legend + plot_kw = {"_no_label": plot_kw} + opts_kw = {"_no_label": opts_kw} + + # add overlay option if absent in opts_ke + opts_kw.setdefault("overlay", {}) + return use_attrs, data, plot_kw, opts_kw, non_dict_data + + +def x_timeseries(data, plot_kw) -> None: + """Get x coordinate for timeseries plot.""" + if "x" not in plot_kw.keys(): + if "time" in data.coords: + plot_kw["x"] = "time" + elif "month" in data.coords: + plot_kw["x"] = "month" + elif "season" in data.coords: + plot_kw["x"] = "season" + elif "year" in data.coords: + plot_kw["x"] = "year" + elif "dayofyear" in data.coords: + plot_kw["x"] = "dayofyear" + elif "annual_cycle" in data.coords: + plot_kw["x"] = "annual_cycle" + elif "x" in data.coords: + plot_kw["x"] = "x" + else: + raise ValueError( + "None if these coordinates; time, month, year," + " season, dayofyear, annual_cycle and x were found in data." + "Please specify x coordinate in plot_kw." + ) + + +def set_plot_attrs_hv( + use_attrs: dict[str, Any], + xr_obj: xr.DataArray | xr.Dataset, + plot_kw: dict[str, Any], + wrap_kw: dict[str, Any] | None = None, +) -> [dict, dict]: + """Set plot attributes with the last plot_kw entry based on use_attr.""" + # set default use_attrs values + use_attrs = { + "title": "description", + "ylabel": "long_name", + "yunits": "units", + "yhover": "standart_name", + } | use_attrs + + wrap_kw = empty_dict(wrap_kw) + + # last plot_kw entry + name = list(plot_kw.keys())[0] + + for key in use_attrs.keys(): + if key not in [ + "title", + "ylabel", + "yunits", + "xunits", + "xhover", + "yhover", + # "xlabel", + # "cbar_label", + # "cbar_units", + # "suptitle", + ]: + warnings.warn(f'Use_attrs element "{key}" not supported') + + if "title" in use_attrs: + title = get_attributes(use_attrs["title"], xr_obj) + plot_kw[name].setdefault("title", wrap_text(title, **wrap_kw)) + + if "ylabel" in use_attrs: + if ( + "yunits" in use_attrs + and len(get_attributes(use_attrs["yunits"], xr_obj)) >= 1 + ): # second condition avoids '[]' as label + ylabel = wrap_text( + get_attributes(use_attrs["ylabel"], xr_obj) + + " (" + + get_attributes(use_attrs["yunits"], xr_obj) + + ")" + ) + else: + ylabel = wrap_text(get_attributes(use_attrs["ylabel"], xr_obj)) + + plot_kw[name].setdefault("ylabel", ylabel) + + if "xlabel" in use_attrs: + if ( + "xunits" in use_attrs + and len(get_attributes(use_attrs["xunits"], xr_obj)) >= 1 + ): # second condition avoids '[]' as label + xlabel = wrap_text( + get_attributes(use_attrs["xlabel"], xr_obj) + + " (" + + get_attributes(use_attrs["xunits"], xr_obj) + + ")" + ) + else: + xlabel = wrap_text(get_attributes(use_attrs["xlabel"], xr_obj)) + else: + xlabel = plot_kw[name]["x"] + plot_kw[name].setdefault("xlabel", xlabel) + + if "yhover" in use_attrs: + if ( + "yunits" in use_attrs + and len(get_attributes(use_attrs["yhover"], xr_obj)) >= 1 + ): # second condition avoids '[]' as label + yhover = wrap_text( + get_attributes(use_attrs["yhover"], xr_obj) + + " (" + + get_attributes(use_attrs["yunits"], xr_obj) + + ")" + ) + else: + yhover = get_attributes(use_attrs["yhover"], xr_obj) + use_attrs["yhover"] = yhover + + if "xhover" in use_attrs: + if ( + "xunits" in use_attrs + and len(get_attributes(use_attrs["xunits"], xr_obj)) >= 1 + ): # second condition avoids '[]' as label + xhover = wrap_text( + get_attributes(use_attrs["xhover"], xr_obj) + + " (" + + get_attributes(use_attrs["xunits"], xr_obj) + + ")" + ) + else: + xhover = wrap_text(get_attributes(use_attrs["xhover"], xr_obj)) + else: + xhover = plot_kw[name]["x"] + use_attrs["xhover"] = xhover + return use_attrs, plot_kw + + +def plot_coords_hook(plot, element, text, loc, bgc) -> None: + """Hook to add text to plot. Use hooks to have access to screen units.""" + pk = {} + pk["background_fill_alpha"] = bgc + + if isinstance(loc, str): + pk["x_units"] = "screen" + pk["y_units"] = "screen" + + width, height = plot.state.width, plot.state.height + + if loc == "center": + pk["y"] = (height - 150) / 2 + pk["x"] = (width - 300) / 2 + else: + if "upper" in loc: + pk["y"] = height - 150 + if "lower" in loc: + pk["y"] = 10 + if "right" in loc: + pk["x"] = width - 180 + pk["text_align"] = "right" + if "left" in loc: + pk["x"] = 10 + pk["text_align"] = "left" + if "center" in loc: + if loc[0] == "c": + pk["y"] = (height - 150) / 2 + else: + pk["x"] = (width - 300) / 2 + elif isinstance(loc, tuple): + pk["x_units"] = "data" + pk["y_units"] = "data" + pk["x"] = loc[0] + pk["y"] = loc[1] + + label = Label(text=text, **pk) + plot.state.add_layout(label) + + +def plot_coords( + xr_obj: xr.DataArray | xr.Dataset, + loc: str | tuple[float, float] | int, + param: str | None = None, + backgroundalpha: float = 1, +) -> hv.Text: + """Plot the coordinates of an xarray object. + + Parameters + ---------- + xr_obj : xr.DataArray | xr.Dataset + The xarray object from which to plot the coordinates. + loc : str | tuple[float, float] | int + Location of text, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html. + If a tuple, must be in axes coordinates. + param : {"location", "time"}, optional + The parameter used. + backgroundalpha : float + Transparency of the text background. 1 is opaque, 0 is transparent. + projection: str + Custom projection. Default is None. + + Returns + ------- + hv.Text + """ + text = None + if param == "location": + if "lat" in xr_obj.coords and "lon" in xr_obj.coords: + text = "lat={:.2f}, lon={:.2f}".format( + float(xr_obj["lat"]), float(xr_obj["lon"]) + ) + else: + warnings.warn( + 'show_lat_lon set to True, but "lat" and/or "lon" not found in coords' + ) + if param == "time": + if "time" in xr_obj.coords: + text = str(xr_obj.time.dt.strftime("%Y-%m-%d").values) + + else: + warnings.warn('show_time set to True, but "time" not found in coords') + + if isinstance(loc, int): + equiv = { + 1: "upper right", + 2: "upper left", + 3: "lower left", + 4: "lower right", + 6: "center left", + 7: "center right", + 8: "lower center", + 9: "upper center", + 10: "center", + } + loc = equiv[loc] + if isinstance(loc, bool): + loc = "lower left" + + return partial(plot_coords_hook, text=text, loc=loc, bgc=backgroundalpha) diff --git a/src/figanos/matplotlib/__init__.py b/src/figanos/matplotlib/__init__.py index 974fcda..b4f2e86 100644 --- a/src/figanos/matplotlib/__init__.py +++ b/src/figanos/matplotlib/__init__.py @@ -1,4 +1,4 @@ -"""Figanos plotting module.""" +"""Figanos matplotlib plotting module.""" from .plot import ( gdfmap,