Skip to content

Commit

Permalink
Merge pull request #35 from FAST-HEP/BK_implement_mapping_multiply
Browse files Browse the repository at this point in the history
Implement multiply by mapping
  • Loading branch information
benkrikler authored May 12, 2020
2 parents 11ce33b + b7bbff5 commit 9de925a
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.6.5] - 2020-05-12
### Added
- Implement the `mapping` option of `multiply_values`, PR #35 [@benkrikler](https://github.com/benkrikler)

## [0.6.4] - 2020-05-07
### Added
- New postprocessing stage to filter columns, PR #34 [@benkrikler](https://github.com/benkrikler)
Expand Down
10 changes: 7 additions & 3 deletions fast_plotter/postproc/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,17 +382,21 @@ def multiply_values(df, constant=0, mapping={}, weight_by_dataframes=[], apply_i
if mask.dtype.kind != "b":
msg = "'apply_if' statement doesn't return a boolean: %s"
raise ValueError(msg % apply_if)
index = df.index
ignored = df.loc[~mask]
df = df.loc[mask]
if mapping:
raise NotImplementedError("'mapping' option not yet implemented")
for select, value in mapping.items():
df = multiply_values(df, constant=value, apply_if=select)
if weight_by_dataframes:
for mul_df in weight_by_dataframes:
df = multiply_dataframe(df, mul_df)
if constant:
df = df * constant
numeric_cols = df.select_dtypes('number')
df = df.assign(**numeric_cols.multiply(constant))
if apply_if:
df = pd.concat([df, ignored])
df = pd.concat([df, ignored]).reindex(index=index)

return df


Expand Down
2 changes: 1 addition & 1 deletion fast_plotter/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ def split_version(version):
return tuple(result)


__version__ = '0.6.4'
__version__ = '0.6.5'
version_info = split_version(__version__) # noqa
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.6.4
current_version = 0.6.5
commit = True
tag = False

Expand Down
21 changes: 18 additions & 3 deletions tests/postproc/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,24 @@ def test_filter_cols(binned_df):
# #def merge(dfs):
# pass

# def test_multiply_values():
# #def multiply_values(df, constant=0, mapping={}, weight_by_dataframes=[], apply_if=None):
# pass
def test_multiply_values(binned_df):
    """Exercise ``funcs.multiply_values`` with each way of choosing a factor.

    Every scenario calls the function on the ``binned_df`` fixture and checks
    that column ``a`` equals ``np.arange(len(binned_df))`` scaled by the
    expected per-row factor, while column ``b`` compares equal to the input's
    ``b`` column.

    Scenarios, in order: a plain ``constant``; a ``constant`` restricted by
    ``apply_if``; a ``mapping`` keyed on the ``int`` selection; and a
    ``mapping`` keyed on the ``cat`` selection.
    """
    row_numbers = np.arange(len(binned_df))

    # (keyword arguments for multiply_values, expected per-row factor on `a`)
    scenarios = [
        (
            {"constant": 3},
            3,
        ),
        (
            {"apply_if": "int % 2 == 0", "constant": 3},
            np.repeat([3, 1, 3, 1], 10),
        ),
        (
            {"mapping": {"int % 2 == 0": 3, "int % 2 == 1": 7.2}},
            np.repeat([3, 7.2, 3, 7.2], 10),
        ),
        (
            {"mapping": {"cat=='foo'": 1.2, "cat=='bar'": 19}},
            np.tile(np.repeat([1.2, 19], 5), 4),
        ),
    ]

    for kwargs, expected_factor in scenarios:
        scaled = funcs.multiply_values(binned_df, **kwargs)
        assert np.array_equal(scaled.a, row_numbers * expected_factor)
        assert np.array_equal(scaled.b, binned_df.b)


# def test_multiply_dataframe():
# #def multiply_dataframe(df, multiply_df, use_column=None):
Expand Down

0 comments on commit 9de925a

Please sign in to comment.