Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create and update rtree spatial index only when needed #179

Merged
merged 1 commit into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 68 additions & 26 deletions mesa_geo/geospace.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@
"""
Return the bounds of the GeoSpace in [min_x, min_y, max_x, max_y] format.
"""
if self._total_bounds is None:
if len(self.agents) > 0:
self._update_bounds(self._agent_layer.total_bounds)
if len(self.layers) > 0:
for layer in self.layers:
self._update_bounds(layer.total_bounds)
return self._total_bounds

def _update_bounds(self, new_bounds: np.ndarray) -> None:
Expand Down Expand Up @@ -128,7 +134,7 @@
"to `False` to suppress this warning message."
)
layer.to_crs(self.crs, inplace=True)
self._update_bounds(layer.total_bounds)
self._total_bounds = None
self._static_layers.append(layer)

def _check_agent(self, agent):
Expand Down Expand Up @@ -161,7 +167,7 @@
for agent in agents:
self._check_agent(agent)
self._agent_layer.add_agents(agents)
self._update_bounds(new_bounds=self._agent_layer.total_bounds)
self._total_bounds = None

def _recreate_rtree(self, new_agents=None):
"""Create a new rtree index from agents geometries."""
Expand All @@ -170,6 +176,7 @@
def remove_agent(self, agent):
"""Remove an agent from the GeoSpace."""
self._agent_layer.remove_agent(agent)
self._total_bounds = None

def get_relation(self, agent, relation):
"""Return a list of related agents.
Expand Down Expand Up @@ -231,34 +238,50 @@
"""

def __init__(self):
# neighborhood graph for touching neighbors
self._neighborhood = None

# Set up rtree index
self.idx = index.Index()
self.idx.agents = {}
# rtree index for spatial indexing (e.g., neighbors within distance, agents at pos, etc.)
self._idx = None
self._id_to_agent = {}
# bounds of the layer in [min_x, min_y, max_x, max_y] format
# While it is possible to calculate the bounds from rtree index,
# total_bounds is almost always needed (e.g., for plotting), while rtree index is not.
# Hence we compute total_bounds separately from rtree index.
self._total_bounds = None

@property
def agents(self):
"""
Return a list of all agents in the layer.
"""

return list(self.idx.agents.values())
return list(self._id_to_agent.values())

@property
def total_bounds(self):
"""
Return the bounds of the layer in [min_x, min_y, max_x, max_y] format.
"""

return self.idx.get_bounds(coordinate_interleaved=True)
if self._total_bounds is None and len(self.agents) > 0:
bounds = np.array([agent.geometry.bounds for agent in self.agents])
min_x, min_y = np.min(bounds[:, :2], axis=0)
max_x, max_y = np.max(bounds[:, 2:], axis=0)
self._total_bounds = np.array([min_x, min_y, max_x, max_y])
return self._total_bounds

def _get_rtree_intersections(self, geometry):
"""
Calculate rtree intersections for candidate agents.
"""

return (self.idx.agents[i] for i in self.idx.intersection(geometry.bounds))
self._ensure_index()
if self._idx is None:
return []

Check warning on line 280 in mesa_geo/geospace.py

View check run for this annotation

Codecov / codecov/patch

mesa_geo/geospace.py#L280

Added line #L280 was not covered by tests
else:
return [
self._id_to_agent[i] for i in self._idx.intersection(geometry.bounds)
]

def _create_neighborhood(self):
"""
Expand All @@ -273,21 +296,27 @@
for agent, key in zip(agents, self._neighborhood.neighbors.keys()):
self._neighborhood.idx[agent] = key

def _ensure_index(self):
"""
Ensure that the rtree index is created.
"""

if self._idx is None:
self._recreate_rtree()

def _recreate_rtree(self, new_agents=None):
"""
Create a new rtree index from agents geometries.
"""

if new_agents is None:
new_agents = []
old_agents = list(self.agents)
agents = old_agents + new_agents

# Bulk insert agents
index_data = ((id(agent), agent.geometry.bounds, None) for agent in agents)
agents = list(self.agents) + new_agents

self.idx = index.Index(index_data)
self.idx.agents = {id(agent): agent for agent in agents}
if len(agents) > 0:
# Bulk insert agents
index_data = ((id(agent), agent.geometry.bounds, None) for agent in agents)
self._idx = index.Index(index_data)

def add_agents(self, agents):
"""
Expand All @@ -303,18 +332,25 @@

