Skip to content

Commit

Permalink
create and update rtree spatial index only when needed
Browse files Browse the repository at this point in the history
  • Loading branch information
wang-boyu authored and Corvince committed Jan 4, 2024
1 parent 08f6570 commit f9f7d6a
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 26 deletions.
94 changes: 68 additions & 26 deletions mesa_geo/geospace.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ def total_bounds(self) -> np.ndarray | None:
"""
Return the bounds of the GeoSpace in [min_x, min_y, max_x, max_y] format.
"""
if self._total_bounds is None:
if len(self.agents) > 0:
self._update_bounds(self._agent_layer.total_bounds)
if len(self.layers) > 0:
for layer in self.layers:
self._update_bounds(layer.total_bounds)
return self._total_bounds

def _update_bounds(self, new_bounds: np.ndarray) -> None:
Expand Down Expand Up @@ -128,7 +134,7 @@ def add_layer(self, layer: ImageLayer | RasterLayer | gpd.GeoDataFrame) -> None:
"to `False` to suppress this warning message."
)
layer.to_crs(self.crs, inplace=True)
self._update_bounds(layer.total_bounds)
self._total_bounds = None
self._static_layers.append(layer)

def _check_agent(self, agent):
Expand Down Expand Up @@ -161,7 +167,7 @@ def add_agents(self, agents):
for agent in agents:
self._check_agent(agent)
self._agent_layer.add_agents(agents)
self._update_bounds(new_bounds=self._agent_layer.total_bounds)
self._total_bounds = None

def _recreate_rtree(self, new_agents=None):
"""Create a new rtree index from agents geometries."""
Expand All @@ -170,6 +176,7 @@ def _recreate_rtree(self, new_agents=None):
def remove_agent(self, agent):
"""Remove an agent from the GeoSpace."""
self._agent_layer.remove_agent(agent)
self._total_bounds = None

def get_relation(self, agent, relation):
"""Return a list of related agents.
Expand Down Expand Up @@ -231,34 +238,50 @@ class _AgentLayer:
"""

def __init__(self):
# neighborhood graph for touching neighbors
self._neighborhood = None

# Set up rtree index
self.idx = index.Index()
self.idx.agents = {}
# rtree index for spatial indexing (e.g., neighbors within distance, agents at pos, etc.)
self._idx = None
self._id_to_agent = {}
# bounds of the layer in [min_x, min_y, max_x, max_y] format
# While it is possible to calculate the bounds from rtree index,
# total_bounds is almost always needed (e.g., for plotting), while rtree index is not.
# Hence we compute total_bounds separately from rtree index.
self._total_bounds = None

@property
def agents(self):
"""
Return a list of all agents in the layer.
"""

return list(self.idx.agents.values())
return list(self._id_to_agent.values())

@property
def total_bounds(self):
"""
Return the bounds of the layer in [min_x, min_y, max_x, max_y] format.
"""

return self.idx.get_bounds(coordinate_interleaved=True)
if self._total_bounds is None and len(self.agents) > 0:
bounds = np.array([agent.geometry.bounds for agent in self.agents])
min_x, min_y = np.min(bounds[:, :2], axis=0)
max_x, max_y = np.max(bounds[:, 2:], axis=0)
self._total_bounds = np.array([min_x, min_y, max_x, max_y])
return self._total_bounds

def _get_rtree_intersections(self, geometry):
"""
Calculate rtree intersections for candidate agents.
"""

return (self.idx.agents[i] for i in self.idx.intersection(geometry.bounds))
self._ensure_index()
if self._idx is None:
return []
else:
return [
self._id_to_agent[i] for i in self._idx.intersection(geometry.bounds)
]

