Skip to content

Commit

Permalink
alias total_cost -> combo_cost. add more log kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Apr 20, 2024
1 parent f6250a0 commit 5f8becf
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 25 deletions.
74 changes: 50 additions & 24 deletions cotengra/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def get_centrality(self, node):
self.compute_centralities()
return self.info[node]["centrality"]

def total_flops(self, dtype=None):
def total_flops(self, dtype=None, log=None):
"""Sum the flops contribution from every node in the tree.
Parameters
Expand All @@ -895,13 +895,18 @@ def total_flops(self, dtype=None):
C = self.multiplicity * self._flops

if dtype is None:
return C
pass
elif "float" in dtype:
C *= 2
elif "complex" in dtype:
C *= 4
else:
raise ValueError(f"Unknown dtype {dtype}")

if "float" in dtype:
return 2 * C
if log is not None:
C = math.log(C, log)

if "complex" in dtype:
return 8 * C
return C

def total_write(self):
"""Sum the total amount of memory that will be created and operated on."""
Expand All @@ -914,7 +919,7 @@ def total_write(self):

return self.multiplicity * self._write

def total_cost(self, factor=DEFAULT_COMBO_FACTOR, combine=sum, log=None):
def combo_cost(self, factor=DEFAULT_COMBO_FACTOR, combine=sum, log=None):
t = 0
for p in self.children:
f = self.get_flops(p)
Expand All @@ -928,6 +933,8 @@ def total_cost(self, factor=DEFAULT_COMBO_FACTOR, combine=sum, log=None):

return t

total_cost = combo_cost

def max_size(self, log=None):
"""The size of the largest intermediate tensor."""
if self.N == 1:
Expand Down Expand Up @@ -1011,10 +1018,7 @@ def contraction_scaling(self):

def contraction_cost(self, log=None):
"""Get the total number of scalar operations ~ time complexity."""
C = float(self.total_flops(dtype=None))
if log is not None:
C = math.log(C, log)
return C
return self.total_flops(dtype=None, log=log)

def contraction_width(self, log=2):
"""Get log2 of the size of the largest tensor."""
Expand Down Expand Up @@ -1072,6 +1076,7 @@ def total_flops_compressed(
order="surface_order",
compress_late=None,
dtype=None,
log=None,
):
"""Estimate the total flops for a compressed contraction of this tree
with maximum bond size ``chi``. This includes basic estimates of the
Expand All @@ -1083,31 +1088,44 @@ def total_flops_compressed(
"number of abstract scalar ops."
)

return self.compressed_contract_stats(
F = self.compressed_contract_stats(
chi=chi,
order=order,
compress_late=compress_late,
).flops

if log is not None:
F = math.log(F, log)

return F

contraction_cost_compressed = total_flops_compressed

def total_write_compressed(
self,
chi=None,
order="surface_order",
compress_late=None,
accel="auto",
log=None,
):
"""Compute the total size of all intermediate tensors when a
compressed contraction is performed with maximum bond size ``chi``,
ordered by ``order``. This is relevant maybe for time complexity and
e.g. autodiff space complexity (since every intermediate is kept).
"""
return self.compressed_contract_stats(
W = self.compressed_contract_stats(
chi=chi,
order=order,
compress_late=compress_late,
).write

def total_cost_compressed(
if log is not None:
W = math.log(W, log)

return W

def combo_cost_compressed(
self,
chi=None,
order="surface_order",
Expand All @@ -1123,10 +1141,14 @@ def total_cost_compressed(
) + factor * self.total_write_compressed(
chi=chi, order=order, compress_late=compress_late
)

if log is not None:
C = math.log(C, log)

return C

total_cost_compressed = combo_cost_compressed

def max_size_compressed(
self, chi=None, order="surface_order", compress_late=None, log=None
):
Expand All @@ -1140,8 +1162,10 @@ def max_size_compressed(
order=order,
compress_late=compress_late,
).max_size

if log is not None:
S = math.log(S, log)

return S

def peak_size_compressed(
Expand All @@ -1162,11 +1186,11 @@ def peak_size_compressed(
order=order,
compress_late=compress_late,
).peak_size

if log is not None:
P = math.log(P, log)
return P

contraction_cost_compressed = total_cost_compressed
return P

def contraction_width_compressed(
self, chi=None, order="surface_order", compress_late=None, log=2
Expand Down Expand Up @@ -3526,16 +3550,16 @@ def describe(self, info="normal", join=" "):
if info == "normal":
return join.join(
(
f"log10[FLOPs]={self.contraction_cost(log=10):.4g}",
f"log2[SIZE]={self.contraction_width(log=2):.4g}",
f"log10[FLOPs]={self.total_flops(log=10):.4g}",
f"log2[SIZE]={self.max_size(log=2):.4g}",
)
)

elif info == "full":
s = [
f"log10[FLOPS]={self.contraction_cost(log=10):.4g}",
f"log10[COMBO]={self.total_cost(log=10):.4g}",
f"log2[SIZE]={self.contraction_width(log=2):.4g}",
f"log10[FLOPS]={self.total_flops(log=10):.4g}",
f"log10[COMBO]={self.combo_cost(log=10):.4g}",
f"log2[SIZE]={self.max_size(log=2):.4g}",
f"log2[PEAK]={self.peak_size(log=2):.4g}",
]
if self.sliced_inds:
Expand All @@ -3544,9 +3568,9 @@ def describe(self, info="normal", join=" "):

elif info == "concise":
s = [
f"F={self.contraction_cost(log=10):.4g}",
f"C={self.total_cost(log=10):.4g}",
f"S={self.contraction_width(log=2):.4g}",
f"F={self.total_flops(log=10):.4g}",
f"C={self.combo_cost(log=10):.4g}",
f"S={self.max_size(log=2):.4g}",
f"P={self.peak_size(log=2):.4g}",
]
if self.sliced_inds:
Expand Down Expand Up @@ -3663,6 +3687,7 @@ def get_default_compress_late(self):

total_flops = ContractionTree.total_flops_compressed
total_write = ContractionTree.total_write_compressed
combo_cost = ContractionTree.combo_cost_compressed
total_cost = ContractionTree.total_cost_compressed
max_size = ContractionTree.max_size_compressed
peak_size = ContractionTree.peak_size_compressed
Expand All @@ -3671,6 +3696,7 @@ def get_default_compress_late(self):

total_flops_exact = ContractionTree.total_flops
total_write_exact = ContractionTree.total_write
combo_cost_exact = ContractionTree.combo_cost
total_cost_exact = ContractionTree.total_cost
max_size_exact = ContractionTree.max_size
peak_size_exact = ContractionTree.peak_size
Expand Down
2 changes: 1 addition & 1 deletion cotengra/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def get_dynamic_programming_minimize(self):

def __call__(self, trial):
tree = trial["tree"]
return math.log2(tree.total_cost(factor=self.factor, combine=max))
return math.log2(tree.combo_cost(factor=self.factor, combine=max))


# --------------------- compressed contraction scoring ---------------------- #
Expand Down

0 comments on commit 5f8becf

Please sign in to comment.