if isinstance(agents, GeoAgent):
agent = agents
self.idx.insert(id(agent), agent.geometry.bounds, None)
self.idx.agents[id(agent)] = agent
self._id_to_agent[id(agent)] = agent
if self._idx:
self._idx.insert(id(agent), agent.geometry.bounds, None)
else:
self._recreate_rtree(agents)
for agent in agents:
self._id_to_agent[id(agent)] = agent
if self._idx:
self._recreate_rtree(agents)
self._total_bounds = None

def remove_agent(self, agent):
"""
Remove an agent from the layer.
"""

self.idx.delete(id(agent), agent.geometry.bounds)
del self.idx.agents[id(agent)]
del self._id_to_agent[id(agent)]
if self._idx:
self._idx.delete(id(agent), agent.geometry.bounds)

Check warning on line 352 in mesa_geo/geospace.py

View check run for this annotation

Codecov / codecov/patch

mesa_geo/geospace.py#L352

Added line #L352 was not covered by tests
self._total_bounds = None

def get_relation(self, agent, relation):
"""Return a list of related agents.
Expand All @@ -327,6 +363,7 @@
Omit to compare against all other agents of the layer.
"""

self._ensure_index()
possible_agents = self._get_rtree_intersections(agent.geometry)
for other_agent in possible_agents:
if (
Expand All @@ -336,6 +373,7 @@
yield other_agent

def get_intersecting_agents(self, agent):
self._ensure_index()
intersecting_agents = self.get_relation(agent, "intersects")
return intersecting_agents

Expand All @@ -347,7 +385,7 @@
Distance is measured as a buffer around the agent's geometry,
set center=True to calculate distance from center.
"""

self._ensure_index()
if center:
geometry = agent.geometry.centroid.buffer(distance)
else:
Expand All @@ -363,6 +401,7 @@
Return a generator of agents at given pos.
"""

self._ensure_index()
if not isinstance(pos, Point):
pos = Point(pos)

Expand All @@ -386,10 +425,13 @@
if not self._neighborhood or self._neighborhood.agents != self.agents:
self._create_neighborhood()

idx = self._neighborhood.idx[agent]
neighbors_idx = self._neighborhood.neighbors[idx]
neighbors = [self.agents[i] for i in neighbors_idx]
return neighbors
if self._neighborhood is None:
return []

Check warning on line 429 in mesa_geo/geospace.py

View check run for this annotation

Codecov / codecov/patch

mesa_geo/geospace.py#L429

Added line #L429 was not covered by tests
else:
idx = self._neighborhood.idx[agent]
neighbors_idx = self._neighborhood.neighbors[idx]
neighbors = [self.agents[i] for i in neighbors_idx]
return neighbors

def get_agents_as_GeoDataFrame(self, agent_cls=GeoAgent) -> gpd.GeoDataFrame:
"""
Expand Down
10 changes: 10 additions & 0 deletions tests/test_GeoSpace.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,13 @@ def test_agents_at(self):
agent.unique_id for agent in self.geo_space.agents_at((1, 1))
}
self.assertEqual(agents_id_found, agents_id)

def test_get_neighbors(self):
self.geo_space.add_agents(self.polygon_agent)
self.assertEqual(len(self.geo_space.get_neighbors(self.polygon_agent)), 0)
self.geo_space.add_agents(self.touching_agent)
self.assertEqual(len(self.geo_space.get_neighbors(self.polygon_agent)), 1)
self.assertEqual(
self.geo_space.get_neighbors(self.polygon_agent)[0].unique_id,
self.touching_agent.unique_id,
)