diff --git a/fast_plotter/postproc/functions.py b/fast_plotter/postproc/functions.py index 9666b73..4388a3c 100644 --- a/fast_plotter/postproc/functions.py +++ b/fast_plotter/postproc/functions.py @@ -201,14 +201,17 @@ def split(df, axis, keep_split_dim, return_meta=True): split the dataframe into a list of dataframes using a given binning dimensions """ - if isinstance(axis, (list, tuple)): - axis = tuple(axis) - else: - axis = (axis, ) + def to_tuple(obj): + if isinstance(obj, (list, tuple)): + return tuple(obj) + else: + return (obj, ) + axis = to_tuple(axis) logger.info("Splitting on axis: '%s'", axis) out_dfs = [] groups = df.groupby(level=axis, group_keys=keep_split_dim) for split_val, group in groups: + split_val = to_tuple(split_val) if not keep_split_dim: group.index = group.index.droplevel(axis) result = group.copy()