diff --git a/.github/workflows/Docs.yaml b/.github/workflows/Docs.yaml index e8102ce..3c7267c 100644 --- a/.github/workflows/Docs.yaml +++ b/.github/workflows/Docs.yaml @@ -2,10 +2,10 @@ name: Docs on: push: branches: - - master + - main pull_request: branches: - - master + - main jobs: deploy: @@ -25,7 +25,7 @@ jobs: - name: Build Docs run: mkdocs build - # Only deploy if master, otherwise just build for testing + # Only deploy if main, otherwise just build for testing - name: Deploy - if: endsWith(github.ref, '/master') + if: endsWith(github.ref, '/main') run: mkdocs gh-deploy --force diff --git a/.github/workflows/Linting.yml b/.github/workflows/Linting.yml index e590ee3..68fca58 100644 --- a/.github/workflows/Linting.yml +++ b/.github/workflows/Linting.yml @@ -2,9 +2,9 @@ name: Linting on: push: - branches: master + branches: main pull_request: - branches: master + branches: main jobs: mypy: @@ -46,26 +46,37 @@ jobs: run: | mypy opt_einsum - black: - name: Black + ruff: + name: Ruff + strategy: + fail-fast: false + matrix: + python-version: [3.12] + environment: ["min-deps"] + runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Python Setup - uses: actions/setup-python@v2.2.1 + - uses: conda-incubator/setup-miniconda@v3 with: - python-version: 3.8 + python-version: ${{ matrix.python-version }} + mamba-version: "*" + channel-priority: true + activate-environment: test + environment-file: devtools/conda-envs/${{ matrix.environment }}-environment.yaml - - name: Create Environment - shell: bash + - name: Environment Information + shell: bash -l {0} run: | - python -m pip install --upgrade pip - python -m pip install black + conda info + conda list + conda config --show-sources + conda config --show - name: Lint - shell: bash + shell: bash -l {0} run: | set -e - black opt_einsum --check + make format-check diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml index 18ab6d2..60cab06 100644 --- a/.github/workflows/Tests.yml +++ b/.github/workflows/Tests.yml @@ -2,9 +2,9 @@ name: Tests on: push: - branches: master + branches: main pull_request: - branches: master + branches: main jobs: miniconda-setup: @@ -13,7 +13,7 @@ jobs: fail-fast: false matrix: include: - - python-version: 3.9 + - python-version: 3.8 environment: "min-deps" - python-version: 3.12 environment: "min-deps" diff --git a/.gitignore b/.gitignore index 98c59ce..3236e13 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Hatch replaces this file +opt_einsum/_version.py + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index c5adf08..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1,9 +0,0 @@ -recursive-include opt_einsum *.py - -include setup.py -include README.md -include LICENSE -include MANIFEST.in - -include versioneer.py -include opt_einsum/_version.py diff --git a/Makefile b/Makefile index 4b91187..1a1ff8e 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,4 @@ .DEFAULT_GOAL := all -isort = isort opt_einsum scripts/ -black = black opt_einsum scripts/ -autoflake = autoflake -ir --remove-all-unused-imports --ignore-init-module-imports --remove-unused-variables opt_einsum scripts/ -mypy = mypy --ignore-missing-imports opt_einsum scripts/ .PHONY: install install: @@ -10,24 +6,17 @@ install: .PHONY: format format: - $(autoflake) - $(isort) - $(black) + ruff check opt_einsum --fix + ruff format opt_einsum .PHONY: format-check format-check: - $(isort) --check-only - $(black) --check - -.PHONY: check-dist -check-dist: - python setup.py check -ms - python setup.py sdist - twine check dist/* + ruff check opt_einsum + ruff format --check opt_einsum .PHONY: mypy mypy: - $(mypy) + mypy opt_einsum .PHONY: test test: diff --git a/devtools/conda-envs/full-environment.yaml b/devtools/conda-envs/full-environment.yaml index a69311b..31969e7 100644 --- a/devtools/conda-envs/full-environment.yaml +++ b/devtools/conda-envs/full-environment.yaml @@ -3,7 +3,7 @@ channels: - conda-forge dependencies: # Base depends - - python >=3.9 + - python >=3.8 - numpy >=1.23 - nomkl @@ -15,10 +15,8 @@ dependencies: - jax # Testing - - autoflake - - black - codecov - - isort - mypy - pytest - pytest-cov + - ruff ==0.5.* diff --git a/devtools/conda-envs/min-deps-environment.yaml b/devtools/conda-envs/min-deps-environment.yaml index d439666..4d93a53 100644 --- a/devtools/conda-envs/min-deps-environment.yaml +++ b/devtools/conda-envs/min-deps-environment.yaml @@ -3,13 +3,11 @@ channels: - conda-forge dependencies: # Base depends - - python >=3.9 + - python >=3.8 # Testing - - autoflake - - black - codecov - - isort - mypy - pytest - pytest-cov + - ruff ==0.5.* diff --git a/devtools/conda-envs/min-ver-environment.yaml b/devtools/conda-envs/min-ver-environment.yaml index 17778c1..c7e02ea 100644 --- a/devtools/conda-envs/min-ver-environment.yaml +++ b/devtools/conda-envs/min-ver-environment.yaml @@ -3,7 +3,7 @@ channels: - conda-forge dependencies: # Base depends - - python >=3.9 + - python >=3.8 - numpy >=1.23 - nomkl @@ -12,10 +12,8 @@ dependencies: - dask ==2021.* # Testing - - autoflake - - black - codecov - - isort - mypy - pytest - pytest-cov + - ruff ==0.5.* diff --git a/devtools/conda-envs/torch-only-environment.yaml b/devtools/conda-envs/torch-only-environment.yaml index fa469fb..6f62222 100644 --- a/devtools/conda-envs/torch-only-environment.yaml +++ b/devtools/conda-envs/torch-only-environment.yaml @@ -4,16 +4,14 @@ channels: - conda-forge dependencies: # Base depends - - python >=3.9 + - python >=3.8 - pytorch::pytorch >=2.0,<3.0.0a - pytorch::cpuonly - mkl # Testing - - autoflake - - black - codecov - - isort - mypy - pytest - pytest-cov + - ruff ==0.5.* diff --git a/opt_einsum/__init__.py b/opt_einsum/__init__.py index 853587e..820bfb1 100644 --- a/opt_einsum/__init__.py +++ b/opt_einsum/__init__.py @@ -1,8 +1,7 @@ -""" -Main init function for opt_einsum. -""" +"""Main init function for opt_einsum.""" from opt_einsum import blas, helpers, path_random, paths +from opt_einsum._version import __version__ from opt_einsum.contract import contract, contract_expression, contract_path from opt_einsum.parser import get_symbol from opt_einsum.path_random import RandomGreedy @@ -10,6 +9,7 @@ from opt_einsum.sharing import shared_intermediates __all__ = [ + "__version__", "blas", "helpers", "path_random", @@ -24,13 +24,6 @@ "shared_intermediates", ] -# Handle versioneer -from opt_einsum._version import get_versions # isort:skip - -versions = get_versions() -__version__ = versions["version"] -__git_revision__ = versions["full-revisionid"] -del get_versions, versions paths.register_path_fn("random-greedy", path_random.random_greedy) paths.register_path_fn("random-greedy-128", path_random.random_greedy_128) diff --git a/opt_einsum/_version.py b/opt_einsum/_version.py index 4234f24..c0ab3b4 100644 --- a/opt_einsum/_version.py +++ b/opt_einsum/_version.py @@ -1,550 +1,16 @@ -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "$Format:%d$" - git_full = "$Format:%H$" - git_date = "$Format:%ci$" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "pep440" - cfg.tag_prefix = "" - cfg.parentdir_prefix = "None" - cfg.versionfile_source = "opt_einsum/_version.py" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen( - [c] + args, - cwd=cwd, - env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr else None), - ) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return { - "version": dirname[len(parentdir_prefix) :], - "full-revisionid": None, - "dirty": False, - "error": None, - "date": None, - } - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r"\d", r)]) - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix) :] - if verbose: - print("picking %s" % r) - return { - "version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": None, - "date": date, - } - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return { - "version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, - "error": "no suitable tags", - "date": None, - } - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command( - GITS, - [ - "describe", - "--tags", - "--dirty", - "--always", - "--long", - "--match", - "%s*" % tag_prefix, - ], - cwd=root, - ) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[: git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( - full_tag, - tag_prefix, - ) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix) :] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return { - "version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None, - } - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return { - "version": rendered, - "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], - "error": None, - "date": pieces.get("date"), - } - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for i in cfg.versionfile_source.split("/"): - root = os.path.dirname(root) - except NameError: - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None, - } - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return { - "version": "0+unknown", - "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", - "date": None, - } +# file generated by setuptools_scm +# don't change, don't track in version control +TYPE_CHECKING = False +if TYPE_CHECKING: + from typing import Tuple, Union + VERSION_TUPLE = Tuple[Union[int, str], ...] +else: + VERSION_TUPLE = object + +version: str +__version__: str +__version_tuple__: VERSION_TUPLE +version_tuple: VERSION_TUPLE + +__version__ = version = '0.0.0.dev' +__version_tuple__ = version_tuple = (0, 0, 0, 'dev', '') diff --git a/opt_einsum/backends/__init__.py b/opt_einsum/backends/__init__.py index 99839d2..487f685 100644 --- a/opt_einsum/backends/__init__.py +++ b/opt_einsum/backends/__init__.py @@ -1,6 +1,4 @@ -""" -Compute backends for opt_einsum. -""" +"""Compute backends for opt_einsum.""" # Backends from opt_einsum.backends.cupy import to_cupy diff --git a/opt_einsum/backends/cupy.py b/opt_einsum/backends/cupy.py index 9fd25ba..34e4763 100644 --- a/opt_einsum/backends/cupy.py +++ b/opt_einsum/backends/cupy.py @@ -1,6 +1,4 @@ -""" -Required functions for optimized contractions of numpy arrays using cupy. -""" +"""Required functions for optimized contractions of numpy arrays using cupy.""" from opt_einsum.helpers import has_array_interface from opt_einsum.sharing import to_backend_cache_wrap diff --git a/opt_einsum/backends/dispatch.py b/opt_einsum/backends/dispatch.py index 0abad45..dd71642 100644 --- a/opt_einsum/backends/dispatch.py +++ b/opt_einsum/backends/dispatch.py @@ -1,5 +1,4 @@ -""" -Handles dispatching array operations to the correct backend library, as well +"""Handles dispatching array operations to the correct backend library, as well as converting arrays to backend formats and then potentially storing them as constants. """ @@ -60,7 +59,7 @@ def _import_func(func: str, backend: str, default: Any = None) -> Any: } try: - import numpy as np + import numpy as np # type: ignore _cached_funcs[("tensordot", "numpy")] = np.tensordot _cached_funcs[("transpose", "numpy")] = np.transpose diff --git a/opt_einsum/backends/jax.py b/opt_einsum/backends/jax.py index d234693..45aa4b2 100644 --- a/opt_einsum/backends/jax.py +++ b/opt_einsum/backends/jax.py @@ -1,6 +1,4 @@ -""" -Required functions for optimized contractions of numpy arrays using jax. -""" +"""Required functions for optimized contractions of numpy arrays using jax.""" from opt_einsum.sharing import to_backend_cache_wrap @@ -12,7 +10,7 @@ def _get_jax_and_to_jax(): global _JAX if _JAX is None: - import jax + import jax # type: ignore @to_backend_cache_wrap @jax.jit @@ -31,7 +29,7 @@ def build_expression(_, expr): # pragma: no cover jax_expr = jax.jit(expr._contract) def jax_contract(*arrays): - import numpy as np + import numpy as np # type: ignore return np.asarray(jax_expr(arrays)) diff --git a/opt_einsum/backends/object_arrays.py b/opt_einsum/backends/object_arrays.py index d870beb..8954a18 100644 --- a/opt_einsum/backends/object_arrays.py +++ b/opt_einsum/backends/object_arrays.py @@ -1,6 +1,4 @@ -""" -Functions for performing contractions with array elements which are objects. -""" +"""Functions for performing contractions with array elements which are objects.""" import functools import operator @@ -24,12 +22,12 @@ def object_einsum(eq: str, *arrays: ArrayType) -> ArrayType: These can be any indexable arrays as long as addition and multiplication is defined on the elements. - Returns + Returns: ------- out : numpy.ndarray The output tensor, with ``dtype=object``. """ - import numpy as np + import numpy as np # type: ignore # when called by ``opt_einsum`` we will always be given a full eq lhs, output = eq.split("->") @@ -47,7 +45,6 @@ def object_einsum(eq: str, *arrays: ArrayType) -> ArrayType: inner_size = tuple(sizes[k] for k in inner) for coo_o in np.ndindex(*out_size): - coord = dict(zip(output, coo_o)) def gen_inner_sum(): diff --git a/opt_einsum/backends/tensorflow.py b/opt_einsum/backends/tensorflow.py index ef2a23c..a6544f8 100644 --- a/opt_einsum/backends/tensorflow.py +++ b/opt_einsum/backends/tensorflow.py @@ -1,6 +1,4 @@ -""" -Required functions for optimized contractions of numpy arrays using tensorflow. -""" +"""Required functions for optimized contractions of numpy arrays using tensorflow.""" from opt_einsum.helpers import has_array_interface from opt_einsum.sharing import to_backend_cache_wrap @@ -14,7 +12,7 @@ def _get_tensorflow_and_device(): global _CACHED_TF_DEVICE if _CACHED_TF_DEVICE is None: - import tensorflow as tf + import tensorflow as tf # type: ignore try: eager = tf.executing_eagerly() diff --git a/opt_einsum/backends/theano.py b/opt_einsum/backends/theano.py index c7ab2ab..5b54aab 100644 --- a/opt_einsum/backends/theano.py +++ b/opt_einsum/backends/theano.py @@ -1,6 +1,4 @@ -""" -Required functions for optimized contractions of numpy arrays using theano. -""" +"""Required functions for optimized contractions of numpy arrays using theano.""" from opt_einsum.helpers import has_array_interface from opt_einsum.sharing import to_backend_cache_wrap @@ -11,7 +9,7 @@ @to_backend_cache_wrap(constants=True) def to_theano(array, constant=False): """Convert a numpy array to ``theano.tensor.TensorType`` instance.""" - import theano + import theano # type: ignore if has_array_interface(array): if constant: diff --git a/opt_einsum/backends/torch.py b/opt_einsum/backends/torch.py index 561a0a0..f7f01f2 100644 --- a/opt_einsum/backends/torch.py +++ b/opt_einsum/backends/torch.py @@ -1,6 +1,4 @@ -""" -Required functions for optimized contractions of numpy arrays using pytorch. -""" +"""Required functions for optimized contractions of numpy arrays using pytorch.""" from opt_einsum.helpers import has_array_interface from opt_einsum.parser import convert_to_valid_einsum_chars @@ -26,7 +24,7 @@ def _get_torch_and_device(): global _TORCH_HAS_TENSORDOT if _TORCH_DEVICE is None: - import torch + import torch # type: ignore device = "cuda" if torch.cuda.is_available() else "cpu" _TORCH_DEVICE = torch, device diff --git a/opt_einsum/blas.py b/opt_einsum/blas.py index 487ace2..7d119b5 100644 --- a/opt_einsum/blas.py +++ b/opt_einsum/blas.py @@ -1,12 +1,10 @@ -""" -Determines if a contraction can use BLAS or not -""" +"""Determines if a contraction can use BLAS or not.""" from typing import List, Sequence, Tuple, Union from opt_einsum.typing import ArrayIndexType -__all__ = ["can_blas", "tensor_blas"] +__all__ = ["can_blas"] def can_blas( @@ -15,8 +13,7 @@ def can_blas( idx_removed: ArrayIndexType, shapes: Union[Sequence[Tuple[int]], None] = None, ) -> Union[str, bool]: - """ - Checks if we can use a BLAS call. + """Checks if we can use a BLAS call. Parameters ---------- @@ -29,19 +26,19 @@ def can_blas( shapes : sequence of tuple[int], optional If given, check also that none of the indices are broadcast dimensions. - Returns + Returns: ------- type : str or bool The type of BLAS call to be used or False if none. - Notes + Notes: ----- We assume several operations are not efficient such as a transposed DDOT, therefore 'ijk,jki->' should prefer einsum. These return the blas type appended with "/EINSUM" to differentiate when they can still be done with tensordot if required, e.g. when a backend has no einsum. - Examples + Examples: -------- >>> can_blas(['ij', 'jk'], 'ik', set('j')) 'GEMM' diff --git a/opt_einsum/contract.py b/opt_einsum/contract.py index ee3c2db..7a7dac9 100644 --- a/opt_einsum/contract.py +++ b/opt_einsum/contract.py @@ -1,6 +1,4 @@ -""" -Contains the primary optimization and contraction routines. -""" +"""Contains the primary optimization and contraction routines.""" from decimal import Decimal from functools import lru_cache @@ -63,7 +61,7 @@ def __init__( self.size_dict = size_dict self.shapes = [tuple(size_dict[k] for k in ks) for ks in input_subscripts.split(",")] - self.eq = "{}->{}".format(input_subscripts, output_subscript) + self.eq = f"{input_subscripts}->{output_subscript}" self.largest_intermediate = Decimal(max(size_list, default=1)) def __repr__(self) -> str: @@ -71,13 +69,13 @@ def __repr__(self) -> str: header = ("scaling", "BLAS", "current", "remaining") path_print = [ - " Complete contraction: {}\n".format(self.eq), - " Naive scaling: {}\n".format(len(self.indices)), - " Optimized scaling: {}\n".format(max(self.scale_list, default=0)), - " Naive FLOP count: {:.3e}\n".format(self.naive_cost), - " Optimized FLOP count: {:.3e}\n".format(self.opt_cost), - " Theoretical speedup: {:.3e}\n".format(self.speedup), - " Largest intermediate: {:.3e} elements\n".format(self.largest_intermediate), + f" Complete contraction: {self.eq}\n", + f" Naive scaling: {len(self.indices)}\n", + f" Optimized scaling: {max(self.scale_list, default=0)}\n", + f" Naive FLOP count: {self.naive_cost:.3e}\n", + f" Optimized FLOP count: {self.opt_cost:.3e}\n", + f" Theoretical speedup: {self.speedup:.3e}\n", + f" Largest intermediate: {self.largest_intermediate:.3e} elements\n", "-" * 80 + "\n", "{:>6} {:>11} {:>22} {:>37}\n".format(*header), "-" * 80, @@ -181,8 +179,7 @@ def contract_path( shapes: bool = False, **kwargs: Any, ) -> Tuple[PathType, PathInfo]: - """ - Find a contraction order `path`, without performing the contraction. + """Find a contraction order `path`, without performing the contraction. Parameters: subscripts: Specifies the subscripts for summation. @@ -224,16 +221,16 @@ def contract_path( shapes: Whether ``contract_path`` should assume arrays (the default) or array shapes have been supplied. - Returns: + Returns: path: The optimized einsum contraciton path PathInfo: A printable object containing various information about the path found. - Notes: + Notes: The resulting path indicates which terms of the input contraction should be contracted first, the result of this contraction is then appended to the end of the contraction list. - Examples: + Examples: We can begin with a chain dot example. In this case, it is optimal to contract the b and c tensors represented by the first element of the path (1, 2). The resulting tensor is added to the end of the contraction and the @@ -316,8 +313,8 @@ def contract_path( if len(sh) != len(term): raise ValueError( - "Einstein sum subscript '{}' does not contain the " - "correct number of indices for operand {}.".format(input_list[tnum], tnum) + f"Einstein sum subscript '{input_list[tnum]}' does not contain the " + f"correct number of indices for operand {tnum}." ) for cnum, char in enumerate(term): dim = int(sh[cnum]) @@ -328,8 +325,8 @@ def contract_path( size_dict[char] = dim elif dim not in (1, size_dict[char]): raise ValueError( - "Size of label '{}' for operand {} ({}) does not match previous " - "terms ({}).".format(char, tnum, size_dict[char], dim) + f"Size of label '{char}' for operand {tnum} ({size_dict[char]}) does not match previous " + f"terms ({dim})." ) else: size_dict[char] = dim @@ -368,7 +365,7 @@ def contract_path( # Build contraction tuple (positions, gemm, einsum_str, remaining) for cnum, contract_inds in enumerate(path_tuple): # Make sure we remove inds from right to left - contract_inds = tuple(sorted(list(contract_inds), reverse=True)) + contract_inds = tuple(sorted(contract_inds, reverse=True)) contract_tuple = helpers.find_contraction(contract_inds, input_sets, output_set) out_inds, input_sets, idx_removed, idx_contract = contract_tuple @@ -445,7 +442,6 @@ def _einsum(*operands: Any, **kwargs: Any) -> ArrayType: # Do we need to temporarily map indices into [a-z,A-Z] range? if not parser.has_valid_einsum_chars_only(einsum_str): - # Explicitly find output str first so as to maintain order if "->" not in einsum_str: einsum_str += "->" + parser.find_output_str(einsum_str) @@ -523,8 +519,7 @@ def contract( backend: BackendType = "auto", **kwargs: Any, ) -> ArrayType: - """ - Evaluates the Einstein summation convention on the operands. A drop in + """Evaluates the Einstein summation convention on the operands. A drop in replacement for NumPy's einsum function that optimizes the order of contraction to reduce overall scaling at the cost of several intermediate arrays. @@ -663,7 +658,6 @@ def _core_contract( """Inner loop used to perform an actual contraction given the output from a ``contract_path(..., einsum_call=True)`` call. """ - # Special handling if out is specified specified_out = out is not None @@ -689,7 +683,6 @@ def _core_contract( # Call tensordot (check if should prefer einsum, but only if available) if blas_flag and ("EINSUM" not in blas_flag or no_einsum): # type: ignore - # Checks have already been handled input_str, results_index = einsum_str.split("->") input_left, input_right = input_str.split(",") @@ -714,7 +707,6 @@ def _core_contract( # Build a new view if needed if (tensor_result != results_index) or handle_out: - transpose = tuple(map(tensor_result.index, results_index)) new_view = _transpose(new_view, axes=transpose, backend=backend) @@ -757,7 +749,7 @@ def format_const_einsum_str(einsum_str: str, constants: Iterable[int]) -> str: else: lhs, rhs, arrow = einsum_str, "", "" - wrapped_terms = ["[{}]".format(t) if i in constants else t for i, t in enumerate(lhs.split(","))] + wrapped_terms = [f"[{t}]" if i in constants else t for i, t in enumerate(lhs.split(","))] formatted_einsum_str = "{}{}{}".format(",".join(wrapped_terms), arrow, rhs) @@ -904,15 +896,14 @@ def __call__( Returns: The contracted result. """ - backend = parse_backend(arrays, backend) correct_num_args = self._full_num_args if evaluate_constants else self.num_args if len(arrays) != correct_num_args: raise ValueError( - "This `ContractExpression` takes exactly {} array arguments " - "but received {}.".format(self.num_args, len(arrays)) + f"This `ContractExpression` takes exactly {self.num_args} array arguments " + f"but received {len(arrays)}." ) if self._constants_dict and not evaluate_constants: @@ -935,23 +926,23 @@ def __call__( msg = ( "Internal error while evaluating `ContractExpression`. Note that few checks are performed" " - the number and rank of the array arguments must match the original expression. " - "The internal error was: '{}'".format(original_msg), + f"The internal error was: '{original_msg}'", ) err.args = msg raise def __repr__(self) -> str: if self._constants_dict: - constants_repr = ", constants={}".format(sorted(self._constants_dict)) + constants_repr = f", constants={sorted(self._constants_dict)}" else: constants_repr = "" - return "".format(self.contraction, constants_repr) + return f"" def __str__(self) -> str: s = [self.__repr__()] for i, c in enumerate(self.contraction_list): - s.append("\n {}. ".format(i + 1)) - s.append("'{}'".format(c[2]) + (" [{}]".format(c[-1]) if c[-1] else "")) + s.append(f"\n {i + 1}. ") + s.append(f"'{c[2]}'" + (f" [{c[-1]}]" if c[-1] else "")) kwargs = {"dtype": self.dtype, "order": self.order, "casting": self.casting} s.append(f"\neinsum_kwargs={kwargs}") return "".join(s) @@ -1018,7 +1009,6 @@ def contract_expression( Callable with signature `expr(*arrays, out=None, backend='numpy')` where the array's shapes should match `shapes`. Notes: - The `out` keyword argument should be supplied to the generated expression rather than this function. The `backend` keyword argument should also be supplied to the generated @@ -1031,7 +1021,6 @@ def contract_expression( backend, then subsequently reused. Examples: - Basic usage: ```python @@ -1061,8 +1050,7 @@ def contract_expression( for arg in ("out", "backend"): if kwargs.get(arg, None) is not None: raise ValueError( - "'{}' should only be specified when calling a " - "`ContractExpression`, not when building it.".format(arg) + f"'{arg}' should only be specified when calling a " "`ContractExpression`, not when building it." ) if not isinstance(subscripts, str): @@ -1071,7 +1059,7 @@ def contract_expression( kwargs["_gen_expression"] = True # build dict of constant indices mapped to arrays - constants = constants or tuple() + constants = constants or () constants_dict = {i: shapes[i] for i in constants} kwargs["_constants_dict"] = constants_dict diff --git a/opt_einsum/helpers.py b/opt_einsum/helpers.py index 5775259..3d8b41e 100644 --- a/opt_einsum/helpers.py +++ b/opt_einsum/helpers.py @@ -1,16 +1,14 @@ -""" -Contains helper functions for opt_einsum testing scripts -""" +"""Contains helper functions for opt_einsum testing scripts.""" from typing import Any, Collection, Dict, FrozenSet, Iterable, List, Tuple, overload from opt_einsum.typing import ArrayIndexType, ArrayType -__all__ = ["build_views", "compute_size_by_dict", "find_contraction", "flop_count"] +__all__ = ["compute_size_by_dict", "find_contraction", "flop_count"] _valid_chars = "abcdefghijklmopqABC" _sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4] -_default_dim_dict = {c: s for c, s in zip(_valid_chars, _sizes)} +_default_dim_dict = dict(zip(_valid_chars, _sizes)) @overload @@ -22,8 +20,7 @@ def compute_size_by_dict(indices: Collection[str], idx_dict: Dict[str, int]) -> def compute_size_by_dict(indices: Any, idx_dict: Any) -> int: - """ - Computes the product of the elements in indices based on the dictionary + """Computes the product of the elements in indices based on the dictionary idx_dict. Parameters @@ -33,12 +30,12 @@ def compute_size_by_dict(indices: Any, idx_dict: Any) -> int: idx_dict : dictionary Dictionary of index _sizes - Returns + Returns: ------- ret : int The resulting product. - Examples + Examples: -------- >>> compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5}) 90 @@ -55,8 +52,7 @@ def find_contraction( input_sets: List[ArrayIndexType], output_set: ArrayIndexType, ) -> Tuple[FrozenSet[str], List[ArrayIndexType], ArrayIndexType, ArrayIndexType]: - """ - Finds the contraction for a given set of input and output sets. + """Finds the contraction for a given set of input and output sets. Parameters ---------- @@ -67,7 +63,7 @@ def find_contraction( output_set : set Set that represents the rhs side of the overall einsum subscript - Returns + Returns: ------- new_result : set The indices of the resulting contraction @@ -79,9 +75,8 @@ def find_contraction( idx_contraction : set The indices used in the current contraction - Examples + Examples: -------- - # A simple dot product test case >>> pos = (0, 1) >>> isets = [set('ab'), set('bc')] @@ -96,7 +91,6 @@ def find_contraction( >>> find_contraction(pos, isets, oset) ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'}) """ - remaining = list(input_sets) inputs = (remaining.pop(i) for i in sorted(positions, reverse=True)) idx_contract = frozenset.union(*inputs) @@ -115,8 +109,7 @@ def flop_count( num_terms: int, size_dictionary: Dict[str, int], ) -> int: - """ - Computes the number of FLOPS in the contraction. + """Computes the number of FLOPS in the contraction. Parameters ---------- @@ -129,14 +122,13 @@ def flop_count( size_dictionary : dict The size of each of the indices in idx_contraction - Returns + Returns: ------- flop_count : int The total number of FLOPS required for the contraction. - Examples + Examples: -------- - >>> flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5}) 30 @@ -144,7 +136,6 @@ def flop_count( 60 """ - overall_size = compute_size_by_dict(idx_contraction, size_dictionary) op_factor = max(1, num_terms - 1) if inner: diff --git a/opt_einsum/parser.py b/opt_einsum/parser.py index 15f7181..316a421 100644 --- a/opt_einsum/parser.py +++ b/opt_einsum/parser.py @@ -1,10 +1,7 @@ -""" -A functionally equivalent parser of the numpy.einsum input parser -""" +"""A functionally equivalent parser of the numpy.einsum input parser.""" import itertools -from collections.abc import Sequence -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Sequence, Tuple from opt_einsum.typing import ArrayType, TensorShapeType @@ -107,7 +104,7 @@ def convert_to_valid_einsum_chars(einsum_str: str) -> str: valid for numpy einsum. If there are too many symbols, let the backend throw an error. - Examples + Examples: -------- >>> oe.parser.convert_to_valid_einsum_chars("Ĥěļļö") 'cbdda' @@ -120,7 +117,7 @@ def convert_to_valid_einsum_chars(einsum_str: str) -> str: def alpha_canonicalize(equation: str) -> str: """Alpha convert an equation in an order-independent canonical way. - Examples + Examples: -------- >>> oe.parser.alpha_canonicalize("dcba") 'abcd' @@ -138,11 +135,10 @@ def alpha_canonicalize(equation: str) -> str: def find_output_str(subscripts: str) -> str: - """ - Find the output string for the inputs ``subscripts`` under canonical einstein summation rules. + """Find the output string for the inputs ``subscripts`` under canonical einstein summation rules. That is, repeated indices are summed over by default. - Examples + Examples: -------- >>> oe.parser.find_output_str("ab,bc") 'ac' @@ -161,7 +157,7 @@ def find_output_shape(inputs: List[str], shapes: List[TensorShapeType], output: """Find the output shape for given inputs, shapes and output string, taking into account broadcasting. - Examples + Examples: -------- >>> oe.parser.find_output_shape(["ab", "bc"], [(2, 3), (3, 4)], "ac") (2, 4) @@ -182,11 +178,10 @@ def get_shape(x: Any) -> TensorShapeType: Array-like objects are those that have a `shape` attribute, are sequences of BaseTypes, or are BaseTypes. BaseTypes are defined as `bool`, `int`, `float`, `complex`, `str`, and `bytes`. """ - if hasattr(x, "shape"): return x.shape elif isinstance(x, _BaseTypes): - return tuple() + return () elif isinstance(x, Sequence): shape = [] while isinstance(x, Sequence) and not isinstance(x, _BaseTypes): @@ -200,7 +195,7 @@ def get_shape(x: Any) -> TensorShapeType: def possibly_convert_to_numpy(x: Any) -> Any: """Convert things without a 'shape' to ndarrays, but leave everything else. - Examples + Examples: -------- >>> oe.parser.possibly_convert_to_numpy(5) array(5) @@ -221,10 +216,9 @@ def possibly_convert_to_numpy(x: Any) -> Any: >>> oe.parser.possibly_convert_to_numpy(myshape) <__main__.Shape object at 0x10f850710> """ - if not hasattr(x, "shape"): try: - import numpy as np + import numpy as np # type: ignore except ModuleNotFoundError: raise ModuleNotFoundError( "numpy is required to convert non-array objects to arrays. This function will be deprecated in the future." @@ -238,7 +232,7 @@ def possibly_convert_to_numpy(x: Any) -> Any: def convert_subscripts(old_sub: List[Any], symbol_map: Dict[Any, Any]) -> str: """Convert user custom subscripts list to subscript string according to `symbol_map`. - Examples + Examples: -------- >>> oe.parser.convert_subscripts(['abc', 'def'], {'abc':'a', 'def':'b'}) 'ab' @@ -293,8 +287,7 @@ def convert_interleaved_input(operands: Sequence[Any]) -> Tuple[str, Tuple[Any, def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, List[ArrayType]]: - """ - A reproduction of einsum c side einsum parsing in python. + """A reproduction of einsum c side einsum parsing in python. Parameters: operands: Intakes the same inputs as `contract_path`, but NOT the keyword args. The only @@ -302,7 +295,7 @@ def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, L shapes: Whether ``parse_einsum_input`` should assume arrays (the default) or array shapes have been supplied. - Returns + Returns: input_strings: Parsed input strings output_string: Parsed output string operands: The operands to use in the numpy contraction @@ -320,14 +313,13 @@ def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, L ('za,xza', 'xz', [a, b]) ``` """ - if len(operands) == 0: raise ValueError("No input operands") if isinstance(operands[0], str): subscripts = operands[0].replace(" ", "") if shapes: - if any([hasattr(o, "shape") for o in operands[1:]]): + if any(hasattr(o, "shape") for o in operands[1:]): raise ValueError( "shapes is set to True but given at least one operand looks like an array" " (at least one operand has a shape attribute). " @@ -409,9 +401,9 @@ def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, L # Make sure output subscripts are unique and in the input for char in output_subscript: if output_subscript.count(char) != 1: - raise ValueError("Output character '{}' appeared more than once in the output.".format(char)) + raise ValueError(f"Output character '{char}' appeared more than once in the output.") if char not in input_subscripts: - raise ValueError("Output character '{}' did not appear in the input".format(char)) + raise ValueError(f"Output character '{char}' did not appear in the input") # Make sure number operands is equivalent to the number of terms if len(input_subscripts.split(",")) != len(operands): diff --git a/opt_einsum/path_random.py b/opt_einsum/path_random.py index 3ed51e6..1376822 100644 --- a/opt_einsum/path_random.py +++ b/opt_einsum/path_random.py @@ -1,6 +1,4 @@ -""" -Support for random optimizers, including the random-greedy path. -""" +"""Support for random optimizers, including the random-greedy path.""" import functools import heapq @@ -72,7 +70,6 @@ def __init__( parallel: Union[bool, Decimal, int] = False, pre_dispatch: int = 128, ): - if minimize not in ("flops", "size"): raise ValueError("`minimize` should be one of {'flops', 'size'}.") @@ -186,7 +183,6 @@ def __call__( # assess the trials for ssa_path, cost, size in trials: - # keep track of all costs and sizes self.costs.append(cost) self.sizes.append(size) @@ -328,7 +324,6 @@ def _trial_greedy_ssa_path_and_cost( class RandomGreedy(RandomOptimizer): - def __init__( self, cost_fn: str = "memory-removed-jitter", @@ -337,24 +332,23 @@ def __init__( nbranch: int = 8, **kwargs: Any, ): - """ - Parameters: - cost_fn: A function that returns a heuristic 'cost' of a potential contraction - with which to sort candidates. Should have signature - `cost_fn(size12, size1, size2, k12, k1, k2)`. - temperature: When choosing a possible contraction, its relative probability will be - proportional to `exp(-cost / temperature)`. Thus the larger - `temperature` is, the further random paths will stray from the normal - 'greedy' path. Conversely, if set to zero, only paths with exactly the - same cost as the best at each step will be explored. - rel_temperature: Whether to normalize the ``temperature`` at each step to the scale of - the best cost. This is generally beneficial as the magnitude of costs - can vary significantly throughout a contraction. If False, the - algorithm will end up branching when the absolute cost is low, but - stick to the 'greedy' path when the cost is high - this can also be - beneficial. - nbranch: How many potential paths to calculate probability for and choose from at each step. - kwargs: Supplied to RandomOptimizer. + """Parameters: + cost_fn: A function that returns a heuristic 'cost' of a potential contraction + with which to sort candidates. Should have signature + `cost_fn(size12, size1, size2, k12, k1, k2)`. + temperature: When choosing a possible contraction, its relative probability will be + proportional to `exp(-cost / temperature)`. Thus the larger + `temperature` is, the further random paths will stray from the normal + 'greedy' path. Conversely, if set to zero, only paths with exactly the + same cost as the best at each step will be explored. + rel_temperature: Whether to normalize the ``temperature`` at each step to the scale of + the best cost. This is generally beneficial as the magnitude of costs + can vary significantly throughout a contraction. If False, the + algorithm will end up branching when the absolute cost is low, but + stick to the 'greedy' path when the cost is high - this can also be + beneficial. + nbranch: How many potential paths to calculate probability for and choose from at each step. + kwargs: Supplied to RandomOptimizer. """ self.cost_fn = cost_fn self.temperature = temperature @@ -396,7 +390,7 @@ def random_greedy( memory_limit: Optional[int] = None, **optimizer_kwargs: Any, ) -> ArrayType: - """ """ + """A simple wrapper around the `RandomGreedy` optimizer.""" optimizer = RandomGreedy(**optimizer_kwargs) return optimizer(inputs, output, idx_dict, memory_limit) diff --git a/opt_einsum/paths.py b/opt_einsum/paths.py index ac330b8..2f46903 100644 --- a/opt_einsum/paths.py +++ b/opt_einsum/paths.py @@ -1,6 +1,4 @@ -""" -Contains the path technology behind opt_einsum in addition to several path helpers -""" +"""Contains the path technology behind opt_einsum in addition to several path helpers.""" import bisect import functools @@ -9,10 +7,9 @@ import operator import random import re -from collections import Counter, OrderedDict, defaultdict -from typing import Any, Callable +from collections import Counter, defaultdict +from typing import Any, Callable, Dict, FrozenSet, Generator, List, Optional, Sequence, Set, Tuple, Union from typing import Counter as CounterType -from typing import Dict, FrozenSet, Generator, List, Optional, Sequence, Set, Tuple, Union from opt_einsum.helpers import compute_size_by_dict, flop_count from opt_einsum.typing import ArrayIndexType, PathSearchFunctionType, PathType, TensorShapeType @@ -33,7 +30,7 @@ class PathOptimizer: - """Base class for different path optimizers to inherit from. + r"""Base class for different path optimizers to inherit from. Subclassed optimizers should define a call method with signature: @@ -83,8 +80,7 @@ def __call__( def ssa_to_linear(ssa_path: PathType) -> PathType: - """ - Convert a path with static single assignment ids to a path with recycled + """Convert a path with static single assignment ids to a path with recycled linear ids. Example: @@ -101,10 +97,10 @@ def ssa_to_linear(ssa_path: PathType) -> PathType: # ids[ssa_id:] -= 1 # return path - N = sum(map(len, ssa_path)) - len(ssa_path) + 1 - ids = list(range(N)) + n = sum(map(len, ssa_path)) - len(ssa_path) + 1 + ids = list(range(n)) path = [] - ssa = N + ssa = n for scon in ssa_path: con = sorted([bisect.bisect_left(ids, s) for s in scon]) for j in reversed(con): @@ -130,8 +126,7 @@ def ssa_to_linear(ssa_path: PathType) -> PathType: def linear_to_ssa(path: PathType) -> PathType: - """ - Convert a path with recycled linear ids to a path with static single + """Convert a path with recycled linear ids to a path with static single assignment ids. Exmaple: @@ -160,8 +155,7 @@ def calc_k12_flops( j: int, size_dict: Dict[str, int], ) -> Tuple[FrozenSet[str], int]: - """ - Calculate the resulting indices and flops for a potential pairwise + """Calculate the resulting indices and flops for a potential pairwise contraction - used in the recursive (optimal/branch) algorithms. Parameters: @@ -195,8 +189,7 @@ def _compute_oversize_flops( output: ArrayIndexType, size_dict: Dict[str, int], ) -> int: - """ - Compute the flop count for a contraction of all remaining arguments. This + """Compute the flop count for a contraction of all remaining arguments. This is used when a memory limit means that no pairwise contractions can be made. """ idx_contraction = frozenset.union(*map(inputs.__getitem__, remaining)) # type: ignore @@ -211,8 +204,7 @@ def optimal( size_dict: Dict[str, int], memory_limit: Optional[int] = None, ) -> PathType: - """ - Computes all possible pair contractions in a depth-first recursive manner, + """Computes all possible pair contractions in a depth-first recursive manner, sieving results based on `memory_limit` and the best path found so far. Parameters: @@ -225,7 +217,6 @@ def optimal( path: The optimal contraction order within the memory limit constraint. Examples: - ```python isets = [set('abd'), set('ac'), set('bdc')] oset = set('') @@ -243,7 +234,6 @@ def optimal( result_cache: Dict[Tuple[ArrayIndexType, ArrayIndexType], Tuple[FrozenSet[str], int]] = {} def _optimal_iterate(path, remaining, inputs, flops): - # reached end of path (only ever get here if flops is best found so far) if len(remaining) == 1: best_flops["flops"] = flops @@ -345,8 +335,7 @@ def __init__( minimize: str = "flops", cost_fn: str = "memory-removed", ): - """ - Explores possible pair contractions in a depth-first recursive manner like + """Explores possible pair contractions in a depth-first recursive manner like the `optimal` approach, but with extra heuristic early pruning of branches as well sieving by `memory_limit` and the best path found so far. @@ -388,19 +377,16 @@ def __call__( size_dict: Dict[str, int], memory_limit: Optional[int] = None, ) -> PathType: - """ - - Parameters: + """Parameters: inputs_: List of sets that represent the lhs side of the einsum subscript output_: Set that represents the rhs side of the overall einsum subscript size_dict: Dictionary of index sizes - memory_limit: The maximum number of elements in a temporary array + memory_limit: The maximum number of elements in a temporary array. Returns: path: The contraction order within the memory limit constraint. Examples: - ```python isets = [set('abd'), set('ac'), set('bdc')] oset = set('') @@ -417,7 +403,6 @@ def __call__( result_cache: Dict[Tuple[FrozenSet[str], FrozenSet[str]], Tuple[FrozenSet[str], int]] = {} def _branch_iterate(path, inputs, remaining, flops, size): - # reached end of path (only ever get here if flops is best found so far) if len(remaining) == 1: self.best["size"] = size @@ -612,8 +597,7 @@ def ssa_greedy_optimize( choose_fn: Any = None, cost_fn: Any = "memory-removed", ) -> PathType: - """ - This is the core function for :func:`greedy` but produces a path with + """This is the core function for :func:`greedy` but produces a path with static single assignment ids rather than recycled linear ids. SSA ids are cheaper to work with and easier to reason about. """ @@ -660,7 +644,7 @@ def ssa_greedy_optimize( # used it can be contracted. Since we specialize to binary ops, we only care about # ref counts of >=2 or >=3. dim_ref_counts = { - count: set(dim for dim, keys in dim_to_keys.items() if len(keys) >= count) - output for count in [2, 3] + count: {dim for dim, keys in dim_to_keys.items() if len(keys) >= count} - output for count in [2, 3] } # Compute separable part of the objective function for contractions. @@ -687,7 +671,6 @@ def ssa_greedy_optimize( # Greedily contract pairs of tensors. while queue: - con = choose_fn(queue, remaining) if con is None: continue # allow choose_fn to flag all candidates obsolete @@ -711,7 +694,7 @@ def ssa_greedy_optimize( # Find new candidate contractions. k1 = k12 - k2s = set(k2 for dim in k1 for k2 in dim_to_keys[dim]) + k2s = {k2 for dim in k1 for k2 in dim_to_keys[dim]} k2s.discard(k1) if k2s: _push_candidate( @@ -750,8 +733,7 @@ def greedy( choose_fn: Any = None, cost_fn: str = "memory-removed", ) -> PathType: - """ - Finds the path by a three stage algorithm: + """Finds the path by a three stage algorithm: 1. Eagerly compute Hadamard products. 2. Greedily compute contractions to maximize `removed_size` @@ -772,7 +754,6 @@ def greedy( path: The contraction order (a list of tuples of ints). Examples: - ```python isets = [set('abd'), set('ac'), set('bdc')] oset = set('') @@ -789,8 +770,7 @@ def greedy( def _tree_to_sequence(tree: Tuple[Any, ...]) -> PathType: - """ - Converts a contraction tree to a contraction path as it has to be + """Converts a contraction tree to a contraction path as it has to be returned by path optimizers. A contraction tree can either be an int (=no contraction) or a tuple containing the terms to be contracted. An arbitrary number (>= 1) of terms can be contracted at once. Note that @@ -809,7 +789,6 @@ def _tree_to_sequence(tree: Tuple[Any, ...]) -> PathType: #> [(1, 2), (1, 2, 3), (0, 2), (0, 1)] ``` """ - # ((1,2),(0,(4,5,3))) --> [(1, 2), (1, 2, 3), (0, 2), (0, 1)] # # 0 0 0 (1,2) --> ((1,2),(0,(3,4,5))) @@ -821,7 +800,7 @@ def _tree_to_sequence(tree: Tuple[Any, ...]) -> PathType: # # this function iterates through the table shown above from right to left; - if type(tree) == int: + if type(tree) == int: # noqa: E721 return [] c: List[Tuple[Any, ...]] = [tree] # list of remaining contractions (lower part of columns shown above) @@ -830,13 +809,13 @@ def _tree_to_sequence(tree: Tuple[Any, ...]) -> PathType: while len(c) > 0: j = c.pop(-1) - s.insert(0, tuple()) + s.insert(0, ()) - for i in sorted([i for i in j if type(i) == int]): + for i in sorted([i for i in j if type(i) == int]): # noqa: E721 s[0] += (sum(1 for q in t if q < i),) t.insert(s[0][-1], i) - for i_tup in [i_tup for i_tup in j if type(i_tup) != int]: + for i_tup in [i_tup for i_tup in j if type(i_tup) != int]: # noqa: E721 s[0] += (len(t) + len(c),) c.append(i_tup) @@ -844,8 +823,7 @@ def _tree_to_sequence(tree: Tuple[Any, ...]) -> PathType: def _find_disconnected_subgraphs(inputs: List[FrozenSet[int]], output: FrozenSet[int]) -> List[FrozenSet[int]]: - """ - Finds disconnected subgraphs in the given list of inputs. Inputs are + """Finds disconnected subgraphs in the given list of inputs. Inputs are connected if they share summation indices. Note: Disconnected subgraphs can be contracted independently before forming outer products. @@ -865,7 +843,6 @@ def _find_disconnected_subgraphs(inputs: List[FrozenSet[int]], output: FrozenSet #> [{0}, {1}, {2}] ``` """ - subgraphs = [] unused_inputs = set(range(len(inputs))) @@ -941,7 +918,6 @@ def _dp_compare_flops( 3. If the intermediate tensor corresponding to `s` is going to break the memory limit. """ - # TODO: Odd usage with an Iterable[int] to map a dict of type List[int] cost = cost1 + cost2 + compute_size_by_dict(i1_union_i2, size_dict) if cost <= cost_cap: @@ -974,7 +950,6 @@ def _dp_compare_size( on the size of the intermediate tensor created, rather than the number of operations, and so calculates that first. """ - s = s1 | s2 i = _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2) mem = compute_size_by_dict(i, size_dict) @@ -1039,7 +1014,7 @@ def _dp_compare_combo( combine: Callable = sum, ) -> None: """Like ``_dp_compare_flops`` but sieves the potential contraction based - on some combination of both the flops and size, + on some combination of both the flops and size,. """ s = s1 | s2 i = _dp_calc_legs(g, all_tensors, s, inputs, i1_cut_i2_wo_output, i1_union_i2) @@ -1126,8 +1101,7 @@ def _dp_parse_out_single_term_ops( class DynamicProgramming(PathOptimizer): - """ - Finds the optimal path of pairwise contractions without intermediate outer + """Finds the optimal path of pairwise contractions without intermediate outer products based a dynamic programming approach presented in Phys. Rev. E 90, 033315 (2014) (the corresponding preprint is publicly available at https://arxiv.org/abs/1304.6112). This method is especially @@ -1173,12 +1147,11 @@ def __call__( size_dict_: Dict[str, int], memory_limit_: Optional[int] = None, ) -> PathType: - """ - Parameters: + """Parameters: inputs_: List of sets that represent the lhs side of the einsum subscript output_: Set that represents the rhs side of the overall einsum subscript size_dict_: Dictionary of index sizes - memory_limit_: The maximum number of elements in a temporary array + memory_limit_: The maximum number of elements in a temporary array. Returns: path: The contraction order (a list of tuples of ints). @@ -1240,14 +1213,13 @@ def __call__( all_tensors = (1 << len(inputs)) - 1 for g in subgraphs: - # dynamic programming approach to compute x[n] for subgraph g; # x[n][set of n tensors] = (indices, cost, contraction) # the set of n tensors is represented by a bitmap: if bit j is 1, # tensor j is in the set, e.g. 0b100101 = {0,2,5}; set unions # (intersections) can then be computed by bitwise or (and); - x: List[Any] = [None] * 2 + [dict() for j in range(len(g) - 1)] - x[1] = OrderedDict((1 << j, (inputs[j], 0, inputs_contractions[j])) for j in g) + x: List[Any] = [None] * 2 + [{} for j in range(len(g) - 1)] + x[1] = {1 << j: (inputs[j], 0, inputs_contractions[j]) for j in g} # convert set of tensors g to a bitmap set: bitmap_g = functools.reduce(lambda x, y: x | y, (1 << j for j in g)) @@ -1277,7 +1249,6 @@ def __call__( for m in range(1, n // 2 + 1): for s1, (i1, cost1, contract1) in x[m].items(): for s2, (i2, cost2, contract2) in x[n - m].items(): - # can only merge if s1 and s2 are disjoint # and avoid e.g. s1={0}, s2={1} and s1={1}, s2={0} if (not s1 & s2) and (m != n - m or s1 < s2): @@ -1285,7 +1256,6 @@ def __call__( # maybe ignore outer products: if _check_outer(i1_cut_i2_wo_output): - i1_union_i2 = i1 | i2 _check_contraction( cost1, @@ -1362,8 +1332,7 @@ def auto( """Finds the contraction path by automatically choosing the method based on how many input arguments there are. """ - N = len(inputs) - return _AUTO_CHOICES.get(N, greedy)(inputs, output, size_dict, memory_limit) + return _AUTO_CHOICES.get(len(inputs), greedy)(inputs, output, size_dict, memory_limit) _AUTO_HQ_CHOICES = {} @@ -1385,8 +1354,7 @@ def auto_hq( """ from opt_einsum.path_random import random_greedy_128 - N = len(inputs) - return _AUTO_HQ_CHOICES.get(N, random_greedy_128)(inputs, output, size_dict, memory_limit) + return _AUTO_HQ_CHOICES.get(len(inputs), random_greedy_128)(inputs, output, size_dict, memory_limit) _PATH_OPTIONS: Dict[str, PathSearchFunctionType] = { @@ -1407,7 +1375,7 @@ def auto_hq( def register_path_fn(name: str, fn: PathSearchFunctionType) -> None: """Add path finding function ``fn`` as an option with ``name``.""" if name in _PATH_OPTIONS: - raise KeyError("Path optimizer '{}' already exists.".format(name)) + raise KeyError(f"Path optimizer '{name}' already exists.") _PATH_OPTIONS[name.lower()] = fn @@ -1416,8 +1384,6 @@ def get_path_fn(path_type: str) -> PathSearchFunctionType: """Get the correct path finding function from str ``path_type``.""" path_type = path_type.lower() if path_type not in _PATH_OPTIONS: - raise KeyError( - "Path optimizer '{}' not found, valid options are {}.".format(path_type, set(_PATH_OPTIONS.keys())) - ) + raise KeyError(f"Path optimizer '{path_type}' not found, valid options are {set(_PATH_OPTIONS.keys())}.") return _PATH_OPTIONS[path_type] diff --git a/opt_einsum/sharing.py b/opt_einsum/sharing.py index 721d227..c9ee583 100644 --- a/opt_einsum/sharing.py +++ b/opt_einsum/sharing.py @@ -1,5 +1,4 @@ -""" -A module for sharing intermediates between contractions. +"""A module for sharing intermediates between contractions. Copyright (c) 2018 Uber Technologies """ @@ -9,9 +8,8 @@ import numbers import threading from collections import Counter, defaultdict -from typing import Any +from typing import Any, Dict, Generator, List, Optional, Tuple, Union from typing import Counter as CounterType -from typing import Dict, Generator, List, Optional, Tuple, Union from opt_einsum.parser import alpha_canonicalize, parse_einsum_input from opt_einsum.typing import ArrayType diff --git a/opt_einsum/testing.py b/opt_einsum/testing.py index 5c41bf9..2f18b3a 100644 --- a/opt_einsum/testing.py +++ b/opt_einsum/testing.py @@ -1,6 +1,4 @@ -""" -Testing routines for opt_einsum. -""" +"""Testing routines for opt_einsum.""" import random from typing import Any, Dict, List, Literal, Optional, Tuple, Union, overload @@ -12,18 +10,17 @@ _valid_chars = "abcdefghijklmopqABC" _sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4] -_default_dim_dict = {c: s for c, s in zip(_valid_chars, _sizes)} +_default_dim_dict = dict(zip(_valid_chars, _sizes)) def build_shapes(string: str, dimension_dict: Optional[Dict[str, int]] = None) -> Tuple[TensorShapeType, ...]: - """ - Builds random tensor shapes for testing. + """Builds random tensor shapes for testing. Parameters: string: List of tensor strings to build dimension_dict: Dictionary of index sizes, defaults to indices size of 2-7 - Returns + Returns: The resulting shapes. Examples: @@ -34,7 +31,6 @@ def build_shapes(string: str, dimension_dict: Optional[Dict[str, int]] = None) - ``` """ - if dimension_dict is None: dimension_dict = _default_dim_dict @@ -49,15 +45,14 @@ def build_shapes(string: str, dimension_dict: Optional[Dict[str, int]] = None) - def build_views( string: str, dimension_dict: Optional[Dict[str, int]] = None, array_function: Optional[Any] = None ) -> Tuple[ArrayType]: - """ - Builds random numpy arrays for testing. + """Builds random numpy arrays for testing. Parameters: string: List of tensor strings to build dimension_dict: Dictionary of index _sizes array_function: Function to build the arrays, defaults to np.random.rand - Returns + Returns: The resulting views. Examples: @@ -156,7 +151,6 @@ def rand_equation( (9, 5, 3, 3, 9, 5)] ``` """ - np = pytest.importorskip("numpy") if seed is not None: np.random.seed(seed) @@ -223,7 +217,8 @@ def build_arrays_from_tuples(path: PathType) -> List[Any]: path: The path to build arrays from. Returns: - The resulting arrays.""" + The resulting arrays. + """ np = pytest.importorskip("numpy") return [np.random.rand(*x) for x in path] diff --git a/opt_einsum/tests/test_backends.py b/opt_einsum/tests/test_backends.py index 9088335..8145fb1 100644 --- a/opt_einsum/tests/test_backends.py +++ b/opt_einsum/tests/test_backends.py @@ -9,10 +9,10 @@ try: # needed so tensorflow doesn't allocate all gpu mem try: - from tensorflow import ConfigProto + from tensorflow import ConfigProto # type: ignore from tensorflow import Session as TFSession except ImportError: - from tensorflow.compat.v1 import ConfigProto + from tensorflow.compat.v1 import ConfigProto # type: ignore from tensorflow.compat.v1 import Session as TFSession _TF_CONFIG = ConfigProto() _TF_CONFIG.gpu_options.allow_growth = True diff --git a/opt_einsum/tests/test_contract.py b/opt_einsum/tests/test_contract.py index 3037503..313198d 100644 --- a/opt_einsum/tests/test_contract.py +++ b/opt_einsum/tests/test_contract.py @@ -177,7 +177,7 @@ def test_printing(): @pytest.mark.parametrize("out_spec", [False, True]) def test_contract_expressions(string: str, optimize: OptimizeKind, use_blas: bool, out_spec: bool) -> None: views = build_views(string) - shapes = [view.shape if hasattr(view, "shape") else tuple() for view in views] + shapes = [view.shape if hasattr(view, "shape") else () for view in views] expected = contract(string, *views, optimize=False, use_blas=False) expr = contract_expression(string, *shapes, optimize=optimize, use_blas=use_blas) @@ -220,7 +220,7 @@ def test_contract_expression_with_constants(string: str, constants: List[int]) - views = build_views(string) expected = contract(string, *views, optimize=False, use_blas=False) - shapes = [view.shape if hasattr(view, "shape") else tuple() for view in views] + shapes = [view.shape if hasattr(view, "shape") else () for view in views] expr_args: List[Any] = [] ctrc_args = [] diff --git a/opt_einsum/tests/test_input.py b/opt_einsum/tests/test_input.py index fefbf57..8ce9b0a 100644 --- a/opt_einsum/tests/test_input.py +++ b/opt_einsum/tests/test_input.py @@ -2,7 +2,7 @@ Tests the input parsing for opt_einsum. Duplicates the np.einsum input tests. """ -from typing import Any +from typing import Any, List import pytest @@ -12,12 +12,12 @@ np = pytest.importorskip("numpy") -def build_views(string: str) -> list[ArrayType]: +def build_views(string: str) -> List[ArrayType]: """Builds random numpy arrays for testing by using a fixed size dictionary and an input string.""" chars = "abcdefghij" sizes_array = np.array([2, 3, 4, 5, 4, 3, 2, 6, 5, 4]) - sizes = {c: s for c, s in zip(chars, sizes_array)} + sizes = dict(zip(chars, sizes_array)) views = [] @@ -84,7 +84,7 @@ def test_type_errors() -> None: contract(views[0], [Ellipsis, 0], [Ellipsis, ["a"]]) with pytest.raises(TypeError): - contract(views[0], [Ellipsis, dict()], [Ellipsis, "a"]) + contract(views[0], [Ellipsis, {}], [Ellipsis, "a"]) @pytest.mark.parametrize("contract_fn", [contract, contract_path]) diff --git a/opt_einsum/tests/test_parser.py b/opt_einsum/tests/test_parser.py index 6fcd3b2..b39bf8b 100644 --- a/opt_einsum/tests/test_parser.py +++ b/opt_einsum/tests/test_parser.py @@ -62,12 +62,12 @@ def test_parse_with_ellisis() -> None: [(5, 5), (2,)], [[[[[[5, 2]]]]], (1, 1, 1, 1, 2)], [[[[[["abcdef", "b"]]]]], (1, 1, 1, 1, 2)], - ["A", tuple()], - [b"A", tuple()], - [True, tuple()], - [5, tuple()], - [5.0, tuple()], - [5.0 + 0j, tuple()], + ["A", ()], + [b"A", ()], + [True, ()], + [5, ()], + [5.0, ()], + [5.0 + 0j, ()], ], ) def test_get_shapes(array: Any, shape: Tuple[int]) -> None: diff --git a/opt_einsum/tests/test_paths.py b/opt_einsum/tests/test_paths.py index 70f0904..fff2ad3 100644 --- a/opt_einsum/tests/test_paths.py +++ b/opt_einsum/tests/test_paths.py @@ -79,13 +79,11 @@ def check_path(test_output: PathType, benchmark: PathType, bypass: bool = False) def assert_contract_order(func: Any, test_data: Any, max_size: int, benchmark: PathType) -> None: - test_output = func(test_data[0], test_data[1], test_data[2], max_size) assert check_path(test_output, benchmark) def test_size_by_dict() -> None: - sizes_dict = {} for ind, val in zip("abcdez", [2, 5, 9, 11, 13, 0]): sizes_dict[ind] = val @@ -105,7 +103,6 @@ def test_size_by_dict() -> None: def test_flop_cost() -> None: - size_dict = {v: 10 for v in "abcdef"} # Loop over an array @@ -138,7 +135,6 @@ def test_explicit_path() -> None: def test_path_optimal() -> None: - test_func = oe.paths.optimal test_data = explicit_path_tests["GEMM1"] @@ -147,7 +143,6 @@ def test_path_optimal() -> None: def test_path_greedy() -> None: - test_func = oe.paths.greedy test_data = explicit_path_tests["GEMM1"] @@ -156,7 +151,6 @@ def test_path_greedy() -> None: def test_memory_paths() -> None: - expression = "abc,bdef,fghj,cem,mhk,ljk->adgl" views = build_shapes(expression) @@ -197,7 +191,6 @@ def test_path_scalar_cases(alg: OptimizeKind, expression: str, order: PathType) def test_optimal_edge_cases() -> None: - # Edge test5 expression = "a,ac,ab,ad,cd,bd,bc->" edge_test4 = build_shapes(expression, dimension_dict={"a": 20, "b": 20, "c": 20, "d": 20}) @@ -209,7 +202,6 @@ def test_optimal_edge_cases() -> None: def test_greedy_edge_cases() -> None: - expression = "abc,cfd,dbe,efa" dim_dict = {k: 20 for k in expression.replace(",", "")} tensors = build_shapes(expression, dimension_dict=dim_dict) @@ -313,8 +305,7 @@ def test_dp_errors_when_no_contractions_found() -> None: @pytest.mark.parametrize("optimize", ["greedy", "branch-2", "branch-all", "optimal", "dp"]) def test_can_optimize_outer_products(optimize: OptimizeKind) -> None: - - a, b, c = [(10, 10) for _ in range(3)] + a, b, c = ((10, 10) for _ in range(3)) d = (10, 2) assert oe.contract_path("ab,cd,ef,fg", a, b, c, d, optimize=optimize, shapes=True)[0] == [ @@ -539,6 +530,5 @@ def custom_optimizer( def test_path_with_assumed_shapes() -> None: - path, _ = oe.contract_path("ab,bc,cd", [[5, 3]], [[2], [4]], [[3, 2]]) assert path == [(0, 1), (0, 1)] diff --git a/opt_einsum/tests/test_sharing.py b/opt_einsum/tests/test_sharing.py index 42717fb..8b0f1c0 100644 --- a/opt_einsum/tests/test_sharing.py +++ b/opt_einsum/tests/test_sharing.py @@ -16,7 +16,7 @@ pytest.importorskip("numpy") try: - import numpy as np # noqa # type: ignore + import numpy as np # type: ignore numpy_if_found = "numpy" except ImportError: @@ -30,7 +30,7 @@ cupy_if_found = pytest.param("cupy", marks=[pytest.mark.skip(reason="CuPy not installed.")]) # type: ignore try: - import torch # noqa + import torch # type: ignore # noqa torch_if_found = "torch" except ImportError: @@ -89,8 +89,8 @@ def test_complete_sharing(backend: BackendType) -> None: actual = count_cached_ops(cache) print("-" * 40) - print("Without sharing: {} expressions".format(expected)) - print("With sharing: {} expressions".format(actual)) + print(f"Without sharing: {expected} expressions") + print(f"With sharing: {actual} expressions") assert actual == expected @@ -115,8 +115,8 @@ def test_sharing_reused_cache(backend: BackendType) -> None: actual = count_cached_ops(cache) print("-" * 40) - print("Without sharing: {} expressions".format(expected)) - print("With sharing: {} expressions".format(actual)) + print(f"Without sharing: {expected} expressions") + print(f"With sharing: {actual} expressions") assert actual == expected @@ -143,8 +143,8 @@ def test_no_sharing_separate_cache(backend: BackendType) -> None: actual.update(count_cached_ops(cache2)) print("-" * 40) - print("Without sharing: {} expressions".format(expected)) - print("With sharing: {} expressions".format(actual)) + print(f"Without sharing: {expected} expressions") + print(f"With sharing: {actual} expressions") assert actual == expected @@ -211,8 +211,8 @@ def test_sharing_modulo_commutativity(eq: str, backend: BackendType) -> None: actual = count_cached_ops(cache) print("-" * 40) - print("Without sharing: {} expressions".format(expected)) - print("With sharing: {} expressions".format(actual)) + print(f"Without sharing: {expected} expressions") + print(f"With sharing: {actual} expressions") assert actual == expected @@ -241,8 +241,8 @@ def test_partial_sharing(backend: BackendType) -> None: num_exprs_sharing = count_cached_ops(cache) print("-" * 40) - print("Without sharing: {} expressions".format(num_exprs_nosharing)) - print("With sharing: {} expressions".format(num_exprs_sharing)) + print(f"Without sharing: {num_exprs_nosharing} expressions") + print(f"With sharing: {num_exprs_sharing} expressions") assert num_exprs_nosharing["einsum"] > num_exprs_sharing["einsum"] @@ -278,7 +278,7 @@ def test_chain(size: int, backend: BackendType) -> None: print(inputs) for i in range(size + 1): target = alphabet[i] - eq = "{}->{}".format(inputs, target) + eq = f"{inputs}->{target}" path_info = contract_path(eq, *xs) print(path_info[1]) expr = contract_expression(eq, *shapes) @@ -299,7 +299,7 @@ def test_chain_2(size: int, backend: BackendType) -> None: print(inputs) for i in range(size): target = alphabet[i : i + 2] - eq = "{}->{}".format(inputs, target) + eq = f"{inputs}->{target}" path_info = contract_path(eq, *xs) print(path_info[1]) expr = contract_expression(eq, *shapes) @@ -325,15 +325,15 @@ def test_chain_2_growth(backend: BackendType) -> None: with shared_intermediates() as cache: for i in range(size): target = alphabet[i : i + 2] - eq = "{}->{}".format(inputs, target) + eq = f"{inputs}->{target}" expr = contract_expression(eq, *(x.shape for x in xs)) expr(*xs, backend=backend) costs.append(_compute_cost(cache)) - print("sizes = {}".format(repr(sizes))) - print("costs = {}".format(repr(costs))) + print(f"sizes = {repr(sizes)}") + print(f"costs = {repr(costs)}") for size, cost in zip(sizes, costs): - print("{}\t{}".format(size, cost)) + print(f"{size}\t{cost}") @pytest.mark.parametrize("size", [3, 4, 5]) @@ -348,7 +348,7 @@ def test_chain_sharing(size: int, backend: BackendType) -> None: for i in range(size + 1): with shared_intermediates() as cache: target = alphabet[i] - eq = "{}->{}".format(inputs, target) + eq = f"{inputs}->{target}" expr = contract_expression(eq, *tuple(x.shape for x in xs)) expr(*xs, backend=backend) num_exprs_nosharing += _compute_cost(cache) @@ -357,16 +357,16 @@ def test_chain_sharing(size: int, backend: BackendType) -> None: print(inputs) for i in range(size + 1): target = alphabet[i] - eq = "{}->{}".format(inputs, target) + eq = f"{inputs}->{target}" path_info = contract_path(eq, *xs) print(path_info[1]) - expr = contract_expression(eq, *list(x.shape for x in xs)) + expr = contract_expression(eq, *[x.shape for x in xs]) expr(*xs, backend=backend) num_exprs_sharing = _compute_cost(cache) print("-" * 40) - print("Without sharing: {} expressions".format(num_exprs_nosharing)) - print("With sharing: {} expressions".format(num_exprs_sharing)) + print(f"Without sharing: {num_exprs_nosharing} expressions") + print(f"With sharing: {num_exprs_sharing} expressions") assert num_exprs_nosharing > num_exprs_sharing @@ -374,11 +374,11 @@ def test_multithreaded_sharing() -> None: from multiprocessing.pool import ThreadPool def fn(): - X, Y, Z = build_views("ab,bc,cd") + x, y, z = build_views("ab,bc,cd") with shared_intermediates(): - contract("ab,bc,cd->a", X, Y, Z) - contract("ab,bc,cd->b", X, Y, Z) + contract("ab,bc,cd->a", x, y, z) + contract("ab,bc,cd->b", x, y, z) return len(get_sharing_cache()) diff --git a/opt_einsum/typing.py b/opt_einsum/typing.py index 5057011..a807621 100644 --- a/opt_einsum/typing.py +++ b/opt_einsum/typing.py @@ -1,6 +1,4 @@ -""" -Types used in the opt_einsum package -""" +"""Types used in the opt_einsum package.""" from collections import namedtuple from typing import Any, Callable, Collection, Dict, FrozenSet, List, Literal, Optional, Tuple, Union diff --git a/pyproject.toml b/pyproject.toml index f83c639..0d185dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,92 @@ -[tool.black] +[build-system] +requires = ['hatchling', 'hatch-fancy-pypi-readme>=22.5.0', 'hatch-vcs'] +build-backend = 'hatchling.build' + +[project] +name = 'opt_einsum' +description = 'Path optimization of einsum functions.' +authors = [ + {name = 'Daniel Smith', email = 'dgasmith@icloud.com'}, +] +license = 'MIT' +classifiers = [ + 'Development Status :: 5 - Production/Stable', + 'Programming Language :: Python', + 'Programming Language :: Python :: Implementation :: CPython', + 'Programming Language :: Python :: Implementation :: PyPy', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3 :: Only', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', + 'Intended Audience :: Developers', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: MIT License', + 'Topic :: Software Development :: Libraries :: Python Modules', + +] +requires-python = '>=3.8' +dependencies = [ +] +dynamic = ['version', 'readme'] + +[tool.hatch.version] +source = "vcs" +path = 'opt_einsum/_version.py' + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.build.hooks.vcs] +version-file = "opt_einsum/_version.py" + +[tool.hatch.metadata.hooks.fancy-pypi-readme] +content-type = 'text/markdown' +# construct the PyPI readme from README.md and HISTORY.md +fragments = [ + {path = "README.md"}, +] + +[tool.hatch.build.targets.sdist] +exclude = [ + "/.github", + "/devtools", + "/docs", + "/paper", + "/scripts" +] + +[tool.hatch.build.targets.wheel] +packages = ["opt_einsum"] + +[tool.pytest.ini_options] +filterwarnings = [ + 'ignore::DeprecationWarning:tensorflow', + 'ignore::DeprecationWarning:tensorboard', +] + +[tool.ruff] line-length = 120 -target-version = ['py39'] +target-version = 'py38' + +[tool.ruff.lint] +extend-select = ['RUF100', 'UP', 'C', 'D', 'I', 'N', 'NPY', 'Q', 'T', 'W'] +extend-ignore = ['C901', 'D101', 'D102', 'D103', 'D105', 'D107', 'D205', 'D415'] +isort = { known-first-party = ['opt_einsum'] } +mccabe = { max-complexity = 14 } +pydocstyle = { convention = 'google' } + +[tool.ruff.lint.per-file-ignores] +'opt_einsum/tests/*' = ['D', 'T201', 'NPY002', 'ANN001', 'ANN202'] + +[tool.coverage.run] +source = ['opt_einsum'] +omit = ['*/tests/*', 'opt_einsum/_version.py'] +branch = true +relative_files = true + +[[tool.mypy.overrides]] +module = "cupy.*, jax.*, numpy.*, theano.*, tensorflow.*, torch.*" +ignore_missing_imports = true \ No newline at end of file diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index b65d815..0000000 --- a/setup.cfg +++ /dev/null @@ -1,85 +0,0 @@ -# Helper file to handle all configs - -[coverage:run] -# .coveragerc to control coverage.py and pytest-cov -# Omit the test directory from test coverage -omit = - */tests/* - opt_einsum/_version.py - versioneer.py - -[tool:isort] -line_length=120 -include_trailing_comma=True -force_grid_wrap=0 -use_parentheses=True -multi_line_output=3 - -[tool:pytest] -filterwarnings = - ignore::DeprecationWarning:tensorflow - ignore::DeprecationWarning:tensorboard - -[flake8] -# Flake8, PyFlakes, etc -max-line-length = 119 -dictionaries = en_US,python,technical -spellcheck-targets = comments -ignore = F401, # Imports, handled by autoflake - E203, - E225, - E226, - W605, - E501 # Lines, handled by black - -[versioneer] -# Automatic version numbering scheme -VCS = git -style = pep440 -versionfile_source = opt_einsum/_version.py -versionfile_build = opt_einsum/_version.py -tag_prefix = '' - -[mypy] -#plugins = numpy.typing.mypy_plugin - -follow_imports = normal -strict_optional = True -warn_redundant_casts = True -# no_implicit_reexport = True -warn_unused_configs = True -disallow_incomplete_defs = True -warn_unused_ignores = True - -[mypy-opt_einsum.tests.*] -ignore_missing_imports = True - -[mypy-opt_einsum._version] -ignore_errors = True - -[mypy-autograd.*] -ignore_missing_imports = True - -[mypy-cupy.*] -ignore_missing_imports = True - -[mypy-numpy.*] -ignore_missing_imports = True - -[mypy-tensorflow.*] -ignore_missing_imports = True - -[mypy-tf.*] -ignore_missing_imports = True - -[mypy-theano.*] -ignore_missing_imports = True - -[mypy-torch.*] -ignore_missing_imports = True - -[mypy-jax.*] -ignore_missing_imports = True - -[mypy-xla.*] -ignore_missing_imports = True diff --git a/setup.py b/setup.py deleted file mode 100644 index fb2281e..0000000 --- a/setup.py +++ /dev/null @@ -1,56 +0,0 @@ -# -*- coding: utf-8 -*- - -import setuptools -import versioneer - -short_description = "Optimizing numpys einsum function" -try: - with open("README.md", "r") as handle: - long_description = handle.read() -except: # noqa - long_description = short_description - - -if __name__ == "__main__": - setuptools.setup( - name='opt_einsum', - version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), - description=short_description, - author='Daniel Smith', - author_email='dgasmith@icloud.com', - url="https://github.com/dgasmith/opt_einsum", - license='MIT', - packages=setuptools.find_packages(), - python_requires='>=3.9', - install_requires=[ - 'numpy>=1.23', - ], - extras_require={ - 'tests': [ - 'pytest', - 'pytest-cov', - ], - }, - - tests_require=[ - 'pytest', - 'pytest-cov', - ], - - classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Science/Research', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - ], - zip_safe=True, - long_description=long_description, - long_description_content_type="text/markdown" - ) diff --git a/versioneer.py b/versioneer.py deleted file mode 100644 index 6d732af..0000000 --- a/versioneer.py +++ /dev/null @@ -1,1822 +0,0 @@ - -# Version: 0.18 - -"""The Versioneer - like a rocketeer, but for versions. - -The Versioneer -============== - -* like a rocketeer, but for versions! -* https://github.com/warner/python-versioneer -* Brian Warner -* License: Public Domain -* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy -* [![Latest Version] -(https://pypip.in/version/versioneer/badge.svg?style=flat) -](https://pypi.python.org/pypi/versioneer/) -* [![Build Status] -(https://travis-ci.org/warner/python-versioneer.png?branch=master) -](https://travis-ci.org/warner/python-versioneer) - -This is a tool for managing a recorded version number in distutils-based -python projects. The goal is to remove the tedious and error-prone "update -the embedded version string" step from your release process. Making a new -release should be as easy as recording a new tag in your version-control -system, and maybe making new tarballs. - - -## Quick Install - -* `pip install versioneer` to somewhere to your $PATH -* add a `[versioneer]` section to your setup.cfg (see below) -* run `versioneer install` in your source tree, commit the results - -## Version Identifiers - -Source trees come from a variety of places: - -* a version-control system checkout (mostly used by developers) -* a nightly tarball, produced by build automation -* a snapshot tarball, produced by a web-based VCS browser, like github's - "tarball from tag" feature -* a release tarball, produced by "setup.py sdist", distributed through PyPI - -Within each source tree, the version identifier (either a string or a number, -this tool is format-agnostic) can come from a variety of places: - -* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows - about recent "tags" and an absolute revision-id -* the name of the directory into which the tarball was unpacked -* an expanded VCS keyword ($Id$, etc) -* a `_version.py` created by some earlier build step - -For released software, the version identifier is closely related to a VCS -tag. Some projects use tag names that include more than just the version -string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool -needs to strip the tag prefix to extract the version identifier. For -unreleased software (between tags), the version identifier should provide -enough information to help developers recreate the same tree, while also -giving them an idea of roughly how old the tree is (after version 1.2, before -version 1.3). Many VCS systems can report a description that captures this, -for example `git describe --tags --dirty --always` reports things like -"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the -0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes. - -The version identifier is used for multiple purposes: - -* to allow the module to self-identify its version: `myproject.__version__` -* to choose a name and prefix for a 'setup.py sdist' tarball - -## Theory of Operation - -Versioneer works by adding a special `_version.py` file into your source -tree, where your `__init__.py` can import it. This `_version.py` knows how to -dynamically ask the VCS tool for version information at import time. - -`_version.py` also contains `$Revision$` markers, and the installation -process marks `_version.py` to have this marker rewritten with a tag name -during the `git archive` command. As a result, generated tarballs will -contain enough information to get the proper version. - -To allow `setup.py` to compute a version too, a `versioneer.py` is added to -the top level of your source tree, next to `setup.py` and the `setup.cfg` -that configures it. This overrides several distutils/setuptools commands to -compute the version when invoked, and changes `setup.py build` and `setup.py -sdist` to replace `_version.py` with a small static file that contains just -the generated version data. - -## Installation - -See [INSTALL.md](./INSTALL.md) for detailed installation instructions. - -## Version-String Flavors - -Code which uses Versioneer can learn about its version string at runtime by -importing `_version` from your main `__init__.py` file and running the -`get_versions()` function. From the "outside" (e.g. in `setup.py`), you can -import the top-level `versioneer.py` and run `get_versions()`. - -Both functions return a dictionary with different flavors of version -information: - -* `['version']`: A condensed version string, rendered using the selected - style. This is the most commonly used value for the project's version - string. The default "pep440" style yields strings like `0.11`, - `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section - below for alternative styles. - -* `['full-revisionid']`: detailed revision identifier. For Git, this is the - full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". - -* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the - commit date in ISO 8601 format. This will be None if the date is not - available. - -* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that - this is only accurate if run in a VCS checkout, otherwise it is likely to - be False or None - -* `['error']`: if the version string could not be computed, this will be set - to a string describing the problem, otherwise it will be None. It may be - useful to throw an exception in setup.py if this is set, to avoid e.g. - creating tarballs with a version string of "unknown". - -Some variants are more useful than others. Including `full-revisionid` in a -bug report should allow developers to reconstruct the exact code being tested -(or indicate the presence of local changes that should be shared with the -developers). `version` is suitable for display in an "about" box or a CLI -`--version` output: it can be easily compared against release notes and lists -of bugs fixed in various releases. - -The installer adds the following text to your `__init__.py` to place a basic -version in `YOURPROJECT.__version__`: - - from ._version import get_versions - __version__ = get_versions()['version'] - del get_versions - -## Styles - -The setup.cfg `style=` configuration controls how the VCS information is -rendered into a version string. - -The default style, "pep440", produces a PEP440-compliant string, equal to the -un-prefixed tag name for actual releases, and containing an additional "local -version" section with more detail for in-between builds. For Git, this is -TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags ---dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the -tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and -that this commit is two revisions ("+2") beyond the "0.11" tag. For released -software (exactly equal to a known tag), the identifier will only contain the -stripped tag, e.g. "0.11". - -Other styles are available. See [details.md](details.md) in the Versioneer -source tree for descriptions. - -## Debugging - -Versioneer tries to avoid fatal errors: if something goes wrong, it will tend -to return a version of "0+unknown". To investigate the problem, run `setup.py -version`, which will run the version-lookup code in a verbose mode, and will -display the full contents of `get_versions()` (including the `error` string, -which may help identify what went wrong). - -## Known Limitations - -Some situations are known to cause problems for Versioneer. This details the -most significant ones. More can be found on Github -[issues page](https://github.com/warner/python-versioneer/issues). - -### Subprojects - -Versioneer has limited support for source trees in which `setup.py` is not in -the root directory (e.g. `setup.py` and `.git/` are *not* siblings). The are -two common reasons why `setup.py` might not be in the root: - -* Source trees which contain multiple subprojects, such as - [Buildbot](https://github.com/buildbot/buildbot), which contains both - "master" and "slave" subprojects, each with their own `setup.py`, - `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI - distributions (and upload multiple independently-installable tarballs). -* Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other languages) in subdirectories. - -Versioneer will look for `.git` in parent directories, and most operations -should get the right version string. However `pip` and `setuptools` have bugs -and implementation details which frequently cause `pip install .` from a -subproject directory to fail to find a correct version string (so it usually -defaults to `0+unknown`). - -`pip install --editable .` should work correctly. `setup.py install` might -work too. - -Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in -some later version. - -[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking -this issue. The discussion in -[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the -issue from the Versioneer side in more detail. -[pip PR#3176](https://github.com/pypa/pip/pull/3176) and -[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve -pip to let Versioneer work correctly. - -Versioneer-0.16 and earlier only looked for a `.git` directory next to the -`setup.cfg`, so subprojects were completely unsupported with those releases. - -### Editable installs with setuptools <= 18.5 - -`setup.py develop` and `pip install --editable .` allow you to install a -project into a virtualenv once, then continue editing the source code (and -test) without re-installing after every change. - -"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a -convenient way to specify executable scripts that should be installed along -with the python package. - -These both work as expected when using modern setuptools. When using -setuptools-18.5 or earlier, however, certain operations will cause -`pkg_resources.DistributionNotFound` errors when running the entrypoint -script, which must be resolved by re-installing the package. This happens -when the install happens with one version, then the egg_info data is -regenerated while a different version is checked out. Many setup.py commands -cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into -a different virtualenv), so this can be surprising. - -[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes -this one, but upgrading to a newer version of setuptools should probably -resolve it. - -### Unicode version strings - -While Versioneer works (and is continually tested) with both Python 2 and -Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. -Newer releases probably generate unicode version strings on py2. It's not -clear that this is wrong, but it may be surprising for applications when then -write these strings to a network connection or include them in bytes-oriented -APIs like cryptographic checksums. - -[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates -this question. - - -## Updating Versioneer - -To upgrade your project to a new release of Versioneer, do the following: - -* install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg`, if necessary, to include any new configuration settings - indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install` in your source tree, to replace - `SRC/_version.py` -* commit any changed files - -## Future Directions - -This tool is designed to make it easily extended to other version-control -systems: all VCS-specific components are in separate directories like -src/git/ . The top-level `versioneer.py` script is assembled from these -components by running make-versioneer.py . In the future, make-versioneer.py -will take a VCS name as an argument, and will construct a version of -`versioneer.py` that is specific to the given VCS. It might also take the -configuration arguments that are currently provided manually during -installation by editing setup.py . Alternatively, it might go the other -direction and include code from all supported VCS systems, reducing the -number of intermediate scripts. - - -## License - -To make Versioneer easier to embed, all its code is dedicated to the public -domain. The `_version.py` that it creates is also in the public domain. -Specifically, both are released under the Creative Commons "Public Domain -Dedication" license (CC0-1.0), as described in -https://creativecommons.org/publicdomain/zero/1.0/ . - -""" - -from __future__ import print_function -try: - import configparser -except ImportError: - import ConfigParser as configparser -import errno -import json -import os -import re -import subprocess -import sys - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_root(): - """Get the project root directory. - - We require that all commands are run from the project root, i.e. the - directory that contains setup.py, setup.cfg, and versioneer.py . - """ - root = os.path.realpath(os.path.abspath(os.getcwd())) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - # allow 'python path/to/setup.py COMMAND' - root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ("Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") - raise VersioneerBadRootError(err) - try: - # Certain runtime workflows (setup.py install/develop in a setuptools - # tree) execute all dependencies in a single python process, so - # "versioneer" may be imported multiple times, and python's shared - # module-import table will cache the first one. So we can't use - # os.path.dirname(__file__), as that will find whichever - # versioneer.py was first imported, even in later projects. - me = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(me)[0]) - vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir: - print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py)) - except NameError: - pass - return root - - -def get_config_from_root(root): - """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise EnvironmentError (if setup.cfg is missing), or - # configparser.NoSectionError (if it lacks a [versioneer] section), or - # configparser.NoOptionError (if it lacks "VCS="). See the docstring at - # the top of versioneer.py for instructions on writing your setup.cfg . - setup_cfg = os.path.join(root, "setup.cfg") - parser = configparser.ConfigParser() - with open(setup_cfg, "r") as f: - parser.read_file(f) - VCS = parser.get("versioneer", "VCS") # mandatory - - def get(parser, name): - if parser.has_option("versioneer", name): - return parser.get("versioneer", name) - return None - cfg = VersioneerConfig() - cfg.VCS = VCS - cfg.style = get(parser, "style") or "" - cfg.versionfile_source = get(parser, "versionfile_source") - cfg.versionfile_build = get(parser, "versionfile_build") - cfg.tag_prefix = get(parser, "tag_prefix") - if cfg.tag_prefix in ("''", '""'): - cfg.tag_prefix = "" - cfg.parentdir_prefix = get(parser, "parentdir_prefix") - cfg.verbose = get(parser, "verbose") - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -# these dictionaries contain VCS-specific tools -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode - - -LONG_VERSION_PY['git'] = ''' -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" - git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" - git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "%(STYLE)s" - cfg.tag_prefix = "%(TAG_PREFIX)s" - cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" - cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %%s" %% dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %%s" %% (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %%s (error)" %% dispcmd) - print("stdout was %%s" %% stdout) - return None, p.returncode - return stdout, p.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %%s but none started with prefix %%s" %% - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %%d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) - if verbose: - print("discarding '%%s', no digits" %% ",".join(refs - tags)) - if verbose: - print("likely tags: %%s" %% ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - if verbose: - print("picking %%s" %% r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %%s not under git control" %% root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%%s*" %% tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%%s'" - %% describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%%s' doesn't start with prefix '%%s'" - print(fmt %% (full_tag, tag_prefix)) - pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" - %% (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], - cwd=root)[0].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%%d" %% pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%%d" %% pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%%s" %% pieces["short"] - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%%s" %% pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%%s'" %% style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} -''' - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def do_vcs_install(manifest_in, versionfile_source, ipy): - """Git-specific installation logic for Versioneer. - - For Git, this means creating/changing .gitattributes to mark _version.py - for export-subst keyword substitution. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - files = [manifest_in, versionfile_source] - if ipy: - files.append(ipy) - try: - me = __file__ - if me.endswith(".pyc") or me.endswith(".pyo"): - me = os.path.splitext(me)[0] + ".py" - versioneer_file = os.path.relpath(me) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) - present = False - try: - f = open(".gitattributes", "r") - for line in f.readlines(): - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - f.close() - except EnvironmentError: - pass - if not present: - f = open(".gitattributes", "a+") - f.write("%s export-subst\n" % versionfile_source) - f.close() - files.append(".gitattributes") - run_command(GITS, ["add", "--"] + files) - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.18) from -# revision-control system data, or from the parent directory name of an -# unpacked source archive. Distribution tarballs contain a pre-generated copy -# of this file. - -import json - -version_json = ''' -%s -''' # END VERSION_JSON - - -def get_versions(): - return json.loads(version_json) -""" - - -def versions_from_file(filename): - """Try to determine the version from _version.py if present.""" - try: - with open(filename) as f: - contents = f.read() - except EnvironmentError: - raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) - if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) - if not mo: - raise NotThisMethod("no version_json in _version.py") - return json.loads(mo.group(1)) - - -def write_to_version_file(filename, versions): - """Write the given version number to the given _version.py file.""" - os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, - indent=1, separators=(",", ": ")) - with open(filename, "w") as f: - f.write(SHORT_VERSION_PY % contents) - - print("set %s to '%s'" % (filename, versions["version"])) - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -class VersioneerBadRootError(Exception): - """The project root directory is unknown or missing key files.""" - - -def get_versions(verbose=False): - """Get the project version from whatever source is available. - - Returns dict with two keys: 'version' and 'full'. - """ - if "versioneer" in sys.modules: - # see the discussion in cmdclass.py:get_cmdclass() - del sys.modules["versioneer"] - - root = get_root() - cfg = get_config_from_root(root) - - assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" - handlers = HANDLERS.get(cfg.VCS) - assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or cfg.verbose - assert cfg.versionfile_source is not None, \ - "please set versioneer.versionfile_source" - assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" - - versionfile_abs = os.path.join(root, cfg.versionfile_source) - - # extract version from first of: _version.py, VCS command (e.g. 'git - # describe'), parentdir. This is meant to work for developers using a - # source checkout, for users of a tarball created by 'setup.py sdist', - # and for users of a tarball/zipball created by 'git archive' or github's - # download-from-tag feature or the equivalent in other VCSes. - - get_keywords_f = handlers.get("get_keywords") - from_keywords_f = handlers.get("keywords") - if get_keywords_f and from_keywords_f: - try: - keywords = get_keywords_f(versionfile_abs) - ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) - if verbose: - print("got version from expanded keyword %s" % ver) - return ver - except NotThisMethod: - pass - - try: - ver = versions_from_file(versionfile_abs) - if verbose: - print("got version from file %s %s" % (versionfile_abs, ver)) - return ver - except NotThisMethod: - pass - - from_vcs_f = handlers.get("pieces_from_vcs") - if from_vcs_f: - try: - pieces = from_vcs_f(cfg.tag_prefix, root, verbose) - ver = render(pieces, cfg.style) - if verbose: - print("got version from VCS %s" % ver) - return ver - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - if verbose: - print("got version from parentdir %s" % ver) - return ver - except NotThisMethod: - pass - - if verbose: - print("unable to compute version") - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, "error": "unable to compute version", - "date": None} - - -def get_version(): - """Get the short version string for this project.""" - return get_versions()["version"] - - -def get_cmdclass(): - """Get the custom setuptools/distutils subclasses used by Versioneer.""" - if "versioneer" in sys.modules: - del sys.modules["versioneer"] - # this fixes the "python setup.py develop" case (also 'install' and - # 'easy_install .'), in which subdependencies of the main project are - # built (using setup.py bdist_egg) in the same python process. Assume - # a main project A and a dependency B, which use different versions - # of Versioneer. A's setup.py imports A's Versioneer, leaving it in - # sys.modules by the time B's setup.py is executed, causing B to run - # with the wrong versioneer. Setuptools wraps the sub-dep builds in a - # sandbox that restores sys.modules to it's pre-build state, so the - # parent is protected against the child's "import versioneer". By - # removing ourselves from sys.modules here, before the child build - # happens, we protect the child from the parent's versioneer too. - # Also see https://github.com/warner/python-versioneer/issues/52 - - cmds = {} - - # we add "version" to both distutils and setuptools - from distutils.core import Command - - class cmd_version(Command): - description = "report generated version string" - user_options = [] - boolean_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - vers = get_versions(verbose=True) - print("Version: %s" % vers["version"]) - print(" full-revisionid: %s" % vers.get("full-revisionid")) - print(" dirty: %s" % vers.get("dirty")) - print(" date: %s" % vers.get("date")) - if vers["error"]: - print(" error: %s" % vers["error"]) - cmds["version"] = cmd_version - - # we override "build_py" in both distutils and setuptools - # - # most invocation pathways end up running build_py: - # distutils/build -> build_py - # distutils/install -> distutils/build ->.. - # setuptools/bdist_wheel -> distutils/install ->.. - # setuptools/bdist_egg -> distutils/install_lib -> build_py - # setuptools/install -> bdist_egg ->.. - # setuptools/develop -> ? - # pip install: - # copies source tree to a tempdir before running egg_info/etc - # if .git isn't copied too, 'git describe' will fail - # then does setup.py bdist_wheel, or sometimes setup.py install - # setup.py egg_info -> ? - - # we override different "build_py" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.build_py import build_py as _build_py - else: - from distutils.command.build_py import build_py as _build_py - - class cmd_build_py(_build_py): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - _build_py.run(self) - # now locate _version.py in the new build/ directory and replace - # it with an updated value - if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - cmds["build_py"] = cmd_build_py - - if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe - # nczeczulin reports that py2exe won't like the pep440-style string - # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. - # setup(console=[{ - # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION - # "product_version": versioneer.get_version(), - # ... - - class cmd_build_exe(_build_exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _build_exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - cmds["build_exe"] = cmd_build_exe - del cmds["build_py"] - - if 'py2exe' in sys.modules: # py2exe enabled? - try: - from py2exe.distutils_buildexe import py2exe as _py2exe # py3 - except ImportError: - from py2exe.build_exe import py2exe as _py2exe # py2 - - class cmd_py2exe(_py2exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _py2exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - cmds["py2exe"] = cmd_py2exe - - # we override different "sdist" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.sdist import sdist as _sdist - else: - from distutils.command.sdist import sdist as _sdist - - class cmd_sdist(_sdist): - def run(self): - versions = get_versions() - self._versioneer_generated_versions = versions - # unless we update this, the command will keep using the old - # version - self.distribution.metadata.version = versions["version"] - return _sdist.run(self) - - def make_release_tree(self, base_dir, files): - root = get_root() - cfg = get_config_from_root(root) - _sdist.make_release_tree(self, base_dir, files) - # now locate _version.py in the new base_dir directory - # (remembering that it may be a hardlink) and replace it with an - # updated value - target_versionfile = os.path.join(base_dir, cfg.versionfile_source) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) - cmds["sdist"] = cmd_sdist - - return cmds - - -CONFIG_ERROR = """ -setup.cfg is missing the necessary Versioneer configuration. You need -a section like: - - [versioneer] - VCS = git - style = pep440 - versionfile_source = src/myproject/_version.py - versionfile_build = myproject/_version.py - tag_prefix = - parentdir_prefix = myproject- - -You will also need to edit your setup.py to use the results: - - import versioneer - setup(version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), ...) - -Please read the docstring in ./versioneer.py for configuration instructions, -edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. -""" - -SAMPLE_CONFIG = """ -# See the docstring in versioneer.py for instructions. Note that you must -# re-run 'versioneer.py setup' after changing this section, and commit the -# resulting files. - -[versioneer] -#VCS = git -#style = pep440 -#versionfile_source = -#versionfile_build = -#tag_prefix = -#parentdir_prefix = - -""" - -INIT_PY_SNIPPET = """ -from ._version import get_versions -__version__ = get_versions()['version'] -del get_versions -""" - - -def do_setup(): - """Main VCS-independent setup function for installing Versioneer.""" - root = get_root() - try: - cfg = get_config_from_root(root) - except (EnvironmentError, configparser.NoSectionError, - configparser.NoOptionError) as e: - if isinstance(e, (EnvironmentError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) - with open(os.path.join(root, "setup.cfg"), "a") as f: - f.write(SAMPLE_CONFIG) - print(CONFIG_ERROR, file=sys.stderr) - return 1 - - print(" creating %s" % cfg.versionfile_source) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), - "__init__.py") - if os.path.exists(ipy): - try: - with open(ipy, "r") as f: - old = f.read() - except EnvironmentError: - old = "" - if INIT_PY_SNIPPET not in old: - print(" appending to %s" % ipy) - with open(ipy, "a") as f: - f.write(INIT_PY_SNIPPET) - else: - print(" %s unmodified" % ipy) - else: - print(" %s doesn't exist, ok" % ipy) - ipy = None - - # Make sure both the top-level "versioneer.py" and versionfile_source - # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so - # they'll be copied into source distributions. Pip won't be able to - # install the package without this. - manifest_in = os.path.join(root, "MANIFEST.in") - simple_includes = set() - try: - with open(manifest_in, "r") as f: - for line in f: - if line.startswith("include "): - for include in line.split()[1:]: - simple_includes.add(include) - except EnvironmentError: - pass - # That doesn't cover everything MANIFEST.in can do - # (http://docs.python.org/2/distutils/sourcedist.html#commands), so - # it might give some false negatives. Appending redundant 'include' - # lines is safe, though. - if "versioneer.py" not in simple_includes: - print(" appending 'versioneer.py' to MANIFEST.in") - with open(manifest_in, "a") as f: - f.write("include versioneer.py\n") - else: - print(" 'versioneer.py' already in MANIFEST.in") - if cfg.versionfile_source not in simple_includes: - print(" appending versionfile_source ('%s') to MANIFEST.in" % - cfg.versionfile_source) - with open(manifest_in, "a") as f: - f.write("include %s\n" % cfg.versionfile_source) - else: - print(" versionfile_source already in MANIFEST.in") - - # Make VCS-specific changes. For git, this means creating/changing - # .gitattributes to mark _version.py for export-subst keyword - # substitution. - do_vcs_install(manifest_in, cfg.versionfile_source, ipy) - return 0 - - -def scan_setup_py(): - """Validate the contents of setup.py against Versioneer's expectations.""" - found = set() - setters = False - errors = 0 - with open("setup.py", "r") as f: - for line in f.readlines(): - if "import versioneer" in line: - found.add("import") - if "versioneer.get_cmdclass()" in line: - found.add("cmdclass") - if "versioneer.get_version()" in line: - found.add("get_version") - if "versioneer.VCS" in line: - setters = True - if "versioneer.versionfile_source" in line: - setters = True - if len(found) != 3: - print("") - print("Your setup.py appears to be missing some important items") - print("(but I might be wrong). Please make sure it has something") - print("roughly like the following:") - print("") - print(" import versioneer") - print(" setup( version=versioneer.get_version(),") - print(" cmdclass=versioneer.get_cmdclass(), ...)") - print("") - errors += 1 - if setters: - print("You should remove lines like 'versioneer.VCS = ' and") - print("'versioneer.versionfile_source = ' . This configuration") - print("now lives in setup.cfg, and should be removed from setup.py") - print("") - errors += 1 - return errors - - -if __name__ == "__main__": - cmd = sys.argv[1] - if cmd == "setup": - errors = do_setup() - errors += scan_setup_py() - if errors: - sys.exit(1)