diff --git a/mesa_geo/visualization/__init__.py b/mesa_geo/visualization/__init__.py index 256d887d..70f92ad9 100644 --- a/mesa_geo/visualization/__init__.py +++ b/mesa_geo/visualization/__init__.py @@ -1,5 +1,4 @@ # Import specific classes or functions from the modules -from mesa_geo.visualization.geojupyter_viz import GeoJupyterViz -from mesa_geo.visualization.leaflet_viz import LeafletViz +from .components.geospace_leaflet import MapModule, make_geospace_leaflet -__all__ = ["GeoJupyterViz", "LeafletViz"] +__all__ = ["make_geospace_leaflet", "MapModule"] diff --git a/mesa_geo/visualization/leaflet_viz.py b/mesa_geo/visualization/components/geospace_leaflet.py similarity index 86% rename from mesa_geo/visualization/leaflet_viz.py rename to mesa_geo/visualization/components/geospace_leaflet.py index 3c46b1a6..37682a89 100644 --- a/mesa_geo/visualization/leaflet_viz.py +++ b/mesa_geo/visualization/components/geospace_leaflet.py @@ -1,8 +1,3 @@ -""" -# ipyleaflet -Map visualization using [ipyleaflet](https://ipyleaflet.readthedocs.io/), a ipywidgets wrapper for [leaflet.js](https://leafletjs.com/) -""" - import dataclasses from dataclasses import dataclass @@ -11,30 +6,57 @@ import solara import xyzservices from folium.utilities import image_to_url +from mesa.visualization.utils import update_counter from shapely.geometry import Point, mapping from mesa_geo.raster_layers import RasterBase, RasterLayer from mesa_geo.tile_layers import LeafletOption, RasterWebTile -@solara.component -def map(model, map_drawer, zoom, center_default, scroll_wheel_zoom): - # render map in browser - zoom_map = solara.reactive(zoom) - center = solara.reactive(center_default) +def make_geospace_leaflet( + agent_portrayal, + view=None, + tiles=xyzservices.providers.OpenStreetMap.Mapnik, + **kwargs, +): + def MakeSpaceMatplotlib(model): + return GeoSpaceLeaflet(model, agent_portrayal, view, tiles, **kwargs) - base_map = map_drawer.tiles - layers = map_drawer.render(model) + return MakeSpaceMatplotlib + +@solara.component +def GeoSpaceLeaflet(model, agent_portrayal, view, tiles, **kwargs): + update_counter.get() + map_drawer = MapModule(portrayal_method=agent_portrayal, tiles=tiles) + model_view = map_drawer.render(model) + + if view is None: + # longlat [min_x, min_y, max_x, max_y] to latlong [min_y, min_x, max_y, max_x] + transformed_xx, transformed_yy = model.space.transformer.transform( + xx=[model.space.total_bounds[0], model.space.total_bounds[2]], + yy=[model.space.total_bounds[1], model.space.total_bounds[3]], + ) + view = [ + (transformed_yy[0] + transformed_yy[1]) / 2, + (transformed_xx[0] + transformed_xx[1]) / 2, + ] + + layers = ( + [ipyleaflet.TileLayer.element(url=map_drawer.tiles["url"])] if tiles else [] + ) + for layer in model_view["layers"]["rasters"]: + layers.append(ipyleaflet.ImageOverlay(element=image_to_url(layer))) + for layer in model_view["layers"]["vectors"]: + layers.append(ipyleaflet.GeoJSON(element=layer)) ipyleaflet.Map.element( - zoom=zoom_map.value, - center=center.value, - scroll_wheel_zoom=scroll_wheel_zoom, + center=view, layers=[ - ipyleaflet.TileLayer.element(url=base_map["url"]), - ipyleaflet.GeoJSON.element(data=layers["agents"][0]), - *layers["agents"][1], + *layers, + ipyleaflet.GeoJSON.element(data=model_view["agents"][0]), + *model_view["agents"][1], ], + **kwargs, ) @@ -65,9 +87,6 @@ class MapModule: def __init__( self, portrayal_method, - view, - zoom, - scroll_wheel_zoom, tiles, ): """ diff --git a/mesa_geo/visualization/geojupyter_viz.py b/mesa_geo/visualization/geojupyter_viz.py deleted file mode 100644 index 08599345..00000000 --- a/mesa_geo/visualization/geojupyter_viz.py +++ /dev/null @@ -1,236 +0,0 @@ -import matplotlib.pyplot as plt -import mesa.experimental.components.matplotlib as components_matplotlib -import solara -import xyzservices.providers as xyz -from mesa.experimental import solara_viz as jv -from solara.alias import rv - -import mesa_geo.visualization.leaflet_viz as leaflet_viz - -# Avoid interactive backend -plt.switch_backend("agg") - - -# TODO: Turn this function into a Solara component once the current_step.value -# dependency is passed to measure() -""" -Geo-Mesa Visualization Module -============================= -Card: Helper Function that initiates the Solara Card for Browser -GeoJupyterViz: Main Function users employ to create visualization -""" - - -def Card( - model, - measures, - agent_portrayal, - map_drawer, - center_default, - zoom, - scroll_wheel_zoom, - current_step, - color, - layout_type, -): - """ - - - Parameters - ---------- - model : Mesa Model Object - A pointer to the Mesa Model object this allows the visual to get get - model information, such as scheduler and space. - measures : List - Plots associated with model typically from datacollector that represent - critical information collected from the model. - agent_portrayal : Dictionary - Contains details of how visualization should plot key elements of the - such as agent color etc - map_drawer : Method - Function that generates map from GIS data of model - center_default : List - Latitude and Longitude of where center of map should be located - zoom : Int - Zoom level at which to initialize the map - scroll_wheel_zoom: Boolean - True of False on whether user can zoom on map with mouse scroll wheel - default is True - current_step : Int - Number on which step is the model - color : String - Background color for visual - layout_type : String - Type of layout Map or Measure - - Returns - ------- - main : Solara object - Visualization of model - - """ - - with rv.Card( - style_=f"background-color: {color}; width: 100%; height: 100%" - ) as main: - if "Map" in layout_type: - rv.CardTitle(children=["Map"]) - leaflet_viz.map(model, map_drawer, zoom, center_default, scroll_wheel_zoom) - - if "Measure" in layout_type: - rv.CardTitle(children=["Measure"]) - measure = measures[layout_type["Measure"]] - if callable(measure): - # Is a custom object - measure(model) - else: - components_matplotlib.PlotMatplotlib( - model, measure, dependencies=[current_step.value] - ) - return main - - -@solara.component -def GeoJupyterViz( - model_class, - model_params, - measures=None, - name=None, - agent_portrayal=None, - play_interval=150, - # parameters for leaflet_viz - view=None, - zoom=None, - scroll_wheel_zoom=True, - tiles=xyz.OpenStreetMap.Mapnik, - center_point=None, # Due to projection challenges in calculation allow user to specify center point -): - """ - - - Parameters - ---------- - model_class : Mesa Model Object - A pointer to the Mesa Model object this allows the visual to get get - model information, such as scheduler and space. - model_params : Dictionary - Parameters of model with key being the parameter as a string and values being the options - measures : List, optional - Plots associated with model typically from datacollector that represent - critical information collected from the model. The default is None. - name : String, optional - Name of simulation to appear on visual. The default is None. - agent_portrayal : Dictionary, optional - Dictionary of how the agent showed appear. The default is None. - play_interval : INT, optional - Rendering interval of model. The default is 150. - # parameters for leaflet_viz - view : List, optional - Bounds of map to be displayed; must be set with zoom. The default is None. - zoom : Int, optional - Zoom level of map on leaflet - scroll_wheel_zoom : Boolean, optional - True of False for whether or not to enable scroll wheel. The default is True. - Recommend False when using jupyter due to multiple scroll wheel options - tiles : Data source for GIS data, optional - Data Source for GIS map data. The default is xyz.OpenStreetMap.Mapnik. - # Due to projection challenges in calculation allow user to specify - center_point : List, optional - Option to pass in center coordinates of map The default is None.. The default is None. - - - Returns - ------- - Provides information to Card to render model - - """ - - if name is None: - name = model_class.__name__ - - current_step = solara.use_reactive(0) - - # 1. Set up model parameters - user_params, fixed_params = jv.split_model_params(model_params) - model_parameters, set_model_parameters = solara.use_state( - {**fixed_params, **{k: v.get("value") for k, v in user_params.items()}} - ) - - # 2. Set up Model - def make_model(): - model = model_class(**model_parameters) - current_step.value = 0 - return model - - reset_counter = solara.use_reactive(0) - model = solara.use_memo( - make_model, dependencies=[*list(model_parameters.values()), reset_counter.value] - ) - - def handle_change_model_params(name: str, value: any): - set_model_parameters({**model_parameters, name: value}) - - # 3. Set up UI - with solara.AppBar(): - solara.AppBarTitle(name) - - # 4. Set Up Map - # render layout, pass through map build parameters - map_drawer = leaflet_viz.MapModule( - portrayal_method=agent_portrayal, - view=view, - zoom=zoom, - tiles=tiles, - scroll_wheel_zoom=scroll_wheel_zoom, - ) - layers = map_drawer.render(model) - - # determine center point - if center_point: - center_default = center_point - else: - bounds = layers["layers"]["total_bounds"] - center_default = [ - (bounds[0][0] + bounds[1][0]) / 2, - (bounds[0][1] + bounds[1][1]) / 2, - ] - - # Build base data structure for layout - layout_types = [{"Map": "default"}] - - if measures: - layout_types += [{"Measure": elem} for elem in range(len(measures))] - - grid_layout_initial = jv.make_initial_grid_layout(layout_types=layout_types) - grid_layout, set_grid_layout = solara.use_state(grid_layout_initial) - - with solara.Sidebar(): - with solara.Card("Controls", margin=1, elevation=2): - jv.UserInputs(user_params, on_change=handle_change_model_params) - jv.ModelController(model, play_interval, current_step, reset_counter) - with solara.Card("Progress", margin=1, elevation=2): - solara.Markdown(md_text=f"####Step - {current_step}") - - items = [ - Card( - model, - measures, - agent_portrayal, - map_drawer, - center_default, - zoom, - scroll_wheel_zoom, - current_step, - color="white", - layout_type=layout_types[i], - ) - for i in range(len(layout_types)) - ] - - solara.GridDraggable( - items=items, - grid_layout=grid_layout, - resizable=True, - draggable=True, - on_grid_layout=set_grid_layout, - ) diff --git a/tests/test_GeoJupyterViz.py b/tests/test_GeoJupyterViz.py deleted file mode 100644 index 1a646475..00000000 --- a/tests/test_GeoJupyterViz.py +++ /dev/null @@ -1,124 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch - -import solara - -from mesa_geo.visualization.geojupyter_viz import Card, GeoJupyterViz - - -class TestGeoViz(unittest.TestCase): - @patch("mesa_geo.visualization.geojupyter_viz.rv.CardTitle") - @patch("mesa_geo.visualization.geojupyter_viz.rv.Card") - @patch("mesa_geo.visualization.geojupyter_viz.components_matplotlib.PlotMatplotlib") - @patch("mesa_geo.visualization.geojupyter_viz.leaflet_viz.map") - def test_card_function( - self, - mock_map, - mock_PlotMatplotlib, # noqa: N803 - mock_Card, # noqa: N803 - mock_CardTitle, # noqa: N803 - ): - model = MagicMock() - measures = {"Measure1": lambda x: x} - agent_portrayal = MagicMock() - map_drawer = (MagicMock(),) - zoom = (10,) - scroll_wheel_zoom = (True,) - center_default = ([0, 0],) - current_step = MagicMock() - current_step.value = 0 - color = "white" - layout_type = {"Map": "default", "Measure": "Measure1"} - - with patch( - "mesa_geo.visualization.geojupyter_viz.rv.Card", return_value=MagicMock() - ) as mock_rv_card: - _ = Card( - model, - measures, - agent_portrayal, - map_drawer, - center_default, - zoom, - scroll_wheel_zoom, - current_step, - color, - layout_type, - ) - - mock_rv_card.assert_called_once() - mock_CardTitle.assert_any_call(children=["Map"]) - mock_map.assert_called_once_with( - model, map_drawer, zoom, center_default, scroll_wheel_zoom - ) - # mock_PlotMatplotlib.assert_called_once() - - @patch("mesa_geo.visualization.geojupyter_viz.solara.GridDraggable") - @patch("mesa_geo.visualization.geojupyter_viz.solara.Sidebar") - @patch("mesa_geo.visualization.geojupyter_viz.solara.Card") - @patch("mesa_geo.visualization.geojupyter_viz.solara.Markdown") - @patch("mesa_geo.visualization.geojupyter_viz.jv.ModelController") - @patch("mesa_geo.visualization.geojupyter_viz.jv.UserInputs") - @patch("mesa_geo.visualization.geojupyter_viz.jv.split_model_params") - @patch("mesa_geo.visualization.geojupyter_viz.solara.use_memo") - @patch("mesa_geo.visualization.geojupyter_viz.solara.use_reactive") - @patch("mesa_geo.visualization.geojupyter_viz.solara.use_state") - @patch("mesa_geo.visualization.geojupyter_viz.solara.AppBarTitle") - @patch("mesa_geo.visualization.geojupyter_viz.solara.AppBar") - @patch("mesa_geo.visualization.geojupyter_viz.leaflet_viz.MapModule") - def test_geojupyterviz_function( - self, - mock_MapModule, # noqa: N803 - mock_AppBar, # noqa: N803 - mock_AppBarTitle, # noqa: N803 - mock_use_state, - mock_use_reactive, - mock_use_memo, - mock_split_model_params, - mock_UserInputs, # noqa: N803 - mock_ModelController, # noqa: N803 - mock_Markdown, # noqa: N803 - mock_Card, # noqa: N803 - mock_Sidebar, # noqa: N803 - mock_GridDraggable, # noqa: N803 - ): - model_class = MagicMock() - model_params = MagicMock() - measures = [lambda x: x] - name = "TestModel" - agent_portrayal = MagicMock() - play_interval = 150 - view = [0, 0] - zoom = 10 - center_point = [0, 0] - - mock_use_reactive.side_effect = [MagicMock(value=0), MagicMock(value=0)] - mock_split_model_params.return_value = ({}, {}) - mock_use_state.return_value = ({}, MagicMock()) - mock_use_memo.return_value = MagicMock() - - solara.render( - GeoJupyterViz( - model_class=model_class, - model_params=model_params, - measures=measures, - name=name, - agent_portrayal=agent_portrayal, - play_interval=play_interval, - view=view, - zoom=zoom, - center_point=center_point, - ) - ) - - mock_AppBar.assert_called_once() - mock_AppBarTitle.assert_called_once_with(name) - mock_split_model_params.assert_called_once_with(model_params) - mock_use_memo.assert_called_once() - mock_UserInputs.assert_called_once() - mock_ModelController.assert_called_once() - mock_Markdown.assert_called() - mock_Card.assert_called() - mock_Sidebar.assert_called_once() - mock_GridDraggable.assert_called_once() - mock_MapModule.assert_called_once() diff --git a/tests/test_MapModule.py b/tests/test_MapModule.py index b6f421e2..5bf07eb3 100644 --- a/tests/test_MapModule.py +++ b/tests/test_MapModule.py @@ -39,25 +39,19 @@ def tearDown(self) -> None: def test_render_point_agents(self): # test length point agents and Circle marker as default - map_module = mgv.leaflet_viz.MapModule( + map_module = mgv.MapModule( portrayal_method=lambda x: {"color": "Green"}, - view=None, - zoom=3, - scroll_wheel_zoom=True, tiles=xyz.OpenStreetMap.Mapnik, ) self.model.space.add_agents(self.point_agents) self.assertEqual(len(map_module.render(self.model).get("agents")[1]), 7) self.assertIsInstance(map_module.render(self.model).get("agents")[1][3], Circle) # test CircleMarker option - map_module = mgv.leaflet_viz.MapModule( + map_module = mgv.MapModule( portrayal_method=lambda x: { "marker_type": "CircleMarker", "color": "Green", }, - view=None, - zoom=3, - scroll_wheel_zoom=True, tiles=xyz.OpenStreetMap.Mapnik, ) self.model.space.add_agents(self.point_agents) @@ -66,30 +60,24 @@ def test_render_point_agents(self): ) # test Marker option - map_module = mgv.leaflet_viz.MapModule( + map_module = mgv.MapModule( portrayal_method=lambda x: { "marker_type": "AwesomeIcon", "name": "bus", "color": "Green", }, - view=None, - zoom=3, - scroll_wheel_zoom=True, tiles=xyz.OpenStreetMap.Mapnik, ) self.model.space.add_agents(self.point_agents) self.assertEqual(len(map_module.render(self.model).get("agents")[1]), 7) self.assertIsInstance(map_module.render(self.model).get("agents")[1][3], Marker) # test popupProperties for Point - map_module = mgv.leaflet_viz.MapModule( + map_module = mgv.MapModule( portrayal_method=lambda x: { "color": "Red", "radius": 7, "description": "popupMsg", }, - view=None, - zoom=3, - scroll_wheel_zoom=False, tiles=xyz.OpenStreetMap.Mapnik, ) self.model.space.add_agents(self.point_agents) @@ -103,11 +91,8 @@ def test_render_point_agents(self): ) # test ValueError if not known markertype - map_module = mgv.leaflet_viz.MapModule( + map_module = mgv.MapModule( portrayal_method=lambda x: {"marker_type": "Hexagon", "color": "Green"}, - view=None, - zoom=3, - scroll_wheel_zoom=True, tiles=xyz.OpenStreetMap.Mapnik, ) self.model.space.add_agents(self.point_agents) @@ -115,11 +100,8 @@ def test_render_point_agents(self): map_module.render(self.model) def test_render_line_agents(self): - map_module = mgv.leaflet_viz.MapModule( + map_module = mgv.MapModule( portrayal_method=lambda x: {"color": "#3388ff", "weight": 7}, - view=None, - zoom=3, - scroll_wheel_zoom=False, tiles=xyz.OpenStreetMap.Mapnik, ) self.model.space.add_agents(self.line_agents) @@ -141,15 +123,12 @@ def test_render_line_agents(self): }, ) - map_module = mgv.leaflet_viz.MapModule( + map_module = mgv.MapModule( portrayal_method=lambda x: { "color": "#3388ff", "weight": 7, "description": "popupMsg", }, - view=None, - zoom=3, - scroll_wheel_zoom=True, tiles=xyz.OpenStreetMap.Mapnik, ) self.model.space.add_agents(self.line_agents) @@ -177,11 +156,8 @@ def test_render_line_agents(self): def test_render_polygon_agents(self): self.maxDiff = None - map_module = mgv.leaflet_viz.MapModule( + map_module = mgv.MapModule( portrayal_method=lambda x: {"fillColor": "#3388ff", "fillOpacity": 0.7}, - view=None, - zoom=3, - scroll_wheel_zoom=False, tiles=xyz.OpenStreetMap.Mapnik, ) self.model.space.add_agents(self.polygon_agents) @@ -207,15 +183,12 @@ def test_render_polygon_agents(self): }, ) - map_module = mgv.leaflet_viz.MapModule( + map_module = mgv.MapModule( portrayal_method=lambda x: { "fillColor": "#3388ff", "fillOpacity": 0.7, "description": "popupMsg", }, - view=None, - zoom=3, - scroll_wheel_zoom=True, tiles=xyz.OpenStreetMap.Mapnik, ) self.model.space.add_agents(self.polygon_agents) @@ -243,11 +216,8 @@ def test_render_polygon_agents(self): ) def test_render_raster_layers(self): - map_module = mg.visualization.leaflet_viz.MapModule( + map_module = mgv.MapModule( portrayal_method=lambda x: (255, 255, 255, 0.5), - view=None, - zoom=3, - scroll_wheel_zoom=True, tiles=xyz.OpenStreetMap.Mapnik, ) self.model.space.add_layer(self.raster_layer) diff --git a/tests/test_geospace_leaflet.py b/tests/test_geospace_leaflet.py new file mode 100644 index 00000000..6c348046 --- /dev/null +++ b/tests/test_geospace_leaflet.py @@ -0,0 +1,28 @@ +import mesa +import solara +import xyzservices +from mesa.visualization.solara_viz import SolaraViz + +from mesa_geo.visualization import make_geospace_leaflet + + +def test_geospace_leaflet(mocker): + mock_geospace_leaflet = mocker.patch( + "mesa_geo.visualization.components.geospace_leaflet.GeoSpaceLeaflet" + ) + + model = mesa.Model() + mocker.patch.object(mesa.Model, "__new__", return_value=model) + mocker.patch.object(mesa.Model, "__init__", return_value=None) + + agent_portrayal = { + "Shape": "circle", + "color": "gray", + } + # initialize with space drawer unspecified (use default) + # component must be rendered for code to run + solara.render(SolaraViz(model, components=[make_geospace_leaflet(agent_portrayal)])) + # should call default method with class instance and agent portrayal + mock_geospace_leaflet.assert_called_with( + model, agent_portrayal, None, xyzservices.providers.OpenStreetMap.Mapnik + )