Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend redistribute_integer_pairs to floats #4222

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
RELEASE_TYPE: patch

The shrinker contains a pass aimed at pairs of generated integers whose sum is required to exceed some value. This patch extends that pass to cover floats as well.
35 changes: 18 additions & 17 deletions hypothesis-python/src/hypothesis/internal/conjecture/shrinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ def greedy_shrink(self):
"reorder_examples",
"minimize_duplicated_nodes",
"minimize_individual_nodes",
"redistribute_integer_pairs",
"redistribute_numeric_pairs",
"lower_blocks_together",
]
)
Expand Down Expand Up @@ -1224,34 +1224,30 @@ def minimize_duplicated_nodes(self, chooser):
more values at once.
"""
nodes = chooser.choose(self.duplicated_nodes)
# we can't lower any nodes which are trivial. try proceeding with the
# remaining nodes.
nodes = [node for node in nodes if not node.trivial]
if len(nodes) <= 1:
return

# no point in lowering nodes together if one is already trivial.
# TODO_BETTER_SHRINK: we could potentially just drop the trivial nodes
# here and carry on with nontrivial ones?
if any(node.trivial for node in nodes):
return

self.minimize_nodes(nodes)

@defines_shrink_pass()
def redistribute_integer_pairs(self, chooser):
"""If there is a sum of generated integers that we need their sum
def redistribute_numeric_pairs(self, chooser):
"""If there is a sum of generated numbers that we need their sum
to exceed some bound, lowering one of them requires raising the
other. This pass enables that."""
# TODO_SHRINK let's extend this to floats as well.

# look for a pair of nodes (node1, node2) which are both integers and
# aren't separated by too many other nodes. We'll decrease node1 and
# look for a pair of nodes (node1, node2) which are both numeric
# and aren't separated by too many other nodes. We'll decrease node1 and
# increase node2 (note that the other way around doesn't make sense as
# it's strictly worse in the ordering).
node1 = chooser.choose(
self.nodes, lambda node: node.ir_type == "integer" and not node.trivial
self.nodes,
lambda node: node.ir_type in {"integer", "float"} and not node.trivial,
)
node2 = chooser.choose(
self.nodes,
lambda node: node.ir_type == "integer"
lambda node: node.ir_type in {"integer", "float"}
# Note that it's fine for node2 to be trivial, because we're going to
# explicitly make it *not* trivial by adding to its value.
and not node.was_forced
Expand All @@ -1267,8 +1263,13 @@ def boost(k):
if k > m:
return False

node_value = m - k
next_node_value = n + k
try:
node_value = m - k
next_node_value = n + k
except OverflowError: # pragma: no cover
# if n or m is a float and k is over sys.float_info.max, coercing
# k to a float will overflow.
return False
Comment on lines +1266 to +1272
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How well does this handle loss of precision? For integer k I guess very small floats just round to the nearest resulting integer which seems fine, but for large m we might have m - k == m and n + k > n which makes this an antishrink pass in such cases!

(adding some type annotations for m, n, k might help clarify for future readers)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point, this will antishrink as you point out. Let's change this pass to only consider floats below MAX_PRECISE_INTEGER which should sidestep most of these issues

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe worth noting that antishrinking shouldn't cause any problems, right? The shrink order means that we always only shrink, and cases where this ends up increasing shrink order should just get silently discarded. As long as those cases are relatively rare this is probably fine.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think this is just a smallish performance issue, but worth the easy fix.


return self.consider_new_tree(
self.nodes[: node1.index]
Expand Down
6 changes: 3 additions & 3 deletions hypothesis-python/tests/conjecture/test_shrinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,16 +505,16 @@ def shrinker(data: ConjectureData):
assert shrinker.choices == (1, 0) + (0,) * n_gap + (1,)


def test_redistribute_pairs_with_forced_node_integer():
    """The redistribute pass must leave forced draws untouched.

    The second draw is forced to 10, so the pass cannot raise it to
    compensate for lowering the first draw; the choices must survive
    shrinking unchanged.
    """

    @shrinking_from(ir(15, 10))
    def shrinker(data: ConjectureData):
        n1 = data.draw_integer(0, 100)
        n2 = data.draw_integer(0, 100, forced=10)
        if n1 + n2 > 20:
            data.mark_interesting()

    shrinker.fixate_shrink_passes(["redistribute_numeric_pairs"])
    # redistribute_numeric_pairs shouldn't try modifying forced nodes while
    # shrinking. Since the second draw is forced, this isn't possible to shrink
    # with just this pass.
    assert shrinker.choices == (15, 10)
33 changes: 31 additions & 2 deletions hypothesis-python/tests/quality/test_shrink_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,13 +343,29 @@ def test_lists_forced_near_top(n):
) == [0] * (n + 2)


def test_sum_of_pair_int():
    """A pair of bounded integers constrained to sum above 1000 should
    shrink to the minimal satisfying pair (1, 1000)."""
    assert minimal(
        tuples(integers(0, 1000), integers(0, 1000)), lambda x: sum(x) > 1000
    ) == (1, 1000)


def test_sum_of_pair_float():
    """A pair of bounded floats constrained to sum above 1000 should
    shrink to the minimal satisfying pair (1.0, 1000.0)."""
    assert minimal(
        tuples(st.floats(0, 1000), st.floats(0, 1000)), lambda x: sum(x) > 1000
    ) == (1.0, 1000.0)


def test_sum_of_pair_mixed():
    """Redistribution must work regardless of whether the float draw
    comes before or after the integer draw."""

    def exceeds_bound(pair):
        return sum(pair) > 1000

    float_first = tuples(st.floats(0, 1000), st.integers(0, 1000))
    int_first = tuples(st.integers(0, 1000), st.floats(0, 1000))
    assert minimal(float_first, exceeds_bound) == (1.0, 1000.0)
    assert minimal(int_first, exceeds_bound) == (1.0, 1000.0)
tybug marked this conversation as resolved.
Show resolved Hide resolved


def test_sum_of_pair_separated_int():
@st.composite
def separated_sum(draw):
n1 = draw(st.integers(0, 1000))
Expand All @@ -362,6 +378,19 @@ def separated_sum(draw):
assert minimal(separated_sum(), lambda x: sum(x) > 1000) == (1, 1000)


def test_sum_of_pair_separated_float():
    """Floats that must jointly exceed a bound should still shrink to the
    minimal pair even when unrelated draws separate them."""

    @st.composite
    def separated_sum(draw):
        f1 = draw(st.floats(0, 1000))
        # Unrelated draws sit between the two floats, so the shrink pass
        # has to pair up nodes across a gap of other node types.
        draw(st.text())
        draw(st.booleans())
        draw(st.integers())
        f2 = draw(st.floats(0, 1000))
        return (f1, f2)

    # Expected values spelled as floats for consistency with
    # test_sum_of_pair_float (the comparison is equal either way).
    assert minimal(separated_sum(), lambda x: sum(x) > 1000) == (1.0, 1000.0)


def test_calculator_benchmark():
"""This test comes from
https://github.com/jlink/shrinking-challenge/blob/main/challenges/calculator.md,
Expand Down
Loading