diff --git a/src/util.h b/src/util.h index 356ead8..8b869c6 100644 --- a/src/util.h +++ b/src/util.h @@ -287,46 +287,49 @@ class IntervalMap void set_interval(K begin, K end, const V v) { if (begin >= end) return; - // get interval that `begin` intersects with (inclusive) - iter begin_intersect = --my_map.upper_bound(begin); - - // get interval that `end` intersects with (inclusive) + // get end intersector (inclusive) iter end_intersect = --my_map.upper_bound(end); - // if required, insert at start - iter inserted_start = my_map.end(); - if (begin_intersect->second != v) { - inserted_start = my_map.insert_or_assign(begin_intersect, begin, v); - } - // if required, insert at end iter inserted_end = my_map.end(); if (end_intersect->second != v) { inserted_end = my_map.insert_or_assign(end_intersect, end, end_intersect->second); } + // get begin intersector (inclusive) + iter begin_intersect = --my_map.upper_bound(begin); + + // if required, insert at start + iter inserted_start = my_map.end(); + if (begin_intersect->second != v) { + inserted_start = my_map.insert_or_assign(begin_intersect, begin, v); + } + // delete everyone inside iter del_start = inserted_start != my_map.end() ? inserted_start : begin_intersect; - if (del_start->first == begin) { + if (del_start->first < begin || (del_start->first == begin && std::prev(del_start)->second != v)) { del_start++; } iter del_end = inserted_end != my_map.end() ? inserted_end : end_intersect; + if (del_end != my_map.end() && del_end->first == end && std::next(del_end) != my_map.end() && del_end->second == v) { + del_end++; + } - if (del_start != my_map.end()) { + if (del_start != my_map.end() && del_start->first < del_end->first) { my_map.erase(del_start, del_end); } } // iterator which traverses elements in sorted order (smallest to largest) // O(1) - constexpr inline iter &begin() { + constexpr inline auto &begin() { return my_map.begin(); } // end of elements // O(1) - constexpr inline iter &end() { + constexpr inline auto &end() { return my_map.end(); } @@ -338,7 +341,12 @@ class IntervalMap // get value at key `k` // O(log N) - const V operator[](K const& k) { + const inline V& operator[](K const& k) const { + return (--my_map.upper_bound(k))->second; + } + + // don't return reference because we don't want to allow map[whatever] = value as it would edit the next-earliest value instead of inserting a new element. (TODO: write better.) + inline V operator[](K const& k) { return (--my_map.upper_bound(k))->second; } @@ -347,6 +355,18 @@ class IntervalMap my_map.clear(); my_map.insert(my_map.end(), { std::numeric_limits::lowest(), v }); } + + // get num intervals overall + // always at least 1 + constexpr int num_intervals() { + return my_map.size(); + } + + // get number of intervals in a range + // UNTESTED + constexpr int num_intervals(const K& start, const K& end) { + return get_interval(end) - get_interval(start); + } }; #endif // __UTIL_H__