diff --git a/src/fitter/fitter.py b/src/fitter/fitter.py
index 4241101..86b5a45 100644
--- a/src/fitter/fitter.py
+++ b/src/fitter/fitter.py
@@ -33,6 +33,7 @@
 from scipy.stats import entropy as kl_div
 from scipy.stats import kstest
 from tqdm import tqdm
+import multiprocessing.pool
 
 __all__ = ["get_common_distributions", "get_distributions", "Fitter"]
 
@@ -293,33 +294,35 @@ def hist(self):
         _ = pylab.hist(self._data, bins=self.bins, density=self._density)
         pylab.grid(True)
 
-    def _fit_single_distribution(self, distribution):
+    @staticmethod
+    def _fit_single_distribution(distribution, data, x, y, timeout):
+        """Fit one ``scipy.stats`` distribution to *data* and score the fit.
+
+        Returns ``(distribution, values)``; ``values`` is ``None`` when the fit
+        timed out or raised, otherwise the tuple ``(param, pdf_fitted, sq_error,
+        aic, bic, kullback_leibler, ks_stat, ks_pval)``. Nothing is stored on a
+        ``Fitter`` instance here: ``fit`` collects the returned values.
+        """
+        import warnings
+
+        warnings.filterwarnings("ignore", category=RuntimeWarning)
         try:
             # need a subprocess to check time it takes. If too long, skip it
             dist = eval("scipy.stats." + distribution)
 
-            # TODO here, dist.fit may take a while or just hang forever
-            # with some distributions. So, I thought to use signal module
-            # to catch the error when signal takes too long. It did not work
-            # presumably because another try/exception is inside the
-            # fit function, so I used threading with a recipe from stackoverflow
-            # See timed_run function above
-            param = self._timed_run(dist.fit, distribution, args=self._data)
+            param = Fitter._with_timeout(dist.fit, args=(data,), timeout=timeout)
 
             # with signal, does not work. maybe because another expection is caught
             # hoping the order returned by fit is the same as in pdf
-            pdf_fitted = dist.pdf(self.x, *param)
-
-            self.fitted_param[distribution] = param[:]
-            self.fitted_pdf[distribution] = pdf_fitted
+            pdf_fitted = dist.pdf(x, *param)
 
             # calculate error
-            sq_error = pylab.sum((self.fitted_pdf[distribution] - self.y) ** 2)
+            sq_error = pylab.sum((pdf_fitted - y) ** 2)
 
             # calculate information criteria
-            logLik = np.sum(dist.logpdf(self.x, *param))
+            logLik = np.sum(dist.logpdf(x, *param))
             k = len(param[:])
-            n = len(self._data)
+            n = len(data)
             aic = 2 * k - 2 * logLik
 
             # special case of gaussian distribution
@@ -328,30 +331,24 @@ def _fit_single_distribution(self, distribution):
             bic = k * pylab.log(n) - 2 * logLik
 
             # calculate kullback leibler divergence
-            kullback_leibler = kl_div(self.fitted_pdf[distribution], self.y)
+            kullback_leibler = kl_div(pdf_fitted, y)
 
             # calculate goodness-of-fit statistic
             dist_fitted = dist(*param)
-            ks_stat, ks_pval = kstest(self._data, dist_fitted.cdf)
+            ks_stat, ks_pval = kstest(data, dist_fitted.cdf)
 
             logger.info("Fitted {} distribution with error={})".format(distribution, round(sq_error, 6)))
 
-            # compute some errors now
-            self._fitted_errors[distribution] = sq_error
-            self._aic[distribution] = aic
-            self._bic[distribution] = bic
-            self._kldiv[distribution] = kullback_leibler
-            self._ks_stat[distribution] = ks_stat
-            self._ks_pval[distribution] = ks_pval
-        except Exception:  # pragma: no cover
-            logger.warning("SKIPPED {} distribution (taking more than {} seconds)".format(distribution, self.timeout))
-            # if we cannot compute the error, set it to large values
-            self._fitted_errors[distribution] = np.inf
-            self._aic[distribution] = np.inf
-            self._bic[distribution] = np.inf
-            self._kldiv[distribution] = np.inf
-
-    def fit(self, progress=False, n_jobs=-1, max_workers=-1):
+            return distribution, (param, pdf_fitted, sq_error, aic, bic, kullback_leibler, ks_stat, ks_pval)
+        except multiprocessing.TimeoutError:  # pragma: no cover
+            logger.warning("SKIPPED {} distribution (taking more than {} seconds)".format(distribution, timeout))
+        except Exception as err:  # pragma: no cover
+            # not a timeout: report the actual failure instead of blaming the clock
+            logger.warning("SKIPPED {} distribution (fit failed: {})".format(distribution, err))
+
+        return distribution, None
+
+    def fit(self, progress=False, n_jobs=-1, max_workers=-1, prefer="processes"):
         r"""Loop over distributions and find best parameter to fit the data for each
 
         When a distribution is fitted onto the data, we populate a set of
@@ -365,16 +362,30 @@ def fit(self, progress=False, n_jobs=-1, max_workers=-1):
         Indices of the dataframes contains the name of the distribution.
 
         """