def _create_neighborhood(self):
"""
Expand All @@ -273,21 +296,27 @@ def _create_neighborhood(self):
for agent, key in zip(agents, self._neighborhood.neighbors.keys()):
self._neighborhood.idx[agent] = key

def _ensure_index(self):
"""
Ensure that the rtree index is created.
"""

if self._idx is None:
self._recreate_rtree()

def _recreate_rtree(self, new_agents=None):
"""
Create a new rtree index from agents geometries.
"""

if new_agents is None:
new_agents = []
old_agents = list(self.agents)
agents = old_agents + new_agents

# Bulk insert agents
index_data = ((id(agent), agent.geometry.bounds, None) for agent in agents)
agents = list(self.agents) + new_agents

self.idx = index.Index(index_data)
self.idx.agents = {id(agent): agent for agent in agents}
if len(agents) > 0:
# Bulk insert agents
index_data = ((id(agent), agent.geometry.bounds, None) for agent in agents)
self._idx = index.Index(index_data)

def add_agents(self, agents):
"""
Expand All @@ -303,18 +332,25 @@ def add_agents(self, agents):

if isinstance(agents, GeoAgent):
agent = agents
self.idx.insert(id(agent), agent.geometry.bounds, None)
self.idx.agents[id(agent)] = agent
self._id_to_agent[id(agent)] = agent
if self._idx:
self._idx.insert(id(agent), agent.geometry.bounds, None)
else:
self._recreate_rtree(agents)
for agent in agents:
self._id_to_agent[id(agent)] = agent
if self._idx:
self._recreate_rtree(agents)
self._total_bounds = None

def remove_agent(self, agent):
"""
Remove an agent from the layer.
"""

self.idx.delete(id(agent), agent.geometry.bounds)
del self.idx.agents[id(agent)]
del self._id_to_agent[id(agent)]
if self._idx:
self._idx.delete(id(agent), agent.geometry.bounds)
self._total_bounds = None

def get_relation(self, agent, relation):
"""Return a list of related agents.
Expand All @@ -327,6 +363,7 @@ def get_relation(self, agent, relation):
Omit to compare against all other agents of the layer.
"""

self._ensure_index()
possible_agents = self._get_rtree_intersections(agent.geometry)
for other_agent in possible_agents:
if (
Expand All @@ -336,6 +373,7 @@ def get_relation(self, agent, relation):
yield other_agent

def get_intersecting_agents(self, agent):
self._ensure_index()
intersecting_agents = self.get_relation(agent, "intersects")
return intersecting_agents

Expand All @@ -347,7 +385,7 @@ def get_neighbors_within_distance(
Distance is measured as a buffer around the agent's geometry,
set center=True to calculate distance from center.
"""

self._ensure_index()
if center:
geometry = agent.geometry.centroid.buffer(distance)
else:
Expand All @@ -363,6 +401,7 @@ def agents_at(self, pos):
Return a generator of agents at given pos.
"""

self._ensure_index()
if not isinstance(pos, Point):
pos = Point(pos)

Expand All @@ -386,10 +425,13 @@ def get_neighbors(self, agent):
if not self._neighborhood or self._neighborhood.agents != self.agents:
self._create_neighborhood()

idx = self._neighborhood.idx[agent]
neighbors_idx = self._neighborhood.neighbors[idx]
neighbors = [self.agents[i] for i in neighbors_idx]
return neighbors
if self._neighborhood is None:
return []
else:
idx = self._neighborhood.idx[agent]
neighbors_idx = self._neighborhood.neighbors[idx]
neighbors = [self.agents[i] for i in neighbors_idx]
return neighbors

def get_agents_as_GeoDataFrame(self, agent_cls=GeoAgent) -> gpd.GeoDataFrame:
"""
Expand Down
10 changes: 10 additions & 0 deletions tests/test_GeoSpace.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,13 @@ def test_agents_at(self):
agent.unique_id for agent in self.geo_space.agents_at((1, 1))
}
self.assertEqual(agents_id_found, agents_id)

def test_get_neighbors(self):
self.geo_space.add_agents(self.polygon_agent)
self.assertEqual(len(self.geo_space.get_neighbors(self.polygon_agent)), 0)
self.geo_space.add_agents(self.touching_agent)
self.assertEqual(len(self.geo_space.get_neighbors(self.polygon_agent)), 1)
self.assertEqual(
self.geo_space.get_neighbors(self.polygon_agent)[0].unique_id,
self.touching_agent.unique_id,
)

0 comments on commit f9f7d6a

Please sign in to comment.