From f54a87a48902b8a772d98752f30b1be626f6bf39 Mon Sep 17 00:00:00 2001 From: Chris Arderne Date: Sat, 17 Feb 2024 18:07:17 +0000 Subject: [PATCH] add back notebook test --- Makefile | 1 + examples/example.ipynb | 2 +- examples/test.sh | 16 ++++++++++++++++ gridfinder/gridfinder.py | 41 +++++++++++++++++++++------------------- 4 files changed, 40 insertions(+), 20 deletions(-) create mode 100755 examples/test.sh diff --git a/Makefile b/Makefile index eba7e64..85a1925 100644 --- a/Makefile +++ b/Makefile @@ -6,3 +6,4 @@ lint: .PHONY: test test: pytest + ./examples/test.sh diff --git a/examples/example.ipynb b/examples/example.ipynb index 2816705..58c8142 100644 --- a/examples/example.ipynb +++ b/examples/example.ipynb @@ -182,7 +182,7 @@ "metadata": {}, "outputs": [], "source": [ - "true_pos, false_neg = gf.accuracy(grid_truth, guess_out, aoi_in)\n", + "true_pos, false_neg = gf.accuracy(grid_truth_in, guess_out, aoi_in)\n", "print(f\"Points identified as grid that are grid: {100*true_pos:.0f}%\")\n", "print(f\"Actual grid that was missed: {100*false_neg:.0f}%\")" ] diff --git a/examples/test.sh b/examples/test.sh new file mode 100755 index 0000000..98fcada --- /dev/null +++ b/examples/test.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +cd examples + +# convert the example notebook to a script +python -m jupyter nbconvert --to script example.ipynb + + +# disable Jupyter and plt.imshow output +sed -i -e 's/jupyter=True/jupyter=False/g' example.py +sed -i -e 's/plt.imshow(.*)//g' example.py + +# run script +python example.py + +rm example.py example.py-e || true diff --git a/gridfinder/gridfinder.py b/gridfinder/gridfinder.py index 5bb3655..1be32f4 100644 --- a/gridfinder/gridfinder.py +++ b/gridfinder/gridfinder.py @@ -125,26 +125,22 @@ def optimise( while len(queue): _, current_loc = heappop(queue) - current_i = current_loc[0] - current_j = current_loc[1] current_dist = dist[current_loc] for x in (-1, 0, 1): for y in (-1, 0, 1): - next_i = current_i + x - next_j = current_j + y + next_i = current_loc[0] + x + next_j = current_loc[1] + y next_loc = (next_i, next_j) - # ensure we're within bounds - if next_i < 0 or next_j < 0 or next_i >= max_i or next_j >= max_j: - continue - - # ensure we're not looking at the same spot - if next_loc == current_loc: - continue - - # skip if we've already set dist to 0 - if dist[next_loc] == 0.0: + if ( + (x == 0 and y == 0) # same spot + or dist[next_loc] == 0.0 # already zerod + or next_i < 0 # out of bounds + or next_j < 0 + or next_i >= max_i + or next_j >= max_j + ): continue # if the location is connected @@ -169,12 +165,14 @@ def optimise( next_dist = current_dist + dist_add + # visited before if visited[next_loc]: if next_dist < dist[next_loc]: dist[next_loc] = next_dist prev[next_loc] = current_loc heappush(queue, (next_dist, next_loc)) + # brand new cell - progress! else: heappush(queue, (next_dist, next_loc)) visited[next_loc] = 1 @@ -186,9 +184,14 @@ def optimise( progress_new = int(100 * counter / max_cells) if progress_new > progress: progress = progress_new - if progress % 5 == 0: - print(progress) - else: - print(".") - + with nb.objmode(): + print_progress(progress) + print() return dist + + +def print_progress(progress: int) -> None: + if progress % 5 == 0: + print(progress, end="", flush=True) + else: + print(".", end="", flush=True)