Skip to content

Commit

Permalink
Make Compatible with Mesa 3.0
Browse files Browse the repository at this point in the history
- update visualization module to Solara naming convention
- update tutorial to ensure solara naming and `super().init()`
- update tests to solara naming convention
  • Loading branch information
tpike3 committed Aug 24, 2024
1 parent b32c1e0 commit 8ac0f48
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 32 deletions.
6 changes: 4 additions & 2 deletions docs/tutorials/intro_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@
" self, pop_size=30, mobility_range=500, init_infection=0.2, exposure_dist=500, max_infection_risk=0.2,\n",
" max_recovery_time=5\n",
" ):\n",
" super().__init__()\n",
" self.schedule = mesa.time.RandomActivationByType(self)\n",
" self.space = mg.GeoSpace(warn_crs_conversion=False)\n",
" \n",
Expand Down Expand Up @@ -775,6 +776,7 @@
" self, pop_size=30, mobility_range=500, init_infection=0.2, exposure_dist=500, max_infection_risk=0.2,\n",
" max_recovery_time=5\n",
" ):\n",
" super().__init__()\n",
" #Scheduler\n",
" self.schedule = mesa.time.RandomActivationByType(self)\n",
" #Space\n",
Expand Down Expand Up @@ -1081,7 +1083,7 @@
},
"outputs": [],
"source": [
"page = mgv.GeoJupyterViz(\n",
"page = mgv.GeoSolaraViz(\n",
" GeoSIR,\n",
" model_params,\n",
" measures= [[\"infected\", \"susceptible\", \"recovered\", \"dead\"], [\"safe\", \"hotspot\"]],\n",
Expand Down Expand Up @@ -1111,7 +1113,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.12.5"
}
},
"nbformat": 4,
Expand Down
6 changes: 3 additions & 3 deletions mesa_geo/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Import specific classes or functions from the modules
from mesa_geo.visualization.geojupyter_viz import GeoJupyterViz
from mesa_geo.visualization.leaflet_viz import LeafletViz
from .geosolara_viz import GeoJupyterViz, GeoSolaraViz
from .leaflet_viz import LeafletViz

__all__ = ["GeoJupyterViz", "LeafletViz"]
__all__ = ["GeoJupyterViz", "GeoSolaraViz", "LeafletViz"]
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import matplotlib.pyplot as plt
import mesa.experimental.components.matplotlib as components_matplotlib
import mesa.visualization.components.matplotlib as components_matplotlib
import solara
import xyzservices.providers as xyz
from mesa.experimental import jupyter_viz as jv
from mesa.visualization import solara_viz as sv
from solara.alias import rv

