diff --git a/mesa_geo/geospace.py b/mesa_geo/geospace.py index 7cfe6567..44e3c8d8 100644 --- a/mesa_geo/geospace.py +++ b/mesa_geo/geospace.py @@ -91,6 +91,12 @@ def total_bounds(self) -> np.ndarray | None: """ Return the bounds of the GeoSpace in [min_x, min_y, max_x, max_y] format. """ + if self._total_bounds is None: + if len(self.agents) > 0: + self._update_bounds(self._agent_layer.total_bounds) + if len(self.layers) > 0: + for layer in self.layers: + self._update_bounds(layer.total_bounds) return self._total_bounds def _update_bounds(self, new_bounds: np.ndarray) -> None: @@ -128,7 +134,7 @@ def add_layer(self, layer: ImageLayer | RasterLayer | gpd.GeoDataFrame) -> None: "to `False` to suppress this warning message." ) layer.to_crs(self.crs, inplace=True) - self._update_bounds(layer.total_bounds) + self._total_bounds = None self._static_layers.append(layer) def _check_agent(self, agent): @@ -161,7 +167,7 @@ def add_agents(self, agents): for agent in agents: self._check_agent(agent) self._agent_layer.add_agents(agents) - self._update_bounds(new_bounds=self._agent_layer.total_bounds) + self._total_bounds = None def _recreate_rtree(self, new_agents=None): """Create a new rtree index from agents geometries.""" @@ -170,6 +176,7 @@ def _recreate_rtree(self, new_agents=None): def remove_agent(self, agent): """Remove an agent from the GeoSpace.""" self._agent_layer.remove_agent(agent) + self._total_bounds = None def get_relation(self, agent, relation): """Return a list of related agents. @@ -231,11 +238,16 @@ class _AgentLayer: """ def __init__(self): + # neighborhood graph for touching neighbors self._neighborhood = None - - # Set up rtree index - self.idx = index.Index() - self.idx.agents = {} + # rtree index for spatial indexing (e.g., neighbors within distance, agents at pos, etc.) + self._idx = None + self._id_to_agent = {} + # bounds of the layer in [min_x, min_y, max_x, max_y] format + # While it is possible to calculate the bounds from rtree index, + # total_bounds is almost always needed (e.g., for plotting), while rtree index is not. + # Hence we compute total_bounds separately from rtree index. + self._total_bounds = None @property def agents(self): @@ -243,7 +255,7 @@ def agents(self): Return a list of all agents in the layer. """ - return list(self.idx.agents.values()) + return list(self._id_to_agent.values()) @property def total_bounds(self): @@ -251,14 +263,25 @@ def total_bounds(self): Return the bounds of the layer in [min_x, min_y, max_x, max_y] format. """ - return self.idx.get_bounds(coordinate_interleaved=True) + if self._total_bounds is None and len(self.agents) > 0: + bounds = np.array([agent.geometry.bounds for agent in self.agents]) + min_x, min_y = np.min(bounds[:, :2], axis=0) + max_x, max_y = np.max(bounds[:, 2:], axis=0) + self._total_bounds = np.array([min_x, min_y, max_x, max_y]) + return self._total_bounds def _get_rtree_intersections(self, geometry): """ Calculate rtree intersections for candidate agents. """ - return (self.idx.agents[i] for i in self.idx.intersection(geometry.bounds)) + self._ensure_index() + if self._idx is None: + return [] + else: + return [ + self._id_to_agent[i] for i in self._idx.intersection(geometry.bounds) + ] def _create_neighborhood(self): """ @@ -273,6 +296,14 @@ def _create_neighborhood(self): for agent, key in zip(agents, self._neighborhood.neighbors.keys()): self._neighborhood.idx[agent] = key + def _ensure_index(self): + """ + Ensure that the rtree index is created. + """ + + if self._idx is None: + self._recreate_rtree() + def _recreate_rtree(self, new_agents=None): """ Create a new rtree index from agents geometries. @@ -280,14 +311,12 @@ def _recreate_rtree(self, new_agents=None): if new_agents is None: new_agents = [] - old_agents = list(self.agents) - agents = old_agents + new_agents - - # Bulk insert agents - index_data = ((id(agent), agent.geometry.bounds, None) for agent in agents) + agents = list(self.agents) + new_agents - self.idx = index.Index(index_data) - self.idx.agents = {id(agent): agent for agent in agents} + if len(agents) > 0: + # Bulk insert agents + index_data = ((id(agent), agent.geometry.bounds, None) for agent in agents) + self._idx = index.Index(index_data) def add_agents(self, agents): """ @@ -303,18 +332,25 @@ def add_agents(self, agents): if isinstance(agents, GeoAgent): agent = agents - self.idx.insert(id(agent), agent.geometry.bounds, None) - self.idx.agents[id(agent)] = agent + self._id_to_agent[id(agent)] = agent + if self._idx: + self._idx.insert(id(agent), agent.geometry.bounds, None) else: - self._recreate_rtree(agents) + for agent in agents: + self._id_to_agent[id(agent)] = agent + if self._idx: + self._recreate_rtree(agents) + self._total_bounds = None def remove_agent(self, agent): """ Remove an agent from the layer. """ - self.idx.delete(id(agent), agent.geometry.bounds) - del self.idx.agents[id(agent)] + del self._id_to_agent[id(agent)] + if self._idx: + self._idx.delete(id(agent), agent.geometry.bounds) + self._total_bounds = None def get_relation(self, agent, relation): """Return a list of related agents. @@ -327,6 +363,7 @@ def get_relation(self, agent, relation): Omit to compare against all other agents of the layer. """ + self._ensure_index() possible_agents = self._get_rtree_intersections(agent.geometry) for other_agent in possible_agents: if ( @@ -336,6 +373,7 @@ def get_relation(self, agent, relation): yield other_agent def get_intersecting_agents(self, agent): + self._ensure_index() intersecting_agents = self.get_relation(agent, "intersects") return intersecting_agents @@ -347,7 +385,7 @@ def get_neighbors_within_distance( Distance is measured as a buffer around the agent's geometry, set center=True to calculate distance from center. """ - + self._ensure_index() if center: geometry = agent.geometry.centroid.buffer(distance) else: @@ -363,6 +401,7 @@ def agents_at(self, pos): Return a generator of agents at given pos. """ + self._ensure_index() if not isinstance(pos, Point): pos = Point(pos) @@ -386,10 +425,13 @@ def get_neighbors(self, agent): if not self._neighborhood or self._neighborhood.agents != self.agents: self._create_neighborhood() - idx = self._neighborhood.idx[agent] - neighbors_idx = self._neighborhood.neighbors[idx] - neighbors = [self.agents[i] for i in neighbors_idx] - return neighbors + if self._neighborhood is None: + return [] + else: + idx = self._neighborhood.idx[agent] + neighbors_idx = self._neighborhood.neighbors[idx] + neighbors = [self.agents[i] for i in neighbors_idx] + return neighbors def get_agents_as_GeoDataFrame(self, agent_cls=GeoAgent) -> gpd.GeoDataFrame: """ diff --git a/tests/test_GeoSpace.py b/tests/test_GeoSpace.py index 13fb0897..f18c36d1 100644 --- a/tests/test_GeoSpace.py +++ b/tests/test_GeoSpace.py @@ -265,3 +265,13 @@ def test_agents_at(self): agent.unique_id for agent in self.geo_space.agents_at((1, 1)) } self.assertEqual(agents_id_found, agents_id) + + def test_get_neighbors(self): + self.geo_space.add_agents(self.polygon_agent) + self.assertEqual(len(self.geo_space.get_neighbors(self.polygon_agent)), 0) + self.geo_space.add_agents(self.touching_agent) + self.assertEqual(len(self.geo_space.get_neighbors(self.polygon_agent)), 1) + self.assertEqual( + self.geo_space.get_neighbors(self.polygon_agent)[0].unique_id, + self.touching_agent.unique_id, + )