Skip to content

Commit

Permalink
Implement merging strategies for dicts
Browse files Browse the repository at this point in the history
The implementation follows the original pillarstack logic as for lists.
  • Loading branch information
jgraichen committed May 1, 2020
1 parent b5a6231 commit ce6f879
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 26 deletions.
71 changes: 45 additions & 26 deletions salt_tower/pillar/tower.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
"""
from __future__ import absolute_import

import collections
import copy
import errno
import logging
Expand Down Expand Up @@ -93,7 +92,7 @@ def merge(self, *args, **kwargs):
return _merge(*args, **kwargs)

def format(self, obj, *args, **kwargs):
if isinstance(obj, collections.Mapping):
if isinstance(obj, dict):
return {k: self.format(v, *args, **kwargs) for k, v in obj.items()}

if isinstance(obj, list):
Expand All @@ -119,7 +118,7 @@ def run(self, top):
if isinstance(item, str):
self._load_item(base, item)

elif isinstance(item, collections.Mapping):
elif isinstance(item, dict):
for tgt, items in item.items():
if not self._match_minion(tgt):
continue
Expand All @@ -139,7 +138,7 @@ def _match_minion(self, tgt):
def _load_top(self, top):
data = self._compile(top)

if not isinstance(data, collections.Mapping):
if not isinstance(data, dict):
LOGGER.critical("Tower top must be a dict, but is %s.", type(data))
return []

Expand All @@ -161,7 +160,7 @@ def _load_top(self, top):
return data[self.env]

def _load_item(self, base, item):
if isinstance(item, collections.Mapping):
if isinstance(item, dict):
self.update(item, merge=True)

elif isinstance(item, str):
Expand Down Expand Up @@ -220,7 +219,7 @@ def _load_file(self, file, base=None):

data = self._compile(file, context={"basedir": base})

if not isinstance(data, collections.Mapping):
if not isinstance(data, dict):
LOGGER.warning("Loading %s did not return dict, but %s", file, type(data))
return

Expand Down Expand Up @@ -297,34 +296,54 @@ def get_field(self, key, args, kwargs):
return (value, None)


def _merge(tgt, *objects):
def _merge(tgt, *objects, strategy="merge-last"):
for obj in objects:
if isinstance(tgt, collections.Mapping):
tgt = _merge_dict(tgt, obj)
if isinstance(tgt, dict):
tgt = _merge_dict(tgt, copy.deepcopy(obj), strategy)
elif isinstance(tgt, list):
tgt = _merge_list(tgt, obj)
tgt = _merge_list(tgt, copy.deepcopy(obj), strategy)
else:
raise TypeError(f"Cannot merge {type(tgt)}")

return tgt


def _merge_dict(tgt, obj):
if not isinstance(obj, collections.Mapping):
def _merge_dict(tgt, obj, strategy="merge-last"):
if not isinstance(obj, dict):
raise TypeError(f"Cannot merge non-dict type, but is {type(obj)}")

for key, val in obj.items():
if key in tgt:
if isinstance(tgt[key], collections.Mapping) and isinstance(
val, collections.Mapping
):
_merge(tgt[key], val)
elif isinstance(tgt[key], list) and isinstance(val, list):
_merge_list(tgt[key], val)
if "__" in obj:
strategy = obj.pop("__")

if strategy == "remove":
for k in obj:
if k in tgt:
tgt.pop(k)

elif strategy == "merge-last":
for key, val in obj.items():
if key in tgt and isinstance(tgt[key], dict) and isinstance(val, dict):
_merge_dict(tgt[key], val, strategy)
elif key in tgt and isinstance(tgt[key], list) and isinstance(val, list):
_merge_list(tgt[key], val, strategy)
else:
tgt[key] = copy.deepcopy(val)
else:
tgt[key] = copy.deepcopy(val)
tgt[key] = val

elif strategy == "merge-first":
for key, val in obj.items():
if key in tgt and isinstance(tgt[key], dict) and isinstance(val, dict):
_merge_dict(tgt[key], val, strategy)
elif key in tgt and isinstance(tgt[key], list) and isinstance(val, list):
_merge_list(tgt[key], val, strategy)
elif key not in tgt:
tgt[key] = val

elif strategy == "overwrite":
tgt.clear()
tgt.update(obj)

else:
raise ValueError(f"Unknown strategy: {strategy}")

return tgt

Expand All @@ -342,15 +361,15 @@ def _merge_list(tgt, lst, strategy="merge-last"):
tgt.remove(val)

elif strategy == "merge-last":
tgt.extend(copy.deepcopy(lst))
tgt.extend(lst)

elif strategy == "merge-first":
for val in lst:
tgt.insert(0, copy.deepcopy(val))
tgt.insert(0, val)

elif strategy == "overwrite":
del tgt[:]
tgt.extend(copy.deepcopy(lst))
tgt.extend(lst)

else:
raise ValueError(f"Unknown strategy: {strategy}")
Expand Down
36 changes: 36 additions & 0 deletions test/pillar/test_tower_tower.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,42 @@ def test_merge_list_strategy_merge_overwrite(tower):
assert tgt == ['c']


def test_merge_dict_strategy_remove(tower):
tgt = {'a': 0, 'b': 1}
mod = {'__': 'remove', 'a': None}

tower.merge(tgt, mod)

assert tgt == {'b': 1}


def test_merge_dict_strategy_remove_non_existant(tower):
tgt = {'a': 0, 'b': 1}
mod = {'__': 'remove', 'c': None}

tower.merge(tgt, mod)

assert tgt == {'a': 0, 'b': 1}


def test_merge_dict_strategy_merge_first(tower):
tgt = {'a': 0, 'b': 1, 'd': [4]}
mod = {'__': 'merge-first', 'a': 1, 'c': 2, 'd': [5]}

tower.merge(tgt, mod)

assert tgt == {'a': 0, 'b': 1, 'c': 2, 'd': [5, 4]}


def test_merge_dict_strategy_merge_overwrite(tower):
tgt = {'a': 0, 'b': 1}
mod = {'__': 'overwrite', 'c': 2}

tower.merge(tgt, mod)

assert tgt == {'c': 2}


def test_format(tower):
tower.update({'app': {'name': 'MyApp'}})

Expand Down

0 comments on commit ce6f879

Please sign in to comment.