diff --git a/annopro/__init__.py b/annopro/__init__.py deleted file mode 100644 index 58bf170..0000000 --- a/annopro/__init__.py +++ /dev/null @@ -1,82 +0,0 @@ -def console_main(): - import argparse - parser = argparse.ArgumentParser(description='Arguments for AnnoPRO') - parser.add_argument("--fasta_file", "-i", help="The protein sequences file") - parser.add_argument('--output', "-o", default=None, - type=str, help="Output directory") - parser.add_argument('--used_gpu', default="-1", type=str, - help="GPU device selected, default is CPU") - parser.add_argument('--disable_diamond', - action='store_true', default=False, - help="Disable blast with diamond") - parser.add_argument('--overwrite', - action="store_true", - default=False, - help="Overwrite existed output" - ) - parser.add_argument("--version", - action="store_true", default=False, help="Show version") - args = parser.parse_args() - if args.version: - print("{} {}, Copyright Zhejiang University.".format( - __name__, __version__)) - exit(0) - elif args.fasta_file is None: - parser.print_help() - exit(1) - main( - proteins_fasta_file=args.fasta_file, - output_dir=args.output, - used_gpu=args.used_gpu, - with_diamond=(not args.disable_diamond), - overwrite=args.overwrite - ) - - -def main(proteins_fasta_file: str, output_dir: str = None, - used_gpu: str = None, with_diamond: bool = True, overwrite: bool = False): - from annopro.data_procession import process - from diamond4py import Diamond - from annopro import resources - from os.path import join, exists - from annopro.prediction import predict - from shutil import rmtree - import profeat - - if output_dir is None: - output_dir = proteins_fasta_file + ".output" - - if exists(output_dir): - if overwrite: - rmtree(output_dir) - else: - print(f"Output directory {output_dir} already existed!") - exit(1) - - profeat.run(proteins_fasta_file, output_dir) - - diamond_scores_file: str = None - if with_diamond: - diamond_scores_file = join(output_dir, "diamond_scores.txt") - diamond = Diamond( - database=resources.get_resource_path("cafa4.dmnd"), - n_threads=4 - ) - diamond.blastp( - query=proteins_fasta_file, - out=diamond_scores_file - ) - - promap_features_file = join(output_dir, "promap_features.pkl") - process( - proteins_fasta_file=proteins_fasta_file, - profeat_file=join(output_dir, "output-protein.dat"), - save_file=promap_features_file) - predict(output_dir=output_dir, - promap_features_file=promap_features_file, - used_gpu=used_gpu, - diamond_scores_file=diamond_scores_file) - - -from . import _version -__version__ = _version.get_versions()['version'] diff --git a/annopro/__main__.py b/annopro/__main__.py deleted file mode 100644 index 40827e9..0000000 --- a/annopro/__main__.py +++ /dev/null @@ -1,4 +0,0 @@ -from annopro import console_main - -if __name__ == "__main__": - console_main() \ No newline at end of file diff --git a/annopro/_version.py b/annopro/_version.py deleted file mode 100644 index 334f7ac..0000000 --- a/annopro/_version.py +++ /dev/null @@ -1,658 +0,0 @@ - -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. 
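The removed annopro/__init__.py above stitches the whole AnnoPRO pipeline together: Profeat feature generation, an optional DIAMOND blastp run against the bundled cafa4.dmnd database, promap feature construction, and the final prediction step. As a rough sketch (file paths are placeholders), the same pipeline could be driven programmatically through the removed main() function instead of the CLI:

from annopro import main

# Roughly equivalent to: python -m annopro -i proteins.fasta -o proteins_out
main(
    proteins_fasta_file="proteins.fasta",  # placeholder input FASTA
    output_dir="proteins_out",             # defaults to "<input>.output" when None
    used_gpu="-1",                         # "-1" keeps inference on the CPU
    with_diamond=True,                     # also compute DIAMOND blastp scores
    overwrite=False,                       # refuse to clobber an existing output dir
)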
-# Generated by versioneer-0.28 -# https://github.com/python-versioneer/python-versioneer - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys -from typing import Callable, Dict -import functools - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "$Format:%d$" - git_full = "$Format:%H$" - git_date = "$Format:%ci$" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "pep440" - cfg.tag_prefix = "v" - cfg.parentdir_prefix = "annopro-" - cfg.versionfile_source = "annopro/_version.py" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY: Dict[str, str] = {} -HANDLERS: Dict[str, Dict[str, Callable]] = {} - - -def register_vcs_handler(vcs, method): # decorator - """Create decorator to mark a method as the handler of a VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - process = None - - popen_kwargs = {} - if sys.platform == "win32": - # This hides the console window if pythonw.exe is used - startupinfo = subprocess.STARTUPINFO() - startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW - popen_kwargs["startupinfo"] = startupinfo - - for command in commands: - try: - dispcmd = str([command] + args) - # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) - break - except OSError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = process.communicate()[0].strip().decode() - if process.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, process.returncode - return stdout, process.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. 
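The parent-directory fallback whose docstring begins above never touches git at all; it only inspects the directory name and up to two parents for the configured prefix. A minimal sketch of that logic with an invented path:

import os

parentdir_prefix = "annopro-"       # from get_config() above
root = "/tmp/annopro-1.2.3"         # e.g. the directory of an unpacked sdist
dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix):
    print(dirname[len(parentdir_prefix):])   # -> 1.2.3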
We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for _ in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - with open(versionfile_abs, "r") as fobj: - for line in fobj: - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if "refnames" not in keywords: - raise NotThisMethod("Short version file found") - date = keywords.get("date") - if date is not None: - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. 
"2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - # Filter out refs that exactly match prefix or that don't start - # with a number once the prefix is stripped (mostly a concern - # when prefix is '') - if not re.match(r'\d', r): - continue - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - # GIT_DIR can interfere with correct operation of Versioneer. - # It may be intended to be passed to the Versioneer-versioned project, - # but that should not change where we get our version from. - env = os.environ.copy() - env.pop("GIT_DIR", None) - runner = functools.partial(runner, env=env) - - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=not verbose) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) - # --abbrev-ref was added in git-1.6.3 - if rc != 0 or branch_name is None: - raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") - branch_name = branch_name.strip() - - if branch_name == "HEAD": - # If we aren't exactly on a branch, pick a branch which represents - # the current commit. If all else fails, we are on a branchless - # commit. - branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) - # --contains was added in git-1.5.4 - if rc != 0 or branches is None: - raise NotThisMethod("'git branch --contains' returned error") - branches = branches.split("\n") - - # Remove the first line if we're running detached - if "(" in branches[0]: - branches.pop(0) - - # Strip off the leading "* " from the list of branches. - branches = [branch[2:] for branch in branches] - if "master" in branches: - branch_name = "master" - elif not branches: - branch_name = None - else: - # Pick the first branch that is returned. Good or bad. 
- branch_name = branches[0] - - pieces["branch"] = branch_name - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) - pieces["distance"] = len(out.split()) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_branch(pieces): - """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . - - The ".dev0" means not master branch. Note that .dev0 sorts backwards - (a feature branch will appear "older" than the master branch). - - Exceptions: - 1: no tags. 
0[.dev0]+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0" - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def pep440_split_post(ver): - """Split pep440 version string at the post-release segment. - - Returns the release segments before the post-release and the - post-release version number (or -1 if no post-release segment is present). - """ - vc = str.split(ver, ".post") - return vc[0], int(vc[1] or 0) if len(vc) == 2 else None - - -def render_pep440_pre(pieces): - """TAG[.postN.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post0.devDISTANCE - """ - if pieces["closest-tag"]: - if pieces["distance"]: - # update the post release segment - tag_version, post_version = pep440_split_post(pieces["closest-tag"]) - rendered = tag_version - if post_version is not None: - rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) - else: - rendered += ".post0.dev%d" % (pieces["distance"]) - else: - # no commits, use the tag as the version - rendered = pieces["closest-tag"] - else: - # exception #1 - rendered = "0.post0.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_post_branch(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . - - The ".dev0" means not master branch. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 
0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-branch": - rendered = render_pep440_branch(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-post-branch": - rendered = render_pep440_post_branch(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. 
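To make the render styles above concrete: for a commit three commits past the v1.0.2 tag in a dirty working tree, git_pieces_from_vcs would produce roughly the pieces below, and the default pep440 style turns them into a local-version identifier. (Values are invented; the function comes from the removed _version.py.)

from annopro._version import render_pep440

pieces = {"closest-tag": "1.0.2", "distance": 3,
          "short": "abc1234", "dirty": True}
print(render_pep440(pieces))   # -> 1.0.2+3.gabc1234.dirty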
- for _ in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} diff --git a/annopro/data_procession/__init__.py b/annopro/data_procession/__init__.py deleted file mode 100644 index 78452d8..0000000 --- a/annopro/data_procession/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from annopro.data_procession.data_predict import Data_process - - -def process(proteins_fasta_file: str, profeat_file: str, save_file: str): - if proteins_fasta_file == None: - raise ValueError("Must provide the input fasta sequences.") - - data = Data_process(protein_file=profeat_file, - proteins_fasta_file=proteins_fasta_file, - save_file=save_file, num=1484) - data.calculate_feature(row_num=39, size=(39, 39, 7)) diff --git a/annopro/data_procession/data_predict.py b/annopro/data_procession/data_predict.py deleted file mode 100644 index 387102b..0000000 --- a/annopro/data_procession/data_predict.py +++ /dev/null @@ -1,99 +0,0 @@ -import numpy as np -import pandas as pd -import pickle - -from sklearn.metrics.pairwise import cosine_similarity -from annopro.data_procession.utils import Ontology, load_data, MinMaxScaleClip -from profeat import profeat_to_df -from annopro import resources -from fasta import FASTA - -class Data_process(): - - def __init__( - self, - protein_file, - proteins_fasta_file, - save_file, - num, - grid_file="data_grid.pkl", - assess_file="row_asses.pkl", - prosim_file="cafa4_del.csv"): - ''' - protein_file 是生成的蛋白特征文件,即profeat生成的文件 - split_file是需要包含所需信息的蛋白序列文件, - save_file是生成的文件需要保存的位置, - prosim_file是蛋白相似性比对的库 - grid_file,assess_file是使用的map的位置文件 - ''' - self.protein_file = protein_file - self.split_file = proteins_fasta_file - self.save_file = save_file - self.grid_file = grid_file - self.assess_file = assess_file - self.prosim_file = prosim_file - self.num = num - self.__data__() - - def __data__(self): - proteins_f = profeat_to_df(self.protein_file) - proteins_f.dropna(axis=0, inplace=True) - feature_data = proteins_f.iloc[:, :self.num] - with resources.open_text("cafa4_del.csv") as cafa4_del: - mia_data = load_data(cafa4_del, 1485) - mia_data.columns = range(len(mia_data.columns)) - feature_data = (feature_data - mia_data.min()) / \ - ((mia_data.max() - mia_data.min()) + 1e-8) - self.feature_data:pd.DataFrame = feature_data - self.proteins = list(proteins_f.index) - with resources.open_binary(self.grid_file) as gf: - self.data_grid = pickle.load(gf) - with resources.open_binary(self.assess_file) as af: - self.row_asses = pickle.load(af) - with resources.open_text(self.prosim_file) as pf: - prosim_data = load_data(pf, 1485) - # this will be minmax by column - prosim_standard = MinMaxScaleClip(prosim_data) - prosim_map = prosim_standard.to_numpy() - # this will be global minmax scale - prosim_map:np.ndarray = MinMaxScaleClip(prosim_map) - self.prosim_map = prosim_map - self.go = Ontology() - - def calculate_feature(self, row_num, size): - protein_seqs = FASTA(self.split_file).sequences - class_labels = ['Composition', 
'Autocorrelation', 'Physiochemical', 'Interaction', - 'Quasi-sequence-order descriptors', 'PAAC for amino acid index set', 'Amphiphilic Pseudo amino acid composition'] - protein_all = [] - sequences_all = [] - feature_all = [] - prosim_all = [] - data_grid = self.data_grid - row_asses = self.row_asses - proteins = self.proteins - result = self.feature_data.to_numpy() - consine_similarity = 1-cosine_similarity(result, self.prosim_map) - - for i, protein in enumerate(proteins): - if protein in protein_seqs: - col_list = np.zeros(size) - row = 0 - col = 0 - protein_all.append(protein) - sequences_all.append(str(protein_seqs[protein].seq)) - prosim_all.append(consine_similarity[i]) - # 构建promap特征 - for j in range(len(data_grid['x'])): - channel = data_grid['subtype'][j] - index = class_labels.index(channel) - feature_index = row_asses[j] - row = j % row_num - col = j//row_num - col_list[col][row][index] = self.feature_data.iloc[i, feature_index] - feature_all.append(col_list) - data_t = [protein_all, sequences_all, feature_all, prosim_all] - data_t = pd.DataFrame(data_t) - data_t = data_t.T - data_t.columns = ['Proteins', 'Sequence', - 'Promap_feature', 'Protein_similary'] - data_t.to_pickle(self.save_file) diff --git a/annopro/data_procession/utils.py b/annopro/data_procession/utils.py deleted file mode 100644 index fdb43cd..0000000 --- a/annopro/data_procession/utils.py +++ /dev/null @@ -1,273 +0,0 @@ -from collections import deque, Counter -from tensorflow.keras.utils import Sequence -import numpy as np -import pandas as pd -import math -from annopro import resources -from typing import Union - -BIOLOGICAL_PROCESS = 'GO:0008150' -MOLECULAR_FUNCTION = 'GO:0003674' -CELLULAR_COMPONENT = 'GO:0005575' -FUNC_DICT = { - 'cc': CELLULAR_COMPONENT, - 'mf': MOLECULAR_FUNCTION, - 'bp': BIOLOGICAL_PROCESS} - -NAMESPACES = { - 'cc': 'cellular_component', - 'mf': 'molecular_function', - 'bp': 'biological_process' -} - -EXP_CODES = set([ - 'EXP', 'IDA', 'IPI', 'IMP', 'IGI', 'IEP', 'TAS', 'IC', - 'HTP', 'HDA', 'HMP', 'HGI', 'HEP']) - -# CAFA4 Targets -CAFA_TARGETS = set([ - '287', '3702', '4577', '6239', '7227', '7955', '9606', '9823', '10090', - '10116', '44689', '83333', '99287', '226900', '243273', '284812', '559292']) - - -def is_cafa_target(org): - return org in CAFA_TARGETS - - -def is_exp_code(code): - return code in EXP_CODES - - -class Ontology(object): - - def __init__(self, with_rels=False): - self.ont = self.load(with_rels) - self.ic = None - - def has_term(self, term_id): - return term_id in self.ont - - def get_term(self, term_id): - if self.has_term(term_id): - return self.ont[term_id] - return None - - def calculate_ic(self, annots): - cnt = Counter() - for x in annots: - cnt.update(x) - self.ic = {} - for go_id, n in cnt.items(): - parents = self.get_parents(go_id) - if len(parents) == 0: - min_n = n - else: - min_n = min([cnt[x] for x in parents]) - - self.ic[go_id] = math.log(min_n / n, 2) - - def get_ic(self, go_id): - if self.ic is None: - raise Exception('Not yet calculated') - if go_id not in self.ic: - return 0.0 - return self.ic[go_id] - - def load(self, with_rels): - ont = dict() - obj = None - with resources.open_text("go.txt") as f: - for line in f: - line = line.strip() - if not line: - continue - if line == '[Term]': - if obj is not None: - ont[obj['id']] = obj - obj = dict() - obj['is_a'] = list() - obj['part_of'] = list() - obj['regulates'] = list() - obj['alt_ids'] = list() - obj['is_obsolete'] = False - continue - elif line == '[Typedef]': - if obj is not None: - 
ont[obj['id']] = obj - obj = None - else: - if obj is None: - continue - l = line.split(": ") - if l[0] == 'id': - obj['id'] = l[1] - elif l[0] == 'alt_id': - obj['alt_ids'].append(l[1]) - elif l[0] == 'namespace': - obj['namespace'] = l[1] - elif l[0] == 'is_a': - obj['is_a'].append(l[1].split(' ! ')[0]) - elif with_rels and l[0] == 'relationship': - it = l[1].split() - # add all types of relationships - obj['is_a'].append(it[1]) - elif l[0] == 'name': - obj['name'] = l[1] - elif l[0] == 'is_obsolete' and l[1] == 'true': - obj['is_obsolete'] = True - if obj is not None: - ont[obj['id']] = obj - for term_id in list(ont.keys()): - for t_id in ont[term_id]['alt_ids']: - ont[t_id] = ont[term_id] - if ont[term_id]['is_obsolete']: - del ont[term_id] - for term_id, val in ont.items(): - if 'children' not in val: - val['children'] = set() - for p_id in val['is_a']: - if p_id in ont: - if 'children' not in ont[p_id]: - ont[p_id]['children'] = set() - ont[p_id]['children'].add(term_id) - return ont - - def get_anchestors(self, term_id): - if term_id not in self.ont: - return set() - term_set = set() - q = deque() - q.append(term_id) - while(len(q) > 0): - t_id = q.popleft() - if t_id not in term_set: - term_set.add(t_id) - for parent_id in self.ont[t_id]['is_a']: - if parent_id in self.ont: - q.append(parent_id) - return term_set - - def get_parents(self, term_id): - if term_id not in self.ont: - return set() - term_set = set() - for parent_id in self.ont[term_id]['is_a']: - if parent_id in self.ont: - term_set.add(parent_id) - return term_set - - def get_namespace_terms(self, namespace): - terms = set() - for go_id, obj in self.ont.items(): - if obj['namespace'] == namespace: - terms.add(go_id) - return terms - - def get_namespace(self, term_id): - return self.ont[term_id]['namespace'] - - def get_term_set(self, term_id): - if term_id not in self.ont: - return set() - term_set = set() - q = deque() - q.append(term_id) - while len(q) > 0: - t_id = q.popleft() - if t_id not in term_set: - term_set.add(t_id) - for ch_id in self.ont[t_id]['children']: - q.append(ch_id) - return term_set - - -def read_fasta(filename): - seqs = list() - info = list() - seq = '' - inf = '' - with open(filename, 'r') as f: - for line in f: - line = line.strip() - if line.startswith('>'): - if seq != '': - seqs.append(seq) - info.append(inf) - seq = '' - inf = line[1:] - else: - seq += line - seqs.append(seq) - info.append(inf) - return info, seqs - - -class DFGenerator(Sequence): - def __init__(self, df, terms_dict, nb_classes, batch_size): - self.start = 0 - self.size = len(df) - self.df = df - self.batch_size = batch_size - self.nb_classes = nb_classes - self.terms_dict = terms_dict - - def __len__(self): - return np.ceil(len(self.df) / float(self.batch_size)).astype(np.int32) - - def __getitem__(self, idx): - batch_index = np.arange(idx * self.batch_size, - min(self.size, (idx + 1) * self.batch_size)) - df = self.df.iloc[batch_index] - labels = np.zeros((len(df), self.nb_classes), dtype=np.int32) - feature_data = [] - protein_si = [] - for i, row in enumerate(df.itertuples()): - feature_data.append(list(row.Promap_feature)) - protein_si.append(list(row.Protein_similary)) - data_onehot = np.array(feature_data) - data_si = np.array(protein_si) - for t_id in row.Prop_annotations: - if t_id in self.terms_dict: - labels[i, self.terms_dict[t_id]] = 1 - self.start += self.batch_size - return ([data_onehot, data_si], labels) - - def __next__(self): - return self.next() - - def reset(self): - self.start = 0 - - def next(self): 
- if self.start < self.size: - batch_index = np.arange( - self.start, min(self.size, self.start + self.batch_size)) - df = self.df.iloc[batch_index] - labels = np.zeros((len(df), self.nb_classes), dtype=np.int32) - feature_data = [] - protein_si = [] - for i, row in enumerate(df.itertuples()): - feature_data.append(list(row.Promap_feature)) - protein_si.append(list(row.Protein_similary)) - data_onehot = np.array(feature_data) - data_si = np.array(protein_si) - for t_id in row.Prop_annotations: - if t_id in self.terms_dict: - labels[i, self.terms_dict[t_id]] = 1 - self.start += self.batch_size - return ([data_onehot, data_si], labels) - else: - self.reset() - return self.next() - - -def load_data(file:str, num:int): - data = pd.read_csv(file, header=None) - data.dropna(axis=0, inplace=True) - data_noprotein = data.iloc[:, 1:num] - return data_noprotein - - -def MinMaxScaleClip(data: Union[pd.DataFrame, np.ndarray]) -> Union[pd.DataFrame, np.ndarray]: - data_standard = (data - data.min()) / ((data.max() - data.min()) + 1e-8) - return data_standard \ No newline at end of file diff --git a/annopro/focal_loss/__init__.py b/annopro/focal_loss/__init__.py deleted file mode 100644 index 3230f58..0000000 --- a/annopro/focal_loss/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from ._binary_focal_loss import binary_focal_loss -from ._binary_focal_loss import BinaryFocalLoss -from ._categorical_focal_loss import sparse_categorical_focal_loss -from ._categorical_focal_loss import SparseCategoricalFocalLoss - -# Package information -__package__ = 'focal-loss' -__version__ = '0.0.8' -__author__ = 'Artem Mavrin' -__author_email__ = 'artemvmavrin@gmail.com' -__description__ = 'TensorFlow implementation of focal loss.' -__url__ = 'https://github.com/artemmavrin/focal-loss' -__copyright__ = 'Copyright 2020 Artem Mavrin' -__license__ = 'Apache 2.0' - -__doc__ = __description__ diff --git a/annopro/focal_loss/_binary_focal_loss.py b/annopro/focal_loss/_binary_focal_loss.py deleted file mode 100644 index eadb139..0000000 --- a/annopro/focal_loss/_binary_focal_loss.py +++ /dev/null @@ -1,565 +0,0 @@ -"""Binary focal loss implementation.""" -# ____ __ ___ __ __ __ __ ____ ____ -# ( __)/ \ / __) / _\ ( ) ( ) / \ / ___)/ ___) -# ) _)( O )( (__ / \/ (_/\ / (_/\( O )\___ \\___ \ -# (__) \__/ \___)\_/\_/\____/ \____/ \__/ (____/(____/ - -from functools import partial - -import tensorflow as tf - -from .utils.validation import check_bool, check_float - -_EPSILON = tf.keras.backend.epsilon() - - -def binary_focal_loss(y_true, y_pred, gamma, *, pos_weight=None, - from_logits=False, label_smoothing=None): - r"""Focal loss function for binary classification. - - This loss function generalizes binary cross-entropy by introducing a - hyperparameter :math:`\gamma` (gamma), called the *focusing parameter*, - that allows hard-to-classify examples to be penalized more heavily relative - to easy-to-classify examples. - - The focal loss [1]_ is defined as - - .. math:: - - L(y, \hat{p}) - = -\alpha y \left(1 - \hat{p}\right)^\gamma \log(\hat{p}) - - (1 - y) \hat{p}^\gamma \log(1 - \hat{p}) - - where - - * :math:`y \in \{0, 1\}` is a binary class label, - * :math:`\hat{p} \in [0, 1]` is an estimate of the probability of the - positive class, - * :math:`\gamma` is the *focusing parameter* that specifies how much - higher-confidence correct predictions contribute to the overall loss - (the higher the :math:`\gamma`, the higher the rate at which - easy-to-classify examples are down-weighted). 
- * :math:`\alpha` is a hyperparameter that governs the trade-off between - precision and recall by weighting errors for the positive class up or - down (:math:`\alpha=1` is the default, which is the same as no - weighting), - - The usual weighted binary cross-entropy loss is recovered by setting - :math:`\gamma = 0`. - - Parameters - ---------- - y_true : tensor-like - Binary (0 or 1) class labels. - - y_pred : tensor-like - Either probabilities for the positive class or logits for the positive - class, depending on the `from_logits` parameter. The shapes of `y_true` - and `y_pred` should be broadcastable. - - gamma : float - The focusing parameter :math:`\gamma`. Higher values of `gamma` make - easy-to-classify examples contribute less to the loss relative to - hard-to-classify examples. Must be non-negative. - - pos_weight : float, optional - The coefficient :math:`\alpha` to use on the positive examples. Must be - non-negative. - - from_logits : bool, optional - Whether `y_pred` contains logits or probabilities. - - label_smoothing : float, optional - Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary - ground truth labels `y_true` are squeezed toward 0.5, with larger values - of `label_smoothing` leading to label values closer to 0.5. - - Returns - ------- - :class:`tf.Tensor` - The focal loss for each example (assuming `y_true` and `y_pred` have the - same shapes). In general, the shape of the output is the result of - broadcasting the shapes of `y_true` and `y_pred`. - - Warnings - -------- - This function does not reduce its output to a scalar, so it cannot be passed - to :meth:`tf.keras.Model.compile` as a `loss` argument. Instead, use the - wrapper class :class:`~focal_loss.BinaryFocalLoss`. - - Examples - -------- - - This function computes the per-example focal loss between a label and - prediction tensor: - - >>> import numpy as np - >>> from focal_loss import binary_focal_loss - >>> loss = binary_focal_loss([0, 1, 1], [0.1, 0.7, 0.9], gamma=2) - >>> np.set_printoptions(precision=3) - >>> print(loss.numpy()) - [0.001 0.032 0.001] - - Below is a visualization of the focal loss between the positive class and - predicted probabilities between 0 and 1. Note that as :math:`\gamma` - increases, the losses for predictions closer to 1 get smoothly pushed to 0. - - .. plot:: - :include-source: - :align: center - - import numpy as np - import matplotlib.pyplot as plt - - from focal_loss import binary_focal_loss - - ps = np.linspace(0, 1, 100) - gammas = (0, 0.5, 1, 2, 5) - - plt.figure() - for gamma in gammas: - loss = binary_focal_loss(1, ps, gamma=gamma) - label = rf'$\gamma$={gamma}' - if gamma == 0: - label += ' (cross-entropy)' - plt.plot(ps, loss, label=label) - plt.legend(loc='best', frameon=True, shadow=True) - plt.xlim(0, 1) - plt.ylim(0, 4) - plt.xlabel(r'Probability of positive class $\hat{p}$') - plt.ylabel('Loss') - plt.title(r'Plot of focal loss $L(1, \hat{p})$ for different $\gamma$', - fontsize=14) - plt.show() - - Notes - ----- - A classifier often estimates the positive class probability :math:`\hat{p}` - by computing a real-valued *logit* :math:`\hat{y} \in \mathbb{R}` and - applying the *sigmoid function* :math:`\sigma : \mathbb{R} \to (0, 1)` - defined by - - .. math:: - - \sigma(t) = \frac{1}{1 + e^{-t}}, \qquad (t \in \mathbb{R}). - - That is, :math:`\hat{p} = \sigma(\hat{y})`. In this case, the focal loss - can be written as a function of the logit :math:`\hat{y}` instead of the - predicted probability :math:`\hat{p}`: - - .. 
math:: - - L(y, \hat{y}) - = -\alpha y \left(1 - \sigma(\hat{y})\right)^\gamma - \log(\sigma(\hat{y})) - - (1 - y) \sigma(\hat{y})^\gamma \log(1 - \sigma(\hat{y})). - - This is the formula that is computed when specifying `from_logits=True`. - However, this formula is not very numerically stable if implemented - directly; for example, there are multiple log and sigmoid computations - involved. Instead, we use some tricks to rewrite it in the more numerically - stable form - - .. math:: - - L(y, \hat{y}) - = (1 - y) \hat{p}^\gamma \hat{y} - + \left(\alpha y \hat{q}^\gamma + (1 - y) \hat{p}^\gamma\right) - \left(\log(1 + e^{-|\hat{y}|}) + \max\{-\hat{y}, 0\}\right), - - where :math:`\hat{p} = \sigma(\hat{y})` and :math:`\hat{q} = 1 - \hat{p}` - denote the estimates of the probabilities of the positive and negative - classes, respectively. - - Indeed, starting with the observations that - - .. math:: - - \log(\sigma(\hat{y})) - = \log\left(\frac{1}{1 + e^{-\hat{y}}}\right) - = -\log(1 + e^{-\hat{y}}) - - and - - .. math:: - - \log(1 - \sigma(\hat{y})) - = \log\left(\frac{e^{-\hat{y}}}{1 + e^{-\hat{y}}}\right) - = -\hat{y} - \log(1 + e^{-\hat{y}}), - - we obtain - - .. math:: - - \begin{aligned} - L(y, \hat{y}) - &= -\alpha y \hat{q}^\gamma \log(\sigma(\hat{y})) - - (1 - y) \hat{p}^\gamma \log(1 - \sigma(\hat{y})) \\ - &= \alpha y \hat{q}^\gamma \log(1 + e^{-\hat{y}}) - + (1 - y) \hat{p}^\gamma \left(\hat{y} + \log(1 + e^{-\hat{y}})\right)\\ - &= (1 - y) \hat{p}^\gamma \hat{y} - + \left(\alpha y \hat{q}^\gamma + (1 - y) \hat{p}^\gamma\right) - \log(1 + e^{-\hat{y}}). - \end{aligned} - - Note that if :math:`\hat{y} < 0`, then the exponential term - :math:`e^{-\hat{y}}` could become very large. In this case, we can instead - observe that - - .. math:: - - \begin{align*} - \log(1 + e^{-\hat{y}}) - &= \log(1 + e^{-\hat{y}}) + \hat{y} - \hat{y} \\ - &= \log(1 + e^{-\hat{y}}) + \log(e^{\hat{y}}) - \hat{y} \\ - &= \log(1 + e^{\hat{y}}) - \hat{y}. - \end{align*} - - Moreover, the :math:`\hat{y} < 0` and :math:`\hat{y} \geq 0` cases can be - unified by writing - - .. math:: - - \log(1 + e^{-\hat{y}}) - = \log(1 + e^{-|\hat{y}|}) + \max\{-\hat{y}, 0\}. - - Thus, we arrive at the numerically stable formula shown earlier. - - References - ---------- - .. [1] T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár. Focal loss for - dense object detection. IEEE Transactions on Pattern Analysis and - Machine Intelligence, 2018. - (`DOI `__) - (`arXiv preprint `__) - - See Also - -------- - :meth:`~focal_loss.BinaryFocalLoss` - A wrapper around this function that makes it a - :class:`tf.keras.losses.Loss`. 
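Because the logit form above is just the probability form pushed through the sigmoid, the two code paths of the vendored binary_focal_loss should agree numerically away from the epsilon clamp. A quick sanity check, assuming TensorFlow is importable:

import numpy as np
import tensorflow as tf
from annopro.focal_loss import binary_focal_loss

logits = tf.constant([-3.0, 0.5, 2.0])
y_true = [0, 1, 1]
loss_logits = binary_focal_loss(y_true, logits, gamma=2, from_logits=True)
loss_probs = binary_focal_loss(y_true, tf.math.sigmoid(logits), gamma=2)
print(np.allclose(loss_logits.numpy(), loss_probs.numpy(), atol=1e-6))   # True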
- """ - # Validate arguments - gamma = check_float(gamma, name='gamma', minimum=0) - pos_weight = check_float(pos_weight, name='pos_weight', minimum=0, - allow_none=True) - from_logits = check_bool(from_logits, name='from_logits') - label_smoothing = check_float(label_smoothing, name='label_smoothing', - minimum=0, maximum=1, allow_none=True) - - # Ensure predictions are a floating point tensor; converting labels to a - # tensor will be done in the helper functions - y_pred = tf.convert_to_tensor(y_pred) - if not y_pred.dtype.is_floating: - y_pred = tf.dtypes.cast(y_pred, dtype=tf.float32) - - # Delegate per-example loss computation to helpers depending on whether - # predictions are logits or probabilities - if from_logits: - return _binary_focal_loss_from_logits(labels=y_true, logits=y_pred, - gamma=gamma, - pos_weight=pos_weight, - label_smoothing=label_smoothing) - else: - return _binary_focal_loss_from_probs(labels=y_true, p=y_pred, - gamma=gamma, pos_weight=pos_weight, - label_smoothing=label_smoothing) - - -@tf.keras.utils.register_keras_serializable() -class BinaryFocalLoss(tf.keras.losses.Loss): - r"""Focal loss function for binary classification. - - This loss function generalizes binary cross-entropy by introducing a - hyperparameter called the *focusing parameter* that allows hard-to-classify - examples to be penalized more heavily relative to easy-to-classify examples. - - This class is a wrapper around :class:`~focal_loss.binary_focal_loss`. See - the documentation there for details about this loss function. - - Parameters - ---------- - gamma : float - The focusing parameter :math:`\gamma`. Must be non-negative. - - pos_weight : float, optional - The coefficient :math:`\alpha` to use on the positive examples. Must be - non-negative. - - from_logits : bool, optional - Whether model prediction will be logits or probabilities. - - label_smoothing : float, optional - Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary - ground truth labels are squeezed toward 0.5, with larger values of - `label_smoothing` leading to label values closer to 0.5. - - **kwargs : keyword arguments - Other keyword arguments for :class:`tf.keras.losses.Loss` (e.g., `name` - or `reduction`). - - Examples - -------- - - An instance of this class is a callable that takes a tensor of binary ground - truth labels `y_true` and a tensor of model predictions `y_pred` and returns - a scalar tensor obtained by reducing the per-example focal loss (the default - reduction is a batch-wise average). - - >>> from focal_loss import BinaryFocalLoss - >>> loss_func = BinaryFocalLoss(gamma=2) - >>> loss = loss_func([0, 1, 1], [0.1, 0.7, 0.9]) # A scalar tensor - >>> print(f'Mean focal loss: {loss.numpy():.3f}') - Mean focal loss: 0.011 - - Use this class in the :mod:`tf.keras` API like any other binary - classification loss function class found in :mod:`tf.keras.losses` (e.g., - :class:`tf.keras.losses.BinaryCrossentropy`: - - .. code-block:: python - - # Typical usage - model = tf.keras.Model(...) - model.compile( - optimizer=..., - loss=BinaryFocalLoss(gamma=2), # Used here like a tf.keras loss - metrics=..., - ) - history = model.fit(...) - - See Also - -------- - :meth:`~focal_loss.binary_focal_loss` - The function that performs the focal loss computation, taking a label - tensor and a prediction tensor and outputting a loss. 
- """ - - def __init__(self, gamma, *, pos_weight=None, from_logits=False, - label_smoothing=None, **kwargs): - # Validate arguments - gamma = check_float(gamma, name='gamma', minimum=0) - pos_weight = check_float(pos_weight, name='pos_weight', minimum=0, - allow_none=True) - from_logits = check_bool(from_logits, name='from_logits') - label_smoothing = check_float(label_smoothing, name='label_smoothing', - minimum=0, maximum=1, allow_none=True) - - super().__init__(**kwargs) - self.gamma = gamma - self.pos_weight = pos_weight - self.from_logits = from_logits - self.label_smoothing = label_smoothing - - def get_config(self): - """Returns the config of the layer. - - A layer config is a Python dictionary containing the configuration of a - layer. The same layer can be re-instantiated later (without its trained - weights) from this configuration. - - Returns - ------- - dict - This layer's config. - """ - config = super().get_config() - config.update(gamma=self.gamma, pos_weight=self.pos_weight, - from_logits=self.from_logits, - label_smoothing=self.label_smoothing) - return config - - def call(self, y_true, y_pred): - """Compute the per-example focal loss. - - This method simply calls :meth:`~focal_loss.binary_focal_loss` with the - appropriate arguments. - - Parameters - ---------- - y_true : tensor-like - Binary (0 or 1) class labels. - - y_pred : tensor-like - Either probabilities for the positive class or logits for the - positive class, depending on the `from_logits` attribute. The shapes - of `y_true` and `y_pred` should be broadcastable. - - Returns - ------- - :class:`tf.Tensor` - The per-example focal loss. Reduction to a scalar is handled by - this layer's :meth:`~focal_loss.BinaryFocalLoss.__call__` method. - """ - return binary_focal_loss(y_true=y_true, y_pred=y_pred, gamma=self.gamma, - pos_weight=self.pos_weight, - from_logits=self.from_logits, - label_smoothing=self.label_smoothing) - - -# Helper functions below - - -def _process_labels(labels, label_smoothing, dtype): - """Pre-process a binary label tensor, maybe applying smoothing. - - Parameters - ---------- - labels : tensor-like - Tensor of 0's and 1's. - - label_smoothing : float or None - Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary - ground truth labels `y_true` are squeezed toward 0.5, with larger values - of `label_smoothing` leading to label values closer to 0.5. - - dtype : tf.dtypes.DType - Desired type of the elements of `labels`. - - Returns - ------- - tf.Tensor - The processed labels. - """ - labels = tf.dtypes.cast(labels, dtype=dtype) - if label_smoothing is not None: - labels = (1 - label_smoothing) * labels + label_smoothing * 0.5 - return labels - - -def _binary_focal_loss_from_logits(labels, logits, gamma, pos_weight, - label_smoothing): - """Compute focal loss from logits using a numerically stable formula. - - Parameters - ---------- - labels : tensor-like - Tensor of 0's and 1's: binary class labels. - - logits : tf.Tensor - Logits for the positive class. - - gamma : float - Focusing parameter. - - pos_weight : float or None - If not None, losses for the positive class will be scaled by this - weight. - - label_smoothing : float or None - Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary - ground truth labels `y_true` are squeezed toward 0.5, with larger values - of `label_smoothing` leading to label values closer to 0.5. - - Returns - ------- - tf.Tensor - The loss for each example. 
- """ - labels = _process_labels(labels=labels, label_smoothing=label_smoothing, - dtype=logits.dtype) - - # Compute probabilities for the positive class - p = tf.math.sigmoid(logits) - - # Without label smoothing we can use TensorFlow's built-in per-example cross - # entropy loss functions and multiply the result by the modulating factor. - # Otherwise, we compute the focal loss ourselves using a numerically stable - # formula below - if label_smoothing is None: - # The labels and logits tensors' shapes need to be the same for the - # built-in cross-entropy functions. Since we want to allow broadcasting, - # we do some checks on the shapes and possibly broadcast explicitly - # Note: tensor.shape returns a tf.TensorShape, whereas tf.shape(tensor) - # returns an int tf.Tensor; this is why both are used below - labels_shape = labels.shape - logits_shape = logits.shape - if not labels_shape.is_fully_defined() or labels_shape != logits_shape: - labels_shape = tf.shape(labels) - logits_shape = tf.shape(logits) - shape = tf.broadcast_dynamic_shape(labels_shape, logits_shape) - labels = tf.broadcast_to(labels, shape) - logits = tf.broadcast_to(logits, shape) - if pos_weight is None: - loss_func = tf.nn.sigmoid_cross_entropy_with_logits - else: - loss_func = partial(tf.nn.weighted_cross_entropy_with_logits, - pos_weight=pos_weight) - loss = loss_func(labels=labels, logits=logits) - modulation_pos = (1 - p) ** gamma - modulation_neg = p ** gamma - mask = tf.dtypes.cast(labels, dtype=tf.bool) - modulation = tf.where(mask, modulation_pos, modulation_neg) - return modulation * loss - - # Terms for the positive and negative class components of the loss - pos_term = labels * ((1 - p) ** gamma) - neg_term = (1 - labels) * (p ** gamma) - - # Term involving the log and ReLU - log_weight = pos_term - if pos_weight is not None: - log_weight *= pos_weight - log_weight += neg_term - log_term = tf.math.log1p(tf.math.exp(-tf.math.abs(logits))) - log_term += tf.nn.relu(-logits) - log_term *= log_weight - - # Combine all the terms into the loss - loss = neg_term * logits + log_term - return loss - - -def _binary_focal_loss_from_probs(labels, p, gamma, pos_weight, - label_smoothing): - """Compute focal loss from probabilities. - - Parameters - ---------- - labels : tensor-like - Tensor of 0's and 1's: binary class labels. - - p : tf.Tensor - Estimated probabilities for the positive class. - - gamma : float - Focusing parameter. - - pos_weight : float or None - If not None, losses for the positive class will be scaled by this - weight. - - label_smoothing : float or None - Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary - ground truth labels `y_true` are squeezed toward 0.5, with larger values - of `label_smoothing` leading to label values closer to 0.5. - - Returns - ------- - tf.Tensor - The loss for each example. 
- """ - # Predicted probabilities for the negative class - q = 1 - p - - # For numerical stability (so we don't inadvertently take the log of 0) - p = tf.math.maximum(p, _EPSILON) - q = tf.math.maximum(q, _EPSILON) - - # Loss for the positive examples - pos_loss = -(q ** gamma) * tf.math.log(p) - if pos_weight is not None: - pos_loss *= pos_weight - - # Loss for the negative examples - neg_loss = -(p ** gamma) * tf.math.log(q) - - # Combine loss terms - if label_smoothing is None: - labels = tf.dtypes.cast(labels, dtype=tf.bool) - loss = tf.where(labels, pos_loss, neg_loss) - else: - labels = _process_labels(labels=labels, label_smoothing=label_smoothing, - dtype=p.dtype) - loss = labels * pos_loss + (1 - labels) * neg_loss - - return loss diff --git a/annopro/focal_loss/_categorical_focal_loss.py b/annopro/focal_loss/_categorical_focal_loss.py deleted file mode 100644 index ca9dc5b..0000000 --- a/annopro/focal_loss/_categorical_focal_loss.py +++ /dev/null @@ -1,311 +0,0 @@ -"""Multiclass focal loss implementation.""" -# __ _ _ -# / _| | | | | -# | |_ ___ ___ __ _ | | | | ___ ___ ___ -# | _| / _ \ / __| / _` | | | | | / _ \ / __| / __| -# | | | (_) | | (__ | (_| | | | | | | (_) | \__ \ \__ \ -# |_| \___/ \___| \__,_| |_| |_| \___/ |___/ |___/ - -import itertools -from typing import Any, Optional - -import tensorflow as tf - -_EPSILON = tf.keras.backend.epsilon() - - -def sparse_categorical_focal_loss(y_true, y_pred, gamma, *, - class_weight: Optional[Any] = None, - from_logits: bool = False, axis: int = -1 - ) -> tf.Tensor: - r"""Focal loss function for multiclass classification with integer labels. - - This loss function generalizes multiclass softmax cross-entropy by - introducing a hyperparameter called the *focusing parameter* that allows - hard-to-classify examples to be penalized more heavily relative to - easy-to-classify examples. - - See :meth:`~focal_loss.binary_focal_loss` for a description of the focal - loss in the binary setting, as presented in the original work [1]_. - - In the multiclass setting, with integer labels :math:`y`, focal loss is - defined as - - .. math:: - - L(y, \hat{\mathbf{p}}) - = -\left(1 - \hat{p}_y\right)^\gamma \log(\hat{p}_y) - - where - - * :math:`y \in \{0, \ldots, K - 1\}` is an integer class label (:math:`K` - denotes the number of classes), - * :math:`\hat{\mathbf{p}} = (\hat{p}_0, \ldots, \hat{p}_{K-1}) - \in [0, 1]^K` is a vector representing an estimated probability - distribution over the :math:`K` classes, - * :math:`\gamma` (gamma, not :math:`y`) is the *focusing parameter* that - specifies how much higher-confidence correct predictions contribute to - the overall loss (the higher the :math:`\gamma`, the higher the rate at - which easy-to-classify examples are down-weighted). - - The usual multiclass softmax cross-entropy loss is recovered by setting - :math:`\gamma = 0`. - - Parameters - ---------- - y_true : tensor-like - Integer class labels. - - y_pred : tensor-like - Either probabilities or logits, depending on the `from_logits` - parameter. - - gamma : float or tensor-like of shape (K,) - The focusing parameter :math:`\gamma`. Higher values of `gamma` make - easy-to-classify examples contribute less to the loss relative to - hard-to-classify examples. Must be non-negative. This can be a - one-dimensional tensor, in which case it specifies a focusing parameter - for each class. - - class_weight: tensor-like of shape (K,) - Weighting factor for each of the :math:`k` classes. If not specified, - then all classes are weighted equally. 
- - from_logits : bool, optional - Whether `y_pred` contains logits or probabilities. - - axis : int, optional - Channel axis in the `y_pred` tensor. - - Returns - ------- - :class:`tf.Tensor` - The focal loss for each example. - - Examples - -------- - - This function computes the per-example focal loss between a one-dimensional - integer label vector and a two-dimensional prediction matrix: - - >>> import numpy as np - >>> from focal_loss import sparse_categorical_focal_loss - >>> y_true = [0, 1, 2] - >>> y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]] - >>> loss = sparse_categorical_focal_loss(y_true, y_pred, gamma=2) - >>> np.set_printoptions(precision=3) - >>> print(loss.numpy()) - [0.009 0.032 0.082] - - Warnings - -------- - This function does not reduce its output to a scalar, so it cannot be passed - to :meth:`tf.keras.Model.compile` as a `loss` argument. Instead, use the - wrapper class :class:`~focal_loss.SparseCategoricalFocalLoss`. - - References - ---------- - .. [1] T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár. Focal loss for - dense object detection. IEEE Transactions on Pattern Analysis and - Machine Intelligence, 2018. - (`DOI `__) - (`arXiv preprint `__) - - See Also - -------- - :meth:`~focal_loss.SparseCategoricalFocalLoss` - A wrapper around this function that makes it a - :class:`tf.keras.losses.Loss`. - """ - # Process focusing parameter - gamma = tf.convert_to_tensor(gamma, dtype=tf.dtypes.float32) - gamma_rank = gamma.shape.rank - scalar_gamma = gamma_rank == 0 - - # Process class weight - if class_weight is not None: - class_weight = tf.convert_to_tensor(class_weight, - dtype=tf.dtypes.float32) - - # Process prediction tensor - y_pred = tf.convert_to_tensor(y_pred) - y_pred_rank = y_pred.shape.rank - if y_pred_rank is not None: - axis %= y_pred_rank - if axis != y_pred_rank - 1: - # Put channel axis last for sparse_softmax_cross_entropy_with_logits - perm = list(itertools.chain(range(axis), - range(axis + 1, y_pred_rank), [axis])) - y_pred = tf.transpose(y_pred, perm=perm) - elif axis != -1: - raise ValueError( - f'Cannot compute sparse categorical focal loss with axis={axis} on ' - 'a prediction tensor with statically unknown rank.') - y_pred_shape = tf.shape(y_pred) - - # Process ground truth tensor - y_true = tf.dtypes.cast(y_true, dtype=tf.dtypes.int64) - y_true_rank = y_true.shape.rank - - if y_true_rank is None: - raise NotImplementedError('Sparse categorical focal loss not supported ' - 'for target/label tensors of unknown rank') - - reshape_needed = (y_true_rank is not None and y_pred_rank is not None and - y_pred_rank != y_true_rank + 1) - if reshape_needed: - y_true = tf.reshape(y_true, [-1]) - y_pred = tf.reshape(y_pred, [-1, y_pred_shape[-1]]) - - if from_logits: - logits = y_pred - probs = tf.nn.softmax(y_pred, axis=-1) - else: - probs = y_pred - logits = tf.math.log(tf.clip_by_value(y_pred, _EPSILON, 1 - _EPSILON)) - - xent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=y_true, - logits=logits, - ) - - y_true_rank = y_true.shape.rank - probs = tf.gather(probs, y_true, axis=-1, batch_dims=y_true_rank) - if not scalar_gamma: - gamma = tf.gather(gamma, y_true, axis=0, batch_dims=y_true_rank) - focal_modulation = (1 - probs) ** gamma - loss = focal_modulation * xent_loss - - if class_weight is not None: - class_weight = tf.gather(class_weight, y_true, axis=0, - batch_dims=y_true_rank) - loss *= class_weight - - if reshape_needed: - loss = tf.reshape(loss, y_pred_shape[:-1]) - - return loss - - 
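sparse_categorical_focal_loss above also accepts logits. Because every row of the docstring example already sums to one, feeding its element-wise log as logits reproduces the same per-example losses; a small sketch assuming TensorFlow is available:

import tensorflow as tf
from annopro.focal_loss import sparse_categorical_focal_loss

y_true = [0, 1, 2]
probs = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]]
logits = tf.math.log(tf.constant(probs))   # softmax(log p) == p when rows sum to 1
loss = sparse_categorical_focal_loss(y_true, logits, gamma=2, from_logits=True)
print(loss.numpy().round(3))               # approximately [0.009 0.032 0.082]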
-@tf.keras.utils.register_keras_serializable() -class SparseCategoricalFocalLoss(tf.keras.losses.Loss): - r"""Focal loss function for multiclass classification with integer labels. - - This loss function generalizes multiclass softmax cross-entropy by - introducing a hyperparameter :math:`\gamma` (gamma), called the - *focusing parameter*, that allows hard-to-classify examples to be penalized - more heavily relative to easy-to-classify examples. - - This class is a wrapper around - :class:`~focal_loss.sparse_categorical_focal_loss`. See the documentation - there for details about this loss function. - - Parameters - ---------- - gamma : float or tensor-like of shape (K,) - The focusing parameter :math:`\gamma`. Higher values of `gamma` make - easy-to-classify examples contribute less to the loss relative to - hard-to-classify examples. Must be non-negative. This can be a - one-dimensional tensor, in which case it specifies a focusing parameter - for each class. - - class_weight: tensor-like of shape (K,) - Weighting factor for each of the :math:`k` classes. If not specified, - then all classes are weighted equally. - - from_logits : bool, optional - Whether model prediction will be logits or probabilities. - - **kwargs : keyword arguments - Other keyword arguments for :class:`tf.keras.losses.Loss` (e.g., `name` - or `reduction`). - - Examples - -------- - - An instance of this class is a callable that takes a rank-one tensor of - integer class labels `y_true` and a tensor of model predictions `y_pred` and - returns a scalar tensor obtained by reducing the per-example focal loss (the - default reduction is a batch-wise average). - - >>> from focal_loss import SparseCategoricalFocalLoss - >>> loss_func = SparseCategoricalFocalLoss(gamma=2) - >>> y_true = [0, 1, 2] - >>> y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]] - >>> loss_func(y_true, y_pred) - - - Use this class in the :mod:`tf.keras` API like any other multiclass - classification loss function class that accepts integer labels found in - :mod:`tf.keras.losses` (e.g., - :class:`tf.keras.losses.SparseCategoricalCrossentropy`: - - .. code-block:: python - - # Typical usage - model = tf.keras.Model(...) - model.compile( - optimizer=..., - loss=SparseCategoricalFocalLoss(gamma=2), # Used here like a tf.keras loss - metrics=..., - ) - history = model.fit(...) - - See Also - -------- - :meth:`~focal_loss.sparse_categorical_focal_loss` - The function that performs the focal loss computation, taking a label - tensor and a prediction tensor and outputting a loss. - """ - - def __init__(self, gamma, class_weight: Optional[Any] = None, - from_logits: bool = False, **kwargs): - super().__init__(**kwargs) - self.gamma = gamma - self.class_weight = class_weight - self.from_logits = from_logits - - def get_config(self): - """Returns the config of the layer. - - A layer config is a Python dictionary containing the configuration of a - layer. The same layer can be re-instantiated later (without its trained - weights) from this configuration. - - Returns - ------- - dict - This layer's config. - """ - config = super().get_config() - config.update(gamma=self.gamma, class_weight=self.class_weight, - from_logits=self.from_logits) - return config - - def call(self, y_true, y_pred): - """Compute the per-example focal loss. - - This method simply calls - :meth:`~focal_loss.sparse_categorical_focal_loss` with the appropriate - arguments. - - Parameters - ---------- - y_true : tensor-like, shape (N,) - Integer class labels. 
- - y_pred : tensor-like, shape (N, K) - Either probabilities or logits, depending on the `from_logits` - parameter. - - Returns - ------- - :class:`tf.Tensor` - The per-example focal loss. Reduction to a scalar is handled by - this layer's - :meth:`~focal_loss.SparseCateogiricalFocalLoss.__call__` method. - """ - return sparse_categorical_focal_loss(y_true=y_true, y_pred=y_pred, - class_weight=self.class_weight, - gamma=self.gamma, - from_logits=self.from_logits) diff --git a/annopro/focal_loss/utils/__init__.py b/annopro/focal_loss/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/annopro/focal_loss/utils/validation.py b/annopro/focal_loss/utils/validation.py deleted file mode 100644 index c619c38..0000000 --- a/annopro/focal_loss/utils/validation.py +++ /dev/null @@ -1,325 +0,0 @@ -"""Helper functions for function parameter validation.""" - -import numbers - - -def check_type(obj, base, *, name=None, func=None, allow_none=False, - default=None, error_message=None): - """Check whether an object is an instance of a base type. - - Parameters - ---------- - obj : object - The object to be validated. - - name : str - The name of `obj` in the calling function. - - base : type or tuple of type - The base type that `obj` should be an instance of. - - func: callable, optional - A function to be applied to `obj` if it is of type `base`. If None, no - function will be applied and `obj` will be returned as-is. - - allow_none : bool, optional - Indicates whether the value None should be allowed to pass through. - - default : object, optional - The default value to return if `obj` is None and `allow_none` is True. - If `default` is not None, it must be of type `base`, and it will have - `func` applied to it if `func` is not None. - - error_message : str or None, optional - Custom error message to display if the type is incorrect. - - Returns - ------- - base type or None - The validated object. - - Raises - ------ - TypeError - If `obj` is not an instance of `base`. - - Examples - -------- - >>> check_type(1, int) - 1 - >>> check_type(1, (int, str)) - 1 - >>> check_type(1, str) - Traceback (most recent call last): - ... - TypeError: Invalid type. Expected: str. Actual: int. - >>> check_type(1, (str, bool)) - Traceback (most recent call last): - ... - TypeError: Invalid type. Expected: (str, bool). Actual: int. - >>> print(check_type(None, str, allow_none=True)) - None - >>> check_type(1, str, name='num') - Traceback (most recent call last): - ... - TypeError: Invalid type for parameter 'num'. Expected: str. Actual: int. - >>> check_type(1, int, func=str) - '1' - >>> check_type(1, int, func='not callable') - Traceback (most recent call last): - ... - ValueError: Parameter 'func' must be callable or None. - >>> check_type(2.0, str, error_message='Not a string!') - Traceback (most recent call last): - ... - TypeError: Not a string! 
-    >>> check_type(None, int, allow_none=True, default=0)
-    0
-
-    """
-    if allow_none and obj is None:
-        if default is not None:
-            return check_type(default, base=base, name=name, func=func,
-                              allow_none=False)
-        return None
-
-    if isinstance(obj, base):
-        if func is None:
-            return obj
-        elif callable(func):
-            return func(obj)
-        else:
-            raise ValueError('Parameter \'func\' must be callable or None.')
-
-    # Handle wrong type
-    if isinstance(base, tuple):
-        expect = '(' + ', '.join(cls.__name__ for cls in base) + ')'
-    else:
-        expect = base.__name__
-    actual = type(obj).__name__
-    if error_message is None:
-        error_message = 'Invalid type'
-        if name is not None:
-            error_message += f' for parameter \'{name}\''
-        error_message += f'. Expected: {expect}. Actual: {actual}.'
-    raise TypeError(error_message)
-
-
-def check_bool(obj, *, name=None, allow_none=False, default=None):
-    """Validate boolean function arguments.
-
-    Parameters
-    ----------
-    obj : object
-        The object to be validated.
-
-    name : str, optional
-        The name of `obj` in the calling function.
-
-    allow_none : bool, optional
-        Indicates whether the value None should be allowed.
-
-    default : object, optional
-        The default value to return if `obj` is None and `allow_none` is True.
-
-    Returns
-    -------
-    bool or None
-        The validated bool.
-
-    Raises
-    ------
-    TypeError
-        If `obj` is not an instance of bool.
-
-    Examples
-    --------
-    >>> check_bool(True)
-    True
-    >>> check_bool(1.0)
-    Traceback (most recent call last):
-    ...
-    TypeError: Invalid type. Expected: bool. Actual: float.
-    >>> a = (1 < 2)
-    >>> check_bool(a, name='a')
-    True
-    >>> b = 'not a bool'
-    >>> check_bool(b, name='b')
-    Traceback (most recent call last):
-    ...
-    TypeError: Invalid type for parameter 'b'. Expected: bool. Actual: str.
-    """
-    return check_type(obj, name=name, base=bool, func=bool,
-                      allow_none=allow_none, default=default)
-
-
-def _check_numeric(*, check_func, obj, name, base, func, positive, minimum,
-                   maximum, allow_none, default):
-    """Helper function for check_float and check_int."""
-    obj = check_type(obj, name=name, base=base, func=func,
-                     allow_none=allow_none, default=default)
-
-    if obj is None:
-        return None
-
-    positive = check_bool(positive, name='positive')
-    if positive and obj <= 0:
-        if name is None:
-            message = 'Parameter must be positive.'
-        else:
-            message = f'Parameter \'{name}\' must be positive.'
-        raise ValueError(message)
-
-    if minimum is not None:
-        minimum = check_func(minimum, name='minimum')
-        if obj < minimum:
-            if name is None:
-                message = f'Parameter must be at least {minimum}.'
-            else:
-                message = f'Parameter \'{name}\' must be at least {minimum}.'
-            raise ValueError(message)
-
-    if maximum is not None:
-        maximum = check_func(maximum, name='maximum')
-        if obj > maximum:
-            if name is None:
-                message = f'Parameter must be at most {maximum}.'
-            else:
-                message = f'Parameter \'{name}\' must be at most {maximum}.'
-            raise ValueError(message)
-
-    return obj
-
-
-def check_int(obj, *, name=None, positive=False, minimum=None, maximum=None,
-              allow_none=False, default=None):
-    """Validate integer function arguments.
-
-    Parameters
-    ----------
-    obj : object
-        The object to be validated.
-
-    name : str, optional
-        The name of `obj` in the calling function.
-
-    positive : bool, optional
-        Whether `obj` must be a positive integer (1 or greater).
-
-    minimum : int, optional
-        The minimum value that `obj` can take (inclusive).
-
-    maximum : int, optional
-        The maximum value that `obj` can take (inclusive).
- - allow_none : bool, optional - Indicates whether the value None should be allowed. - - default : object, optional - The default value to return if `obj` is None and `allow_none` is True. - - Returns - ------- - int or None - The validated integer. - - Raises - ------ - TypeError - If `obj` is not an integer. - - ValueError - If any of the optional positivity or minimum and maximum value - constraints are violated. - - Examples - -------- - >>> check_int(0) - 0 - >>> check_int(1, positive=True) - 1 - >>> check_int(1.0) - Traceback (most recent call last): - ... - TypeError: Invalid type. Expected: Integral. Actual: float. - >>> check_int(-1, positive=True) - Traceback (most recent call last): - ... - ValueError: Parameter must be positive. - >>> check_int(1, name='a', minimum=10) - Traceback (most recent call last): - ... - ValueError: Parameter 'a' must be at least 10. - - """ - return _check_numeric(check_func=check_int, obj=obj, name=name, - base=numbers.Integral, func=int, positive=positive, - minimum=minimum, maximum=maximum, - allow_none=allow_none, default=default) - - -def check_float(obj, *, name=None, positive=False, minimum=None, maximum=None, - allow_none=False, default=None): - """Validate float function arguments. - - Parameters - ---------- - obj : object - The object to be validated. - - name : str, optional - The name of `obj` in the calling function. - - positive : bool, optional - Whether `obj` must be a positive float. - - minimum : float, optional - The minimum value that `obj` can take (inclusive). - - maximum : float, optional - The maximum value that `obj` can take (inclusive). - - allow_none : bool, optional - Indicates whether the value None should be allowed. - - default : object, optional - The default value to return if `obj` is None and `allow_none` is True. - - Returns - ------- - float or None - The validated float. - - Raises - ------ - TypeError - If `obj` is not a float. - - ValueError - If any of the optional positivity or minimum and maximum value - constraints are violated. - - Examples - -------- - >>> check_float(0) - 0.0 - >>> check_float(1.0, positive=True) - 1.0 - >>> check_float(1.0 + 1.0j) - Traceback (most recent call last): - ... - TypeError: Invalid type. Expected: Real. Actual: complex. - >>> check_float(-1, positive=True) - Traceback (most recent call last): - ... - ValueError: Parameter must be positive. - >>> check_float(1.2, name='a', minimum=10) - Traceback (most recent call last): - ... - ValueError: Parameter 'a' must be at least 10.0. 
- - """ - return _check_numeric(check_func=check_float, obj=obj, name=name, - base=numbers.Real, func=float, positive=positive, - minimum=minimum, maximum=maximum, - allow_none=allow_none, default=default) diff --git a/annopro/prediction.py b/annopro/prediction.py deleted file mode 100644 index 0c0188a..0000000 --- a/annopro/prediction.py +++ /dev/null @@ -1,181 +0,0 @@ -import os -from tensorflow.keras.utils import Sequence -from tensorflow.keras.models import load_model -import numpy as np -import pandas as pd -import math -import pickle -import annopro.resources as resources -from annopro.focal_loss import BinaryFocalLoss -from annopro.data_procession.utils import NAMESPACES, Ontology - - -def predict(output_dir: str, promap_features_file: str, - used_gpu: str = "-1", diamond_scores_file: str = None): - if output_dir == None: - raise ValueError("Must provide the input fasta sequences.") - os.environ["CUDA_VISIBLE_DEVICES"] = used_gpu - for term_type in NAMESPACES.keys(): - init_evaluate(term_type=term_type, - promap_features_file=promap_features_file, - diamond_scores_file=diamond_scores_file, - output_dir=output_dir) - - -class DFGenerator(Sequence): - def __init__(self, df, terms_dict, nb_classes, batch_size): - self.start = 0 - self.size = len(df) - self.df = df - self.batch_size = batch_size - self.nb_classes = nb_classes - self.terms_dict = terms_dict - - def __len__(self): - return np.ceil(len(self.df) / float(self.batch_size)).astype(np.int32) - - def __getitem__(self, idx): - batch_index = np.arange(idx * self.batch_size, - min(self.size, (idx + 1) * self.batch_size)) - df = self.df.iloc[batch_index] - labels = np.zeros((len(df), self.nb_classes), dtype=np.int32) - feature_data = [] - protein_si = [] - for i, row in enumerate(df.itertuples()): - feature_data.append(list(row.Promap_feature)) - protein_si.append(list(row.Protein_similary)) - data_onehot = np.array(feature_data) - data_si = np.array(protein_si) - self.start += self.batch_size - return ([data_onehot, data_si]) - - def __next__(self): - return self.next() - - def reset(self): - self.start = 0 - - def next(self): - if self.start < self.size: - batch_index = np.arange( - self.start, min(self.size, self.start + self.batch_size)) - df = self.df.iloc[batch_index] - labels = np.zeros((len(df), self.nb_classes), dtype=np.int32) - feature_data = [] - protein_si = [] - for i, row in enumerate(df.itertuples()): - feature_data.append(list(row.Promap_feature)) - protein_si.append(list(row.Protein_similary)) - data_onehot = np.array(feature_data) - data_si = np.array(protein_si) - self.start += self.batch_size - return ([data_onehot, data_si]) - else: - self.reset() - return self.next() - - -def diamond_score(diamond_scores_file, label, data_path, term_type): - with resources.open_binary("go.pkl") as file: - go: Ontology = pickle.load(file) - assert isinstance(go, Ontology) - with resources.open_binary("cafa_train.pkl") as file: - train_df = pd.read_pickle(file) - test_df = pd.read_pickle(data_path) - annotations = train_df['Prop_annotations'].values - annotations = list(map(lambda x: set(x), annotations)) - - prot_index = {} - for i, row in enumerate(train_df.itertuples()): - prot_index[row.Proteins] = i - - diamond_scores = {} - with open(diamond_scores_file) as f: - for line in f: - it = line.strip().split("\t") - if it[0] not in diamond_scores: - diamond_scores[it[0]] = {} - diamond_scores[it[0]][it[1]] = float(it[11]) - blast_preds = [] - - for i, row in enumerate(test_df.itertuples()): - annots = {} - prot_id = row.Proteins 
- # BlastKNN - if prot_id in diamond_scores: - sim_prots = diamond_scores[prot_id] - allgos = set() - total_score = 0.0 - for p_id, score in sim_prots.items(): - allgos |= annotations[prot_index[p_id]] - total_score += score - allgos = list(sorted(allgos)) - sim = np.zeros(len(allgos), dtype=np.float32) - for j, go_id in enumerate(allgos): - s = 0.0 - for p_id, score in sim_prots.items(): - if go_id in annotations[prot_index[p_id]]: - s += score - sim[j] = s / total_score - ind = np.argsort(-sim) - for go_id, score in zip(allgos, sim): - annots[go_id] = score - blast_preds.append(annots) - with resources.open_binary(f"terms_{NAMESPACES[term_type]}.pkl") as term_path: - terms = pd.read_pickle(term_path) - terms = terms['terms'].values.flatten() - alphas = {NAMESPACES['mf']: 0.55, - NAMESPACES['bp']: 0.6, NAMESPACES['cc']: 0.4} - - for i in range(0, len(label)): - annots_dict = blast_preds[i].copy() - for go_id in annots_dict: - annots_dict[go_id] *= alphas[go.get_namespace(go_id)] - for j in range(0, len(label[0])): - go_id = terms[j] - label[i, j] = label[i, j]*(1 - alphas[go.get_namespace(go_id)]) - if go_id in annots_dict: - label[i, j] = label[i, j] + annots_dict[go_id] - return label - - -def init_evaluate(term_type, promap_features_file, diamond_scores_file, output_dir: str, - data_size=8000, batch_size=16): - with resources.open_binary(f"terms_{NAMESPACES[term_type]}.pkl") as file: - terms_df = pd.read_pickle(file) - with open(promap_features_file, 'rb') as file: - data_df = pd.read_pickle(file) - if len(data_df) > data_size: - data_df = data_df.sample(n=data_size) - data_df.index = range(len(data_df)) - model = load_model( - resources.get_resource_path(f"{term_type}.h5"), - custom_objects={"focus_loss": BinaryFocalLoss}) - proteins = data_df["Proteins"] - terms = terms_df['terms'].values.flatten() - terms_dict = {v: i for i, v in enumerate(terms)} - nb_classes = len(terms) - data_generator = DFGenerator(data_df, terms_dict, nb_classes, batch_size) - data_steps = int(math.ceil(len(data_df) / batch_size)) - preds = model.predict(data_generator, steps=data_steps) - if diamond_scores_file: - preds = diamond_score(diamond_scores_file, preds, - promap_features_file, term_type=term_type) - # label_di=defaultdict(list) - protein = [] - go_terms = [] - score = [] - for i in range(len(preds)): - for j in range(len(preds[i])): - if preds[i][j] > 0: - protein.append(proteins[i]) - go_terms.append(terms[j]) - score.append(preds[i][j]) - res = [protein, go_terms, score] - res = pd.DataFrame(res) - res = res.T - res.columns = ['Proteins', 'GO-terms', 'Scores'] - res.sort_values(by='Scores', axis=0, ascending=False, inplace=True) - result_file = os.path.join(output_dir, f"{term_type}_result.csv") - res.to_csv(result_file, sep=',', index=False, header=True) - return res diff --git a/annopro/resources.py b/annopro/resources.py deleted file mode 100644 index de7bcad..0000000 --- a/annopro/resources.py +++ /dev/null @@ -1,115 +0,0 @@ -""" -This module manages all required resources for annopro. 
-""" -from io import FileIO, TextIOWrapper -import os -import wget -import hashlib - -RESOURCE_DIR = os.path.join(os.path.expanduser("~"), ".annopro/data") -os.makedirs(RESOURCE_DIR, exist_ok=True) - -RESOURCE_DICT = { - "cafa_train.pkl": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/cafa_train.pkl", - "md5sum": "07d3e4334c31c914efec3f52cae5e498" - }, - "cafa4_del.csv": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/cafa4_del.csv", - "md5sum": "d00b71439084cb19b7d3d0d4fbbaa819" - }, - "cafa4.dmnd": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/cafa4.dmnd", - "md5sum": "a2a6cba9af26dbe1911e14a306db2712" - }, - "data_grid.pkl": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/data_grid.pkl", - "md5sum": "fb2d2d86a4bc21c6e60fac996b3a90d3" - }, - "go.pkl": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/go.pkl", - "md5sum": "8d7a975d38a4af670b0370f4ea722a2a" - }, - "go.txt": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/go.txt", - "md5sum": "1dae308468fa00ae6d5796fd22c65044" - }, - "row_asses.pkl": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/row_asses.pkl", - "md5sum": "bf9bb1eda744a60c381d19b275ac6f33" - }, - "terms_biological_process.pkl": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/terms_biological_process.pkl", - "md5sum": "e79cf5e006432c19606b8a482cf7ddfa" - }, - "terms_cellular_component.pkl": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/terms_cellular_component.pkl", - "md5sum": "8cf58075eba65e2bb710566e4ef93f42" - }, - "terms_molecular_function.pkl": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/terms_molecular_function.pkl", - "md5sum": "002c3696fbca0061402a68129b35dcd4" - }, - "bp.h5": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/model_param/bp.h5", - "md5sum": "7e19158e5252a70ff831f5f583b1c2ed" - }, - "cc.h5": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/model_param/cc.h5", - "md5sum": "73876beec9370ff56b58878cf4446d2c" - }, - "mf.h5": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/model_param/mf.h5", - "md5sum": "f4fb632f553afeb45571a29e46286bb8" - } -} - - -def md5sum(file_path: str) -> str: - with open(file_path, "rb") as f: - md5 = hashlib.md5() - while True: - data = f.read(65536) - if not data: - break - md5.update(data) - return md5.hexdigest() - - -def md5check(file_path: str, expected: str): - return md5sum(file_path).startswith(expected) - - -def download_resource(name: str, overwrite: bool = False) -> str: - if name in RESOURCE_DICT: - resource = RESOURCE_DICT[name] - path_name = os.path.join(RESOURCE_DIR, name) - if os.path.exists(path_name): - if overwrite or not md5check(path_name, resource["md5sum"]): - os.remove(path_name) - else: - return path_name - print(f"Download {name}...") - wget.download( - url=resource["url"], - out=path_name) - print(f"\nValidate md5sum of {name}...") - if not md5check(path_name, resource["md5sum"]): - raise RuntimeError(f"{name} do not pass md5 validation, please visit https://github.com/idrblab/AnnoPRO for help") - return path_name - else: - raise FileNotFoundError(f"Invalid resource name: {name}") - - -def get_resource_path(name: str) -> str: - return download_resource(name) - - -def open_binary(name: str) -> FileIO: - file_path = get_resource_path(name) - return open(file_path, "rb") - - -def open_text(name: str) -> TextIOWrapper: - file_path = get_resource_path(name) - return open(file_path, "rt")