Skip to content

Commit

Permalink
Revert backend changes
Browse files Browse the repository at this point in the history
  • Loading branch information
aarmey committed Aug 12, 2024
1 parent d64a88b commit b6c6339
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 62 deletions.
146 changes: 84 additions & 62 deletions tensorly/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings

from .core import Backend, backend_array, backend_types, backend_basic_math
from .core import Backend, backend_array
import importlib
import os
import threading
Expand Down Expand Up @@ -32,71 +32,93 @@ def __get__(self, instance, cls=None):


class BackendManager(types.ModuleType):
_functions = (
[
"moveaxis",
"trace",
"shape",
"ndim",
"where",
"copy",
"transpose",
"arange",
"eye",
"kron",
"concatenate",
"max",
"all",
"mean",
"sum",
"prod",
"sign",
"argmin",
"argmax",
"stack",
"conj",
"diag",
"dot",
"tensordot",
"clip",
"kron",
"diagonal",
"lstsq",
"eps",
"finfo",
"solve",
"svd",
"qr",
"randn",
"gamma",
"digamma",
"check_random_state",
"sort",
"eigh",
"index_update",
"context",
"tensor",
"norm",
"to_numpy",
"is_tensor",
"argsort",
"flip",
"asin",
"acos",
"atan",
"asinh",
"acosh",
"atanh",
"logsumexp",
]
+ backend_array
+ backend_basic_math
)
_functions = [
"moveaxis",
"trace",
"shape",
"ndim",
"where",
"copy",
"transpose",
"arange",
"eye",
"kron",
"concatenate",
"max",
"all",
"mean",
"sum",
"prod",
"sign",
"argmin",
"argmax",
"stack",
"conj",
"diag",
"log",
"log2",
"dot",
"tensordot",
"diagonal",
"exp",
"clip",
"kr",
"kron",
"lstsq",
"eps",
"finfo",
"solve",
"svd",
"qr",
"randn",
"gamma",
"digamma",
"check_random_state",
"sort",
"eigh",
"index_update",
"context",
"tensor",
"norm",
"to_numpy",
"is_tensor",
"argsort",
"flip",
"sin",
"cos",
"tan",
"asin",
"acos",
"atan",
"arcsin",
"arccos",
"arctan",
"sinh",
"cosh",
"tanh",
"arcsinh",
"arccosh",
"arctanh",
"asinh",
"acosh",
"atanh",
"partial_svd",
"logsumexp",
] + backend_array
_attributes = [
"int64",
"int32",
"float64",
"float32",
"pi",
"e",
"inf",
"nan",
"complex128",
"complex64",
"index",
"backend_name",
] + backend_types
]
available_backend_names = ["numpy", "pytorch", "tensorflow", "cupy", "jax"]
_default_backend = "numpy"
_loaded_backends = dict()
Expand Down
1 change: 1 addition & 0 deletions tensorly/contrib/sparse/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def inner(*args, **kwargs):
norm = dispatch_sparse(backend.norm)
dot = dispatch_sparse(backend.dot)
kron = dispatch_sparse(backend.kron)
kr = dispatch_sparse(backend.kr)
solve = dispatch_sparse(backend.solve)
qr = dispatch_sparse(backend.qr)
unfold = dispatch_sparse(base.unfold)
Expand Down

0 comments on commit b6c6339

Please sign in to comment.