-        import warnings
-
-        warnings.filterwarnings("ignore", category=RuntimeWarning)
-
         N = len(self.distributions)
         with tqdm_joblib(desc=f"Fitting {N} distributions", total=N, disable=not progress) as progress_bar:
-            Parallel(n_jobs=max_workers, backend="threading")(
-                delayed(self._fit_single_distribution)(dist) for dist in self.distributions
+            results = Parallel(n_jobs=max_workers, prefer=prefer)(
+                delayed(Fitter._fit_single_distribution)(dist, self._data, self.x, self.y, self.timeout) for dist in self.distributions
             )
 
+        for distribution, values in results:
+            if values is not None:
+                param, pdf_fitted, sq_error, aic, bic, kullback_leibler, ks_stat, ks_pval = values
+
+                self.fitted_param[distribution] = param
+                self.fitted_pdf[distribution] = pdf_fitted
+                self._fitted_errors[distribution] = sq_error
+                self._aic[distribution] = aic
+                self._bic[distribution] = bic
+                self._kldiv[distribution] = kullback_leibler
+                self._ks_stat[distribution] = ks_stat
+                self._ks_pval[distribution] = ks_pval
+            else:
+                self._fitted_errors[distribution] = np.inf
+                self._aic[distribution] = np.inf
+                self._bic[distribution] = np.inf
+                self._kldiv[distribution] = np.inf
+
         self.df_errors = pd.DataFrame(
             {
                 "sumsquare_error": self._fitted_errors,
@@ -451,45 +462,18 @@ def summary(self, Nbest=5, lw=2, plot=True, method="sumsquare_error", clf=True):
             names = self.df_errors.sort(method).index[0:Nbest]
         return self.df_errors.loc[names]
 
-    def _timed_run(self, func, distribution, args=(), kwargs={}, default=None):
-        """This function will spawn a thread and run the given function
-        using the args, kwargs and return the given default value if the
-        timeout is exceeded.
-
-        http://stackoverflow.com/questions/492519/timeout-on-a-python-function-call
-        """
-
-        class InterruptableThread(threading.Thread):
-            def __init__(self):
-                threading.Thread.__init__(self)
-                self.result = default
-                self.exc_info = (None, None, None)
-
-            def run(self):
-                try:
-                    self.result = func(args, **kwargs)
-                except Exception as err:  # pragma: no cover
-                    self.exc_info = sys.exc_info()
-
-            def suicide(self):  # pragma: no cover
-                raise RuntimeError("Stop has been called")
-
-        it = InterruptableThread()
-        it.start()
-        started_at = datetime.now()
-        it.join(self.timeout)
-        ended_at = datetime.now()
-        diff = ended_at - started_at
-
-        if it.exc_info[0] is not None:  # pragma: no cover ;  if there were any exceptions
-            a, b, c = it.exc_info
-            raise Exception(a, b, c)  # communicate that to caller
-
-        if it.is_alive():  # pragma: no cover
-            it.suicide()
-            raise RuntimeError
-        else:
-            return it.result
+    @staticmethod
+    def _with_timeout(func, args=(), kwargs=None, timeout=30):
+        """Call ``func(*args, **kwargs)`` in a helper thread, waiting at most *timeout* seconds.
+
+        Returns the result of *func* and re-raises any exception it raised.
+        Raises ``multiprocessing.TimeoutError`` when no result arrives in time;
+        the helper thread cannot be interrupted, so a call that hangs keeps
+        running in the background.
+        """
+        with multiprocessing.pool.ThreadPool(1) as pool:
+            async_result = pool.apply_async(func, args, kwargs or {})
+            return async_result.get(timeout=timeout)
 
 
 """ For book-keeping