From 894a010e90356626c55dd30b8041be51ab8e8c47 Mon Sep 17 00:00:00 2001
From: lingyan
Date: Thu, 2 Nov 2023 19:12:02 +0800
Subject: [PATCH 1/2] Add files via upload

---
 annopro/annopro_bp.py | 595 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 595 insertions(+)
 create mode 100644 annopro/annopro_bp.py

diff --git a/annopro/annopro_bp.py b/annopro/annopro_bp.py
new file mode 100644
index 0000000..25f8af0
--- /dev/null
+++ b/annopro/annopro_bp.py
@@ -0,0 +1,595 @@
+# Required libraries and parameters
+import os
+from tensorflow.keras.utils import Sequence, plot_model
+from tensorflow.keras.optimizers import Adam
+from tensorflow.keras import models, layers
+from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger, TensorBoard
+# import matplotlib.pyplot as plt
+from tensorflow.keras.models import Model, load_model
+from tensorflow.keras.layers import (
+    Input, Dense, Embedding, Conv2D, Flatten, Concatenate, TimeDistributed,
+    MaxPool2D, Dropout, RepeatVector, Layer, Reshape, SimpleRNN, LSTM, BatchNormalization, GRU,
+    GlobalAveragePooling2D, GlobalMaxPooling2D, multiply, Permute, Add, Activation, Lambda, Multiply
+)
+import tensorflow as tf
+import numpy as np
+import pandas as pd
+import math
+import pickle
+from sklearn.metrics import roc_curve, auc, matthews_corrcoef, precision_score, recall_score, roc_auc_score
+from collections import deque, Counter
+from tqdm import tqdm
+from tensorflow.keras import backend as K
+from argparse import ArgumentParser
+
+
+# AAINDEX = {'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'Q': 6, 'E': 7, 'G': 8, 'H': 9, 'I': 10, 'L': 11,
+#            'K': 12, 'M': 13, 'F': 14, 'P': 15, 'S': 16, 'T': 17, 'W': 18, 'Y': 19, 'V': 20}
+# MAXLEN = 2000
+
+
+# def to_onehot(seq, start=0):
+#     onehot = np.zeros((MAXLEN, 21), dtype=np.int32)
+#     l = min(MAXLEN, len(seq))
+#     for i in range(start, start + l):
+#         onehot[i, AAINDEX.get(seq[i - start], 0)] = 1
+#     onehot[0:start, 0] = 1
+#     onehot[start + l:, 0] = 1
+#     return onehot
+
+# gpus = tf.config.experimental.list_physical_devices('GPU')
+
+# tf.config.experimental.set_memory_growth(gpus[0], True)
+
+
+
+class Ontology(object):
+    def __init__(self, filename='/public/home/zhengqq/data/Train/go-basic.obo', with_rels=False):
+        self.ont = self.load(filename, with_rels)
+        self.ic = None
+
+    def has_term(self, term_id):
+        return term_id in self.ont
+
+    def get_term(self, term_id):
+        if self.has_term(term_id):
+            return self.ont[term_id]
+        return None
+
+    def get_anchestors(self, term_id):
+        if term_id not in self.ont:
+            return set()
+        term_set = set()
+        q = deque()
+        q.append(term_id)
+        while (len(q) > 0):
+            t_id = q.popleft()
+            if t_id not in term_set:
+                term_set.add(t_id)
+                for parent_id in self.ont[t_id]['is_a']:
+                    if parent_id in self.ont:
+                        q.append(parent_id)
+        return term_set
+
+    def get_parents(self, term_id):
+        if term_id not in self.ont:
+            return set()
+        term_set = set()
+        for parent_id in self.ont[term_id]['is_a']:
+            if parent_id in self.ont:
+                term_set.add(parent_id)
+        return term_set
+
+    def get_namespace_terms(self, namespace):
+        terms = set()
+        for go_id, obj in self.ont.items():
+            if obj['namespace'] == namespace:
+                terms.add(go_id)
+        return terms
+
+    def get_namespace(self, term_id):
+        return self.ont[term_id]['namespace']
+
+    def get_term_set(self, term_id):
+        if term_id not in self.ont:
+            return set()
+        term_set = set()
+        q = deque()
+        q.append(term_id)
+        while len(q) > 0:
+            t_id = q.popleft()
+            if t_id not in term_set:
+                term_set.add(t_id)
+                for ch_id in self.ont[t_id]['children']:
+                    
q.append(ch_id) + return term_set + + +class DFGenerator(Sequence): + def __init__(self, df, terms_dict, nb_classes, batch_size): + self.start = 0 + self.size = len(df) + self.df = df + self.batch_size = batch_size + self.nb_classes = nb_classes + self.terms_dict = terms_dict + + def __len__(self): + return np.ceil(len(self.df) / float(self.batch_size)).astype(np.int32) + + def __getitem__(self, idx): + batch_index = np.arange(idx * self.batch_size, min(self.size, (idx + 1) * self.batch_size)) + df = self.df.iloc[batch_index] + labels = np.zeros((len(df), self.nb_classes), dtype=np.int32) + feature_data = [] + protein_si = [] + for i, row in enumerate(df.itertuples()): + feature_data.append(list(row.Promap_feature)) + protein_si.append(list(row.Protein_similary)) + data_onehot = np.array(feature_data) + data_si = np.array(protein_si) + for t_id in row.Annotations: + if t_id in self.terms_dict: + labels[i, self.terms_dict[t_id]] = 1 + self.start += self.batch_size + return ([data_onehot, data_si], labels) + + def __next__(self): + return self.next() + + def reset(self): + self.start = 0 + + def next(self): + if self.start < self.size: + batch_index = np.arange( + self.start, min(self.size, self.start + self.batch_size)) + df = self.df.iloc[batch_index] + labels = np.zeros((len(df), self.nb_classes), dtype=np.int32) + feature_data = [] + protein_si = [] + for i, row in enumerate(df.itertuples()): + feature_data.append(list(row.Promap_feature)) + protein_si.append(list(row.Protein_similary)) + data_onehot = np.array(feature_data) + data_si = np.array(protein_si) + for t_id in row.Annotations: + if t_id in self.terms_dict: + labels[i, self.terms_dict[t_id]] = 1 + self.start += self.batch_size + return ([data_onehot, data_si], labels) + else: + self.reset() + return self.next() + + +def load_data(file): + f = open(file, 'rb') + data = pickle.load(f) + return data + + +def load_weight(model_path1, model_path2): + model = load_model(model_path1) + loaded_model = load_model(model_path2) + old_weights = loaded_model.get_weights() + now_weights = model.get_weights() + cnt = 0 + for i in range(len(old_weights)): + if old_weights[cnt].shape == now_weights[i].shape: + now_weights[i] = old_weights[cnt] + cnt = cnt + 1 + + print(f'{cnt} layers weights copied, total {len(now_weights)}') + model.set_weights(now_weights) + model.save(model_path1) + +def diamond_score(diamond_scores_file, label, data_path,term_path): + with open("/public/home/zhengqq/data/Train/go.pkl", 'rb') as file: + go = pickle.loads(file.read()) + train_df = pd.read_pickle("/public/home/zhengqq/data/Train/train_dic.pkl") + test_df = data_path + annotations = train_df['Annotation'].values + annotations = list(map(lambda x: set(x), annotations)) + prot_index = {} + for i, row in enumerate(train_df.itertuples()): + prot_index[row.Protein] = i + diamond_scores = {} + with open(diamond_scores_file) as f: + for line in f: + it = line.strip().split("\t") + if it[0] not in diamond_scores: + diamond_scores[it[0]] = {} + diamond_scores[it[0]][it[1]] = float(it[11]) + blast_preds = [] + + for i, row in enumerate(test_df.itertuples()): + annots = {} + prot_id = row.Proteins + # BlastKNN + if prot_id in diamond_scores: + sim_prots = diamond_scores[prot_id] + allgos = set() + total_score = 0.0 + for p_id, score in sim_prots.items(): + allgos |= annotations[prot_index[p_id]] + total_score += score + allgos = list(sorted(allgos)) + sim = np.zeros(len(allgos), dtype=np.float32) + for j, go_id in enumerate(allgos): + s = 0.0 + for p_id, score in 
sim_prots.items(): + if go_id in annotations[prot_index[p_id]]: + s += score + sim[j] = s / total_score + ind = np.argsort(-sim) + for go_id, score in zip(allgos, sim): + annots[go_id] = score + blast_preds.append(annots) + terms = pd.read_pickle(term_path) + terms = terms.terms.values.flatten() + terms_dict = {v: i for i, v in enumerate(terms)} + NAMESPACES = {'cc': 'cellular_component', 'mf': 'molecular_function', 'bp': 'biological_process'} + alphas = {NAMESPACES['mf']: 0.55, NAMESPACES['bp']: 0.6, NAMESPACES['cc']: 0.4} + + for i in range(0, len(label)): + annots_dict = blast_preds[i].copy() + for go_id in annots_dict: + annots_dict[go_id] *= alphas[go.get_namespace(go_id)] + for j in range(0, len(label[0])): + go_id = terms[j] + # print(go_id) + if go_id == 'GO:0071427': + label[i, j] = label[i, j]*(1 - alphas['biological_process']) + elif go_id == 'GO:0099061': + label[i, j] = label[i, j]*(1 - alphas['cellular_component']) + else: + label[i, j] = label[i, j]*(1 - alphas[go.get_namespace(go_id)]) + if go_id in annots_dict: + label[i, j] = label[i, j] + annots_dict[go_id] + return label + +# def plot_curve(history): +# plt.figure() +# x_range = range(0, len(history.history['loss'])) +# plt.plot(x_range, history.history['loss'], 'bo', label='Training loss') +# plt.plot(x_range, history.history['val_loss'], 'b', label='Validation loss') +# plt.title('Training and validation loss') +# plt.legend() + + +def init_evaluate(data_size, batch_size, model_file, data_path, term_path): + with open(term_path, 'rb') as file: + terms_df = pickle.load(file) + # with open(data_path, 'rb') as file: + data_df = data_path + if len(data_df) > data_size: + data_df = data_df.sample(n=data_size) + model = load_model(model_file) + # data_file = data_path.split('/')[-1].split('.')[0] + terms = terms_df.terms.values.flatten() + terms_dict = {v: i for i, v in enumerate(terms)} + nb_classes = len(terms) + labels = np.zeros((len(data_df), nb_classes), dtype=np.int32) + for i, row in enumerate(data_df.itertuples()): + for go_id in row.Annotations: + if go_id in terms_dict: + labels[i, terms_dict[go_id]] = 1 + print('predict……') + data_generator = DFGenerator(data_df, terms_dict, nb_classes, batch_size) + data_steps = int(math.ceil(len(data_df) / batch_size)) + preds = model.predict(data_generator, steps=data_steps) + # preds=diamond_score("/public/home/zhengqq/data/fanxiu/valid.txt",preds, data_path,term_path) #25个小数据集 + preds=diamond_score("/public/home/zhengqq/data/fanxiu/GB/20421pkl/20421.txt",preds, data_path,term_path) #human总数据集 + return terms, labels, preds + + +def fmeasure(real_annots, pred_annots): + cnt = 0 + precision = 0.0 + recall = 0.0 + p_total = 0 + for i in range(len(real_annots)): + if len(real_annots[i]) == 0: + continue + tp = set(real_annots[i]).intersection(set(pred_annots[i])) + fp = pred_annots[i] - tp + fn = real_annots[i] - tp + tpn = len(tp) + fpn = len(fp) + fnn = len(fn) + cnt += 1 + recall += tpn / (1.0 * (tpn + fnn)) + if len(pred_annots[i]) > 0: + p_total += 1 + precision_x = tpn / (1.0 * (tpn + fpn)) + precision += precision_x + recall /= cnt + if p_total > 0: + precision /= p_total + fscore = 0.0 + if precision + recall > 0: + fscore = 2 * precision * recall / (precision + recall) + return fscore, precision, recall + + +def evaluate_annotations(labels_np, preds_np, terms): + fmax = 0.0 + tmax = 0.0 + precisions = [] + recalls = [] + labels = list(map(lambda x: set(terms[x == 1]), labels_np)) + for t in range(1, 101): + threshold = t / 100.0 + preds = preds_np.copy() + preds[preds 
>= threshold] = 1 + preds[preds != 1] = 0 + # fscore, pr, rc = fmeasure(labels, prop_annotations(preds, terms)) + fscore, pr, rc = fmeasure(labels, list(map(lambda x: set(terms[x == 1]), preds))) + precisions.append(pr) + recalls.append(rc) + if fmax < fscore: + fmax = fscore + tmax = t + preds = preds_np.copy() + preds[preds >= tmax / 100.0] = 1 + preds[preds != 1] = 0 + mcc = matthews_corrcoef(labels_np.flatten(), preds.flatten()) + precisions = np.array(precisions) + recalls = np.array(recalls) + sorted_index = np.argsort(recalls) + recalls = recalls[sorted_index] + precisions = precisions[sorted_index] + return fmax, tmax, recalls, precisions, mcc + + +def evaluate(model_file, data_path, data_size=8000, batch_size=16, + term_path='/public/home/zhengqq/data/Train/new_terms/new_terms_MF.pkl'): + ont = ['GO:0003674', 'GO:0008150', 'GO:0005575'] + namespace = ['molecular_function', 'biological_process', 'cellular_component', 'all'] + terms, labels, preds = init_evaluate(data_size, batch_size, model_file, data_path, term_path) + # with open("/home/zhengly/promap/data/go.pkl", 'rb') as file: + # go = pickle.loads(file.read()) + # plt.figure(1, figsize=(16, 3)) + evaluate_info = f'{model_file}:\n' + print(f'evaluate ……') + # if i == 3: + # chose = np.ones(len(terms), dtype=bool) + # else: + # go_set = go.get_namespace_terms(namespace[i]) + # go_set.remove(ont[i]) + # chose = list(map(lambda x: x in go_set, terms)) + # _terms = terms[chose] + # _labels = labels[:, chose] + # _preds = preds[:, chose] + print(labels.flatten()) + roc_auc = roc_auc_score(labels.flatten(), preds.flatten()) + + fmax, alpha, recalls, precisions, mcc = evaluate_annotations(labels, preds, terms) + precision_t=precisions[alpha] + recall_t=recalls[alpha] + sorted_index = np.argsort(recalls) + recalls = recalls[sorted_index] + precisions = precisions[sorted_index] + AUPR=auc(recalls,precisions) + # plt.subplot(1, 4, i + 1) + # plt.plot(recalls, precisions, color='darkorange', lw=1, label=f'Fmax={fmax:0.3f}') + # plt.xlim([0.0, 1.0]) + # plt.ylim([0.0, 1.0]) + # plt.xlabel('Recall') + # plt.ylabel('Precision') + # plt.title(f'P-R curve of {namespace[i]}') + # plt.legend(loc="lower right") + evaluate_info += f'\t, {len(terms)}: fmax={fmax:0.3f}, mcc={mcc:0.3f}, roc_auc={roc_auc:0.3f}, precision={precisions[alpha]:0.3f}, recall={recalls[alpha]:0.3f}, threshold={alpha}, AUPR={AUPR}\n' + # plt.show() + # aucli = [] + # fs = [] + # with open(term_path, 'rb') as file: + # terms_df = pickle.load(file) + # tags = terms_df['tag'] + # for i in range(1, 10): + # tag_select = tags == i + # _terms = terms[tag_select] + # _labels = labels[:, tag_select] + # _preds = preds[:, tag_select] + # aucli.append(roc_auc_score(_labels.flatten(), _preds.flatten())) + # (res) = evaluate_annotations(_labels, _preds, _terms) + # fs.append(res[0]) + # plt.figure() + # plt.plot(range(1, len(aucli) + 1), aucli, lw=1, label=f'STD of auc={np.std(aucli):0.5f}') + # # plt.plot(range(1,len(fs)+1), fs, lw=1, color='orange', label=f'STD of fmax={np.std(fs):0.5f}') + # plt.xlabel('Depth') + # plt.legend(loc="lower right") + # plt.ylim([0.0, 1.0]) + # plt.show() + # evaluate_info += f'\tauc_std={np.std(aucli):0.5f}, fmax_std={np.std(fs):0.5f}\n' + # evaluate_info += f'\tauc_std={np.std(aucli):0.5f}\n' + # print(aucli) + print(evaluate_info) + with open("logfile.json", "a") as file: + file.write(evaluate_info) + + +def train_model(hparams): + print(tf.test.is_gpu_available()) + model = load_model(hparams.model_file) + # model.summary() + checkpointer = 
ModelCheckpoint(filepath=hparams.model_file, verbose=1, save_best_only=True)
+    earlystopper = EarlyStopping(monitor='val_loss', patience=3, verbose=1)
+    tbCallBack = TensorBoard(log_dir="./model", histogram_freq=1, write_grads=True)
+    logger = CSVLogger('/public/home/zhengqq/data/log/reCRNN_bi.log')
+    terms_df = pd.read_pickle(hparams.term_file)
+    terms = terms_df.terms.values.flatten()
+    terms_dict = {v: i for i, v in enumerate(terms)}
+    nb_classes = len(terms)
+    t = open(hparams.data_path, 'rb')
+    data_df = pickle.load(t)
+    # t.close()
+    if len(data_df) > hparams.data_size:
+        data_df = data_df.sample(n=hparams.data_size, random_state=918)
+
+    if hparams.valid_path == 'none':
+        valid_df = data_df.sample(frac=0.2, random_state=918)
+        train_df = data_df[~data_df.index.isin(valid_df.index)]
+    else:
+        train_df = data_df
+        v = open(hparams.valid_path, 'rb')
+        valid_df = pickle.load(v)
+    # valid_steps = int(math.ceil(len(valid_df) / hparams.batch_size))
+    # train_steps = int(math.ceil(len(train_df) / hparams.batch_size))
+    # train_generator = DFGenerator(train_df, terms_dict, nb_classes, hparams.batch_size)
+    # valid_generator = DFGenerator(valid_df, terms_dict, nb_classes, hparams.batch_size)
+    # # Train the model
+    # if hparams.early:
+    #     his = model.fit(
+    #         train_generator,
+    #         steps_per_epoch=train_steps,
+    #         epochs=hparams.epochs,
+    #         validation_data=valid_generator,
+    #         validation_steps=valid_steps,
+    #         max_queue_size=hparams.batch_size,
+    #         workers=12,
+    #         callbacks=[logger, checkpointer, earlystopper])
+    # else:
+    #     his = model.fit(
+    #         train_generator,
+    #         steps_per_epoch=train_steps,
+    #         epochs=hparams.epochs,
+    #         validation_data=valid_generator,
+    #         validation_steps=valid_steps,
+    #         max_queue_size=hparams.batch_size,
+    #         workers=12,
+    #         callbacks=[logger, checkpointer])
+    evaluate(hparams.model_file, valid_df, data_size=10000, batch_size=32, term_path=hparams.term_file)
+
+
+
+def pretrain_model(hparams):
+    terms_df = pd.read_pickle(hparams.term_file)
+    terms = terms_df.terms.values.flatten()
+    batch_size = 32
+    # promap channel
+    params = {
+        'max_kernel': 129,
+        'initializer': 'glorot_normal',
+        'dense_depth': 0,
+        'nb_filters': 512,
+        'optimizer': Adam(lr=2e-4),
+        'loss': 'binary_crossentropy'
+    }
+    nb_classes = len(terms)
+    inp_hot = Input(shape=(39, 39, 7), dtype=np.float32)
+    cnn1 = Conv2D(64, (3, 3), activation='relu', input_shape=(39, 39, 7))(inp_hot)
+    pool1 = MaxPool2D((2, 2))(cnn1)
+    cnn2 = Conv2D(128, (3, 3), activation='relu')(pool1)
+    pool2 = MaxPool2D((2, 2))(cnn2)
+    cnn_out = Flatten()(pool2)
+
+
+    # protein_similary
+    inp_similary = Input(shape=(92120), dtype=np.float32)
+    encoded1 = Dense(2048, activation='relu')(inp_similary)
+    encoded2 = Dense(1024, activation='relu')(encoded1)
+    encoded3 = Dense(512, activation='relu')(encoded2)
+    decoded1 = Dense(1024, activation='relu')(encoded3)
+    decoded2 = Dense(2048, activation='relu')(decoded1)
+
+    # concatenate
+    concat = Concatenate()([cnn_out, decoded2])
+    net = BatchNormalization()(concat)
+    net = Dropout(0.5)(net)
+    dense = Dense(nb_classes, activation='sigmoid')(net)
+    model = Model(inputs=[inp_hot, inp_similary], outputs=dense)
+    model.compile(optimizer=params['optimizer'], loss=params['loss'])
+    model.save(hparams.model_file)
+    model.summary()
+
+
+def CRNN_model(hparams):
+    terms_df = pd.read_pickle(hparams.term_file)
+    terms = terms_df.index.values.flatten()
+    batch_size = 32
+    # promap channel
+    params = {
+        'max_kernel': 129,
+        'initializer': 'glorot_normal',
+        'dense_depth': 0,
+        'nb_filters': 512,
+        'optimizer': Adam(lr=2e-4),
+        'loss': 'binary_crossentropy'
+    }
+    nb_classes = len(terms)
+    inp_hot = Input(shape=(39, 39, 7), dtype=np.float32)
+    cnn1 = Conv2D(64, (3, 3), activation='relu', input_shape=(39, 39, 7),trainable=hparams.frozen)(inp_hot)
+    pool1 = MaxPool2D((2, 2),trainable=hparams.frozen)(cnn1)
+    cnn2 = Conv2D(128, (3, 3), activation='relu',trainable=hparams.frozen)(pool1)
+    pool2 = MaxPool2D((2, 2),trainable=hparams.frozen)(cnn2)
+    cnn_out = Flatten()(pool2)
+
+
+    # protein_similary
+    inp_similary = Input(shape=(92120), dtype=np.float32)
+    encoded1 = Dense(2048, activation='relu',trainable=hparams.frozen)(inp_similary)
+    encoded2 = Dense(1024, activation='relu',trainable=hparams.frozen)(encoded1)
+    encoded3 = Dense(512, activation='relu',trainable=hparams.frozen)(encoded2)
+    decoded1 = Dense(1024, activation='relu',trainable=hparams.frozen)(encoded3)
+    decoded2 = Dense(2048, activation='relu',trainable=hparams.frozen)(decoded1)
+
+    # concatenate
+    concat = Concatenate()([cnn_out, decoded2])
+    net = BatchNormalization()(concat)
+    net = Dropout(0.5)(net)
+    out = Dense(nb_classes, activation='relu')(net)
+    drop = Dropout(0.5)(out)
+    repeat = RepeatVector(11)(drop)
+    gru1 = LSTM(256, activation='tanh', return_sequences=True)(repeat)
+    gru2 = LSTM(256, activation='tanh', return_sequences=True)(gru1)
+    gru3 = LSTM(256, activation='tanh', return_sequences=True)(gru2)
+    net = layers.Flatten()(gru3)
+    classify = layers.Dense(nb_classes, activation='sigmoid')(net)
+    model = Model(inputs=[inp_hot, inp_similary], outputs=classify)
+    model.compile(optimizer=params['optimizer'], loss=params['loss'])
+    if hparams.load:
+        loaded_model = load_model(hparams.origin_path)
+        old_weights = loaded_model.get_weights()
+        now_weights = model.get_weights()
+        cnt = 0
+        for i in range(len(old_weights)):
+            if old_weights[cnt].shape == now_weights[i].shape:
+                now_weights[i] = old_weights[cnt]
+                cnt = cnt + 1
+            print(i, cnt, len(now_weights), len(now_weights), len(old_weights))
+        print(f'{cnt} layers weights copied, total {len(now_weights)}')
+        model.set_weights(now_weights)
+    model.save(hparams.model_file)
+    model.summary()
+
+
+# pretrain_model('/public/home/zhengqq/data/model_parma/annopro/0727pre_annopro_bp.h5')
+# train_model('/public/home/zhengqq/data/model_parma/annopro/0727pre_annopro_bp.h5', '/public/home/zhengqq/data/Train/cafa5/annopro/time_train_annopro.pkl',
+#             valid_path='/public/home/zhengqq/data/Train/cafa5/annopro/time_valid_annopro.pkl', epochs=200, data_size=60000)
+# evaluate('0727pre_annopro_bp', '/public/home/zhengqq/data/Train/cafa5/annopro/time_valid_annopro.pkl', data_size=100000, batch_size=32, term_path='/public/home/zhengqq/data/Train/new_terms/new_terms_BP.pkl')
+# CRNN_model('/public/home/zhengqq/data/model_parma/annopro/0727annopro_bp.h5',frozen=False,load=True,origin_path='/public/home/zhengqq/data/model_parma/annopro/0725pre_annopro.h5',term_file='/public/home/zhengqq/data/Train/new_terms/new_terms_BP.pkl')
+# train_model('/public/home/zhengqq/data/model_parma/annopro/0727annopro_bp.h5','/public/home/zhengqq/data/Train/cafa5/annopro/time_train_annopro.pkl',valid_path='/public/home/zhengqq/data/Train/cafa5/annopro/time_valid_annopro.pkl', batch_size=32,epochs=200)
+if __name__ == "__main__":
+
+    parser = ArgumentParser()
+    parser.add_argument("--data_path", default='/public/home/zhengqq/data/fanxiu/GB/25train/BP_High_Normal.pkl')
+    parser.add_argument("--model_file", default='/public/home/zhengqq/data/fanxiu/model/BP_High_Normal.h5')
+    parser.add_argument("--term_file", 
default='/public/home/zhengqq/data/fanxiu/GB/25terms/BP_High_Normal.pkl') + parser.add_argument("--valid_path", default='none') + parser.add_argument("--data_size", default=1000000) + parser.add_argument("--batch_size", default=32) + parser.add_argument("--epochs", default=1000) + parser.add_argument("--early", default=True) + parser.add_argument("--gpu", default="0") + parser.add_argument("--frozen", default=False) + parser.add_argument("--load", default=True) + parser.add_argument("--origin_path", default=True) + args = parser.parse_args() + # pretrain_model(args) + # train_model(args) + + # CRNN_model(args) + train_model(args) + From 56035ffcac721027b5dc343d274021bb4e6d1173 Mon Sep 17 00:00:00 2001 From: lingyan Date: Mon, 15 Jan 2024 22:28:07 +0800 Subject: [PATCH 2/2] Delete annopro directory --- annopro/__init__.py | 82 --- annopro/__main__.py | 4 - annopro/_version.py | 658 ------------------ annopro/annopro_bp.py | 595 ---------------- annopro/data_procession/__init__.py | 11 - annopro/data_procession/data_predict.py | 99 --- annopro/data_procession/utils.py | 273 -------- annopro/focal_loss/__init__.py | 16 - annopro/focal_loss/_binary_focal_loss.py | 565 --------------- annopro/focal_loss/_categorical_focal_loss.py | 311 --------- annopro/focal_loss/utils/__init__.py | 0 annopro/focal_loss/utils/validation.py | 325 --------- annopro/prediction.py | 181 ----- annopro/resources.py | 115 --- 14 files changed, 3235 deletions(-) delete mode 100644 annopro/__init__.py delete mode 100644 annopro/__main__.py delete mode 100644 annopro/_version.py delete mode 100644 annopro/annopro_bp.py delete mode 100644 annopro/data_procession/__init__.py delete mode 100644 annopro/data_procession/data_predict.py delete mode 100644 annopro/data_procession/utils.py delete mode 100644 annopro/focal_loss/__init__.py delete mode 100644 annopro/focal_loss/_binary_focal_loss.py delete mode 100644 annopro/focal_loss/_categorical_focal_loss.py delete mode 100644 annopro/focal_loss/utils/__init__.py delete mode 100644 annopro/focal_loss/utils/validation.py delete mode 100644 annopro/prediction.py delete mode 100644 annopro/resources.py diff --git a/annopro/__init__.py b/annopro/__init__.py deleted file mode 100644 index 58bf170..0000000 --- a/annopro/__init__.py +++ /dev/null @@ -1,82 +0,0 @@ -def console_main(): - import argparse - parser = argparse.ArgumentParser(description='Arguments for AnnoPRO') - parser.add_argument("--fasta_file", "-i", help="The protein sequences file") - parser.add_argument('--output', "-o", default=None, - type=str, help="Output directory") - parser.add_argument('--used_gpu', default="-1", type=str, - help="GPU device selected, default is CPU") - parser.add_argument('--disable_diamond', - action='store_true', default=False, - help="Disable blast with diamond") - parser.add_argument('--overwrite', - action="store_true", - default=False, - help="Overwrite existed output" - ) - parser.add_argument("--version", - action="store_true", default=False, help="Show version") - args = parser.parse_args() - if args.version: - print("{} {}, Copyright Zhejiang University.".format( - __name__, __version__)) - exit(0) - elif args.fasta_file is None: - parser.print_help() - exit(1) - main( - proteins_fasta_file=args.fasta_file, - output_dir=args.output, - used_gpu=args.used_gpu, - with_diamond=(not args.disable_diamond), - overwrite=args.overwrite - ) - - -def main(proteins_fasta_file: str, output_dir: str = None, - used_gpu: str = None, with_diamond: bool = True, overwrite: bool = False): - from 
annopro.data_procession import process - from diamond4py import Diamond - from annopro import resources - from os.path import join, exists - from annopro.prediction import predict - from shutil import rmtree - import profeat - - if output_dir is None: - output_dir = proteins_fasta_file + ".output" - - if exists(output_dir): - if overwrite: - rmtree(output_dir) - else: - print(f"Output directory {output_dir} already existed!") - exit(1) - - profeat.run(proteins_fasta_file, output_dir) - - diamond_scores_file: str = None - if with_diamond: - diamond_scores_file = join(output_dir, "diamond_scores.txt") - diamond = Diamond( - database=resources.get_resource_path("cafa4.dmnd"), - n_threads=4 - ) - diamond.blastp( - query=proteins_fasta_file, - out=diamond_scores_file - ) - - promap_features_file = join(output_dir, "promap_features.pkl") - process( - proteins_fasta_file=proteins_fasta_file, - profeat_file=join(output_dir, "output-protein.dat"), - save_file=promap_features_file) - predict(output_dir=output_dir, - promap_features_file=promap_features_file, - used_gpu=used_gpu, - diamond_scores_file=diamond_scores_file) - - -from . import _version -__version__ = _version.get_versions()['version'] diff --git a/annopro/__main__.py b/annopro/__main__.py deleted file mode 100644 index 40827e9..0000000 --- a/annopro/__main__.py +++ /dev/null @@ -1,4 +0,0 @@ -from annopro import console_main - -if __name__ == "__main__": - console_main() \ No newline at end of file diff --git a/annopro/_version.py b/annopro/_version.py deleted file mode 100644 index 334f7ac..0000000 --- a/annopro/_version.py +++ /dev/null @@ -1,658 +0,0 @@ - -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. -# Generated by versioneer-0.28 -# https://github.com/python-versioneer/python-versioneer - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys -from typing import Callable, Dict -import functools - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). 
- git_refnames = "$Format:%d$" - git_full = "$Format:%H$" - git_date = "$Format:%ci$" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "pep440" - cfg.tag_prefix = "v" - cfg.parentdir_prefix = "annopro-" - cfg.versionfile_source = "annopro/_version.py" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY: Dict[str, str] = {} -HANDLERS: Dict[str, Dict[str, Callable]] = {} - - -def register_vcs_handler(vcs, method): # decorator - """Create decorator to mark a method as the handler of a VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - process = None - - popen_kwargs = {} - if sys.platform == "win32": - # This hides the console window if pythonw.exe is used - startupinfo = subprocess.STARTUPINFO() - startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW - popen_kwargs["startupinfo"] = startupinfo - - for command in commands: - try: - dispcmd = str([command] + args) - # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) - break - except OSError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = process.communicate()[0].strip().decode() - if process.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, process.returncode - return stdout, process.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for _ in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. 
When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - with open(versionfile_abs, "r") as fobj: - for line in fobj: - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - except OSError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if "refnames" not in keywords: - raise NotThisMethod("Short version file found") - date = keywords.get("date") - if date is not None: - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = {r.strip() for r in refnames.strip("()").split(",")} - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. 
"2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - # Filter out refs that exactly match prefix or that don't start - # with a number once the prefix is stripped (mostly a concern - # when prefix is '') - if not re.match(r'\d', r): - continue - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - # GIT_DIR can interfere with correct operation of Versioneer. - # It may be intended to be passed to the Versioneer-versioned project, - # but that should not change where we get our version from. - env = os.environ.copy() - env.pop("GIT_DIR", None) - runner = functools.partial(runner, env=env) - - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=not verbose) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) - # --abbrev-ref was added in git-1.6.3 - if rc != 0 or branch_name is None: - raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") - branch_name = branch_name.strip() - - if branch_name == "HEAD": - # If we aren't exactly on a branch, pick a branch which represents - # the current commit. If all else fails, we are on a branchless - # commit. - branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) - # --contains was added in git-1.5.4 - if rc != 0 or branches is None: - raise NotThisMethod("'git branch --contains' returned error") - branches = branches.split("\n") - - # Remove the first line if we're running detached - if "(" in branches[0]: - branches.pop(0) - - # Strip off the leading "* " from the list of branches. - branches = [branch[2:] for branch in branches] - if "master" in branches: - branch_name = "master" - elif not branches: - branch_name = None - else: - # Pick the first branch that is returned. Good or bad. 
- branch_name = branches[0] - - pieces["branch"] = branch_name - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) - pieces["distance"] = len(out.split()) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() - # Use only the last line. Previous lines may contain GPG signature - # information. - date = date.splitlines()[-1] - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_branch(pieces): - """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . - - The ".dev0" means not master branch. Note that .dev0 sorts backwards - (a feature branch will appear "older" than the master branch). - - Exceptions: - 1: no tags. 
0[.dev0]+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0" - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def pep440_split_post(ver): - """Split pep440 version string at the post-release segment. - - Returns the release segments before the post-release and the - post-release version number (or -1 if no post-release segment is present). - """ - vc = str.split(ver, ".post") - return vc[0], int(vc[1] or 0) if len(vc) == 2 else None - - -def render_pep440_pre(pieces): - """TAG[.postN.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post0.devDISTANCE - """ - if pieces["closest-tag"]: - if pieces["distance"]: - # update the post release segment - tag_version, post_version = pep440_split_post(pieces["closest-tag"]) - rendered = tag_version - if post_version is not None: - rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) - else: - rendered += ".post0.dev%d" % (pieces["distance"]) - else: - # no commits, use the tag as the version - rendered = pieces["closest-tag"] - else: - # exception #1 - rendered = "0.post0.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_post_branch(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . - - The ".dev0" means not master branch. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["branch"] != "master": - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Exceptions: - 1: no tags. 
0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-branch": - rendered = render_pep440_branch(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-post-branch": - rendered = render_pep440_post_branch(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. 
- for _ in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} diff --git a/annopro/annopro_bp.py b/annopro/annopro_bp.py deleted file mode 100644 index 25f8af0..0000000 --- a/annopro/annopro_bp.py +++ /dev/null @@ -1,595 +0,0 @@ -# 要求的库与参数 -import os -from tensorflow.keras.utils import Sequence, plot_model -from tensorflow.keras.optimizers import Adam -from tensorflow.keras import models, layers -from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger, TensorBoard -# import matplotlib.pyplot as plt -from tensorflow.keras.models import Model, load_model -from tensorflow.keras.layers import ( - Input, Dense, Embedding, Conv2D, Flatten, Concatenate, TimeDistributed, - MaxPool2D, Dropout, RepeatVector, Layer, Reshape, SimpleRNN, LSTM, BatchNormalization, GRU, Reshape, - GlobalAveragePooling2D, GlobalMaxPooling2D, multiply, Permute, Add, Activation, Lambda, Permute, Multiply -) -import tensorflow as tf -import numpy as np -import pandas as pd -import math -import pickle -from sklearn.metrics import roc_curve, auc, matthews_corrcoef, precision_score, recall_score, roc_auc_score -from collections import deque, Counter -from tqdm import tqdm -from tensorflow.keras import backend as K -from argparse import ArgumentParser - - -# AAINDEX = {'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'Q': 6, 'E': 7, 'G': 8, 'H': 9, 'I': 10, 'L': 11, -# 'K': 12, 'M': 13, 'F': 14, 'P': 15, 'S': 16, 'T': 17, 'W': 18, 'Y': 19, 'V': 20} -# MAXLEN = 2000 - - -# def to_onehot(seq, start=0): -# onehot = np.zeros((MAXLEN, 21), dtype=np.int32) -# l = min(MAXLEN, len(seq)) -# for i in range(start, start + l): -# onehot[i, AAINDEX.get(seq[i - start], 0)] = 1 -# onehot[0:start, 0] = 1 -# onehot[start + l:, 0] = 1 -# return onehot - -# gpus = tf.config.experimental.list_physical_devices('GPU') - -# tf.config.experimental.set_memory_growth(gpus[0], True) - - - -class Ontology(object): - def __init__(self, filename='/public/home/zhengqq/data/Train/go-basic.obo', with_rels=False): - self.ont = self.load(filename, with_rels) - self.ic = None - - def has_term(self, term_id): - return term_id in self.ont - - def get_term(self, term_id): - if self.has_term(term_id): - return self.ont[term_id] - return None - - def get_anchestors(self, term_id): - if term_id not in self.ont: - return set() - term_set = set() - q = deque() - q.append(term_id) - while (len(q) > 0): - t_id = q.popleft() - if t_id not in term_set: - term_set.add(t_id) - for parent_id in self.ont[t_id]['is_a']: - if parent_id in self.ont: - q.append(parent_id) - return term_set - - def get_parents(self, term_id): - if term_id not in self.ont: - return set() - term_set = set() - for parent_id in self.ont[term_id]['is_a']: - if parent_id in self.ont: - term_set.add(parent_id) - return term_set - - def get_namespace_terms(self, namespace): - terms = set() - for go_id, obj in self.ont.items(): - if obj['namespace'] == namespace: - terms.add(go_id) - return terms - - def get_namespace(self, 
term_id): - return self.ont[term_id]['namespace'] - - def get_term_set(self, term_id): - if term_id not in self.ont: - return set() - term_set = set() - q = deque() - q.append(term_id) - while len(q) > 0: - t_id = q.popleft() - if t_id not in term_set: - term_set.add(t_id) - for ch_id in self.ont[t_id]['children']: - q.append(ch_id) - return term_set - - -class DFGenerator(Sequence): - def __init__(self, df, terms_dict, nb_classes, batch_size): - self.start = 0 - self.size = len(df) - self.df = df - self.batch_size = batch_size - self.nb_classes = nb_classes - self.terms_dict = terms_dict - - def __len__(self): - return np.ceil(len(self.df) / float(self.batch_size)).astype(np.int32) - - def __getitem__(self, idx): - batch_index = np.arange(idx * self.batch_size, min(self.size, (idx + 1) * self.batch_size)) - df = self.df.iloc[batch_index] - labels = np.zeros((len(df), self.nb_classes), dtype=np.int32) - feature_data = [] - protein_si = [] - for i, row in enumerate(df.itertuples()): - feature_data.append(list(row.Promap_feature)) - protein_si.append(list(row.Protein_similary)) - data_onehot = np.array(feature_data) - data_si = np.array(protein_si) - for t_id in row.Annotations: - if t_id in self.terms_dict: - labels[i, self.terms_dict[t_id]] = 1 - self.start += self.batch_size - return ([data_onehot, data_si], labels) - - def __next__(self): - return self.next() - - def reset(self): - self.start = 0 - - def next(self): - if self.start < self.size: - batch_index = np.arange( - self.start, min(self.size, self.start + self.batch_size)) - df = self.df.iloc[batch_index] - labels = np.zeros((len(df), self.nb_classes), dtype=np.int32) - feature_data = [] - protein_si = [] - for i, row in enumerate(df.itertuples()): - feature_data.append(list(row.Promap_feature)) - protein_si.append(list(row.Protein_similary)) - data_onehot = np.array(feature_data) - data_si = np.array(protein_si) - for t_id in row.Annotations: - if t_id in self.terms_dict: - labels[i, self.terms_dict[t_id]] = 1 - self.start += self.batch_size - return ([data_onehot, data_si], labels) - else: - self.reset() - return self.next() - - -def load_data(file): - f = open(file, 'rb') - data = pickle.load(f) - return data - - -def load_weight(model_path1, model_path2): - model = load_model(model_path1) - loaded_model = load_model(model_path2) - old_weights = loaded_model.get_weights() - now_weights = model.get_weights() - cnt = 0 - for i in range(len(old_weights)): - if old_weights[cnt].shape == now_weights[i].shape: - now_weights[i] = old_weights[cnt] - cnt = cnt + 1 - - print(f'{cnt} layers weights copied, total {len(now_weights)}') - model.set_weights(now_weights) - model.save(model_path1) - -def diamond_score(diamond_scores_file, label, data_path,term_path): - with open("/public/home/zhengqq/data/Train/go.pkl", 'rb') as file: - go = pickle.loads(file.read()) - train_df = pd.read_pickle("/public/home/zhengqq/data/Train/train_dic.pkl") - test_df = data_path - annotations = train_df['Annotation'].values - annotations = list(map(lambda x: set(x), annotations)) - prot_index = {} - for i, row in enumerate(train_df.itertuples()): - prot_index[row.Protein] = i - diamond_scores = {} - with open(diamond_scores_file) as f: - for line in f: - it = line.strip().split("\t") - if it[0] not in diamond_scores: - diamond_scores[it[0]] = {} - diamond_scores[it[0]][it[1]] = float(it[11]) - blast_preds = [] - - for i, row in enumerate(test_df.itertuples()): - annots = {} - prot_id = row.Proteins - # BlastKNN - if prot_id in diamond_scores: - sim_prots = 
diamond_scores[prot_id] - allgos = set() - total_score = 0.0 - for p_id, score in sim_prots.items(): - allgos |= annotations[prot_index[p_id]] - total_score += score - allgos = list(sorted(allgos)) - sim = np.zeros(len(allgos), dtype=np.float32) - for j, go_id in enumerate(allgos): - s = 0.0 - for p_id, score in sim_prots.items(): - if go_id in annotations[prot_index[p_id]]: - s += score - sim[j] = s / total_score - ind = np.argsort(-sim) - for go_id, score in zip(allgos, sim): - annots[go_id] = score - blast_preds.append(annots) - terms = pd.read_pickle(term_path) - terms = terms.terms.values.flatten() - terms_dict = {v: i for i, v in enumerate(terms)} - NAMESPACES = {'cc': 'cellular_component', 'mf': 'molecular_function', 'bp': 'biological_process'} - alphas = {NAMESPACES['mf']: 0.55, NAMESPACES['bp']: 0.6, NAMESPACES['cc']: 0.4} - - for i in range(0, len(label)): - annots_dict = blast_preds[i].copy() - for go_id in annots_dict: - annots_dict[go_id] *= alphas[go.get_namespace(go_id)] - for j in range(0, len(label[0])): - go_id = terms[j] - # print(go_id) - if go_id == 'GO:0071427': - label[i, j] = label[i, j]*(1 - alphas['biological_process']) - elif go_id == 'GO:0099061': - label[i, j] = label[i, j]*(1 - alphas['cellular_component']) - else: - label[i, j] = label[i, j]*(1 - alphas[go.get_namespace(go_id)]) - if go_id in annots_dict: - label[i, j] = label[i, j] + annots_dict[go_id] - return label - -# def plot_curve(history): -# plt.figure() -# x_range = range(0, len(history.history['loss'])) -# plt.plot(x_range, history.history['loss'], 'bo', label='Training loss') -# plt.plot(x_range, history.history['val_loss'], 'b', label='Validation loss') -# plt.title('Training and validation loss') -# plt.legend() - - -def init_evaluate(data_size, batch_size, model_file, data_path, term_path): - with open(term_path, 'rb') as file: - terms_df = pickle.load(file) - # with open(data_path, 'rb') as file: - data_df = data_path - if len(data_df) > data_size: - data_df = data_df.sample(n=data_size) - model = load_model(model_file) - # data_file = data_path.split('/')[-1].split('.')[0] - terms = terms_df.terms.values.flatten() - terms_dict = {v: i for i, v in enumerate(terms)} - nb_classes = len(terms) - labels = np.zeros((len(data_df), nb_classes), dtype=np.int32) - for i, row in enumerate(data_df.itertuples()): - for go_id in row.Annotations: - if go_id in terms_dict: - labels[i, terms_dict[go_id]] = 1 - print('predict……') - data_generator = DFGenerator(data_df, terms_dict, nb_classes, batch_size) - data_steps = int(math.ceil(len(data_df) / batch_size)) - preds = model.predict(data_generator, steps=data_steps) - # preds=diamond_score("/public/home/zhengqq/data/fanxiu/valid.txt",preds, data_path,term_path) #25个小数据集 - preds=diamond_score("/public/home/zhengqq/data/fanxiu/GB/20421pkl/20421.txt",preds, data_path,term_path) #human总数据集 - return terms, labels, preds - - -def fmeasure(real_annots, pred_annots): - cnt = 0 - precision = 0.0 - recall = 0.0 - p_total = 0 - for i in range(len(real_annots)): - if len(real_annots[i]) == 0: - continue - tp = set(real_annots[i]).intersection(set(pred_annots[i])) - fp = pred_annots[i] - tp - fn = real_annots[i] - tp - tpn = len(tp) - fpn = len(fp) - fnn = len(fn) - cnt += 1 - recall += tpn / (1.0 * (tpn + fnn)) - if len(pred_annots[i]) > 0: - p_total += 1 - precision_x = tpn / (1.0 * (tpn + fpn)) - precision += precision_x - recall /= cnt - if p_total > 0: - precision /= p_total - fscore = 0.0 - if precision + recall > 0: - fscore = 2 * precision * recall / (precision + 
recall) - return fscore, precision, recall - - -def evaluate_annotations(labels_np, preds_np, terms): - fmax = 0.0 - tmax = 0.0 - precisions = [] - recalls = [] - labels = list(map(lambda x: set(terms[x == 1]), labels_np)) - for t in range(1, 101): - threshold = t / 100.0 - preds = preds_np.copy() - preds[preds >= threshold] = 1 - preds[preds != 1] = 0 - # fscore, pr, rc = fmeasure(labels, prop_annotations(preds, terms)) - fscore, pr, rc = fmeasure(labels, list(map(lambda x: set(terms[x == 1]), preds))) - precisions.append(pr) - recalls.append(rc) - if fmax < fscore: - fmax = fscore - tmax = t - preds = preds_np.copy() - preds[preds >= tmax / 100.0] = 1 - preds[preds != 1] = 0 - mcc = matthews_corrcoef(labels_np.flatten(), preds.flatten()) - precisions = np.array(precisions) - recalls = np.array(recalls) - sorted_index = np.argsort(recalls) - recalls = recalls[sorted_index] - precisions = precisions[sorted_index] - return fmax, tmax, recalls, precisions, mcc - - -def evaluate(model_file, data_path, data_size=8000, batch_size=16, - term_path='/public/home/zhengqq/data/Train/new_terms/new_terms_MF.pkl'): - ont = ['GO:0003674', 'GO:0008150', 'GO:0005575'] - namespace = ['molecular_function', 'biological_process', 'cellular_component', 'all'] - terms, labels, preds = init_evaluate(data_size, batch_size, model_file, data_path, term_path) - # with open("/home/zhengly/promap/data/go.pkl", 'rb') as file: - # go = pickle.loads(file.read()) - # plt.figure(1, figsize=(16, 3)) - evaluate_info = f'{model_file}:\n' - print(f'evaluate ……') - # if i == 3: - # chose = np.ones(len(terms), dtype=bool) - # else: - # go_set = go.get_namespace_terms(namespace[i]) - # go_set.remove(ont[i]) - # chose = list(map(lambda x: x in go_set, terms)) - # _terms = terms[chose] - # _labels = labels[:, chose] - # _preds = preds[:, chose] - print(labels.flatten()) - roc_auc = roc_auc_score(labels.flatten(), preds.flatten()) - - fmax, alpha, recalls, precisions, mcc = evaluate_annotations(labels, preds, terms) - precision_t=precisions[alpha] - recall_t=recalls[alpha] - sorted_index = np.argsort(recalls) - recalls = recalls[sorted_index] - precisions = precisions[sorted_index] - AUPR=auc(recalls,precisions) - # plt.subplot(1, 4, i + 1) - # plt.plot(recalls, precisions, color='darkorange', lw=1, label=f'Fmax={fmax:0.3f}') - # plt.xlim([0.0, 1.0]) - # plt.ylim([0.0, 1.0]) - # plt.xlabel('Recall') - # plt.ylabel('Precision') - # plt.title(f'P-R curve of {namespace[i]}') - # plt.legend(loc="lower right") - evaluate_info += f'\t, {len(terms)}: fmax={fmax:0.3f}, mcc={mcc:0.3f}, roc_auc={roc_auc:0.3f}, precision={precisions[alpha]:0.3f}, recall={recalls[alpha]:0.3f}, threshold={alpha}, AUPR={AUPR}\n' - # plt.show() - # aucli = [] - # fs = [] - # with open(term_path, 'rb') as file: - # terms_df = pickle.load(file) - # tags = terms_df['tag'] - # for i in range(1, 10): - # tag_select = tags == i - # _terms = terms[tag_select] - # _labels = labels[:, tag_select] - # _preds = preds[:, tag_select] - # aucli.append(roc_auc_score(_labels.flatten(), _preds.flatten())) - # (res) = evaluate_annotations(_labels, _preds, _terms) - # fs.append(res[0]) - # plt.figure() - # plt.plot(range(1, len(aucli) + 1), aucli, lw=1, label=f'STD of auc={np.std(aucli):0.5f}') - # # plt.plot(range(1,len(fs)+1), fs, lw=1, color='orange', label=f'STD of fmax={np.std(fs):0.5f}') - # plt.xlabel('Depth') - # plt.legend(loc="lower right") - # plt.ylim([0.0, 1.0]) - # plt.show() - # evaluate_info += f'\tauc_std={np.std(aucli):0.5f}, fmax_std={np.std(fs):0.5f}\n' - # 
evaluate_info += f'\tauc_std={np.std(aucli):0.5f}\n' - # print(aucli) - print(evaluate_info) - with open("logfile.json", "a") as file: - file.write(evaluate_info) - - -def train_model(hparams): - print(tf.test.is_gpu_available()) - model = load_model(hparams.model_file) - # model.summary() - checkpointer = ModelCheckpoint(filepath=hparams.model_file, verbose=1, save_best_only=True) - earlystopper = EarlyStopping(monitor='val_loss', patience=3, verbose=1) - tbCallBack = TensorBoard(log_dir="./model", histogram_freq=1, write_grads=True) - logger = CSVLogger('/public/home/zhengqq/data/log/reCRNN_bi.log') - terms_df = pd.read_pickle(hparams.term_file) - terms = terms_df.terms.values.flatten() - terms_dict = {v: i for i, v in enumerate(terms)} - nb_classes = len(terms) - t = open(hparams.data_path, 'rb') - data_df = pickle.load(t) - # t.close() - if len(data_df) > hparams.data_size: - data_df = data_df.sample(n=hparams.data_size, random_state=918) - - if hparams.valid_path == 'none': - valid_df = data_df.sample(frac=0.2, random_state=918) - train_df = data_df[~data_df.index.isin(valid_df.index)] - else: - train_df = data_df - v = open(hparams.valid_path, 'rb') - valid_df = pickle.load(v) - # valid_steps = int(math.ceil(len(valid_df) / hparams.batch_size)) - # train_steps = int(math.ceil(len(train_df) / hparams.batch_size)) - # train_generator = DFGenerator(train_df, terms_dict, nb_classes, hparams.batch_size) - # valid_generator = DFGenerator(valid_df, terms_dict, nb_classes, hparams.batch_size) - # # 训练模型 - # if hparams.early: - # his = model.fit( - # train_generator, - # steps_per_epoch=train_steps, - # epochs=hparams.epochs, - # validation_data=valid_generator, - # validation_steps=valid_steps, - # max_queue_size=hparams.batch_size, - # workers=12, - # callbacks=[logger, checkpointer, earlystopper]) - # else: - # his = model.fit( - # train_generator, - # steps_per_epoch=train_steps, - # epochs=hparams.epochs, - # validation_data=valid_generator, - # validation_steps=valid_steps, - # max_queue_size=hparams.batch_size, - # workers=12, - # callbacks=[logger, checkpointer]) - evaluate(hparams.model_file, valid_df, data_size=10000, batch_size=32, term_path=hparams.term_file) - - - -def pretrain_model(hparams): - terms_df = pd.read_pickle(hparams.term_file) - terms = terms_df.terms.values.flatten() - batch_size = 32 - # promap通道 - params = { - 'max_kernel': 129, - 'initializer': 'glorot_normal', - 'dense_depth': 0, - 'nb_filters': 512, - 'optimizer': Adam(lr=2e-4), - 'loss': 'binary_crossentropy' - } - nb_classes = len(terms) - inp_hot = Input(shape=(39, 39, 7), dtype=np.float32) - cnn1 = Conv2D(64, (3, 3), activation='relu', input_shape=(39, 39, 7))(inp_hot) - pool1 = MaxPool2D((2, 2))(cnn1) - cnn2 = Conv2D(128, (3, 3), activation='relu')(pool1) - pool2 = MaxPool2D((2, 2))(cnn2) - cnn_out = Flatten()(pool2) - - - # protein_similary - inp_similary = Input(shape=(92120), dtype=np.float32) - encoded1 = Dense(2048, activation='relu')(inp_similary) - encoded2 = Dense(1024, activation='relu')(encoded1) - encoded3 = Dense(512, activation='relu')(encoded2) - decoded1 = Dense(1024, activation='relu')(encoded3) - decoded2 = Dense(2048, activation='relu')(decoded1) - - # concenate - concat = Concatenate()([cnn_out, decoded2]) - net = BatchNormalization()(concat) - net = Dropout(0.5)(net) - dense =Dense(nb_classes, activation='sigmoid')(net) - model = Model(inputs=[inp_hot, inp_similary], outputs=dense) - model.compile(optimizer=params['optimizer'], loss=params['loss']) - model.save(hparams.model_file) - 
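# Illustrative sketch, not part of the patch: the pre-training network defined
# above, rebuilt in isolation so the two-branch layout is easy to see. A small
# CNN reads the 39x39x7 promap image, a dense encoder/decoder reads the
# 92120-dim protein-similarity profile, and the two are concatenated into one
# sigmoid multi-label head. The function name and the nb_classes default are
# assumptions for illustration only; recent TF releases also spell the
# optimizer argument learning_rate rather than lr.
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import (Conv2D, MaxPool2D, Flatten, Dense,
                                     Concatenate, BatchNormalization, Dropout)
from tensorflow.keras.optimizers import Adam


def build_pretrain_sketch(nb_classes: int = 1000) -> Model:
    # Promap branch: two conv/pool stages over the 39x39x7 feature image.
    inp_hot = Input(shape=(39, 39, 7), name="promap")
    x = Conv2D(64, (3, 3), activation="relu")(inp_hot)
    x = MaxPool2D((2, 2))(x)
    x = Conv2D(128, (3, 3), activation="relu")(x)
    x = MaxPool2D((2, 2))(x)
    cnn_out = Flatten()(x)

    # Similarity branch: bottleneck to 512 dims, then expand back to 2048.
    inp_sim = Input(shape=(92120,), name="protein_similarity")
    e = Dense(2048, activation="relu")(inp_sim)
    e = Dense(1024, activation="relu")(e)
    e = Dense(512, activation="relu")(e)
    d = Dense(1024, activation="relu")(e)
    d = Dense(2048, activation="relu")(d)

    # Fuse both branches and predict one probability per GO term.
    net = Concatenate()([cnn_out, d])
    net = BatchNormalization()(net)
    net = Dropout(0.5)(net)
    out = Dense(nb_classes, activation="sigmoid")(net)
    model = Model(inputs=[inp_hot, inp_sim], outputs=out)
    model.compile(optimizer=Adam(learning_rate=2e-4), loss="binary_crossentropy")
    return model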
model.summary() - - -def CRNN_model(hparams): - terms_df = pd.read_pickle(hparams.term_file) - terms = terms_df.index.values.flatten() - batch_size = 32 - # promap通道 - params = { - 'max_kernel': 129, - 'initializer': 'glorot_normal', - 'dense_depth': 0, - 'nb_filters': 512, - 'optimizer': Adam(lr=2e-4), - 'loss': 'binary_crossentropy' - } - nb_classes = len(terms) - inp_hot = Input(shape=(39, 39, 7), dtype=np.float32) - cnn1 = Conv2D(64, (3, 3), activation='relu', input_shape=(39, 39, 7),trainable=hparams.frozen)(inp_hot) - pool1 = MaxPool2D((2, 2),trainable=hparams.frozen)(cnn1) - cnn2 = Conv2D(128, (3, 3), activation='relu',trainable=hparams.frozen)(pool1) - pool2 = MaxPool2D((2, 2),trainable=hparams.frozen)(cnn2) - cnn_out = Flatten()(pool2) - - - # protein_similary - inp_similary = Input(shape=(92120), dtype=np.float32) - encoded1 = Dense(2048, activation='relu',trainable=hparams.frozen)(inp_similary) - encoded2 = Dense(1024, activation='relu',trainable=hparams.frozen)(encoded1) - encoded3 = Dense(512, activation='relu',trainable=hparams.frozen)(encoded2) - decoded1 = Dense(1024, activation='relu',trainable=hparams.frozen)(encoded3) - decoded2 = Dense(2048, activation='relu',trainable=hparams.frozen)(decoded1) - - # concenate - concat = Concatenate()([cnn_out, decoded2]) - net = BatchNormalization()(concat) - net = Dropout(0.5)(net) - out =Dense(nb_classes, activation='relu')(net) - drop = Dropout(0.5)(out) - repeat = RepeatVector(11)(drop) - gru1 = LSTM(256, activation='tanh', return_sequences=True)(repeat) - gru2 = LSTM(256, activation='tanh', return_sequences=True)(gru1) - gru3 = LSTM(256, activation='tanh', return_sequences=True)(gru2) - net = layers.Flatten()(gru3) - classify = layers.Dense(nb_classes, activation='sigmoid')(net) - model = Model(inputs=[inp_hot, inp_similary], outputs=classify) - model.compile(optimizer=params['optimizer'], loss=params['loss']) - if hparams.load: - loaded_model = load_model(hparams.origin_path) - old_weights = loaded_model.get_weights() - now_weights = model.get_weights() - cnt = 0 - for i in range(len(old_weights)): - if old_weights[cnt].shape == now_weights[i].shape: - now_weights[i] = old_weights[cnt] - cnt = cnt + 1 - print(i, cnt, len(now_weights), len(now_weights), len(old_weights)) - print(f'{cnt} layers weights copied, total {len(now_weights)}') - model.set_weights(now_weights) - model.save(hparams.model_file) - model.summary() - - -# pretrain_model('/public/home/zhengqq/data/model_parma/annopro/0727pre_annopro_bp.h5') -# train_model('/public/home/zhengqq/data/model_parma/annopro/0727pre_annopro_bp.h5', '/public/home/zhengqq/data/Train/cafa5/annopro/time_train_annopro.pkl', -# valid_path='/public/home/zhengqq/data/Train/cafa5/annopro/time_valid_annopro.pkl', epochs=200, data_size=60000) -# evaluate('0727pre_annopro_bp', '/public/home/zhengqq/data/Train/cafa5/annopro/time_valid_annopro.pkl', data_size=100000, batch_size=32, term_path='/public/home/zhengqq/data/Train/new_terms/new_terms_BP.pkl') -# CRNN_model('/public/home/zhengqq/data/model_parma/annopro/0727annopro_bp.h5',frozen=False,load=True,origin_path='/public/home/zhengqq/data/model_parma/annopro/0725pre_annopro.h5',term_file='/public/home/zhengqq/data/Train/new_terms/new_terms_BP.pkl') -# train_model('/public/home/zhengqq/data/model_parma/annopro/0727annopro_bp.h5','/public/home/zhengqq/data/Train/cafa5/annopro/time_train_annopro.pkl',valid_path='/public/home/zhengqq/data/Train/cafa5/annopro/time_valid_annopro.pkl', batch_size=32,epochs=200) -if __name__ =="__main__": - - parser = 
ArgumentParser() - parser.add_argument("--data_path", default='/public/home/zhengqq/data/fanxiu/GB/25train/BP_High_Normal.pkl') - parser.add_argument("--model_file", default='/public/home/zhengqq/data/fanxiu/model/BP_High_Normal.h5') - parser.add_argument("--term_file", default='/public/home/zhengqq/data/fanxiu/GB/25terms/BP_High_Normal.pkl') - parser.add_argument("--valid_path", default='none') - parser.add_argument("--data_size", default=1000000) - parser.add_argument("--batch_size", default=32) - parser.add_argument("--epochs", default=1000) - parser.add_argument("--early", default=True) - parser.add_argument("--gpu", default="0") - parser.add_argument("--frozen", default=False) - parser.add_argument("--load", default=True) - parser.add_argument("--origin_path", default=True) - args = parser.parse_args() - # pretrain_model(args) - # train_model(args) - - # CRNN_model(args) - train_model(args) - diff --git a/annopro/data_procession/__init__.py b/annopro/data_procession/__init__.py deleted file mode 100644 index 78452d8..0000000 --- a/annopro/data_procession/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from annopro.data_procession.data_predict import Data_process - - -def process(proteins_fasta_file: str, profeat_file: str, save_file: str): - if proteins_fasta_file == None: - raise ValueError("Must provide the input fasta sequences.") - - data = Data_process(protein_file=profeat_file, - proteins_fasta_file=proteins_fasta_file, - save_file=save_file, num=1484) - data.calculate_feature(row_num=39, size=(39, 39, 7)) diff --git a/annopro/data_procession/data_predict.py b/annopro/data_procession/data_predict.py deleted file mode 100644 index 387102b..0000000 --- a/annopro/data_procession/data_predict.py +++ /dev/null @@ -1,99 +0,0 @@ -import numpy as np -import pandas as pd -import pickle - -from sklearn.metrics.pairwise import cosine_similarity -from annopro.data_procession.utils import Ontology, load_data, MinMaxScaleClip -from profeat import profeat_to_df -from annopro import resources -from fasta import FASTA - -class Data_process(): - - def __init__( - self, - protein_file, - proteins_fasta_file, - save_file, - num, - grid_file="data_grid.pkl", - assess_file="row_asses.pkl", - prosim_file="cafa4_del.csv"): - ''' - protein_file 是生成的蛋白特征文件,即profeat生成的文件 - split_file是需要包含所需信息的蛋白序列文件, - save_file是生成的文件需要保存的位置, - prosim_file是蛋白相似性比对的库 - grid_file,assess_file是使用的map的位置文件 - ''' - self.protein_file = protein_file - self.split_file = proteins_fasta_file - self.save_file = save_file - self.grid_file = grid_file - self.assess_file = assess_file - self.prosim_file = prosim_file - self.num = num - self.__data__() - - def __data__(self): - proteins_f = profeat_to_df(self.protein_file) - proteins_f.dropna(axis=0, inplace=True) - feature_data = proteins_f.iloc[:, :self.num] - with resources.open_text("cafa4_del.csv") as cafa4_del: - mia_data = load_data(cafa4_del, 1485) - mia_data.columns = range(len(mia_data.columns)) - feature_data = (feature_data - mia_data.min()) / \ - ((mia_data.max() - mia_data.min()) + 1e-8) - self.feature_data:pd.DataFrame = feature_data - self.proteins = list(proteins_f.index) - with resources.open_binary(self.grid_file) as gf: - self.data_grid = pickle.load(gf) - with resources.open_binary(self.assess_file) as af: - self.row_asses = pickle.load(af) - with resources.open_text(self.prosim_file) as pf: - prosim_data = load_data(pf, 1485) - # this will be minmax by column - prosim_standard = MinMaxScaleClip(prosim_data) - prosim_map = prosim_standard.to_numpy() - # this will be 
global minmax scale - prosim_map:np.ndarray = MinMaxScaleClip(prosim_map) - self.prosim_map = prosim_map - self.go = Ontology() - - def calculate_feature(self, row_num, size): - protein_seqs = FASTA(self.split_file).sequences - class_labels = ['Composition', 'Autocorrelation', 'Physiochemical', 'Interaction', - 'Quasi-sequence-order descriptors', 'PAAC for amino acid index set', 'Amphiphilic Pseudo amino acid composition'] - protein_all = [] - sequences_all = [] - feature_all = [] - prosim_all = [] - data_grid = self.data_grid - row_asses = self.row_asses - proteins = self.proteins - result = self.feature_data.to_numpy() - consine_similarity = 1-cosine_similarity(result, self.prosim_map) - - for i, protein in enumerate(proteins): - if protein in protein_seqs: - col_list = np.zeros(size) - row = 0 - col = 0 - protein_all.append(protein) - sequences_all.append(str(protein_seqs[protein].seq)) - prosim_all.append(consine_similarity[i]) - # 构建promap特征 - for j in range(len(data_grid['x'])): - channel = data_grid['subtype'][j] - index = class_labels.index(channel) - feature_index = row_asses[j] - row = j % row_num - col = j//row_num - col_list[col][row][index] = self.feature_data.iloc[i, feature_index] - feature_all.append(col_list) - data_t = [protein_all, sequences_all, feature_all, prosim_all] - data_t = pd.DataFrame(data_t) - data_t = data_t.T - data_t.columns = ['Proteins', 'Sequence', - 'Promap_feature', 'Protein_similary'] - data_t.to_pickle(self.save_file) diff --git a/annopro/data_procession/utils.py b/annopro/data_procession/utils.py deleted file mode 100644 index fdb43cd..0000000 --- a/annopro/data_procession/utils.py +++ /dev/null @@ -1,273 +0,0 @@ -from collections import deque, Counter -from tensorflow.keras.utils import Sequence -import numpy as np -import pandas as pd -import math -from annopro import resources -from typing import Union - -BIOLOGICAL_PROCESS = 'GO:0008150' -MOLECULAR_FUNCTION = 'GO:0003674' -CELLULAR_COMPONENT = 'GO:0005575' -FUNC_DICT = { - 'cc': CELLULAR_COMPONENT, - 'mf': MOLECULAR_FUNCTION, - 'bp': BIOLOGICAL_PROCESS} - -NAMESPACES = { - 'cc': 'cellular_component', - 'mf': 'molecular_function', - 'bp': 'biological_process' -} - -EXP_CODES = set([ - 'EXP', 'IDA', 'IPI', 'IMP', 'IGI', 'IEP', 'TAS', 'IC', - 'HTP', 'HDA', 'HMP', 'HGI', 'HEP']) - -# CAFA4 Targets -CAFA_TARGETS = set([ - '287', '3702', '4577', '6239', '7227', '7955', '9606', '9823', '10090', - '10116', '44689', '83333', '99287', '226900', '243273', '284812', '559292']) - - -def is_cafa_target(org): - return org in CAFA_TARGETS - - -def is_exp_code(code): - return code in EXP_CODES - - -class Ontology(object): - - def __init__(self, with_rels=False): - self.ont = self.load(with_rels) - self.ic = None - - def has_term(self, term_id): - return term_id in self.ont - - def get_term(self, term_id): - if self.has_term(term_id): - return self.ont[term_id] - return None - - def calculate_ic(self, annots): - cnt = Counter() - for x in annots: - cnt.update(x) - self.ic = {} - for go_id, n in cnt.items(): - parents = self.get_parents(go_id) - if len(parents) == 0: - min_n = n - else: - min_n = min([cnt[x] for x in parents]) - - self.ic[go_id] = math.log(min_n / n, 2) - - def get_ic(self, go_id): - if self.ic is None: - raise Exception('Not yet calculated') - if go_id not in self.ic: - return 0.0 - return self.ic[go_id] - - def load(self, with_rels): - ont = dict() - obj = None - with resources.open_text("go.txt") as f: - for line in f: - line = line.strip() - if not line: - continue - if line == '[Term]': - if 
obj is not None: - ont[obj['id']] = obj - obj = dict() - obj['is_a'] = list() - obj['part_of'] = list() - obj['regulates'] = list() - obj['alt_ids'] = list() - obj['is_obsolete'] = False - continue - elif line == '[Typedef]': - if obj is not None: - ont[obj['id']] = obj - obj = None - else: - if obj is None: - continue - l = line.split(": ") - if l[0] == 'id': - obj['id'] = l[1] - elif l[0] == 'alt_id': - obj['alt_ids'].append(l[1]) - elif l[0] == 'namespace': - obj['namespace'] = l[1] - elif l[0] == 'is_a': - obj['is_a'].append(l[1].split(' ! ')[0]) - elif with_rels and l[0] == 'relationship': - it = l[1].split() - # add all types of relationships - obj['is_a'].append(it[1]) - elif l[0] == 'name': - obj['name'] = l[1] - elif l[0] == 'is_obsolete' and l[1] == 'true': - obj['is_obsolete'] = True - if obj is not None: - ont[obj['id']] = obj - for term_id in list(ont.keys()): - for t_id in ont[term_id]['alt_ids']: - ont[t_id] = ont[term_id] - if ont[term_id]['is_obsolete']: - del ont[term_id] - for term_id, val in ont.items(): - if 'children' not in val: - val['children'] = set() - for p_id in val['is_a']: - if p_id in ont: - if 'children' not in ont[p_id]: - ont[p_id]['children'] = set() - ont[p_id]['children'].add(term_id) - return ont - - def get_anchestors(self, term_id): - if term_id not in self.ont: - return set() - term_set = set() - q = deque() - q.append(term_id) - while(len(q) > 0): - t_id = q.popleft() - if t_id not in term_set: - term_set.add(t_id) - for parent_id in self.ont[t_id]['is_a']: - if parent_id in self.ont: - q.append(parent_id) - return term_set - - def get_parents(self, term_id): - if term_id not in self.ont: - return set() - term_set = set() - for parent_id in self.ont[term_id]['is_a']: - if parent_id in self.ont: - term_set.add(parent_id) - return term_set - - def get_namespace_terms(self, namespace): - terms = set() - for go_id, obj in self.ont.items(): - if obj['namespace'] == namespace: - terms.add(go_id) - return terms - - def get_namespace(self, term_id): - return self.ont[term_id]['namespace'] - - def get_term_set(self, term_id): - if term_id not in self.ont: - return set() - term_set = set() - q = deque() - q.append(term_id) - while len(q) > 0: - t_id = q.popleft() - if t_id not in term_set: - term_set.add(t_id) - for ch_id in self.ont[t_id]['children']: - q.append(ch_id) - return term_set - - -def read_fasta(filename): - seqs = list() - info = list() - seq = '' - inf = '' - with open(filename, 'r') as f: - for line in f: - line = line.strip() - if line.startswith('>'): - if seq != '': - seqs.append(seq) - info.append(inf) - seq = '' - inf = line[1:] - else: - seq += line - seqs.append(seq) - info.append(inf) - return info, seqs - - -class DFGenerator(Sequence): - def __init__(self, df, terms_dict, nb_classes, batch_size): - self.start = 0 - self.size = len(df) - self.df = df - self.batch_size = batch_size - self.nb_classes = nb_classes - self.terms_dict = terms_dict - - def __len__(self): - return np.ceil(len(self.df) / float(self.batch_size)).astype(np.int32) - - def __getitem__(self, idx): - batch_index = np.arange(idx * self.batch_size, - min(self.size, (idx + 1) * self.batch_size)) - df = self.df.iloc[batch_index] - labels = np.zeros((len(df), self.nb_classes), dtype=np.int32) - feature_data = [] - protein_si = [] - for i, row in enumerate(df.itertuples()): - feature_data.append(list(row.Promap_feature)) - protein_si.append(list(row.Protein_similary)) - data_onehot = np.array(feature_data) - data_si = np.array(protein_si) - for t_id in 
row.Prop_annotations: - if t_id in self.terms_dict: - labels[i, self.terms_dict[t_id]] = 1 - self.start += self.batch_size - return ([data_onehot, data_si], labels) - - def __next__(self): - return self.next() - - def reset(self): - self.start = 0 - - def next(self): - if self.start < self.size: - batch_index = np.arange( - self.start, min(self.size, self.start + self.batch_size)) - df = self.df.iloc[batch_index] - labels = np.zeros((len(df), self.nb_classes), dtype=np.int32) - feature_data = [] - protein_si = [] - for i, row in enumerate(df.itertuples()): - feature_data.append(list(row.Promap_feature)) - protein_si.append(list(row.Protein_similary)) - data_onehot = np.array(feature_data) - data_si = np.array(protein_si) - for t_id in row.Prop_annotations: - if t_id in self.terms_dict: - labels[i, self.terms_dict[t_id]] = 1 - self.start += self.batch_size - return ([data_onehot, data_si], labels) - else: - self.reset() - return self.next() - - -def load_data(file:str, num:int): - data = pd.read_csv(file, header=None) - data.dropna(axis=0, inplace=True) - data_noprotein = data.iloc[:, 1:num] - return data_noprotein - - -def MinMaxScaleClip(data: Union[pd.DataFrame, np.ndarray]) -> Union[pd.DataFrame, np.ndarray]: - data_standard = (data - data.min()) / ((data.max() - data.min()) + 1e-8) - return data_standard \ No newline at end of file diff --git a/annopro/focal_loss/__init__.py b/annopro/focal_loss/__init__.py deleted file mode 100644 index 3230f58..0000000 --- a/annopro/focal_loss/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from ._binary_focal_loss import binary_focal_loss -from ._binary_focal_loss import BinaryFocalLoss -from ._categorical_focal_loss import sparse_categorical_focal_loss -from ._categorical_focal_loss import SparseCategoricalFocalLoss - -# Package information -__package__ = 'focal-loss' -__version__ = '0.0.8' -__author__ = 'Artem Mavrin' -__author_email__ = 'artemvmavrin@gmail.com' -__description__ = 'TensorFlow implementation of focal loss.' -__url__ = 'https://github.com/artemmavrin/focal-loss' -__copyright__ = 'Copyright 2020 Artem Mavrin' -__license__ = 'Apache 2.0' - -__doc__ = __description__ diff --git a/annopro/focal_loss/_binary_focal_loss.py b/annopro/focal_loss/_binary_focal_loss.py deleted file mode 100644 index eadb139..0000000 --- a/annopro/focal_loss/_binary_focal_loss.py +++ /dev/null @@ -1,565 +0,0 @@ -"""Binary focal loss implementation.""" -# ____ __ ___ __ __ __ __ ____ ____ -# ( __)/ \ / __) / _\ ( ) ( ) / \ / ___)/ ___) -# ) _)( O )( (__ / \/ (_/\ / (_/\( O )\___ \\___ \ -# (__) \__/ \___)\_/\_/\____/ \____/ \__/ (____/(____/ - -from functools import partial - -import tensorflow as tf - -from .utils.validation import check_bool, check_float - -_EPSILON = tf.keras.backend.epsilon() - - -def binary_focal_loss(y_true, y_pred, gamma, *, pos_weight=None, - from_logits=False, label_smoothing=None): - r"""Focal loss function for binary classification. - - This loss function generalizes binary cross-entropy by introducing a - hyperparameter :math:`\gamma` (gamma), called the *focusing parameter*, - that allows hard-to-classify examples to be penalized more heavily relative - to easy-to-classify examples. - - The focal loss [1]_ is defined as - - .. 
math:: - - L(y, \hat{p}) - = -\alpha y \left(1 - \hat{p}\right)^\gamma \log(\hat{p}) - - (1 - y) \hat{p}^\gamma \log(1 - \hat{p}) - - where - - * :math:`y \in \{0, 1\}` is a binary class label, - * :math:`\hat{p} \in [0, 1]` is an estimate of the probability of the - positive class, - * :math:`\gamma` is the *focusing parameter* that specifies how much - higher-confidence correct predictions contribute to the overall loss - (the higher the :math:`\gamma`, the higher the rate at which - easy-to-classify examples are down-weighted). - * :math:`\alpha` is a hyperparameter that governs the trade-off between - precision and recall by weighting errors for the positive class up or - down (:math:`\alpha=1` is the default, which is the same as no - weighting), - - The usual weighted binary cross-entropy loss is recovered by setting - :math:`\gamma = 0`. - - Parameters - ---------- - y_true : tensor-like - Binary (0 or 1) class labels. - - y_pred : tensor-like - Either probabilities for the positive class or logits for the positive - class, depending on the `from_logits` parameter. The shapes of `y_true` - and `y_pred` should be broadcastable. - - gamma : float - The focusing parameter :math:`\gamma`. Higher values of `gamma` make - easy-to-classify examples contribute less to the loss relative to - hard-to-classify examples. Must be non-negative. - - pos_weight : float, optional - The coefficient :math:`\alpha` to use on the positive examples. Must be - non-negative. - - from_logits : bool, optional - Whether `y_pred` contains logits or probabilities. - - label_smoothing : float, optional - Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary - ground truth labels `y_true` are squeezed toward 0.5, with larger values - of `label_smoothing` leading to label values closer to 0.5. - - Returns - ------- - :class:`tf.Tensor` - The focal loss for each example (assuming `y_true` and `y_pred` have the - same shapes). In general, the shape of the output is the result of - broadcasting the shapes of `y_true` and `y_pred`. - - Warnings - -------- - This function does not reduce its output to a scalar, so it cannot be passed - to :meth:`tf.keras.Model.compile` as a `loss` argument. Instead, use the - wrapper class :class:`~focal_loss.BinaryFocalLoss`. - - Examples - -------- - - This function computes the per-example focal loss between a label and - prediction tensor: - - >>> import numpy as np - >>> from focal_loss import binary_focal_loss - >>> loss = binary_focal_loss([0, 1, 1], [0.1, 0.7, 0.9], gamma=2) - >>> np.set_printoptions(precision=3) - >>> print(loss.numpy()) - [0.001 0.032 0.001] - - Below is a visualization of the focal loss between the positive class and - predicted probabilities between 0 and 1. Note that as :math:`\gamma` - increases, the losses for predictions closer to 1 get smoothly pushed to 0. - - .. 
plot:: - :include-source: - :align: center - - import numpy as np - import matplotlib.pyplot as plt - - from focal_loss import binary_focal_loss - - ps = np.linspace(0, 1, 100) - gammas = (0, 0.5, 1, 2, 5) - - plt.figure() - for gamma in gammas: - loss = binary_focal_loss(1, ps, gamma=gamma) - label = rf'$\gamma$={gamma}' - if gamma == 0: - label += ' (cross-entropy)' - plt.plot(ps, loss, label=label) - plt.legend(loc='best', frameon=True, shadow=True) - plt.xlim(0, 1) - plt.ylim(0, 4) - plt.xlabel(r'Probability of positive class $\hat{p}$') - plt.ylabel('Loss') - plt.title(r'Plot of focal loss $L(1, \hat{p})$ for different $\gamma$', - fontsize=14) - plt.show() - - Notes - ----- - A classifier often estimates the positive class probability :math:`\hat{p}` - by computing a real-valued *logit* :math:`\hat{y} \in \mathbb{R}` and - applying the *sigmoid function* :math:`\sigma : \mathbb{R} \to (0, 1)` - defined by - - .. math:: - - \sigma(t) = \frac{1}{1 + e^{-t}}, \qquad (t \in \mathbb{R}). - - That is, :math:`\hat{p} = \sigma(\hat{y})`. In this case, the focal loss - can be written as a function of the logit :math:`\hat{y}` instead of the - predicted probability :math:`\hat{p}`: - - .. math:: - - L(y, \hat{y}) - = -\alpha y \left(1 - \sigma(\hat{y})\right)^\gamma - \log(\sigma(\hat{y})) - - (1 - y) \sigma(\hat{y})^\gamma \log(1 - \sigma(\hat{y})). - - This is the formula that is computed when specifying `from_logits=True`. - However, this formula is not very numerically stable if implemented - directly; for example, there are multiple log and sigmoid computations - involved. Instead, we use some tricks to rewrite it in the more numerically - stable form - - .. math:: - - L(y, \hat{y}) - = (1 - y) \hat{p}^\gamma \hat{y} - + \left(\alpha y \hat{q}^\gamma + (1 - y) \hat{p}^\gamma\right) - \left(\log(1 + e^{-|\hat{y}|}) + \max\{-\hat{y}, 0\}\right), - - where :math:`\hat{p} = \sigma(\hat{y})` and :math:`\hat{q} = 1 - \hat{p}` - denote the estimates of the probabilities of the positive and negative - classes, respectively. - - Indeed, starting with the observations that - - .. math:: - - \log(\sigma(\hat{y})) - = \log\left(\frac{1}{1 + e^{-\hat{y}}}\right) - = -\log(1 + e^{-\hat{y}}) - - and - - .. math:: - - \log(1 - \sigma(\hat{y})) - = \log\left(\frac{e^{-\hat{y}}}{1 + e^{-\hat{y}}}\right) - = -\hat{y} - \log(1 + e^{-\hat{y}}), - - we obtain - - .. math:: - - \begin{aligned} - L(y, \hat{y}) - &= -\alpha y \hat{q}^\gamma \log(\sigma(\hat{y})) - - (1 - y) \hat{p}^\gamma \log(1 - \sigma(\hat{y})) \\ - &= \alpha y \hat{q}^\gamma \log(1 + e^{-\hat{y}}) - + (1 - y) \hat{p}^\gamma \left(\hat{y} + \log(1 + e^{-\hat{y}})\right)\\ - &= (1 - y) \hat{p}^\gamma \hat{y} - + \left(\alpha y \hat{q}^\gamma + (1 - y) \hat{p}^\gamma\right) - \log(1 + e^{-\hat{y}}). - \end{aligned} - - Note that if :math:`\hat{y} < 0`, then the exponential term - :math:`e^{-\hat{y}}` could become very large. In this case, we can instead - observe that - - .. math:: - - \begin{align*} - \log(1 + e^{-\hat{y}}) - &= \log(1 + e^{-\hat{y}}) + \hat{y} - \hat{y} \\ - &= \log(1 + e^{-\hat{y}}) + \log(e^{\hat{y}}) - \hat{y} \\ - &= \log(1 + e^{\hat{y}}) - \hat{y}. - \end{align*} - - Moreover, the :math:`\hat{y} < 0` and :math:`\hat{y} \geq 0` cases can be - unified by writing - - .. math:: - - \log(1 + e^{-\hat{y}}) - = \log(1 + e^{-|\hat{y}|}) + \max\{-\hat{y}, 0\}. - - Thus, we arrive at the numerically stable formula shown earlier. - - References - ---------- - .. [1] T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár. 
Focal loss for - dense object detection. IEEE Transactions on Pattern Analysis and - Machine Intelligence, 2018. - (`DOI `__) - (`arXiv preprint `__) - - See Also - -------- - :meth:`~focal_loss.BinaryFocalLoss` - A wrapper around this function that makes it a - :class:`tf.keras.losses.Loss`. - """ - # Validate arguments - gamma = check_float(gamma, name='gamma', minimum=0) - pos_weight = check_float(pos_weight, name='pos_weight', minimum=0, - allow_none=True) - from_logits = check_bool(from_logits, name='from_logits') - label_smoothing = check_float(label_smoothing, name='label_smoothing', - minimum=0, maximum=1, allow_none=True) - - # Ensure predictions are a floating point tensor; converting labels to a - # tensor will be done in the helper functions - y_pred = tf.convert_to_tensor(y_pred) - if not y_pred.dtype.is_floating: - y_pred = tf.dtypes.cast(y_pred, dtype=tf.float32) - - # Delegate per-example loss computation to helpers depending on whether - # predictions are logits or probabilities - if from_logits: - return _binary_focal_loss_from_logits(labels=y_true, logits=y_pred, - gamma=gamma, - pos_weight=pos_weight, - label_smoothing=label_smoothing) - else: - return _binary_focal_loss_from_probs(labels=y_true, p=y_pred, - gamma=gamma, pos_weight=pos_weight, - label_smoothing=label_smoothing) - - -@tf.keras.utils.register_keras_serializable() -class BinaryFocalLoss(tf.keras.losses.Loss): - r"""Focal loss function for binary classification. - - This loss function generalizes binary cross-entropy by introducing a - hyperparameter called the *focusing parameter* that allows hard-to-classify - examples to be penalized more heavily relative to easy-to-classify examples. - - This class is a wrapper around :class:`~focal_loss.binary_focal_loss`. See - the documentation there for details about this loss function. - - Parameters - ---------- - gamma : float - The focusing parameter :math:`\gamma`. Must be non-negative. - - pos_weight : float, optional - The coefficient :math:`\alpha` to use on the positive examples. Must be - non-negative. - - from_logits : bool, optional - Whether model prediction will be logits or probabilities. - - label_smoothing : float, optional - Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary - ground truth labels are squeezed toward 0.5, with larger values of - `label_smoothing` leading to label values closer to 0.5. - - **kwargs : keyword arguments - Other keyword arguments for :class:`tf.keras.losses.Loss` (e.g., `name` - or `reduction`). - - Examples - -------- - - An instance of this class is a callable that takes a tensor of binary ground - truth labels `y_true` and a tensor of model predictions `y_pred` and returns - a scalar tensor obtained by reducing the per-example focal loss (the default - reduction is a batch-wise average). - - >>> from focal_loss import BinaryFocalLoss - >>> loss_func = BinaryFocalLoss(gamma=2) - >>> loss = loss_func([0, 1, 1], [0.1, 0.7, 0.9]) # A scalar tensor - >>> print(f'Mean focal loss: {loss.numpy():.3f}') - Mean focal loss: 0.011 - - Use this class in the :mod:`tf.keras` API like any other binary - classification loss function class found in :mod:`tf.keras.losses` (e.g., - :class:`tf.keras.losses.BinaryCrossentropy`: - - .. code-block:: python - - # Typical usage - model = tf.keras.Model(...) - model.compile( - optimizer=..., - loss=BinaryFocalLoss(gamma=2), # Used here like a tf.keras loss - metrics=..., - ) - history = model.fit(...) 
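# Illustrative sketch, assuming one wants this focal loss in place of the plain
# binary_crossentropy used by pretrain_model/CRNN_model: the BinaryFocalLoss
# class bundled in annopro.focal_loss plugs into compile() like any Keras loss.
# The tiny model and the gamma/pos_weight values below are placeholders, not
# values taken from the patch.
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from annopro.focal_loss import BinaryFocalLoss

inp = Input(shape=(128,))
out = Dense(10, activation="sigmoid")(inp)   # 10 dummy GO terms
toy_model = Model(inp, out)
toy_model.compile(
    optimizer=Adam(learning_rate=2e-4),
    # gamma down-weights easy negatives; pos_weight up-weights the rare positives
    # typical of sparse multi-label GO targets.
    loss=BinaryFocalLoss(gamma=2, pos_weight=4.0),
)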
- - See Also - -------- - :meth:`~focal_loss.binary_focal_loss` - The function that performs the focal loss computation, taking a label - tensor and a prediction tensor and outputting a loss. - """ - - def __init__(self, gamma, *, pos_weight=None, from_logits=False, - label_smoothing=None, **kwargs): - # Validate arguments - gamma = check_float(gamma, name='gamma', minimum=0) - pos_weight = check_float(pos_weight, name='pos_weight', minimum=0, - allow_none=True) - from_logits = check_bool(from_logits, name='from_logits') - label_smoothing = check_float(label_smoothing, name='label_smoothing', - minimum=0, maximum=1, allow_none=True) - - super().__init__(**kwargs) - self.gamma = gamma - self.pos_weight = pos_weight - self.from_logits = from_logits - self.label_smoothing = label_smoothing - - def get_config(self): - """Returns the config of the layer. - - A layer config is a Python dictionary containing the configuration of a - layer. The same layer can be re-instantiated later (without its trained - weights) from this configuration. - - Returns - ------- - dict - This layer's config. - """ - config = super().get_config() - config.update(gamma=self.gamma, pos_weight=self.pos_weight, - from_logits=self.from_logits, - label_smoothing=self.label_smoothing) - return config - - def call(self, y_true, y_pred): - """Compute the per-example focal loss. - - This method simply calls :meth:`~focal_loss.binary_focal_loss` with the - appropriate arguments. - - Parameters - ---------- - y_true : tensor-like - Binary (0 or 1) class labels. - - y_pred : tensor-like - Either probabilities for the positive class or logits for the - positive class, depending on the `from_logits` attribute. The shapes - of `y_true` and `y_pred` should be broadcastable. - - Returns - ------- - :class:`tf.Tensor` - The per-example focal loss. Reduction to a scalar is handled by - this layer's :meth:`~focal_loss.BinaryFocalLoss.__call__` method. - """ - return binary_focal_loss(y_true=y_true, y_pred=y_pred, gamma=self.gamma, - pos_weight=self.pos_weight, - from_logits=self.from_logits, - label_smoothing=self.label_smoothing) - - -# Helper functions below - - -def _process_labels(labels, label_smoothing, dtype): - """Pre-process a binary label tensor, maybe applying smoothing. - - Parameters - ---------- - labels : tensor-like - Tensor of 0's and 1's. - - label_smoothing : float or None - Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary - ground truth labels `y_true` are squeezed toward 0.5, with larger values - of `label_smoothing` leading to label values closer to 0.5. - - dtype : tf.dtypes.DType - Desired type of the elements of `labels`. - - Returns - ------- - tf.Tensor - The processed labels. - """ - labels = tf.dtypes.cast(labels, dtype=dtype) - if label_smoothing is not None: - labels = (1 - label_smoothing) * labels + label_smoothing * 0.5 - return labels - - -def _binary_focal_loss_from_logits(labels, logits, gamma, pos_weight, - label_smoothing): - """Compute focal loss from logits using a numerically stable formula. - - Parameters - ---------- - labels : tensor-like - Tensor of 0's and 1's: binary class labels. - - logits : tf.Tensor - Logits for the positive class. - - gamma : float - Focusing parameter. - - pos_weight : float or None - If not None, losses for the positive class will be scaled by this - weight. - - label_smoothing : float or None - Float in [0, 1]. When 0, no smoothing occurs. 
When positive, the binary - ground truth labels `y_true` are squeezed toward 0.5, with larger values - of `label_smoothing` leading to label values closer to 0.5. - - Returns - ------- - tf.Tensor - The loss for each example. - """ - labels = _process_labels(labels=labels, label_smoothing=label_smoothing, - dtype=logits.dtype) - - # Compute probabilities for the positive class - p = tf.math.sigmoid(logits) - - # Without label smoothing we can use TensorFlow's built-in per-example cross - # entropy loss functions and multiply the result by the modulating factor. - # Otherwise, we compute the focal loss ourselves using a numerically stable - # formula below - if label_smoothing is None: - # The labels and logits tensors' shapes need to be the same for the - # built-in cross-entropy functions. Since we want to allow broadcasting, - # we do some checks on the shapes and possibly broadcast explicitly - # Note: tensor.shape returns a tf.TensorShape, whereas tf.shape(tensor) - # returns an int tf.Tensor; this is why both are used below - labels_shape = labels.shape - logits_shape = logits.shape - if not labels_shape.is_fully_defined() or labels_shape != logits_shape: - labels_shape = tf.shape(labels) - logits_shape = tf.shape(logits) - shape = tf.broadcast_dynamic_shape(labels_shape, logits_shape) - labels = tf.broadcast_to(labels, shape) - logits = tf.broadcast_to(logits, shape) - if pos_weight is None: - loss_func = tf.nn.sigmoid_cross_entropy_with_logits - else: - loss_func = partial(tf.nn.weighted_cross_entropy_with_logits, - pos_weight=pos_weight) - loss = loss_func(labels=labels, logits=logits) - modulation_pos = (1 - p) ** gamma - modulation_neg = p ** gamma - mask = tf.dtypes.cast(labels, dtype=tf.bool) - modulation = tf.where(mask, modulation_pos, modulation_neg) - return modulation * loss - - # Terms for the positive and negative class components of the loss - pos_term = labels * ((1 - p) ** gamma) - neg_term = (1 - labels) * (p ** gamma) - - # Term involving the log and ReLU - log_weight = pos_term - if pos_weight is not None: - log_weight *= pos_weight - log_weight += neg_term - log_term = tf.math.log1p(tf.math.exp(-tf.math.abs(logits))) - log_term += tf.nn.relu(-logits) - log_term *= log_weight - - # Combine all the terms into the loss - loss = neg_term * logits + log_term - return loss - - -def _binary_focal_loss_from_probs(labels, p, gamma, pos_weight, - label_smoothing): - """Compute focal loss from probabilities. - - Parameters - ---------- - labels : tensor-like - Tensor of 0's and 1's: binary class labels. - - p : tf.Tensor - Estimated probabilities for the positive class. - - gamma : float - Focusing parameter. - - pos_weight : float or None - If not None, losses for the positive class will be scaled by this - weight. - - label_smoothing : float or None - Float in [0, 1]. When 0, no smoothing occurs. When positive, the binary - ground truth labels `y_true` are squeezed toward 0.5, with larger values - of `label_smoothing` leading to label values closer to 0.5. - - Returns - ------- - tf.Tensor - The loss for each example. 
- """ - # Predicted probabilities for the negative class - q = 1 - p - - # For numerical stability (so we don't inadvertently take the log of 0) - p = tf.math.maximum(p, _EPSILON) - q = tf.math.maximum(q, _EPSILON) - - # Loss for the positive examples - pos_loss = -(q ** gamma) * tf.math.log(p) - if pos_weight is not None: - pos_loss *= pos_weight - - # Loss for the negative examples - neg_loss = -(p ** gamma) * tf.math.log(q) - - # Combine loss terms - if label_smoothing is None: - labels = tf.dtypes.cast(labels, dtype=tf.bool) - loss = tf.where(labels, pos_loss, neg_loss) - else: - labels = _process_labels(labels=labels, label_smoothing=label_smoothing, - dtype=p.dtype) - loss = labels * pos_loss + (1 - labels) * neg_loss - - return loss diff --git a/annopro/focal_loss/_categorical_focal_loss.py b/annopro/focal_loss/_categorical_focal_loss.py deleted file mode 100644 index ca9dc5b..0000000 --- a/annopro/focal_loss/_categorical_focal_loss.py +++ /dev/null @@ -1,311 +0,0 @@ -"""Multiclass focal loss implementation.""" -# __ _ _ -# / _| | | | | -# | |_ ___ ___ __ _ | | | | ___ ___ ___ -# | _| / _ \ / __| / _` | | | | | / _ \ / __| / __| -# | | | (_) | | (__ | (_| | | | | | | (_) | \__ \ \__ \ -# |_| \___/ \___| \__,_| |_| |_| \___/ |___/ |___/ - -import itertools -from typing import Any, Optional - -import tensorflow as tf - -_EPSILON = tf.keras.backend.epsilon() - - -def sparse_categorical_focal_loss(y_true, y_pred, gamma, *, - class_weight: Optional[Any] = None, - from_logits: bool = False, axis: int = -1 - ) -> tf.Tensor: - r"""Focal loss function for multiclass classification with integer labels. - - This loss function generalizes multiclass softmax cross-entropy by - introducing a hyperparameter called the *focusing parameter* that allows - hard-to-classify examples to be penalized more heavily relative to - easy-to-classify examples. - - See :meth:`~focal_loss.binary_focal_loss` for a description of the focal - loss in the binary setting, as presented in the original work [1]_. - - In the multiclass setting, with integer labels :math:`y`, focal loss is - defined as - - .. math:: - - L(y, \hat{\mathbf{p}}) - = -\left(1 - \hat{p}_y\right)^\gamma \log(\hat{p}_y) - - where - - * :math:`y \in \{0, \ldots, K - 1\}` is an integer class label (:math:`K` - denotes the number of classes), - * :math:`\hat{\mathbf{p}} = (\hat{p}_0, \ldots, \hat{p}_{K-1}) - \in [0, 1]^K` is a vector representing an estimated probability - distribution over the :math:`K` classes, - * :math:`\gamma` (gamma, not :math:`y`) is the *focusing parameter* that - specifies how much higher-confidence correct predictions contribute to - the overall loss (the higher the :math:`\gamma`, the higher the rate at - which easy-to-classify examples are down-weighted). - - The usual multiclass softmax cross-entropy loss is recovered by setting - :math:`\gamma = 0`. - - Parameters - ---------- - y_true : tensor-like - Integer class labels. - - y_pred : tensor-like - Either probabilities or logits, depending on the `from_logits` - parameter. - - gamma : float or tensor-like of shape (K,) - The focusing parameter :math:`\gamma`. Higher values of `gamma` make - easy-to-classify examples contribute less to the loss relative to - hard-to-classify examples. Must be non-negative. This can be a - one-dimensional tensor, in which case it specifies a focusing parameter - for each class. - - class_weight: tensor-like of shape (K,) - Weighting factor for each of the :math:`k` classes. If not specified, - then all classes are weighted equally. 
- - from_logits : bool, optional - Whether `y_pred` contains logits or probabilities. - - axis : int, optional - Channel axis in the `y_pred` tensor. - - Returns - ------- - :class:`tf.Tensor` - The focal loss for each example. - - Examples - -------- - - This function computes the per-example focal loss between a one-dimensional - integer label vector and a two-dimensional prediction matrix: - - >>> import numpy as np - >>> from focal_loss import sparse_categorical_focal_loss - >>> y_true = [0, 1, 2] - >>> y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]] - >>> loss = sparse_categorical_focal_loss(y_true, y_pred, gamma=2) - >>> np.set_printoptions(precision=3) - >>> print(loss.numpy()) - [0.009 0.032 0.082] - - Warnings - -------- - This function does not reduce its output to a scalar, so it cannot be passed - to :meth:`tf.keras.Model.compile` as a `loss` argument. Instead, use the - wrapper class :class:`~focal_loss.SparseCategoricalFocalLoss`. - - References - ---------- - .. [1] T. Lin, P. Goyal, R. Girshick, K. He and P. Dollár. Focal loss for - dense object detection. IEEE Transactions on Pattern Analysis and - Machine Intelligence, 2018. - (`DOI `__) - (`arXiv preprint `__) - - See Also - -------- - :meth:`~focal_loss.SparseCategoricalFocalLoss` - A wrapper around this function that makes it a - :class:`tf.keras.losses.Loss`. - """ - # Process focusing parameter - gamma = tf.convert_to_tensor(gamma, dtype=tf.dtypes.float32) - gamma_rank = gamma.shape.rank - scalar_gamma = gamma_rank == 0 - - # Process class weight - if class_weight is not None: - class_weight = tf.convert_to_tensor(class_weight, - dtype=tf.dtypes.float32) - - # Process prediction tensor - y_pred = tf.convert_to_tensor(y_pred) - y_pred_rank = y_pred.shape.rank - if y_pred_rank is not None: - axis %= y_pred_rank - if axis != y_pred_rank - 1: - # Put channel axis last for sparse_softmax_cross_entropy_with_logits - perm = list(itertools.chain(range(axis), - range(axis + 1, y_pred_rank), [axis])) - y_pred = tf.transpose(y_pred, perm=perm) - elif axis != -1: - raise ValueError( - f'Cannot compute sparse categorical focal loss with axis={axis} on ' - 'a prediction tensor with statically unknown rank.') - y_pred_shape = tf.shape(y_pred) - - # Process ground truth tensor - y_true = tf.dtypes.cast(y_true, dtype=tf.dtypes.int64) - y_true_rank = y_true.shape.rank - - if y_true_rank is None: - raise NotImplementedError('Sparse categorical focal loss not supported ' - 'for target/label tensors of unknown rank') - - reshape_needed = (y_true_rank is not None and y_pred_rank is not None and - y_pred_rank != y_true_rank + 1) - if reshape_needed: - y_true = tf.reshape(y_true, [-1]) - y_pred = tf.reshape(y_pred, [-1, y_pred_shape[-1]]) - - if from_logits: - logits = y_pred - probs = tf.nn.softmax(y_pred, axis=-1) - else: - probs = y_pred - logits = tf.math.log(tf.clip_by_value(y_pred, _EPSILON, 1 - _EPSILON)) - - xent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=y_true, - logits=logits, - ) - - y_true_rank = y_true.shape.rank - probs = tf.gather(probs, y_true, axis=-1, batch_dims=y_true_rank) - if not scalar_gamma: - gamma = tf.gather(gamma, y_true, axis=0, batch_dims=y_true_rank) - focal_modulation = (1 - probs) ** gamma - loss = focal_modulation * xent_loss - - if class_weight is not None: - class_weight = tf.gather(class_weight, y_true, axis=0, - batch_dims=y_true_rank) - loss *= class_weight - - if reshape_needed: - loss = tf.reshape(loss, y_pred_shape[:-1]) - - return loss - - 
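# Illustrative sketch, not part of the patch: a quick numerical check of the
# property stated in the docstring above, namely that gamma=0 recovers ordinary
# sparse softmax cross-entropy, while gamma>0 shrinks the loss most on the
# confidently-correct examples. The toy labels and probabilities are arbitrary.
import numpy as np
import tensorflow as tf
from annopro.focal_loss import sparse_categorical_focal_loss

y_true = [0, 1, 2]
y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]]

focal0 = sparse_categorical_focal_loss(y_true, y_pred, gamma=0)
xent = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
np.testing.assert_allclose(focal0.numpy(), xent.numpy(), rtol=1e-5)

focal2 = sparse_categorical_focal_loss(y_true, y_pred, gamma=2)
print(focal2.numpy())   # per-example losses, smallest for the well-classified first example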
-@tf.keras.utils.register_keras_serializable() -class SparseCategoricalFocalLoss(tf.keras.losses.Loss): - r"""Focal loss function for multiclass classification with integer labels. - - This loss function generalizes multiclass softmax cross-entropy by - introducing a hyperparameter :math:`\gamma` (gamma), called the - *focusing parameter*, that allows hard-to-classify examples to be penalized - more heavily relative to easy-to-classify examples. - - This class is a wrapper around - :class:`~focal_loss.sparse_categorical_focal_loss`. See the documentation - there for details about this loss function. - - Parameters - ---------- - gamma : float or tensor-like of shape (K,) - The focusing parameter :math:`\gamma`. Higher values of `gamma` make - easy-to-classify examples contribute less to the loss relative to - hard-to-classify examples. Must be non-negative. This can be a - one-dimensional tensor, in which case it specifies a focusing parameter - for each class. - - class_weight: tensor-like of shape (K,) - Weighting factor for each of the :math:`k` classes. If not specified, - then all classes are weighted equally. - - from_logits : bool, optional - Whether model prediction will be logits or probabilities. - - **kwargs : keyword arguments - Other keyword arguments for :class:`tf.keras.losses.Loss` (e.g., `name` - or `reduction`). - - Examples - -------- - - An instance of this class is a callable that takes a rank-one tensor of - integer class labels `y_true` and a tensor of model predictions `y_pred` and - returns a scalar tensor obtained by reducing the per-example focal loss (the - default reduction is a batch-wise average). - - >>> from focal_loss import SparseCategoricalFocalLoss - >>> loss_func = SparseCategoricalFocalLoss(gamma=2) - >>> y_true = [0, 1, 2] - >>> y_pred = [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]] - >>> loss_func(y_true, y_pred) - - - Use this class in the :mod:`tf.keras` API like any other multiclass - classification loss function class that accepts integer labels found in - :mod:`tf.keras.losses` (e.g., - :class:`tf.keras.losses.SparseCategoricalCrossentropy`: - - .. code-block:: python - - # Typical usage - model = tf.keras.Model(...) - model.compile( - optimizer=..., - loss=SparseCategoricalFocalLoss(gamma=2), # Used here like a tf.keras loss - metrics=..., - ) - history = model.fit(...) - - See Also - -------- - :meth:`~focal_loss.sparse_categorical_focal_loss` - The function that performs the focal loss computation, taking a label - tensor and a prediction tensor and outputting a loss. - """ - - def __init__(self, gamma, class_weight: Optional[Any] = None, - from_logits: bool = False, **kwargs): - super().__init__(**kwargs) - self.gamma = gamma - self.class_weight = class_weight - self.from_logits = from_logits - - def get_config(self): - """Returns the config of the layer. - - A layer config is a Python dictionary containing the configuration of a - layer. The same layer can be re-instantiated later (without its trained - weights) from this configuration. - - Returns - ------- - dict - This layer's config. - """ - config = super().get_config() - config.update(gamma=self.gamma, class_weight=self.class_weight, - from_logits=self.from_logits) - return config - - def call(self, y_true, y_pred): - """Compute the per-example focal loss. - - This method simply calls - :meth:`~focal_loss.sparse_categorical_focal_loss` with the appropriate - arguments. - - Parameters - ---------- - y_true : tensor-like, shape (N,) - Integer class labels. 
- - y_pred : tensor-like, shape (N, K) - Either probabilities or logits, depending on the `from_logits` - parameter. - - Returns - ------- - :class:`tf.Tensor` - The per-example focal loss. Reduction to a scalar is handled by - this layer's - :meth:`~focal_loss.SparseCateogiricalFocalLoss.__call__` method. - """ - return sparse_categorical_focal_loss(y_true=y_true, y_pred=y_pred, - class_weight=self.class_weight, - gamma=self.gamma, - from_logits=self.from_logits) diff --git a/annopro/focal_loss/utils/__init__.py b/annopro/focal_loss/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/annopro/focal_loss/utils/validation.py b/annopro/focal_loss/utils/validation.py deleted file mode 100644 index c619c38..0000000 --- a/annopro/focal_loss/utils/validation.py +++ /dev/null @@ -1,325 +0,0 @@ -"""Helper functions for function parameter validation.""" - -import numbers - - -def check_type(obj, base, *, name=None, func=None, allow_none=False, - default=None, error_message=None): - """Check whether an object is an instance of a base type. - - Parameters - ---------- - obj : object - The object to be validated. - - name : str - The name of `obj` in the calling function. - - base : type or tuple of type - The base type that `obj` should be an instance of. - - func: callable, optional - A function to be applied to `obj` if it is of type `base`. If None, no - function will be applied and `obj` will be returned as-is. - - allow_none : bool, optional - Indicates whether the value None should be allowed to pass through. - - default : object, optional - The default value to return if `obj` is None and `allow_none` is True. - If `default` is not None, it must be of type `base`, and it will have - `func` applied to it if `func` is not None. - - error_message : str or None, optional - Custom error message to display if the type is incorrect. - - Returns - ------- - base type or None - The validated object. - - Raises - ------ - TypeError - If `obj` is not an instance of `base`. - - Examples - -------- - >>> check_type(1, int) - 1 - >>> check_type(1, (int, str)) - 1 - >>> check_type(1, str) - Traceback (most recent call last): - ... - TypeError: Invalid type. Expected: str. Actual: int. - >>> check_type(1, (str, bool)) - Traceback (most recent call last): - ... - TypeError: Invalid type. Expected: (str, bool). Actual: int. - >>> print(check_type(None, str, allow_none=True)) - None - >>> check_type(1, str, name='num') - Traceback (most recent call last): - ... - TypeError: Invalid type for parameter 'num'. Expected: str. Actual: int. - >>> check_type(1, int, func=str) - '1' - >>> check_type(1, int, func='not callable') - Traceback (most recent call last): - ... - ValueError: Parameter 'func' must be callable or None. - >>> check_type(2.0, str, error_message='Not a string!') - Traceback (most recent call last): - ... - TypeError: Not a string! 
- >>> check_type(None, int, allow_none=True, default=0) - 0 - - """ - if allow_none and obj is None: - if default is not None: - return check_type(default, base=base, name=name, func=func, - allow_none=False) - return None - - if isinstance(obj, base): - if func is None: - return obj - elif callable(func): - return func(obj) - else: - raise ValueError('Parameter \'func\' must be callable or None.') - - # Handle wrong type - if isinstance(base, tuple): - expect = '(' + ', '.join(cls.__name__ for cls in base) + ')' - else: - expect = base.__name__ - actual = type(obj).__name__ - if error_message is None: - error_message = 'Invalid type' - if name is not None: - error_message += f' for parameter \'{name}\'' - error_message += f'. Expected: {expect}. Actual: {actual}.' - raise TypeError(error_message) - - -def check_bool(obj, *, name=None, allow_none=False, default=None): - """Validate boolean function arguments. - - Parameters - ---------- - obj : object - The object to be validated. - - name : str, optional - The name of `obj` in the calling function. - - allow_none : bool, optional - Indicates whether the value None should be allowed. - - default : object, optional - The default value to return if `obj` is None and `allow_none` is True. - - Returns - ------- - bool or None - The validated bool. - - Raises - ------ - TypeError - If `obj` is not an instance of bool. - - Examples - -------- - >>> check_bool(True) - True - >>> check_bool(1.0) - Traceback (most recent call last): - ... - TypeError: Invalid type. Expected: bool. Actual: float. - >>> a = (1 < 2) - >>> check_bool(a, name='a') - True - >>> b = 'not a bool' - >>> check_bool(b, name='b') - Traceback (most recent call last): - ... - TypeError: Invalid type for parameter 'b'. Expected: bool. Actual: str. - """ - return check_type(obj, name=name, base=bool, func=bool, - allow_none=allow_none, default=default) - - -def _check_numeric(*, check_func, obj, name, base, func, positive, minimum, - maximum, allow_none, default): - """Helper function for check_float and check_int.""" - obj = check_type(obj, name=name, base=base, func=func, - allow_none=allow_none, default=default) - - if obj is None: - return None - - positive = check_bool(positive, name='positive') - if positive and obj <= 0: - if name is None: - message = 'Parameter must be positive.' - else: - message = f'Parameter \'{name}\' must be positive.' - raise ValueError(message) - - if minimum is not None: - minimum = check_func(minimum, name='minimum') - if obj < minimum: - if name is None: - message = f'Parameter must be at least {minimum}.' - else: - message = f'Parameter \'{name}\' must be at least {minimum}.' - raise ValueError(message) - - if maximum is not None: - maximum = check_func(maximum, name='minimum') - if obj > maximum: - if name is None: - message = f'Parameter must be at most {maximum}.' - else: - message = f'Parameter \'{name}\' must be at most {maximum}.' - raise ValueError(message) - - return obj - - -def check_int(obj, *, name=None, positive=False, minimum=None, maximum=None, - allow_none=False, default=None): - """Validate integer function arguments. - - Parameters - ---------- - obj : object - The object to be validated. - - name : str, optional - The name of `obj` in the calling function. - - positive : bool, optional - Whether `obj` must be a positive integer (1 or greater). - - minimum : int, optional - The minimum value that `obj` can take (inclusive). - - maximum : int, optional - The maximum value that `obj` can take (inclusive). 
-
-    allow_none : bool, optional
-        Indicates whether the value None should be allowed.
-
-    default : object, optional
-        The default value to return if `obj` is None and `allow_none` is True.
-
-    Returns
-    -------
-    int or None
-        The validated integer.
-
-    Raises
-    ------
-    TypeError
-        If `obj` is not an integer.
-
-    ValueError
-        If any of the optional positivity or minimum and maximum value
-        constraints are violated.
-
-    Examples
-    --------
-    >>> check_int(0)
-    0
-    >>> check_int(1, positive=True)
-    1
-    >>> check_int(1.0)
-    Traceback (most recent call last):
-    ...
-    TypeError: Invalid type. Expected: Integral. Actual: float.
-    >>> check_int(-1, positive=True)
-    Traceback (most recent call last):
-    ...
-    ValueError: Parameter must be positive.
-    >>> check_int(1, name='a', minimum=10)
-    Traceback (most recent call last):
-    ...
-    ValueError: Parameter 'a' must be at least 10.
-
-    """
-    return _check_numeric(check_func=check_int, obj=obj, name=name,
-                          base=numbers.Integral, func=int, positive=positive,
-                          minimum=minimum, maximum=maximum,
-                          allow_none=allow_none, default=default)
-
-
-def check_float(obj, *, name=None, positive=False, minimum=None, maximum=None,
-                allow_none=False, default=None):
-    """Validate float function arguments.
-
-    Parameters
-    ----------
-    obj : object
-        The object to be validated.
-
-    name : str, optional
-        The name of `obj` in the calling function.
-
-    positive : bool, optional
-        Whether `obj` must be a positive float.
-
-    minimum : float, optional
-        The minimum value that `obj` can take (inclusive).
-
-    maximum : float, optional
-        The maximum value that `obj` can take (inclusive).
-
-    allow_none : bool, optional
-        Indicates whether the value None should be allowed.
-
-    default : object, optional
-        The default value to return if `obj` is None and `allow_none` is True.
-
-    Returns
-    -------
-    float or None
-        The validated float.
-
-    Raises
-    ------
-    TypeError
-        If `obj` is not a float.
-
-    ValueError
-        If any of the optional positivity or minimum and maximum value
-        constraints are violated.
-
-    Examples
-    --------
-    >>> check_float(0)
-    0.0
-    >>> check_float(1.0, positive=True)
-    1.0
-    >>> check_float(1.0 + 1.0j)
-    Traceback (most recent call last):
-    ...
-    TypeError: Invalid type. Expected: Real. Actual: complex.
-    >>> check_float(-1, positive=True)
-    Traceback (most recent call last):
-    ...
-    ValueError: Parameter must be positive.
-    >>> check_float(1.2, name='a', minimum=10)
-    Traceback (most recent call last):
-    ...
-    ValueError: Parameter 'a' must be at least 10.0.
- - """ - return _check_numeric(check_func=check_float, obj=obj, name=name, - base=numbers.Real, func=float, positive=positive, - minimum=minimum, maximum=maximum, - allow_none=allow_none, default=default) diff --git a/annopro/prediction.py b/annopro/prediction.py deleted file mode 100644 index 0c0188a..0000000 --- a/annopro/prediction.py +++ /dev/null @@ -1,181 +0,0 @@ -import os -from tensorflow.keras.utils import Sequence -from tensorflow.keras.models import load_model -import numpy as np -import pandas as pd -import math -import pickle -import annopro.resources as resources -from annopro.focal_loss import BinaryFocalLoss -from annopro.data_procession.utils import NAMESPACES, Ontology - - -def predict(output_dir: str, promap_features_file: str, - used_gpu: str = "-1", diamond_scores_file: str = None): - if output_dir == None: - raise ValueError("Must provide the input fasta sequences.") - os.environ["CUDA_VISIBLE_DEVICES"] = used_gpu - for term_type in NAMESPACES.keys(): - init_evaluate(term_type=term_type, - promap_features_file=promap_features_file, - diamond_scores_file=diamond_scores_file, - output_dir=output_dir) - - -class DFGenerator(Sequence): - def __init__(self, df, terms_dict, nb_classes, batch_size): - self.start = 0 - self.size = len(df) - self.df = df - self.batch_size = batch_size - self.nb_classes = nb_classes - self.terms_dict = terms_dict - - def __len__(self): - return np.ceil(len(self.df) / float(self.batch_size)).astype(np.int32) - - def __getitem__(self, idx): - batch_index = np.arange(idx * self.batch_size, - min(self.size, (idx + 1) * self.batch_size)) - df = self.df.iloc[batch_index] - labels = np.zeros((len(df), self.nb_classes), dtype=np.int32) - feature_data = [] - protein_si = [] - for i, row in enumerate(df.itertuples()): - feature_data.append(list(row.Promap_feature)) - protein_si.append(list(row.Protein_similary)) - data_onehot = np.array(feature_data) - data_si = np.array(protein_si) - self.start += self.batch_size - return ([data_onehot, data_si]) - - def __next__(self): - return self.next() - - def reset(self): - self.start = 0 - - def next(self): - if self.start < self.size: - batch_index = np.arange( - self.start, min(self.size, self.start + self.batch_size)) - df = self.df.iloc[batch_index] - labels = np.zeros((len(df), self.nb_classes), dtype=np.int32) - feature_data = [] - protein_si = [] - for i, row in enumerate(df.itertuples()): - feature_data.append(list(row.Promap_feature)) - protein_si.append(list(row.Protein_similary)) - data_onehot = np.array(feature_data) - data_si = np.array(protein_si) - self.start += self.batch_size - return ([data_onehot, data_si]) - else: - self.reset() - return self.next() - - -def diamond_score(diamond_scores_file, label, data_path, term_type): - with resources.open_binary("go.pkl") as file: - go: Ontology = pickle.load(file) - assert isinstance(go, Ontology) - with resources.open_binary("cafa_train.pkl") as file: - train_df = pd.read_pickle(file) - test_df = pd.read_pickle(data_path) - annotations = train_df['Prop_annotations'].values - annotations = list(map(lambda x: set(x), annotations)) - - prot_index = {} - for i, row in enumerate(train_df.itertuples()): - prot_index[row.Proteins] = i - - diamond_scores = {} - with open(diamond_scores_file) as f: - for line in f: - it = line.strip().split("\t") - if it[0] not in diamond_scores: - diamond_scores[it[0]] = {} - diamond_scores[it[0]][it[1]] = float(it[11]) - blast_preds = [] - - for i, row in enumerate(test_df.itertuples()): - annots = {} - prot_id = row.Proteins 
-        # BlastKNN
-        if prot_id in diamond_scores:
-            sim_prots = diamond_scores[prot_id]
-            allgos = set()
-            total_score = 0.0
-            for p_id, score in sim_prots.items():
-                allgos |= annotations[prot_index[p_id]]
-                total_score += score
-            allgos = list(sorted(allgos))
-            sim = np.zeros(len(allgos), dtype=np.float32)
-            for j, go_id in enumerate(allgos):
-                s = 0.0
-                for p_id, score in sim_prots.items():
-                    if go_id in annotations[prot_index[p_id]]:
-                        s += score
-                sim[j] = s / total_score
-            ind = np.argsort(-sim)
-            for go_id, score in zip(allgos, sim):
-                annots[go_id] = score
-        blast_preds.append(annots)
-    with resources.open_binary(f"terms_{NAMESPACES[term_type]}.pkl") as term_path:
-        terms = pd.read_pickle(term_path)
-    terms = terms['terms'].values.flatten()
-    alphas = {NAMESPACES['mf']: 0.55,
-              NAMESPACES['bp']: 0.6, NAMESPACES['cc']: 0.4}
-
-    for i in range(0, len(label)):
-        annots_dict = blast_preds[i].copy()
-        for go_id in annots_dict:
-            annots_dict[go_id] *= alphas[go.get_namespace(go_id)]
-        for j in range(0, len(label[0])):
-            go_id = terms[j]
-            label[i, j] = label[i, j]*(1 - alphas[go.get_namespace(go_id)])
-            if go_id in annots_dict:
-                label[i, j] = label[i, j] + annots_dict[go_id]
-    return label
-
-
-def init_evaluate(term_type, promap_features_file, diamond_scores_file, output_dir: str,
-                  data_size=8000, batch_size=16):
-    with resources.open_binary(f"terms_{NAMESPACES[term_type]}.pkl") as file:
-        terms_df = pd.read_pickle(file)
-    with open(promap_features_file, 'rb') as file:
-        data_df = pd.read_pickle(file)
-    if len(data_df) > data_size:
-        data_df = data_df.sample(n=data_size)
-        data_df.index = range(len(data_df))
-    model = load_model(
-        resources.get_resource_path(f"{term_type}.h5"),
-        custom_objects={"focus_loss": BinaryFocalLoss})
-    proteins = data_df["Proteins"]
-    terms = terms_df['terms'].values.flatten()
-    terms_dict = {v: i for i, v in enumerate(terms)}
-    nb_classes = len(terms)
-    data_generator = DFGenerator(data_df, terms_dict, nb_classes, batch_size)
-    data_steps = int(math.ceil(len(data_df) / batch_size))
-    preds = model.predict(data_generator, steps=data_steps)
-    if diamond_scores_file:
-        preds = diamond_score(diamond_scores_file, preds,
-                              promap_features_file, term_type=term_type)
-    # label_di=defaultdict(list)
-    protein = []
-    go_terms = []
-    score = []
-    for i in range(len(preds)):
-        for j in range(len(preds[i])):
-            if preds[i][j] > 0:
-                protein.append(proteins[i])
-                go_terms.append(terms[j])
-                score.append(preds[i][j])
-    res = [protein, go_terms, score]
-    res = pd.DataFrame(res)
-    res = res.T
-    res.columns = ['Proteins', 'GO-terms', 'Scores']
-    res.sort_values(by='Scores', axis=0, ascending=False, inplace=True)
-    result_file = os.path.join(output_dir, f"{term_type}_result.csv")
-    res.to_csv(result_file, sep=',', index=False, header=True)
-    return res
diff --git a/annopro/resources.py b/annopro/resources.py
deleted file mode 100644
index de7bcad..0000000
--- a/annopro/resources.py
+++ /dev/null
@@ -1,115 +0,0 @@
-"""
-This module manages all required resources for annopro.
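# --- Editor's note (illustrative, not part of the deleted files) ---
# diamond_score above blends the network prediction with the BlastKNN score
# per GO term, using a namespace-specific weight alpha
# (mf: 0.55, bp: 0.6, cc: 0.4):
#
#     final = (1 - alpha) * model_score + alpha * blast_score
#
# Worked example with made-up numbers: for a biological_process term with
# model_score = 0.7 and blast_score = 0.5,
#     final = 0.4 * 0.7 + 0.6 * 0.5 = 0.28 + 0.30 = 0.58
# A term absent from the BlastKNN predictions keeps only the down-weighted
# model score, (1 - alpha) * model_score.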
-""" -from io import FileIO, TextIOWrapper -import os -import wget -import hashlib - -RESOURCE_DIR = os.path.join(os.path.expanduser("~"), ".annopro/data") -os.makedirs(RESOURCE_DIR, exist_ok=True) - -RESOURCE_DICT = { - "cafa_train.pkl": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/cafa_train.pkl", - "md5sum": "07d3e4334c31c914efec3f52cae5e498" - }, - "cafa4_del.csv": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/cafa4_del.csv", - "md5sum": "d00b71439084cb19b7d3d0d4fbbaa819" - }, - "cafa4.dmnd": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/cafa4.dmnd", - "md5sum": "a2a6cba9af26dbe1911e14a306db2712" - }, - "data_grid.pkl": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/data_grid.pkl", - "md5sum": "fb2d2d86a4bc21c6e60fac996b3a90d3" - }, - "go.pkl": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/go.pkl", - "md5sum": "8d7a975d38a4af670b0370f4ea722a2a" - }, - "go.txt": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/go.txt", - "md5sum": "1dae308468fa00ae6d5796fd22c65044" - }, - "row_asses.pkl": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/row_asses.pkl", - "md5sum": "bf9bb1eda744a60c381d19b275ac6f33" - }, - "terms_biological_process.pkl": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/terms_biological_process.pkl", - "md5sum": "e79cf5e006432c19606b8a482cf7ddfa" - }, - "terms_cellular_component.pkl": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/terms_cellular_component.pkl", - "md5sum": "8cf58075eba65e2bb710566e4ef93f42" - }, - "terms_molecular_function.pkl": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/data/terms_molecular_function.pkl", - "md5sum": "002c3696fbca0061402a68129b35dcd4" - }, - "bp.h5": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/model_param/bp.h5", - "md5sum": "7e19158e5252a70ff831f5f583b1c2ed" - }, - "cc.h5": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/model_param/cc.h5", - "md5sum": "73876beec9370ff56b58878cf4446d2c" - }, - "mf.h5": { - "url": "https://promap.oss-cn-hangzhou.aliyuncs.com/model_param/mf.h5", - "md5sum": "f4fb632f553afeb45571a29e46286bb8" - } -} - - -def md5sum(file_path: str) -> str: - with open(file_path, "rb") as f: - md5 = hashlib.md5() - while True: - data = f.read(65536) - if not data: - break - md5.update(data) - return md5.hexdigest() - - -def md5check(file_path: str, expected: str): - return md5sum(file_path).startswith(expected) - - -def download_resource(name: str, overwrite: bool = False) -> str: - if name in RESOURCE_DICT: - resource = RESOURCE_DICT[name] - path_name = os.path.join(RESOURCE_DIR, name) - if os.path.exists(path_name): - if overwrite or not md5check(path_name, resource["md5sum"]): - os.remove(path_name) - else: - return path_name - print(f"Download {name}...") - wget.download( - url=resource["url"], - out=path_name) - print(f"\nValidate md5sum of {name}...") - if not md5check(path_name, resource["md5sum"]): - raise RuntimeError(f"{name} do not pass md5 validation, please visit https://github.com/idrblab/AnnoPRO for help") - return path_name - else: - raise FileNotFoundError(f"Invalid resource name: {name}") - - -def get_resource_path(name: str) -> str: - return download_resource(name) - - -def open_binary(name: str) -> FileIO: - file_path = get_resource_path(name) - return open(file_path, "rb") - - -def open_text(name: str) -> TextIOWrapper: - file_path = get_resource_path(name) - return open(file_path, "rt")