import mesa_geo.visualization.leaflet_viz as leaflet_viz
Expand Down Expand Up @@ -91,7 +91,7 @@ def Card(


@solara.component
def GeoJupyterViz(
def GeoSolaraViz(
model_class,
model_params,
measures=None,
Expand Down Expand Up @@ -151,7 +151,7 @@ def GeoJupyterViz(
current_step = solara.use_reactive(0)

# 1. Set up model parameters
user_params, fixed_params = jv.split_model_params(model_params)
user_params, fixed_params = sv.split_model_params(model_params)
model_parameters, set_model_parameters = solara.use_state(
{**fixed_params, **{k: v.get("value") for k, v in user_params.items()}}
)
Expand Down Expand Up @@ -201,13 +201,13 @@ def handle_change_model_params(name: str, value: any):
if measures:
layout_types += [{"Measure": elem} for elem in range(len(measures))]

grid_layout_initial = jv.make_initial_grid_layout(layout_types=layout_types)
grid_layout_initial = sv.make_initial_grid_layout(layout_types=layout_types)
grid_layout, set_grid_layout = solara.use_state(grid_layout_initial)

with solara.Sidebar():
with solara.Card("Controls", margin=1, elevation=2):
jv.UserInputs(user_params, on_change=handle_change_model_params)
jv.ModelController(model, play_interval, current_step, reset_counter)
sv.UserInputs(user_params, on_change=handle_change_model_params)
sv.ModelController(model, play_interval, current_step, reset_counter)
with solara.Card("Progress", margin=1, elevation=2):
solara.Markdown(md_text=f"####Step - {current_step}")

Expand All @@ -234,3 +234,6 @@ def handle_change_model_params(name: str, value: any):
draggable=True,
on_grid_layout=set_grid_layout,
)


GeoJupyterViz = GeoSolaraViz
40 changes: 20 additions & 20 deletions tests/test_GeoJupyterViz.py → tests/test_GeoSolaraViz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@

import solara

from mesa_geo.visualization.geojupyter_viz import Card, GeoJupyterViz
from mesa_geo.visualization.geosolara_viz import Card, GeoSolaraViz


class TestGeoViz(unittest.TestCase):
@patch("mesa_geo.visualization.geojupyter_viz.rv.CardTitle")
@patch("mesa_geo.visualization.geojupyter_viz.rv.Card")
@patch("mesa_geo.visualization.geojupyter_viz.components_matplotlib.PlotMatplotlib")
@patch("mesa_geo.visualization.geojupyter_viz.leaflet_viz.map")
@patch("mesa_geo.visualization.geosolara_viz.rv.CardTitle")
@patch("mesa_geo.visualization.geosolara_viz.rv.Card")
@patch("mesa_geo.visualization.geosolara_viz.components_matplotlib.PlotMatplotlib")
@patch("mesa_geo.visualization.geosolara_viz.leaflet_viz.map")
def test_card_function(
self,
mock_map,
Expand All @@ -31,7 +31,7 @@ def test_card_function(
layout_type = {"Map": "default", "Measure": "Measure1"}

with patch(
"mesa_geo.visualization.geojupyter_viz.rv.Card", return_value=MagicMock()
"mesa_geo.visualization.geosolara_viz.rv.Card", return_value=MagicMock()
) as mock_rv_card:
_ = Card(
model,
Expand All @@ -53,19 +53,19 @@ def test_card_function(
)
# mock_PlotMatplotlib.assert_called_once()

@patch("mesa_geo.visualization.geojupyter_viz.solara.GridDraggable")
@patch("mesa_geo.visualization.geojupyter_viz.solara.Sidebar")
@patch("mesa_geo.visualization.geojupyter_viz.solara.Card")
@patch("mesa_geo.visualization.geojupyter_viz.solara.Markdown")
@patch("mesa_geo.visualization.geojupyter_viz.jv.ModelController")
@patch("mesa_geo.visualization.geojupyter_viz.jv.UserInputs")
@patch("mesa_geo.visualization.geojupyter_viz.jv.split_model_params")
@patch("mesa_geo.visualization.geojupyter_viz.solara.use_memo")
@patch("mesa_geo.visualization.geojupyter_viz.solara.use_reactive")
@patch("mesa_geo.visualization.geojupyter_viz.solara.use_state")
@patch("mesa_geo.visualization.geojupyter_viz.solara.AppBarTitle")
@patch("mesa_geo.visualization.geojupyter_viz.solara.AppBar")
@patch("mesa_geo.visualization.geojupyter_viz.leaflet_viz.MapModule")
@patch("mesa_geo.visualization.geosolara_viz.solara.GridDraggable")
@patch("mesa_geo.visualization.geosolara_viz.solara.Sidebar")
@patch("mesa_geo.visualization.geosolara_viz.solara.Card")
@patch("mesa_geo.visualization.geosolara_viz.solara.Markdown")
@patch("mesa_geo.visualization.geosolara_viz.sv.ModelController")
@patch("mesa_geo.visualization.geosolara_viz.sv.UserInputs")
@patch("mesa_geo.visualization.geosolara_viz.sv.split_model_params")
@patch("mesa_geo.visualization.geosolara_viz.solara.use_memo")
@patch("mesa_geo.visualization.geosolara_viz.solara.use_reactive")
@patch("mesa_geo.visualization.geosolara_viz.solara.use_state")
@patch("mesa_geo.visualization.geosolara_viz.solara.AppBarTitle")
@patch("mesa_geo.visualization.geosolara_viz.solara.AppBar")
@patch("mesa_geo.visualization.geosolara_viz.leaflet_viz.MapModule")
def test_geojupyterviz_function(
self,
mock_MapModule, # noqa: N803
Expand Down Expand Up @@ -98,7 +98,7 @@ def test_geojupyterviz_function(
mock_use_memo.return_value = MagicMock()

solara.render(
GeoJupyterViz(
GeoSolaraViz(
model_class=model_class,
model_params=model_params,
measures=measures,
Expand Down

0 comments on commit 8ac0f48

Please sign in to comment.