Commit

fix for auto optimizers running under threads
jcmgray committed Oct 17, 2023
1 parent 36058fd commit 990a258
Showing 2 changed files with 48 additions and 14 deletions.
cotengra/presets.py: 29 additions & 14 deletions
@@ -1,5 +1,7 @@
 """Preset configured optimizers.
 """
+import threading
+
 from .core import ContractionTree
 from .hyperoptimizers.hyper import HyperOptimizer, ReusableHyperOptimizer
 from .oe import (
@@ -48,22 +50,33 @@ def __init__(
         self.optimal_cutoff = optimal_cutoff
         self._optimize_optimal_fn = get_optimize_optimal()
 
-        hyperoptimizer_kwargs.setdefault("methods", ("rgreedy",))
-        hyperoptimizer_kwargs.setdefault("max_repeats", 128)
-        hyperoptimizer_kwargs.setdefault("max_time", "rate:1e9")
-        hyperoptimizer_kwargs.setdefault("parallel", False)
-        hyperoptimizer_kwargs.setdefault("reconf_opts", {})
-        hyperoptimizer_kwargs["reconf_opts"].setdefault("subtree_size", 4)
-        hyperoptimizer_kwargs["reconf_opts"].setdefault("maxiter", 100)
+        self.kwargs = hyperoptimizer_kwargs
+        self.kwargs.setdefault("methods", ("rgreedy",))
+        self.kwargs.setdefault("max_repeats", 128)
+        self.kwargs.setdefault("max_time", "rate:1e9")
+        self.kwargs.setdefault("parallel", False)
+        self.kwargs.setdefault("reconf_opts", {})
+        self.kwargs["reconf_opts"].setdefault("subtree_size", 4)
+        self.kwargs["reconf_opts"].setdefault("maxiter", 100)
 
+        self._hyperoptimizers_by_thread = {}
         if cache:
-            self._optimizer_hyper = ReusableHyperOptimizer(
-                minimize=minimize, **hyperoptimizer_kwargs
-            )
+            self._optimizer_hyper_cls = ReusableHyperOptimizer
         else:
-            self._optimizer_hyper = HyperOptimizer(
-                minimize=minimize, **hyperoptimizer_kwargs
+            self._optimizer_hyper_cls = HyperOptimizer
+
+    def _get_optimizer_hyper_threadsafe(self):
+        # since the hyperoptimizer is stateful while running,
+        # we need to instantiate a separate one for each thread
+        tid = threading.get_ident()
+        try:
+            return self._hyperoptimizers_by_thread[tid]
+        except KeyError:
+            opt = self._optimizer_hyper_cls(
+                minimize=self.minimize, **self.kwargs
             )
+            self._hyperoptimizers_by_thread[tid] = opt
+            return opt
 
     def search(self, inputs, output, size_dict, **kwargs):
         if estimate_optimal_hardness(inputs) < self.optimal_cutoff:
@@ … @@ def search(self, inputs, output, size_dict, **kwargs):
             )
         else:
             # use hyperoptimizer
-            return self._optimizer_hyper.search(
+            return self._get_optimizer_hyper_threadsafe().search(
                 inputs,
                 output,
                 size_dict,
@@ … @@ def __call__(self, inputs, output, size_dict, **kwargs):
             )
         else:
             # use hyperoptimizer
-            return self._optimizer_hyper(inputs, output, size_dict, **kwargs)
+            return self._get_optimizer_hyper_threadsafe()(
+                inputs, output, size_dict, **kwargs
+            )
 
 
 class AutoHQOptimizer(AutoOptimizer):
tests/test_optimizers.py: 19 additions & 0 deletions
@@ -277,3 +277,22 @@ def test_plotting():
     opt.search(inputs, output, size_dict)
     opt.plot_trials()
     opt.plot_scatter()
+
+
+def test_auto_optimizers_threadsafe():
+    from concurrent.futures import ThreadPoolExecutor
+
+    pool = ThreadPoolExecutor(2)
+
+    contractions = [ctg.utils.tree_equation(100, seed=i) for i in range(10)]
+    fs = [
+        pool.submit(
+            ctg.array_contract_tree,
+            inputs,
+            output,
+            size_dict,
+            optimize="auto-hq",
+        )
+        for inputs, output, _, size_dict in contractions
+    ]
+    [f.result() for f in fs]
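
For reference, the thread-safety fix above boils down to a per-thread instance cache keyed on threading.get_ident(). The sketch below illustrates that pattern in isolation; it is a simplified, hypothetical example (the PerThreadInstance class and its factory argument are illustrative only, not part of cotengra):

import threading
from concurrent.futures import ThreadPoolExecutor


class PerThreadInstance:
    """Cache one instance of a stateful object per thread."""

    def __init__(self, factory):
        self.factory = factory
        self._instances = {}

    def get(self):
        # key the cache on the current thread's identifier, so concurrent
        # threads never share (and mutate) the same stateful object
        tid = threading.get_ident()
        try:
            return self._instances[tid]
        except KeyError:
            obj = self._instances[tid] = self.factory()
            return obj


cache = PerThreadInstance(list)  # any stateful factory would do
with ThreadPoolExecutor(4) as pool:
    ids = set(pool.map(lambda _: id(cache.get()), range(32)))
assert len(ids) <= 4  # at most one instance per worker thread

The design choice mirrors the commit: rather than serializing access to one shared optimizer with a lock, each calling thread gets its own independent instance, at the cost of one optimizer per thread.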
