From c3b8f214cf65fa388b98aca6169bf560d616b6e4 Mon Sep 17 00:00:00 2001 From: Julius Busecke Date: Fri, 1 Jul 2022 16:15:49 -0400 Subject: [PATCH] Match variable string input (#245) * Add failing test * manual addition * final test * fix the test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update whats-new.rst Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- cmip6_preprocessing/postprocessing.py | 8 +++++++- docs/whats-new.rst | 2 ++ tests/test_postprocessing.py | 21 +++++++++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/cmip6_preprocessing/postprocessing.py b/cmip6_preprocessing/postprocessing.py index a978de11..286ceaa6 100644 --- a/cmip6_preprocessing/postprocessing.py +++ b/cmip6_preprocessing/postprocessing.py @@ -5,7 +5,11 @@ import numpy as np import xarray as xr -from cmip6_preprocessing.utils import _key_from_attrs, cmip6_dataset_id +from cmip6_preprocessing.utils import ( + _key_from_attrs, + _maybe_make_list, + cmip6_dataset_id, +) try: @@ -604,6 +608,8 @@ def match_metrics( # metrics should never match the variable exact_attrs_wo_var = [ma for ma in exact_attrs if ma != "variable_id"] + match_variables = _maybe_make_list(match_variables) + # if match is set to exact check all these attributes if match_attrs == "exact": match_attrs = exact_attrs_wo_var diff --git a/docs/whats-new.rst b/docs/whats-new.rst index 8d7f9cc4..69a067d4 100644 --- a/docs/whats-new.rst +++ b/docs/whats-new.rst @@ -25,6 +25,8 @@ By `Tom Nicholas `_ and `Julius Busecke `_ - Fixes incompatibility with upstream changes in xarray>=0.19.0 (:issue:`173`, :pull:`174`). By `Julius Busecke `_ diff --git a/tests/test_postprocessing.py b/tests/test_postprocessing.py index 68f1d95b..70ee3d16 100644 --- a/tests/test_postprocessing.py +++ b/tests/test_postprocessing.py @@ -357,6 +357,27 @@ def test_match_metrics_print_statistics(capsys, metricname): assert "No match found:" + str({metricname: 0}) in captured.out +def test_match_metrics_match_variable_str_input(): + # give a dataset that has member_id as dim (indicator that it was aggregated). + metricname = "area" + attrs = { + "source_id": "a", + "grid_label": "a", + "experiment_id": "a", + "table_id": "a", + "variant_label": "a", + "version": "a", + } + ds = random_ds() + ds.attrs = attrs + ds_metric = random_ds().rename({"data": metricname}) + ds_metric.attrs = attrs + + ds_dict_parsed_list = match_metrics({"a": ds}, {"aa": ds_metric}, [metricname]) + ds_dict_parsed_str = match_metrics({"a": ds}, {"aa": ds_metric}, metricname) + xr.testing.assert_equal(ds_dict_parsed_str["a"], ds_dict_parsed_list["a"]) + + @pytest.mark.parametrize("combine_func_kwargs", [{}, {"compat": "override"}]) def test_combine_datasets_merge(combine_func_kwargs): attrs_a = {