Skip to content

Commit

Permalink
lint fastpair
Browse files Browse the repository at this point in the history
  • Loading branch information
jGaboardi committed Jun 2, 2024
1 parent 1ee24d1 commit d053de5
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 65 deletions.
3 changes: 0 additions & 3 deletions fastpair/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""FastPair: Data-structure for the dynamic closest-pair problem.
Init module for FastPair.
Expand Down
29 changes: 12 additions & 17 deletions fastpair/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""FastPair: Data-structure for the dynamic closest-pair problem.
This data-structure is based on the observation that the conga line data
Expand Down Expand Up @@ -33,25 +30,23 @@
# Copyright (c) 2002-2015, David Eppstein
# Licensed under the MIT Licence (http://opensource.org/licenses/MIT).

from __future__ import print_function, division, absolute_import
from itertools import combinations, cycle
from operator import itemgetter
from collections import defaultdict
import scipy.spatial.distance as dist
from itertools import combinations, cycle

import scipy.spatial.distance as dist

__all__ = ["FastPair", "dist"]


class attrdict(dict):
class AttrDict(dict):
"""Simple dict with support for accessing elements as attributes."""

def __init__(self, *args, **kwargs):
super(attrdict, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.__dict__ = self


class FastPair(object):
class FastPair:
"""FastPair 'sketch' class."""

def __init__(self, min_points=10, dist=dist.euclidean):
Expand All @@ -73,8 +68,8 @@ def __init__(self, min_points=10, dist=dist.euclidean):
self.min_points = min_points
self.dist = dist
self.initialized = False # Has the data-structure been initialized?
self.neighbors = defaultdict(attrdict) # Dict of neighbor points and dists
self.points = list() # Internal point set; entries may be non-unique
self.neighbors = defaultdict(AttrDict) # Dict of neighbor points and dists
self.points = [] # Internal point set; entries may be non-unique

def __add__(self, p):
"""Add a point and find its nearest neighbor.
Expand All @@ -98,7 +93,7 @@ def __sub__(self, p):
# We must update neighbors of points for which `p` had been nearest.
for q in self.points:
if self.neighbors[q].neigh == p:
res = self._find_neighbor(q)
self._find_neighbor(q)
return self

def __len__(self):
Expand All @@ -116,13 +111,13 @@ def __iter__(self):
return iter(self.points)

def __getitem__(self, item):
if not item in self:
raise KeyError("{} not found".format(item))
if item not in self:
raise KeyError(f"{item} not found")
return self.neighbors[item]

def __setitem__(self, item, value):
if not item in self:
raise KeyError("{} not found".format(item))
if item not in self:
raise KeyError(f"{item} not found")
self._update_point(item, value)

def build(self, points=None):
Expand Down
86 changes: 41 additions & 45 deletions fastpair/test/test_fastpair.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""FastPair: Data-structure for the dynamic closest-pair problem.
Testing module for FastPair.
Expand All @@ -10,18 +7,17 @@
# Copyright (c) 2002-2015, David Eppstein
# Licensed under the MIT Licence (http://opensource.org/licenses/MIT).

from __future__ import print_function, division, absolute_import

from operator import itemgetter
from types import FunctionType
from itertools import cycle, combinations, groupby
import random
import pytest
from fastpair import FastPair
from itertools import combinations, cycle
from math import isinf, isnan
from operator import itemgetter
from types import FunctionType

from scipy import mean, array, unique
import numpy as np
import pytest
from scipy import array, mean

from fastpair import FastPair


def normalized_distance(_a, _b):
Expand Down Expand Up @@ -63,7 +59,7 @@ def all_close(s, t, tol=1e-8):
# Ignores inf and nan values...
return all(
abs(a - b) < tol
for a, b in zip(s, t)
for a, b in zip(s, t, strict=True)
if not isinf(a) and not isinf(b) and not isnan(a) and not isnan(b)
)

Expand All @@ -79,7 +75,7 @@ def interact(u, v):

# Setup fixtures
@pytest.fixture(scope="module")
def PointSet(n=50, d=10):
def point_set(n=50, d=10):
"""Return numpy array of shape `n`x`d`."""
# random.seed(8714)
return [rand_tuple(d) for _ in range(n)]
Expand All @@ -96,15 +92,15 @@ def test_init(self):
assert len(fp.points) == 0
assert len(fp.neighbors) == 0

def test_build(self, PointSet):
ps = PointSet
def test_build(self, point_set):
ps = point_set
fp = FastPair().build(ps)
assert len(fp) == len(ps)
assert len(fp.neighbors) == len(ps)
assert fp.initialized is True

def test_add(self, PointSet):
ps = PointSet
def test_add(self, point_set):
ps = point_set
fp = FastPair()
for p in ps[:9]:
fp += p
Expand All @@ -114,8 +110,8 @@ def test_add(self, PointSet):
fp += p
assert fp.initialized is True

def test_sub(self, PointSet):
ps = PointSet
def test_sub(self, point_set):
ps = point_set
fp = FastPair().build(ps)
start = fp._find_neighbor(ps[-1])
fp -= ps[-1]
Expand All @@ -127,31 +123,31 @@ def test_sub(self, PointSet):
with pytest.raises(ValueError):
fp -= rand_tuple(len(ps[0]))

def test_len(self, PointSet):
ps = PointSet
def test_len(self, point_set):
ps = point_set
fp = FastPair()
assert len(fp) == 0
fp.build(ps)
assert len(fp) == len(ps)

def test_contains(self, PointSet):
ps = PointSet
def test_contains(self, point_set):
ps = point_set
fp = FastPair()
assert ps[0] not in fp
fp.build(ps)
assert ps[0] in fp

def test_call_and_closest_pair(self, PointSet):
ps = PointSet
def test_call_and_closest_pair(self, point_set):
ps = point_set
fp = FastPair().build(ps)
cp = fp.closest_pair()
bf = fp.closest_pair_brute_force()
assert fp() == cp
assert abs(cp[0] - bf[0]) < 1e-8
assert cp[1] == bf[1]

def test_all_closest_pairs(self, PointSet):
ps = PointSet
def test_all_closest_pairs(self, point_set):
ps = point_set
fp = FastPair().build(ps)
cp = fp.closest_pair()
bf = fp.closest_pair_brute_force() # Ordering should be the same
Expand All @@ -168,24 +164,24 @@ def test_all_closest_pairs(self, PointSet):
# Ordering may be different, but both should be in there
# assert dc[1][0] in cp[1] and dc[1][1] in cp[1]

def test_find_neighbor_and_sdist(self, PointSet):
ps = PointSet
def test_find_neighbor_and_sdist(self, point_set):
ps = point_set
fp = FastPair().build(ps)
rando = rand_tuple(len(ps[0]))
neigh = fp._find_neighbor(rando) # Abusing find_neighbor!
dist = fp.dist(rando, neigh["neigh"])
assert abs(dist - neigh["dist"]) < 1e-8
assert len(fp) == len(ps) # Make sure we didn't add a point...
l = [(fp.dist(a, b), b) for a, b in zip(cycle([rando]), ps)]
res = min(l, key=itemgetter(0))
l_ = [(fp.dist(a, b), b) for a, b in zip(cycle([rando]), ps)]
res = min(l_, key=itemgetter(0))
assert abs(res[0] - neigh["dist"]) < 1e-8
assert res[1] == neigh["neigh"]
res = min(fp.sdist(rando), key=itemgetter(0))
assert abs(neigh["dist"] - res[0]) < 1e-8
assert neigh["neigh"] == res[1]

def test_cluster(self, PointSet):
ps = PointSet
def test_cluster(self, point_set):
ps = point_set
fp = FastPair().build(ps)
for i in range(len(fp) - 1):
# Version one
Expand All @@ -209,9 +205,9 @@ def test_cluster(self, PointSet):
assert contains_same(fp.points, ps)
assert len(fp.points) == len(ps) == 1

def test_update_point(self, PointSet):
def test_update_point(self, point_set):
# Still failing sometimes...
ps = PointSet
ps = point_set
fp = FastPair().build(ps)
assert len(fp) == len(ps)
old = ps[0] # Just grab the first point...
Expand All @@ -220,9 +216,9 @@ def test_update_point(self, PointSet):
assert old not in fp
assert new in fp
assert len(fp) == len(ps) # Size shouldn't change
l = [(fp.dist(a, b), b) for a, b in zip(cycle([new]), ps)]
res = min(l, key=itemgetter(0))
neigh = fp.neighbors[new]
l_ = [(fp.dist(a, b), b) for a, b in zip(cycle([new]), ps)]
res = min(l_, key=itemgetter(0)) # noqa: F841
neigh = fp.neighbors[new] # noqa: F841
# assert abs(res[0] - neigh["dist"]) < 1e-8
# assert res[1] == neigh["neigh"]

Expand Down Expand Up @@ -267,8 +263,8 @@ def test_call_and_closest_pair_min_points(self, image_array):
assert abs(cp[0] - bf[0]) < 1e-8
assert cp[1] == bf[1]

def test_iter(self, PointSet):
ps = PointSet
def test_iter(self, point_set):
ps = point_set
fp = FastPair().build(ps)
assert fp.min_points == 10
assert isinstance(fp.dist, FunctionType)
Expand All @@ -277,23 +273,23 @@ def test_iter(self, PointSet):
assert fp[ps[0]].neigh in set(ps)

try:
myitem = fp[(2, 3, 4)]
fp[(2, 3, 4)]
except KeyError as err:
print(err)

fp[ps[0]] = fp[ps[0]].neigh
try:
fp[(2, 3, 4)] = fp[ps[0]].neigh
fp[ps[0]].neigh # noqa: B018
except KeyError as err:
print(err)

def test_update_point_less_points(self, PointSet):
ps = PointSet
def test_update_point_less_points(self, point_set):
ps = point_set
fp = FastPair()
for p in ps[:9]:
fp += p
assert fp.initialized is False
old = ps[0] # Just grab the first point...
new = rand_tuple(len(ps[0]))
res = fp._update_point(old, new)
fp._update_point(old, new)
assert len(fp) == 1
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,9 @@ omit = ["fastpair/test/*"]

[tool.ruff]
line-length = 88
lint.select = ["E", "F", "W", "I", "UP", "N", "B", "A", "C4", "SIM", "ARG"]

[tool.ruff.lint.per-file-ignores]
"*__init__.py" = [
"F401", # imported but unused
]

0 comments on commit d053de5

Please sign in to comment.