improve slicing performance
jcmgray committed Apr 25, 2024
1 parent 5f8becf commit 183aff0
Showing 2 changed files with 83 additions and 89 deletions.
167 changes: 80 additions & 87 deletions cotengra/slicer.py
```diff
@@ -1,12 +1,17 @@
-"""Functionality for identifying indices to sliced.
-"""
+"""Functionality for identifying indices to sliced."""
 
 import collections
 from math import log
 
 from .core import ContractionTree
 from .plot import plot_slicings, plot_slicings_alt
 from .scoring import get_score_fn
-from .utils import MaxCounter, compute_size_by_dict, get_rng, oset
+from .utils import MaxCounter, get_rng
+
+IDX_INVOLVED = 0
+IDX_LEGS = 1
+IDX_SIZE = 2
+IDX_FLOPS = 3
 
 
 class ContractionCosts:
```
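The main structural change: each contraction record becomes a plain 4-tuple indexed by the module-level `IDX_*` constants, rather than a dict with string keys. A minimal sketch of the new layout (the concrete index names and sizes below are invented for illustration):

```python
# each contraction is now a 4-tuple instead of a dict:
#   (involved indices, output indices, output size, flop count)
IDX_INVOLVED, IDX_LEGS, IDX_SIZE, IDX_FLOPS = 0, 1, 2, 3

contraction = (
    {"a", "b", "c"},  # IDX_INVOLVED: every index appearing in the contraction
    {"a", "c"},       # IDX_LEGS: indices surviving on the output tensor
    16,               # IDX_SIZE: number of elements in the output
    64,               # IDX_FLOPS: scalar operations to compute it
)

assert contraction[IDX_FLOPS] == 64
assert "b" not in contraction[IDX_LEGS]
```

Tuples are cheaper to build and store than dicts, and treating them as immutable records means copies of a `ContractionCosts` can share them (see `_set_state_from` below).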
```diff
@@ -43,7 +48,7 @@ def __init__(
         original_flops=None,
     ):
         self.size_dict = dict(size_dict)
-        self.contractions = tuple(contractions)
+        self.contractions = list(contractions)
 
         self._flops = 0
         self._sizes = MaxCounter()
```
```diff
@@ -52,15 +57,15 @@ def __init__(
         self._where = collections.defaultdict(set)
 
         for i, c in enumerate(self.contractions):
-            self._flops += c["flops"]
-            self._sizes.add(c["size"])
+            self._flops += c[IDX_FLOPS]
+            self._sizes.add(c[IDX_SIZE])
 
-            for ix in c["involved"].union(c["legs"]):
+            for ix in c[IDX_INVOLVED]:
                 d = self.size_dict[ix]
-                self._flop_reductions[ix] += int((1 - 1 / d) * c["flops"])
+                self._flop_reductions[ix] += c[IDX_FLOPS] - c[IDX_FLOPS] // d
                 self._where[ix].add(i)
-                if ix in c["legs"]:
-                    self._write_reductions[ix] += int((1 - 1 / d) * c["size"])
+                if ix in c[IDX_LEGS]:
+                    self._write_reductions[ix] += c[IDX_SIZE] - c[IDX_SIZE] // d
 
         self.nslices = nslices
         if original_flops is None:
```
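Note the arithmetic change bundled in here: slicing an index of size `d` divides a cost `x` by `d`, so the potential saving is now computed as `x - x // d` in pure integer arithmetic rather than `int((1 - 1 / d) * x)`. This avoids a float round-trip, which matters because flop counts for large contractions easily exceed what a double can represent exactly. A quick illustration (values invented):

```python
d = 2
flops = 10**40  # flop counts this large are routine before slicing

exact = flops - flops // d          # stays in exact integer arithmetic
approx = int((1 - 1 / d) * flops)   # detours through 64-bit floats

assert exact == 5 * 10**39
assert exact != approx  # float rounding is already visible at this scale
```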
```diff
@@ -70,7 +75,7 @@ def __init__(
     def _set_state_from(self, other):
         """Copy all internal structure from another ``ContractionCosts``."""
         self.size_dict = other.size_dict.copy()
-        self.contractions = tuple(c.copy() for c in other.contractions)
+        self.contractions = other.contractions.copy()
         self.nslices = other.nslices
         self.original_flops = other.original_flops
         self._flops = other._flops
```
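A toy illustration of why a single shallow `list.copy()` now suffices (hypothetical values): the tuple records are never mutated in place, only replaced wholesale, so sharing them between copies is safe.

```python
contractions = [("ab", 64), ("bc", 128)]
copy = contractions.copy()             # shares the immutable tuple records

copy[1] = ("bc", 64)                   # rebinds an entry, doesn't mutate it
assert contractions[1] == ("bc", 128)  # the original is untouched
```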
```diff
@@ -85,6 +90,31 @@ def copy(self):
         new._set_state_from(self)
         return new
 
+    @classmethod
+    def from_contraction_tree(cls, contraction_tree, **kwargs):
+        """Generate a set of contraction costs from a ``ContractionTree``
+        object.
+        """
+        size_dict = contraction_tree.size_dict
+        contractions = (
+            (
+                set(contraction_tree.get_involved(node)),
+                set(contraction_tree.get_legs(node)),
+                contraction_tree.get_size(node),
+                contraction_tree.get_flops(node),
+            )
+            for node in contraction_tree.info
+            # ignore leaf nodes
+            if len(node) != 1
+        )
+        return cls(contractions, size_dict, **kwargs)
+
+    @classmethod
+    def from_info(cls, info, **kwargs):
+        """Generate a set of contraction costs from a ``PathInfo`` object."""
+        tree = ContractionTree.from_info(info)
+        return cls.from_contraction_tree(tree, **kwargs)
+
     @property
     def size(self):
         return self._sizes.max()
```
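Usage is unchanged from the outside. A hedged sketch of driving the new constructors (assuming the usual `opt_einsum` entry point; the shapes are arbitrary):

```python
import numpy as np
import opt_einsum as oe
from cotengra.slicer import ContractionCosts

# build a PathInfo for a small matrix-chain contraction
arrays = [np.random.rand(8, 8) for _ in range(3)]
_, info = oe.contract_path("ab,bc,cd->ad", *arrays)

# from_info now just routes through a ContractionTree
costs = ContractionCosts.from_info(info)
print(costs.total_flops, costs.size)
```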
```diff
@@ -101,97 +131,58 @@ def total_flops(self):
     def overhead(self):
         return self.total_flops / self.original_flops
 
-    @classmethod
-    def from_info(cls, info, **kwargs):
-        """Generate a set of contraction costs from a ``PathInfo`` object."""
-        cs = []
-        size_dict = info.size_dict
-
-        # add all the input 'contractions'
-        for term in info.input_subscripts.split(","):
-            cs.append(
-                {
-                    "involved": oset(),
-                    "legs": oset(term),
-                    "size": compute_size_by_dict(term, size_dict),
-                    "flops": 0,
-                }
-            )
-
-        for c in info.contraction_list:
-            eq = c[2]
-            lhs, rhs = eq.split("->")
-            legs = oset(rhs)
-            involved = oset.union(*map(oset, lhs.split(",")))
-
-            cs.append(
-                {
-                    "involved": involved,
-                    "legs": legs,
-                    "size": compute_size_by_dict(legs, size_dict),
-                    "flops": compute_size_by_dict(involved, size_dict),
-                }
-            )
-
-        return cls(cs, size_dict)
-
-    @classmethod
-    def from_contraction_tree(cls, contraction_tree, **kwargs):
-        """Generate a set of contraction costs from a ``ContractionTree``
-        object.
-        """
-        size_dict = contraction_tree.size_dict
-        cs = (
-            {
-                "involved": oset(contraction_tree.get_involved(node)),
-                "legs": oset(contraction_tree.get_legs(node)),
-                "size": contraction_tree.get_size(node),
-                "flops": contraction_tree.get_flops(node),
-            }
-            for node in contraction_tree.info
-        )
-        return cls(cs, size_dict, **kwargs)
-
     def remove(self, ix, inplace=False):
         """ """
         cost = self if inplace else self.copy()
 
         d = cost.size_dict[ix]
         cost.nslices *= d
-        ix_s = oset([ix])
 
-        for i in cost._where[ix]:
-            c = cost.contractions[i]
+        for i in cost._where.pop(ix):
 
-            # update the potential flops reductions of other inds
-            for oix in c["involved"]:
-                di = cost.size_dict[oix]
-                cost._flop_reductions[oix] -= int(
-                    (1 - 1 / di) * c["flops"] * (1 - 1 / d)
-                )
+            old_involved, old_legs, old_size, old_flops = cost.contractions[i]
 
             # update the actual flops reduction
-            old_flops = c["flops"]
             new_flops = old_flops // d
             cost._flops += new_flops - old_flops
-            c["flops"] = new_flops
-            c["involved"] = c["involved"].difference(ix_s)
+            new_involved = old_involved.copy()
+            new_involved.discard(ix)
 
-            # update the tensor sizes
-            if ix in c["legs"]:
-                # update the potential size reductions of other inds
-                for oix in c["legs"]:
-                    di = cost.size_dict[oix]
-                    cost._write_reductions[oix] -= int(
-                        (1 - 1 / di) * c["size"] * (1 - 1 / d)
-                    )
+            # update the potential flops reductions of other inds
+            for oix in new_involved:
+                di = cost.size_dict[oix]
+                old_flops_reduction = old_flops - old_flops // di
+                new_flops_reduction = old_flops_reduction // d
+                cost._flop_reductions[oix] += (
+                    new_flops_reduction - old_flops_reduction
+                )
 
-            old_size = c["size"]
+            # update the tensor sizes
+            if ix in old_legs:
                 new_size = old_size // d
                 cost._sizes.discard(old_size)
                 cost._sizes.add(new_size)
-                c["size"] = new_size
-                c["legs"] = c["legs"].difference(ix_s)
+                new_legs = old_legs.copy()
+                new_legs.discard(ix)
+
+                # update the potential size reductions of other inds
+                for oix in new_legs:
+                    di = cost.size_dict[oix]
+                    old_size_reduction = old_size - old_size // di
+                    new_size_reduction = old_size_reduction // d
+                    cost._write_reductions[oix] -= (
+                        old_size_reduction - new_size_reduction
+                    )
+            else:
+                new_size = old_size
+                new_legs = old_legs
 
+            cost.contractions[i] = (
+                new_involved,
+                new_legs,
+                new_size,
+                new_flops,
+            )
 
         del cost.size_dict[ix]
         del cost._flop_reductions[ix]
```
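The counter bookkeeping in the reworked `remove` follows from the same integer identity: if slicing `ix` (size `d`) divides a contraction's flops by `d`, the saving still available from any other index `oix` shrinks by the factor `d` as well. A worked example with invented sizes:

```python
d, di = 4, 2       # d: size of the index being sliced; di: size of oix
old_flops = 1024

# saving available from slicing oix, before ix is removed
old_flops_reduction = old_flops - old_flops // di   # 512

# after removing ix the contraction costs old_flops // d = 256 flops,
# so the saving from oix shrinks by the same factor d
new_flops_reduction = old_flops_reduction // d      # 128
assert new_flops_reduction == 256 - 256 // di

# the running counter is nudged by the (negative) difference
delta = new_flops_reduction - old_flops_reduction   # -384
```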
```diff
@@ -376,9 +367,11 @@ def trial(
             cost.size_dict,
             key=lambda ix:
             # the base score
-            self.minimize.score_slice_index(cost, ix) -
+            self.minimize.score_slice_index(cost, ix)
+            -
             # a smudge that replicates boltzmann sampling
-            temperature * log(-log(self.rng.random())) -
+            temperature * log(-log(self.rng.random()))
+            -
             # penalize forbidden (outer) indices
             (0 if ix not in self.forbidden else float("inf")),
         )
```
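The "smudge" comment describes the Gumbel-max trick: `-log(-log(u))` for uniform `u` is a standard Gumbel(0, 1) sample, and taking the argmax of `score + temperature * gumbel` draws an index with probability proportional to `exp(score / temperature)`, i.e. Boltzmann sampling. A standalone sketch of the idea (scores invented):

```python
import random
from collections import Counter
from math import log

def gumbel_max_sample(scores, temperature, rng=random):
    # argmax of score plus scaled Gumbel noise ~ Boltzmann/softmax sampling
    return max(
        scores,
        key=lambda k: scores[k] - temperature * log(-log(rng.random())),
    )

scores = {"a": 1.0, "b": 0.5, "c": 0.0}
counts = Counter(gumbel_max_sample(scores, 1.0) for _ in range(10_000))
# frequencies approach exp(s) / sum(exp(s)) ≈ {a: 0.51, b: 0.31, c: 0.19}
print({k: v / 10_000 for k, v in counts.items()})
```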
5 changes: 3 additions & 2 deletions docs/changelog.md
```diff
@@ -8,10 +8,11 @@
 
 **Enhancements**
 
-- add [RandomGreedyOptimizer](cotengra.pathfinders.path_basic.RandomGreedyOptimizer) which is a lightweight and performant randomized greedy optimizer, eschewing
-both hyper parameter tuning and full contraction tree construction, making it suitable for very large contractions (10,000s of tensors+).
+- add [RandomGreedyOptimizer](cotengra.pathfinders.path_basic.RandomGreedyOptimizer) which is a lightweight and performant randomized greedy optimizer, eschewing both hyper parameter tuning and full contraction tree construction, making it suitable for very large contractions (10,000s of tensors+).
 - add [optimize_random_greedy_track_flops](cotengra.pathfinders.path_basic.optimize_random_greedy_track_flops) which runs N trials of (random) greedy path optimization, whilst computing the FLOP count simultaneously. This or its accelerated rust counterpart in `cotengrust` is the driver for the above optimizer.
 - add `parallel="threads"` backend, and make it the default for `RandomGreedyOptimizer` when `cotengrust` is present, since its version of `optimize_random_greedy_track_flops` releases the GIL.
+- significantly improve both the speed and memory usage of [`SliceFinder`](cotengra.slicer.SliceFinder)
+- alias `tree.total_cost()` to `tree.combo_cost()`
 
 
 ## v0.6.0 (2024-04-10)
```
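For context, the improved `SliceFinder` is driven roughly as follows (a hedged sketch based on cotengra's documented interface; `info` is an `opt_einsum` `PathInfo` as in the earlier sketch, and the target size is arbitrary):

```python
import cotengra as ctg

# search for indices to slice so that no intermediate
# tensor exceeds 2**27 elements
sf = ctg.SliceFinder(info, target_size=2**27)
ix_sl, cost_sl = sf.search(temperature=1.0)

print(ix_sl)             # the chosen indices
print(cost_sl.overhead)  # extra flops incurred relative to no slicing
```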
