From 335bf5b23acf19fe5c2b1b011579ab5d45782dd6 Mon Sep 17 00:00:00 2001 From: Mukundan314 Date: Wed, 6 Nov 2024 20:33:29 +0530 Subject: [PATCH] Refactor LazySegmentTree to be easier to modify --- pyrival/data_structures/LazySegmentTree.py | 76 ++++++++++++++-------- 1 file changed, 48 insertions(+), 28 deletions(-) diff --git a/pyrival/data_structures/LazySegmentTree.py b/pyrival/data_structures/LazySegmentTree.py index f6a5221..842578f 100644 --- a/pyrival/data_structures/LazySegmentTree.py +++ b/pyrival/data_structures/LazySegmentTree.py @@ -1,56 +1,74 @@ class LazySegmentTree: - def __init__(self, data, default=0, func=max): + def __init__(self, data): """initialize the lazy segment tree with data""" - self._default = default - self._func = func - self._len = len(data) self._size = _size = 1 << (self._len - 1).bit_length() - self._lazy = [0] * (2 * _size) + self._buffer_idx = 2 * _size - self.data = [default] * (2 * _size) + self.lazy = [0] * (2 * _size + 1) + self.data = [0] * (2 * _size + 1) self.data[_size:_size + self._len] = data for i in reversed(range(_size)): - self.data[i] = func(self.data[i + i], self.data[i + i + 1]) + self._merge_data(i + i, i + i + 1, i) def __len__(self): return self._len - def _push(self, idx): - """push query on idx to its children""" - # Let the children know of the queries - q, self._lazy[idx] = self._lazy[idx], 0 + def _get_range(self, a): + shift = self._size.bit_length() - a.bit_length() + return a << shift, (a << shift) + (1 << shift) + + def _unset_lazy(self, a): + """a: lazy_idx; unset a""" + self.lazy[a] = 0 + + def _apply_to_data(self, a, b): + """a: lazy_idx, b: data_idx; apply a to b""" + l, r = self._get_range(b) + self.data[b] += self.lazy[a] * (r - l) + + def _apply_to_lazy(self, a, b): + """a: lazy_idx, b: lazy_idx; apply a to b""" + self.lazy[b] += self.lazy[a] - self._lazy[2 * idx] += q - self._lazy[2 * idx + 1] += q - self.data[2 * idx] += q - self.data[2 * idx + 1] += q + def _merge_data(self, a, b, c): + """a: data_idx, b: data_idx, c: data_idx; merge a and b store result in c""" + self.data[c] = self.data[a] + self.data[b] def _update(self, idx): """updates the node idx to know of all queries applied to it via its ancestors""" for i in reversed(range(1, idx.bit_length())): - self._push(idx >> i) + _idx = idx >> i + self._apply_to_data(_idx, 2 * _idx) + self._apply_to_lazy(_idx, 2 * _idx) + self._apply_to_data(_idx, 2 * _idx + 1) + self._apply_to_lazy(_idx, 2 * _idx + 1) + self._unset_lazy(_idx) def _build(self, idx): """make the changes to idx be known to its ancestors""" idx >>= 1 while idx: - self.data[idx] = self._func(self.data[2 * idx], self.data[2 * idx + 1]) + self._lazy[idx] + self._merge_data(2 * idx, 2 * idx + 1, idx) + self._apply_to_data(idx, idx) idx >>= 1 - def add(self, start, stop, value): - """lazily add value to [start, stop)""" + def apply(self, start, stop, value): + """lazily apply value to [start, stop)""" start = start_copy = start + self._size stop = stop_copy = stop + self._size + + self.lazy[self._buffer_idx] = value + while start < stop: if start & 1: - self._lazy[start] += value - self.data[start] += value + self._apply_to_lazy(self._buffer_idx, start) + self._apply_to_data(self._buffer_idx, start) start += 1 if stop & 1: stop -= 1 - self._lazy[stop] += value - self.data[stop] += value + self._apply_to_lazy(self._buffer_idx, stop) + self._apply_to_data(self._buffer_idx, stop) start >>= 1 stop >>= 1 @@ -58,7 +76,7 @@ def add(self, start, stop, value): self._build(start_copy) self._build(stop_copy - 1) - def query(self, start, stop, default=0): + def query(self, start, stop): """func of data[start, stop)""" start += self._size stop += self._size @@ -67,17 +85,19 @@ def query(self, start, stop, default=0): self._update(start) self._update(stop - 1) - res = default + self.data[self._buffer_idx] = 0 + while start < stop: if start & 1: - res = self._func(res, self.data[start]) + self._merge_data(self._buffer_idx, start, self._buffer_idx) start += 1 if stop & 1: stop -= 1 - res = self._func(res, self.data[stop]) + self._merge_data(self._buffer_idx, stop, self._buffer_idx) start >>= 1 stop >>= 1 - return res + + return self.data[self._buffer_idx] def __repr__(self): return "LazySegmentTree({0})".format(self.data)