diff --git a/CHANGELOG.md b/CHANGELOG.md index c425c32..f116f51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.7.0] - 2020-05-14 +### Added +- Post-processing functions that interact with curator to apply or rebin datasets, PR #36 [@benkrikler](https://github.com/benkrikler) + ## [0.6.5] - 2020-05-12 ### Added - Implement the multiply_values with a mapping, PR #35 [@benkrikler](https://github.com/benkrikler) diff --git a/fast_plotter/postproc/__init__.py b/fast_plotter/postproc/__init__.py index e69de29..052d6ed 100644 --- a/fast_plotter/postproc/__init__.py +++ b/fast_plotter/postproc/__init__.py @@ -0,0 +1,4 @@ +from .functions import open_many + + +__all__ = ["open_many"] diff --git a/fast_plotter/postproc/functions.py b/fast_plotter/postproc/functions.py index c132396..d2b8992 100644 --- a/fast_plotter/postproc/functions.py +++ b/fast_plotter/postproc/functions.py @@ -3,6 +3,7 @@ import re import numpy as np import pandas as pd +from .query_curator import prepare_datasets_scale_factor, make_dataset_map import logging logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -102,6 +103,18 @@ def explode(mapping, expect_depth, prefix="", depth=0): return out_df +def rebin_by_curator_cfg(df, curator_cfg, map_from="name", map_to="eventtype", + column_from="dataset", column_to=None, + default_from=None, default_to=None, error_all_missing=True): + mapping = make_dataset_map(curator_cfg, + map_from=map_from, map_to=map_to, + default_from=default_from, + default_to=default_to, + error_all_missing=error_all_missing) + df = rebin(df, axis=column_from, mapping=mapping, rename=column_to) + return df + + def split_dimension(df, axis, delimeter=";"): """ Split up a binning dimensions @@ -400,18 +413,28 @@ def multiply_values(df, constant=0, mapping={}, weight_by_dataframes=[], apply_i return df -def multiply_dataframe(df, multiply_df, use_column=None): +def multiply_dataframe(df, multiply_df, use_column=None, level=None): if isinstance(multiply_df, six.string_types): multiply_df = open_many([multiply_df], return_meta=False)[0] if use_column is not None: multiply_df = multiply_df[use_column] if isinstance(multiply_df, pd.Series): - out = df.mul(multiply_df, axis=0) + out = df.mul(multiply_df, axis=0, level=level) else: - out = df * multiply_df + out = df.mul(multiply_df, level=level) return out +def scale_datasets(df, curator_cfg, multiply_by=[], divide_by=[], + dataset_col="dataset", eventtype="mc", use_column=None): + """ + Pull fields from a fast-curator config for datasets, and use these to normalise inputs + """ + scale = prepare_datasets_scale_factor(curator_cfg, multiply_by, divide_by, dataset_col, eventtype) + result = multiply_dataframe(df, scale, use_column=use_column, level=dataset_col) + return result + + def normalise_group(df, groupby_dimensions, apply_if=None, use_column=None): logger.info("Normalising within groups defined by: %s", str(groupby_dimensions)) norm_to = 1 / df.groupby(level=groupby_dimensions).sum() diff --git a/fast_plotter/postproc/query_curator.py b/fast_plotter/postproc/query_curator.py new file mode 100644 index 0000000..ba0bdb0 --- /dev/null +++ b/fast_plotter/postproc/query_curator.py @@ -0,0 +1,60 @@ +import pandas as pd +from fast_curator import read + + +def _get_cfg(cfg): + if isinstance(cfg, list): + return cfg + return read.from_yaml(cfg) + + +def prepare_datasets_scale_factor(curator_cfg, multiply_by=[], divide_by=[], dataset_col="dataset", eventtype="mc"): + dataset_cfg = _get_cfg(curator_cfg) + + sfs = {} + for dataset in dataset_cfg: + if eventtype and dataset.eventtype not in eventtype: + sfs[dataset.name] = 1 + continue + + scale = 1 + for m in multiply_by: + scale *= float(getattr(dataset, m)) + for d in divide_by: + scale /= float(getattr(dataset, d)) + sfs[dataset.name] = scale + + sfs = pd.Series(sfs, name=dataset_col) + return sfs + + +def make_dataset_map(curator_cfg, map_from="name", map_to="eventtype", + default_from=None, default_to=None, error_all_missing=True): + dataset_cfg = _get_cfg(curator_cfg) + + mapping = {} + missing_from = 0 + missing_to = 0 + for dataset in dataset_cfg: + if hasattr(dataset, map_from): + key = getattr(dataset, map_from) + else: + key = default_from + missing_from += 1 + + if hasattr(dataset, map_to): + value = getattr(dataset, map_to) + else: + value = default_to + missing_to += 1 + + mapping[key] = value + if missing_from == len(dataset_cfg) and error_all_missing: + msg = "None of the datasets contain the 'from' field, '%s'" + raise RuntimeError(msg % map_from) + + if missing_to == len(dataset_cfg) and error_all_missing: + msg = "None of the datasets contain the 'to' field, '%s'" + raise RuntimeError(msg % map_to) + + return mapping diff --git a/fast_plotter/postproc/stages.py b/fast_plotter/postproc/stages.py index d97bc7c..612896e 100644 --- a/fast_plotter/postproc/stages.py +++ b/fast_plotter/postproc/stages.py @@ -131,6 +131,16 @@ class MultiplyValues(BaseManipulator): func = "multiply_values" +class ScaleDatasets(BaseManipulator): + cardinality = "one-to-one" + func = "scale_datasets" + + +class RebinByCuratorCfg(BaseManipulator): + cardinality = "one-to-one" + func = "rebin_by_curator_cfg" + + class NormaliseGroup(BaseManipulator): cardinality = "one-to-one" func = "normalise_group" diff --git a/fast_plotter/version.py b/fast_plotter/version.py index 148d646..bf21ade 100644 --- a/fast_plotter/version.py +++ b/fast_plotter/version.py @@ -12,5 +12,5 @@ def split_version(version): return tuple(result) -__version__ = '0.6.5' +__version__ = '0.7.0' version_info = split_version(__version__) # noqa diff --git a/setup.cfg b/setup.cfg index c65bee9..9ffd9b8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.6.5 +current_version = 0.7.0 commit = True tag = False diff --git a/setup.py b/setup.py index 81eba23..2679c80 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,8 @@ def get_version(): return _globals["__version__"] -requirements = ['matplotlib', 'pandas', 'numpy', 'scipy'] +requirements = ['matplotlib', 'pandas', 'numpy', 'scipy', + 'fast-curator', 'fast-flow'] setup_requirements = ['pytest-runner', ]