Reduce import overhead to improve runtime (#1821)
* consolidate how to get supported datatypes

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* fix get backend logic

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* import modin and pyspark in builtin_checks conditionally

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

* fix unit test pa.typing.pyspark

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>

---------

Signed-off-by: cosmicBboy <niels.bantilan@gmail.com>
cosmicBboy authored Sep 30, 2024
1 parent 0207fba commit 4f8bdbf
Showing 9 changed files with 184 additions and 201 deletions.
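The thrust of the change is to stop importing optional backends (pyspark, modin, dask, geopandas) at `import pandera` time and defer them until a matching data object is actually validated. A quick way to sanity-check the improvement locally — a sketch only; absolute numbers vary by environment, and CPython's `python -X importtime -c "import pandera"` gives a finer per-module breakdown:

# Rough one-shot timing of pandera's import cost. Run in a fresh
# interpreter: repeated imports are served from the sys.modules cache.
import time

start = time.perf_counter()
import pandera  # noqa: F401

print(f"import pandera: {time.perf_counter() - start:.2f}s")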
5 changes: 1 addition & 4 deletions pandera/__init__.py
@@ -1,6 +1,7 @@
 # pylint: disable=wrong-import-position
 """A flexible and expressive pandas validation library."""

+import os
 import platform

 from pandera._patch_numpy2 import _patch_numpy2
@@ -12,7 +13,6 @@
 import pandera.backends.base.builtin_hypotheses
 import pandera.backends.pandas
 from pandera import errors, external_config
-from pandera.accessors import pandas_accessor
 from pandera.api import extensions
 from pandera.api.checks import Check
 from pandera.api.dataframe.model_components import (
@@ -82,9 +82,6 @@
 from pandera.dtypes import Complex256, Float128


-external_config._set_pyspark_environment_variables()
-
-
 __all__ = [
     # dtypes
     "Bool",
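For context, the deleted `_set_pyspark_environment_variables()` call configured pyspark-related environment variables on every import (pandera uses it to set `PYARROW_IGNORE_TIMEZONE`, which `pyspark.pandas` expects), touching the pyspark probe even for pandas-only users. A hypothetical lazy equivalent, gated on `sys.modules` in the same spirit as the builtin_checks change below — an assumption for illustration, since this diff only shows the call being removed:

import os
import sys


def _set_pyspark_environment_variables() -> None:
    # Hypothetical: only configure pyspark.pandas when the user has already
    # imported pyspark, instead of probing for it on every pandera import.
    if "pyspark" in sys.modules:
        os.environ.setdefault("PYARROW_IGNORE_TIMEZONE", "1")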
2 changes: 2 additions & 0 deletions pandera/api/extensions.py
@@ -27,6 +27,7 @@ def register_builtin_check(
     fn=None,
     strategy: Optional[Callable] = None,
     _check_cls: Type = Check,
+    aliases: Optional[List[str]] = None,
     **outer_kwargs,
 ):
     """Register a check method to the Check namespace.
@@ -41,6 +42,7 @@
             register_builtin_check,
             strategy=strategy,
             _check_cls=_check_cls,
+            aliases=aliases,
             **outer_kwargs,
         )

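The new `aliases` argument lets a builtin check answer to additional method names, e.g. `Check.eq` alongside `Check.equal_to`. A minimal sketch mirroring the `equal_to`/"eq" registration visible in builtin_checks.py below; the check body here is illustrative, not pandera's exact implementation:

from typing import Any

from pandera.api.extensions import register_builtin_check


@register_builtin_check(aliases=["eq"])
def equal_to(data, value: Any):
    # Element-wise comparison against the expected value.
    return data == value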
191 changes: 139 additions & 52 deletions pandera/api/pandas/types.py
@@ -1,12 +1,15 @@
+# pylint: disable=unused-import
 """Utility functions for pandas validation."""

 from functools import lru_cache
-from typing import NamedTuple, Tuple, Type, Union
+from typing import Any, NamedTuple, Type, TypeVar, Union, Optional

 import numpy as np
 import pandas as pd

 from pandera.dtypes import DataType
+from pandera.errors import BackendNotFoundError
+

 PandasDtypeInputTypes = Union[
     str,
@@ -17,68 +20,146 @@
     np.dtype,
 ]

-SupportedTypes = NamedTuple(
-    "SupportedTypes",
-    (
-        ("table_types", Tuple[type, ...]),
-        ("field_types", Tuple[type, ...]),
-        ("index_types", Tuple[type, ...]),
-        ("multiindex_types", Tuple[type, ...]),
-    ),
+PANDAS_LIKE_CLS_NAMES = frozenset(
+    [
+        "DataFrame",
+        "Series",
+        "Index",
+        "MultiIndex",
+        "GeoDataFrame",
+        "GeoSeries",
+    ]
 )


-@lru_cache(maxsize=None)
-def supported_types() -> SupportedTypes:
-    """Get the types supported by pandera schemas."""
-    # pylint: disable=import-outside-toplevel
-    table_types = [pd.DataFrame]
-    field_types = [pd.Series]
-    index_types = [pd.Index]
-    multiindex_types = [pd.MultiIndex]
-
-    try:
-        import pyspark.pandas as ps
-
-        table_types.append(ps.DataFrame)
-        field_types.append(ps.Series)
-        index_types.append(ps.Index)
-        multiindex_types.append(ps.MultiIndex)
-    except ImportError:
-        pass
-    try:  # pragma: no cover
-        import modin.pandas as mpd
-
-        table_types.append(mpd.DataFrame)
-        field_types.append(mpd.Series)
-        index_types.append(mpd.Index)
-        multiindex_types.append(mpd.MultiIndex)
-    except ImportError:
-        pass
-    try:
-        import dask.dataframe as dd
-
-        table_types.append(dd.DataFrame)
-        field_types.append(dd.Series)
-        index_types.append(dd.Index)
-    except ImportError:
-        pass
-
-    return SupportedTypes(
-        tuple(table_types),
-        tuple(field_types),
-        tuple(index_types),
-        tuple(multiindex_types),
+class BackendTypes(NamedTuple):
+
+    # list of datatypes available
+    dataframe_datatypes: tuple
+    series_datatypes: tuple
+    index_datatypes: tuple
+    multiindex_datatypes: tuple
+    check_backend_types: tuple
+
+
+@lru_cache
+def get_backend_types(check_cls_fqn: str):
+
+    dataframe_datatypes = []
+    series_datatypes = []
+    index_datatypes = []
+    multiindex_datatypes = []
+
+    mod_name, *mod_path, cls_name = check_cls_fqn.split(".")
+    if mod_name != "pandera":
+        if cls_name not in PANDAS_LIKE_CLS_NAMES:
+            raise BackendNotFoundError(
+                f"cls_name {cls_name} not in {PANDAS_LIKE_CLS_NAMES}"
+            )
+
+    if mod_name == "pandera":
+        # assume mod_path e.g. ["typing", "pandas"]
+        assert mod_path[0] == "typing"
+        *_, mod_name = mod_path
+    else:
+        mod_name = mod_name.split(".")[-1]
+
+    def register_pandas_backend():
+        from pandera.accessors import pandas_accessor
+
+        dataframe_datatypes.append(pd.DataFrame)
+        series_datatypes.append(pd.Series)
+        index_datatypes.append(pd.Index)
+        multiindex_datatypes.append(pd.MultiIndex)
+
+    def register_dask_backend():
+        import dask.dataframe as dd
+        from pandera.accessors import dask_accessor
+
+        dataframe_datatypes.append(dd.DataFrame)
+        series_datatypes.append(dd.Series)
+        index_datatypes.append(dd.Index)
+
+    def register_modin_backend():
+        import modin.pandas as mpd
+        from pandera.accessors import modin_accessor
+
+        dataframe_datatypes.append(mpd.DataFrame)
+        series_datatypes.append(mpd.Series)
+        index_datatypes.append(mpd.Index)
+        multiindex_datatypes.append(mpd.MultiIndex)
+
+    def register_pyspark_backend():
+        import pyspark.pandas as ps
+        from pandera.accessors import pyspark_accessor
+
+        dataframe_datatypes.append(ps.DataFrame)
+        series_datatypes.append(ps.Series)
+        index_datatypes.append(ps.Index)
+        multiindex_datatypes.append(ps.MultiIndex)
+
+    def register_geopandas_backend():
+        import geopandas as gpd
+
+        register_pandas_backend()
+        dataframe_datatypes.append(gpd.GeoDataFrame)
+        series_datatypes.append(gpd.GeoSeries)
+
+    register_fn = {
+        "pandas": register_pandas_backend,
+        "dask_expr": register_dask_backend,
+        "modin": register_modin_backend,
+        "pyspark": register_pyspark_backend,
+        "geopandas": register_geopandas_backend,
+        "pandera": lambda: None,
+    }[mod_name]
+
+    register_fn()
+
+    check_backend_types = [
+        *dataframe_datatypes,
+        *series_datatypes,
+        *index_datatypes,
+    ]
+
+    return BackendTypes(
+        dataframe_datatypes=tuple(dataframe_datatypes),
+        series_datatypes=tuple(series_datatypes),
+        index_datatypes=tuple(index_datatypes),
+        multiindex_datatypes=tuple(multiindex_datatypes),
+        check_backend_types=tuple(check_backend_types),
     )


+T = TypeVar("T")
+
+
+def _get_fullname(_cls: Type) -> str:
+    return f"{_cls.__module__}.{_cls.__name__}"
+
+
+def get_backend_types_from_mro(_cls: Type) -> Optional[BackendTypes]:
+    try:
+        return get_backend_types(_get_fullname(_cls))
+    except BackendNotFoundError:
+        for base_cls in _cls.__bases__:
+            try:
+                return get_backend_types(_get_fullname(base_cls))
+            except BackendNotFoundError:
+                pass
+        return None
+
+
 def is_table(obj):
     """Verifies whether an object is table-like.

     Where a table is a 2-dimensional data matrix of rows and columns, which
     can be indexed in multiple different ways.
     """
-    return isinstance(obj, supported_types().table_types)
+    backend_types = get_backend_types_from_mro(type(obj))
+    return backend_types is not None and isinstance(
+        obj, backend_types.dataframe_datatypes
+    )

@@ -87,27 +168,33 @@ def is_field(obj):
     """Verifies whether an object is field-like.

     Where a field is a columnar representation of data in a table-like
     data structure.
     """
-    return isinstance(obj, supported_types().field_types)
+    backend_types = get_backend_types_from_mro(type(obj))
+    return backend_types is not None and isinstance(
+        obj, backend_types.series_datatypes
+    )


 def is_index(obj):
     """Verifies whether an object is a table index."""
-    return isinstance(obj, supported_types().index_types)
+    backend_types = get_backend_types_from_mro(type(obj))
+    return backend_types is not None and isinstance(
+        obj, backend_types.index_datatypes
+    )


 def is_multiindex(obj):
     """Verifies whether an object is a multi-level table index."""
-    return isinstance(obj, supported_types().multiindex_types)
+    backend_types = get_backend_types_from_mro(type(obj))
+    return backend_types is not None and isinstance(
+        obj, backend_types.multiindex_datatypes
+    )


 def is_table_or_field(obj):
     """Verifies whether an object is table- or field-like."""
     return is_table(obj) or is_field(obj)


+is_supported_check_obj = is_table_or_field
+
+
 def is_bool(x):
     """Verifies whether an object is a boolean type."""
     return isinstance(x, (bool, np.bool_))
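A minimal sketch of how the new lazy lookup behaves, assuming a pandas-only environment with this commit applied. `get_backend_types` keys its cache on the fully-qualified class name, and `get_backend_types_from_mro` falls back to base classes, so subclasses of registered types still resolve:

import pandas as pd

from pandera.api.pandas.types import get_backend_types, is_table

# "pandas.core.frame.DataFrame" -> mod_name "pandas", cls_name "DataFrame":
# only the pandas backend gets registered; pyspark/modin/dask stay unimported.
backend_types = get_backend_types("pandas.core.frame.DataFrame")
assert pd.DataFrame in backend_types.dataframe_datatypes


class MyFrame(pd.DataFrame):
    """__main__.MyFrame is not a registered name, but its base pd.DataFrame is."""


assert is_table(pd.DataFrame({"x": [1, 2]}))
assert is_table(MyFrame({"x": [1, 2]}))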
16 changes: 11 additions & 5 deletions pandera/backends/pandas/builtin_checks.py
@@ -1,5 +1,6 @@
 """Pandas implementation of built-in checks"""

+import sys
 import operator
 import re
 from typing import Any, Iterable, Optional, TypeVar, Union, cast
@@ -10,18 +11,21 @@
 from pandera.api.extensions import register_builtin_check


-from pandera.typing.modin import MODIN_INSTALLED
-from pandera.typing.pyspark import PYSPARK_INSTALLED
+MODIN_IMPORTED = "modin" in sys.modules
+PYSPARK_IMPORTED = "pyspark" in sys.modules

-if MODIN_INSTALLED and not PYSPARK_INSTALLED:  # pragma: no cover
+
+# TODO: create a separate module for each framework: dask, modin, pyspark
+# so checks are registered for the correct framework.
+if MODIN_IMPORTED and not PYSPARK_IMPORTED:  # pragma: no cover
     import modin.pandas as mpd

     PandasData = Union[pd.Series, pd.DataFrame, mpd.Series, mpd.DataFrame]
-elif not MODIN_INSTALLED and PYSPARK_INSTALLED:  # pragma: no cover
+elif not MODIN_IMPORTED and PYSPARK_IMPORTED:  # pragma: no cover
     import pyspark.pandas as ppd

     PandasData = Union[pd.Series, pd.DataFrame, ppd.Series, ppd.DataFrame]  # type: ignore[misc]
-elif MODIN_INSTALLED and PYSPARK_INSTALLED:  # pragma: no cover
+elif MODIN_IMPORTED and PYSPARK_IMPORTED:  # pragma: no cover
     import modin.pandas as mpd
     import pyspark.pandas as ppd

@@ -40,6 +44,8 @@
 T = TypeVar("T")


+# TODO: remove aliases, it's not needed anymore since aliases are defined in the
+# Check class.
 @register_builtin_check(
     aliases=["eq"],
     strategy=st.eq_strategy,
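The key change here: the `*_INSTALLED` flags came from `pandera.typing.modin`/`pandera.typing.pyspark`, which establish them by actually attempting the heavy imports, whereas `"modin" in sys.modules` is a plain dict lookup that only reports libraries the user has already imported. A sketch of the trade-off, using a hypothetical module name:

import sys

# Cheap: a dict lookup that never triggers an import. True only if the
# user (or some other module) has already imported the library.
heavy_imported = "some_heavy_lib" in sys.modules  # hypothetical name

# Expensive: the try/import probe the old flags relied on executes the
# whole package just to learn that it exists.
try:
    import some_heavy_lib  # noqa: F401
    heavy_installed = True
except ImportError:
    heavy_installed = False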
8 changes: 0 additions & 8 deletions pandera/backends/pandas/hypotheses.py
@@ -12,14 +12,6 @@
 from pandera.backends.pandas.checks import PandasCheckBackend


-try:
-    from scipy import stats  # pylint: disable=unused-import
-except ImportError:  # pragma: no cover
-    HAS_SCIPY = False
-else:
-    HAS_SCIPY = True
-
-
 DEFAULT_ALPHA = 0.01

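This hunk only deletes the eager scipy probe; the availability check presumably moves to where scipy is first needed. A hypothetical deferred-import version of the same guard — the name and placement are assumptions, not taken from this commit:

def _get_scipy_stats():
    """Import scipy.stats on first use rather than at module import time."""
    try:
        from scipy import stats
    except ImportError as exc:  # pragma: no cover
        raise ImportError("scipy is required for hypothesis checks") from exc
    return stats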
