Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into offline_output
Browse files Browse the repository at this point in the history
  • Loading branch information
lmcinnes committed Nov 14, 2024
2 parents 67af68b + f19048b commit 7618bda
Show file tree
Hide file tree
Showing 9 changed files with 431 additions and 113 deletions.
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ target/

# Jupyter Notebook
.ipynb_checkpoints
.virtual_documents

# IPython
profile_default/
Expand Down Expand Up @@ -158,3 +159,10 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Sphinx and documentation intermediates
examples/*.zip
examples/*.html
doc/_build
doc/auto_examples
doc/sg_execution_times.rst
92 changes: 81 additions & 11 deletions datamapplot/config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,33 @@
from warnings import warn
from collections.abc import Sequence
import inspect as ins
import json
import platformdirs
from pathlib import Path
import platformdirs
from typing import Any, Callable, cast, ParamSpec, TypeVar, Union
from warnings import warn


P = ParamSpec("P")
T = TypeVar("T")


DEFAULT_CONFIG = {
"dpi": 100,
"figsize": (10, 10),
"cdn_url": "unpkg.com",
}


class ConfigError(Exception):

def __init__(self, message: str, parameter: ins.Parameter) -> None:
super().__init__(message)
self.parameter = parameter


UnconfigurableParameters = Sequence[str]


class ConfigManager:
"""Configuration manager for the datamapplot package."""

Expand All @@ -19,27 +38,27 @@ def __new__(cls):
cls._instance = super(ConfigManager, cls).__new__(cls)
cls._instance._config = {}
return cls._instance

def __init__(self):
if self._instance is None:
if not self._config:
self._config_dir = platformdirs.user_config_dir("datamapplot")
self._config_file = Path(self._config_dir) / "config.json"
self._config = DEFAULT_CONFIG.copy()

self._ensure_config_file()
self._load_config()

def _ensure_config_file(self) -> None:
"""Create config directory and file if they don't exist."""
try:
self._config_file.parent.mkdir(parents=True, exist_ok=True)

if not self._config_file.exists():
with open(self._config_file, 'w') as f:
json.dump(DEFAULT_CONFIG, f, indent=2)
except Exception as e:
warn(f"Error creating config file: {e}")

def _load_config(self) -> None:
"""Load configuration from file."""
try:
Expand All @@ -48,18 +67,18 @@ def _load_config(self) -> None:
self._config.update(loaded_config)
except Exception as e:
warn(f"Error loading config file: {e}")

def save(self) -> None:
"""Save current configuration to file."""
try:
with open(self._config_file, 'w') as f:
json.dump(self._config, f, indent=2)
except Exception as e:
warn(f"Error saving config file: {e}")

def __getitem__(self, key):
return self._config[key]

def __setitem__(self, key, value):
self._config[key] = value

Expand All @@ -69,4 +88,55 @@ def __delitem__(self, key):
def __contains__(self, key):
return key in self._config


def complete(
self,
fn_or_unc: Union[None, UnconfigurableParameters, Callable[P, T]] = None,
unconfigurable: UnconfigurableParameters = set(),
) -> Union[Callable[[Callable[P, T]], Callable[P, T]], Callable[P, T]]:
def decorator(fn: Callable[P, T]) -> Callable[P, T]:
sig = ins.signature(fn)

def fn_with_config(*args, **kwargs):
bound_args = sig.bind(*args, **kwargs)
bindings = bound_args.arguments
from_config = {}
for name, param in sig.parameters.items():
if name not in bindings and name in self:
if not _is_admissible(param):
raise ConfigError(
"Only keyword (or plausibly keyword) parameters "
"can be set through the DataMapPlot configuration "
f"file. Parameter {param.name} ({param.kind}) "
"is thus not admissible.",
param
)
if name in unconfigurable:
raise ConfigError(
f"Parameter {param.name} is deliberately listed as "
"forbidden from being defined through the DataMapPlot "
"configuration file.",
param
)
from_config[name] = self[name]
return fn(*bound_args.args, **(bound_args.kwargs | from_config))

fn_with_config._gets_completed = True
return fn_with_config

if fn_or_unc is None:
return decorator
elif not hasattr(fn_or_unc, "__call__"):
unconfigurable = cast(UnconfigurableParameters, fn_or_unc)
return decorator
return decorator(fn_or_unc)

@staticmethod
def gets_completed(func) -> bool:
return hasattr(func, "_gets_completed") and func._gets_completed


_KINDS_ADMISSIBLE = {ins.Parameter.POSITIONAL_OR_KEYWORD, ins.Parameter.KEYWORD_ONLY}


def _is_admissible(param: ins.Parameter) -> bool:
return param.kind in _KINDS_ADMISSIBLE
32 changes: 5 additions & 27 deletions datamapplot/create_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pandas as pd
import textwrap
import colorcet
import inspect

from matplotlib import pyplot as plt
from matplotlib.colors import to_rgb
Expand All @@ -23,6 +22,10 @@
from datamapplot.config import ConfigManager


cfg = ConfigManager()


@cfg.complete(unconfigurable={"data_map_coords", "labels"})
def create_plot(
data_map_coords,
labels=None,
Expand Down Expand Up @@ -172,19 +175,6 @@ def create_plot(
The axes contained within the figure that the plot is rendered to.
"""
function_signature = inspect.signature(create_plot)
function_args = locals()
config = ConfigManager()

for param_name, param_value in function_signature.parameters.items():
if param_name in ("data_map_coords", "labels"):
continue

provided_value = function_args.get(param_name)
if provided_value == param_value.default:
if param_name in config:
function_args[param_name] = config[param_name]

if labels is None:
label_locations = np.zeros((0, 2), dtype=np.float32)
label_text = []
Expand Down Expand Up @@ -333,6 +323,7 @@ def create_plot(
return fig, ax


@cfg.complete(unconfigurable={"data_map_coords", "label_layers", "hover_text"})
def create_interactive_plot(
data_map_coords,
*label_layers,
Expand Down Expand Up @@ -472,19 +463,6 @@ def create_interactive_plot(
-------
"""
function_signature = inspect.signature(create_interactive_plot)
function_args = locals()
config = ConfigManager()

for param_name, param_value in function_signature.parameters.items():
if param_name in ("data_map_coords", "label_layers", "hover_text"):
continue

provided_value = function_args.get(param_name)
if provided_value is param_value.default:
if param_name in config:
function_args[param_name] = config[param_name]

if len(label_layers) == 0:
label_dataframe = pd.DataFrame(
{
Expand Down
18 changes: 4 additions & 14 deletions datamapplot/interactive_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import warnings
import zipfile
import json
import inspect
import platformdirs

import jinja2
Expand All @@ -33,6 +32,9 @@
from datamapplot.config import ConfigManager
from datamapplot import offline_mode_caching


cfg = ConfigManager()

_DECKGL_TEMPLATE_STR = (files("datamapplot") / "deckgl_template.html").read_text(
encoding="utf-8"
)
Expand Down Expand Up @@ -399,6 +401,7 @@ def label_text_and_polygon_dataframes(
return pd.DataFrame(data)


@cfg.complete(unconfigurable={"point_dataframe", "label_dataframe"})
def render_html(
point_dataframe,
label_dataframe,
Expand Down Expand Up @@ -718,19 +721,6 @@ def render_html(
An interactive figure with hover, pan, and zoom. This will display natively
in a notebook, and can be saved to an HTML file via the `save` method.
"""
function_signature = inspect.signature(render_html)
function_args = locals()
config = ConfigManager()

for param_name, param_value in function_signature.parameters.items():
if param_name in ("point_dataframe", "label_dataframe"):
continue

provided_value = function_args.get(param_name)
if provided_value is param_value.default:
if param_name in config:
function_args[param_name] = config[param_name]

# Compute point scaling
n_points = point_dataframe.shape[0]
if point_size_scale is not None:
Expand Down
32 changes: 12 additions & 20 deletions datamapplot/plot_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@

import requests
import re
import inspect


cfg = ConfigManager()


class GoogleAPIUnreachable(Warning):
Expand Down Expand Up @@ -167,6 +169,15 @@ def add_glow_to_scatterplot(
)


@cfg.complete(
unconfigurable={
"data_map_coords",
"color_list",
"label_text",
"label_locations",
"label_cluster_sizes",
}
)
def render_plot(
data_map_coords,
color_list,
Expand Down Expand Up @@ -452,25 +463,6 @@ def render_plot(
The axes contained within the figure that the plot is rendered to.
"""
function_signature = inspect.signature(render_plot)
function_args = locals()
config = ConfigManager()

for param_name, param_value in function_signature.parameters.items():
if param_name in (
"data_map_coords",
"color_list",
"label_text",
"label_locations",
"label_cluster_sizes",
):
continue

provided_value = function_args.get(param_name)
if provided_value is param_value.default:
if param_name in config:
function_args[param_name] = config[param_name]

# Create the figure
if ax is None:
fig, ax = plt.subplots(figsize=figsize, dpi=dpi, constrained_layout=True)
Expand Down
Loading

0 comments on commit 7618bda

Please sign in to comment.