Skip to content

Commit

Permalink
Add Parallel prefer parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
vitorandreazza committed May 27, 2024
1 parent 6a8c238 commit 4640088
Showing 1 changed file with 44 additions and 77 deletions.
121 changes: 44 additions & 77 deletions src/fitter/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from scipy.stats import entropy as kl_div
from scipy.stats import kstest
from tqdm import tqdm
import multiprocessing

__all__ = ["get_common_distributions", "get_distributions", "Fitter"]

Expand Down Expand Up @@ -293,33 +294,28 @@ def hist(self):
_ = pylab.hist(self._data, bins=self.bins, density=self._density)
pylab.grid(True)

def _fit_single_distribution(self, distribution):
@staticmethod
def _fit_single_distribution(distribution, data, x, y, timeout):
import warnings

warnings.filterwarnings("ignore", category=RuntimeWarning)
try:
# need a subprocess to check time it takes. If too long, skip it
dist = eval("scipy.stats." + distribution)

# TODO here, dist.fit may take a while or just hang forever
# with some distributions. So, I thought to use signal module
# to catch the error when signal takes too long. It did not work
# presumably because another try/exception is inside the
# fit function, so I used threading with a recipe from stackoverflow
# See timed_run function above
param = self._timed_run(dist.fit, distribution, args=self._data)
param = Fitter._with_timeout(dist.fit, args=(data,), timeout=timeout)

# with signal, does not work. maybe because another exception is caught
# hoping the order returned by fit is the same as in pdf
pdf_fitted = dist.pdf(self.x, *param)

self.fitted_param[distribution] = param[:]
self.fitted_pdf[distribution] = pdf_fitted
pdf_fitted = dist.pdf(x, *param)

# calculate error
sq_error = pylab.sum((self.fitted_pdf[distribution] - self.y) ** 2)
sq_error = pylab.sum((pdf_fitted - y) ** 2)

# calculate information criteria
logLik = np.sum(dist.logpdf(self.x, *param))
logLik = np.sum(dist.logpdf(x, *param))
k = len(param[:])
n = len(self._data)
n = len(data)
aic = 2 * k - 2 * logLik

# special case of gaussian distribution
Expand All @@ -328,30 +324,21 @@ def _fit_single_distribution(self, distribution):
bic = k * pylab.log(n) - 2 * logLik

# calculate kullback leibler divergence
kullback_leibler = kl_div(self.fitted_pdf[distribution], self.y)
kullback_leibler = kl_div(pdf_fitted, y)

# calculate goodness-of-fit statistic
dist_fitted = dist(*param)
ks_stat, ks_pval = kstest(self._data, dist_fitted.cdf)
ks_stat, ks_pval = kstest(data, dist_fitted.cdf)

logger.info("Fitted {} distribution with error={})".format(distribution, round(sq_error, 6)))

# compute some errors now
self._fitted_errors[distribution] = sq_error
self._aic[distribution] = aic
self._bic[distribution] = bic
self._kldiv[distribution] = kullback_leibler
self._ks_stat[distribution] = ks_stat
self._ks_pval[distribution] = ks_pval
return distribution, (param, pdf_fitted, sq_error, aic, bic, kullback_leibler, ks_stat, ks_pval)
except Exception: # pragma: no cover
logger.warning("SKIPPED {} distribution (taking more than {} seconds)".format(distribution, self.timeout))
# if we cannot compute the error, set it to large values
self._fitted_errors[distribution] = np.inf
self._aic[distribution] = np.inf
self._bic[distribution] = np.inf
self._kldiv[distribution] = np.inf

def fit(self, progress=False, n_jobs=-1, max_workers=-1):
logger.warning("SKIPPED {} distribution (taking more than {} seconds)".format(distribution, timeout))

return distribution, None

def fit(self, progress=False, n_jobs=-1, max_workers=-1, prefer="processes"):
r"""Loop over distributions and find best parameter to fit the data for each
When a distribution is fitted onto the data, we populate a set of
Expand All @@ -365,16 +352,30 @@ def fit(self, progress=False, n_jobs=-1, max_workers=-1):
Indices of the dataframes contains the name of the distribution.
"""
import warnings

warnings.filterwarnings("ignore", category=RuntimeWarning)

N = len(self.distributions)
with tqdm_joblib(desc=f"Fitting {N} distributions", total=N, disable=not progress) as progress_bar:
Parallel(n_jobs=max_workers, backend="threading")(
delayed(self._fit_single_distribution)(dist) for dist in self.distributions
results = Parallel(n_jobs=max_workers, prefer=prefer)(
delayed(Fitter._fit_single_distribution)(dist, self._data, self.x, self.y, self.timeout) for dist in self.distributions
)

for distribution, values in results:
if values is not None:
param, pdf_fitted, sq_error, aic, bic, kullback_leibler, ks_stat, ks_pval = values

self.fitted_param[distribution] = param
self.fitted_pdf[distribution] = pdf_fitted
self._fitted_errors[distribution] = sq_error
self._aic[distribution] = aic
self._bic[distribution] = bic
self._kldiv[distribution] = kullback_leibler
self._ks_stat[distribution] = ks_stat
self._ks_pval[distribution] = ks_pval
else:
self._fitted_errors[distribution] = np.inf
self._aic[distribution] = np.inf
self._bic[distribution] = np.inf
self._kldiv[distribution] = np.inf

self.df_errors = pd.DataFrame(
{
"sumsquare_error": self._fitted_errors,
Expand Down Expand Up @@ -451,45 +452,11 @@ def summary(self, Nbest=5, lw=2, plot=True, method="sumsquare_error", clf=True):
names = self.df_errors.sort(method).index[0:Nbest]
return self.df_errors.loc[names]

def _timed_run(self, func, distribution, args=(), kwargs={}, default=None):
"""This function will spawn a thread and run the given function
using the args, kwargs and return the given default value if the
timeout is exceeded.
http://stackoverflow.com/questions/492519/timeout-on-a-python-function-call
"""

class InterruptableThread(threading.Thread):
def __init__(self):
threading.Thread.__init__(self)
self.result = default
self.exc_info = (None, None, None)

def run(self):
try:
self.result = func(args, **kwargs)
except Exception as err: # pragma: no cover
self.exc_info = sys.exc_info()

def suicide(self): # pragma: no cover
raise RuntimeError("Stop has been called")

it = InterruptableThread()
it.start()
started_at = datetime.now()
it.join(self.timeout)
ended_at = datetime.now()
diff = ended_at - started_at

if it.exc_info[0] is not None: # pragma: no cover ; if there were any exceptions
a, b, c = it.exc_info
raise Exception(a, b, c) # communicate that to caller

if it.is_alive(): # pragma: no cover
it.suicide()
raise RuntimeError
else:
return it.result
@staticmethod
def _with_timeout(func, args=(), kwargs={}, timeout=30):
with multiprocessing.pool.ThreadPool(1) as pool:
async_result = pool.apply_async(func, args, kwargs)
return async_result.get(timeout=timeout)


""" For book-keeping
Expand Down

0 comments on commit 4640088

Please sign in to comment.