From 587335d0916705694fa7db17683c32601294e932 Mon Sep 17 00:00:00 2001
From: Haifeng Wu
Date: Fri, 15 Dec 2023 18:26:18 +0800
Subject: [PATCH] update version

---
 deeptables/_version.py                |  2 +-
 deeptables/models/deeptable.py        |  2 +-
 deeptables/utils/dataset_generator.py | 17 +++++++++++++----
 requirements.txt                      |  3 +--
 4 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/deeptables/_version.py b/deeptables/_version.py
index 13a85f7..44b1806 100644
--- a/deeptables/_version.py
+++ b/deeptables/_version.py
@@ -1 +1 @@
-__version__ = '0.2.5'
+__version__ = '0.2.6'
diff --git a/deeptables/models/deeptable.py b/deeptables/models/deeptable.py
index 8a3ec1d..66dea62 100644
--- a/deeptables/models/deeptable.py
+++ b/deeptables/models/deeptable.py
@@ -5,7 +5,6 @@
 import pickle
 import time
 
-import dask
 import numpy as np
 import pandas as pd
 from joblib import Parallel, delayed
@@ -415,6 +414,7 @@ def fit_cross_validation(self, X, y, X_eval=None, X_test=None, num_folds=5, stra
         oof_proba = np.full((X_shape[0], 1), np.nan)
 
         if is_dask_installed and DaskToolBox.exist_dask_object(X, y):
+            import dask
             X = DaskToolBox.reset_index(DaskToolBox.to_dask_frame_or_series(X))
             y = DaskToolBox.to_dask_type(y)
             if DaskToolBox.is_dask_dataframe_or_series(y):
diff --git a/deeptables/utils/dataset_generator.py b/deeptables/utils/dataset_generator.py
index 1629a17..4702bd6 100644
--- a/deeptables/utils/dataset_generator.py
+++ b/deeptables/utils/dataset_generator.py
@@ -6,14 +6,12 @@
 from distutils.version import LooseVersion
 from functools import partial
 
-import dask
-import dask.dataframe as dd
 import numpy as np
 import tensorflow as tf
 from tensorflow.keras.utils import to_categorical as tf_to_categorical
 
 from deeptables.utils import consts, dt_logging
-
+from hypernets.tabular import get_tool_box, is_dask_installed
 logger = dt_logging.get_logger(__name__)
 
 TFDG_DASK_CHUNK = 100
@@ -105,6 +103,7 @@ def __call__(self, X, y=None, *, batch_size, shuffle, drop_remainder):
         return ds
 
     def _to_ds20(self, X, y=None, *, batch_size, shuffle, drop_remainder):
+        import dask
         ds_types = {}
         ds_shapes = {}
         meta = self._get_meta(X)
@@ -118,6 +117,7 @@ def _to_ds20(self, X, y=None, *, batch_size, shuffle, drop_remainder):
                 ds_types[k] = 'int32'
 
         if y is not None:
+            import dask.dataframe as dd
             if isinstance(y, dd.Series):
                 y = y.to_dask_array(lengths=True)
             if self.task == consts.TASK_MULTICLASS:
@@ -149,6 +149,7 @@ def to_spec(name, dtype, idx):
         sig = {k: to_spec(k, dtype, idx) for k, (dtype, idx) in meta.items()}
 
         if y is not None:
+            import dask.dataframe as dd
             if isinstance(y, dd.Series):
                 y = y.to_dask_array(lengths=True)
             if self.task == consts.TASK_MULTICLASS:
@@ -167,6 +168,7 @@ def to_spec(name, dtype, idx):
 
     @staticmethod
     def _generate(meta, X, y, *, batch_size, shuffle, drop_remainder):
+        import dask
         total_size = dask.compute(X.shape)[0][0]
         chunk_size = min(total_size, batch_size * TFDG_DASK_CHUNK)
         fn = partial(_TFDGForDask._compute_chunk, X, y, chunk_size)
@@ -205,6 +207,7 @@ def _generate(meta, X, y, *, batch_size, shuffle, drop_remainder):
 
     @staticmethod
     def _to_categorical(y, *, num_classes):
+        import dask
         if len(y.shape) == 1:
             y = y.reshape(dask.compute(y.shape[0])[0], 1)
         fn = partial(tf_to_categorical, num_classes=num_classes, dtype='float32')
@@ -213,6 +216,7 @@ def _to_categorical(y, *, num_classes):
 
     @staticmethod
     def _compute_chunk(X, y, chunk_size, i):
+        import dask
         try:
             Xc = X[i:i + chunk_size]
             yc = y[i:i + chunk_size] if y is not None else None
@@ -236,7 +240,12 @@ def _range(start, stop, step, shuffle):
 def to_dataset(config, task, num_classes, X, y=None, *,
                batch_size, shuffle, drop_remainder,
                categorical_columns, continuous_columns, var_len_categorical_columns):
-    cls = _TFDGForDask if isinstance(X, dd.DataFrame) else _TFDGForPandas
+
+    if is_dask_installed:
+        import dask.dataframe as dd
+        cls = _TFDGForDask if isinstance(X, dd.DataFrame) else _TFDGForPandas
+    else:
+        cls = _TFDGForPandas
 
     logger.info(f'create dataset generator with {cls.__name__}, '
                 f'batch_size={batch_size}, shuffle={shuffle}, drop_remainder={drop_remainder}')
diff --git a/requirements.txt b/requirements.txt
index e8f0e21..7c9cb4f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,6 @@ numpy>=1.16.5
 scikit-learn>=0.22.1
 lightgbm>=2.2.0
 category_encoders>=2.1.0
-hypernets>=0.2.5.1
+hypernets>=0.3.0
 h5py>=2.10.0
 eli5
-dask
\ No newline at end of file
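
Every hunk above applies the same change: dask is dropped from requirements.txt and from the module-level imports, the `import dask` / `import dask.dataframe as dd` statements are deferred into the functions that actually use them, and the dask-specific branch in `to_dataset` is guarded by `is_dask_installed` from `hypernets.tabular`. Below is a minimal, self-contained sketch of that optional-dependency pattern; the flag and function names are illustrative stand-ins, not the deeptables or hypernets APIs.

```python
import importlib.util

import pandas as pd

# True when dask is importable; find_spec() checks availability without importing it.
# (Illustrative local flag; in the patch the flag comes from hypernets.tabular.)
is_dask_installed = importlib.util.find_spec("dask") is not None


def pick_generator(X):
    """Return which dataset-generator branch would handle X: 'dask' or 'pandas'."""
    if is_dask_installed:
        # Deferred import: only runs when dask is actually present.
        import dask.dataframe as dd
        if isinstance(X, dd.DataFrame):
            return "dask"
    return "pandas"


if __name__ == "__main__":
    # A pandas input takes the pandas branch whether or not dask is installed.
    print(pick_generator(pd.DataFrame({"a": [1, 2, 3]})))
```

Deferring the imports to the call sites means importing the modules no longer requires dask, so it can be removed from the hard requirements while the dask code paths keep working when it is installed.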