diff --git a/psite/train.py b/psite/train.py index 17b5c5b..7507c5b 100644 --- a/psite/train.py +++ b/psite/train.py @@ -54,9 +54,25 @@ def get_txrep(txinfo, type_rep='longest', path_exp=None, ignore_version=False): txinfo = txinfo.loc[txinfo['transcript_biotype'] == 'protein_coding'] if type_rep == 'longest': txrep = txinfo.sort_values(['gene_id', 'tx_len'], ascending=[True, False]) - txrep = txrep.groupby('gene_id').first().reset_index() + txrep = txrep.groupby('gene_id').first().reset_index(drop=True) elif type_rep == 'principal': - txrep = txinfo.loc[(txinfo['txtype'] == 'principal')] + try: + tx_extra = pd.read_table(path_exp) + tx_extra.rename(columns={ + 'Gene stable ID': 'gene_id', 'Transcript stable ID': 'tx_name', + 'APPRIS annotation': 'appris', 'Ensembl Canonical': 'canonical', + 'RefSeq match transcript (MANE Select)': 'mane_select'}, inplace=True) + tx_extra = tx_extra[~tx_extra.appris.isna()] + tx_extra = tx_extra.assign(appris=pd.Categorical(tx_extra.appris, categories=['principal1', + 'principal2', 'principal3', 'principal4', 'principal5', 'alternative1', 'alternative2'])) + tx_extra = tx_extra.sort_values(by='appris').reset_index(drop=True) + appris_prin = ['principal1', 'principal2', 'principal3', 'principal4', 'principal5'] + tx_extra = tx_extra[tx_extra.appris.isin(appris_prin)] + tx_extra = tx_extra[~tx_extra.duplicated(subset=['gene_id'])] + txrep = txinfo[txinfo.tx_name.isin(tx_extra.tx_name)] + except: + print('input extra txinfo file is incorrect', file=sys.stderr) + exit(1) elif type_rep == 'salmon': # salmon_output_dir/quant.sf try: tx_quant = pd.read_csv(path_exp, sep='\t') @@ -68,8 +84,8 @@ def get_txrep(txinfo, type_rep='longest', path_exp=None, ignore_version=False): tx_quant = tx_quant.groupby('gene_id').first().reset_index() txrep = txinfo.loc[(txinfo['tx_name'].isin(tx_quant['Name']))] except: - print('input salmon quant results incorrect', file=sys.stderr) - exit() + print('input salmon quant results is incorrect', file=sys.stderr) + exit(1) elif type_rep == 'kallisto': # kallisto_output_dir/abundance.tsv try: tx_quant = pd.read_csv(path_exp, sep='\t') @@ -81,11 +97,11 @@ def get_txrep(txinfo, type_rep='longest', path_exp=None, ignore_version=False): tx_quant = tx_quant.groupby('gene_id').first().reset_index() txrep = txinfo.loc[(txinfo['tx_name'].isin(tx_quant['target_id']))] except: - print('input kallisto quant results incorrect', file=sys.stderr) - exit() + print('input kallisto quant results is incorrect', file=sys.stderr) + exit(1) else: print('Incorrect txinfo_rep option!', file=sys.stderr) - exit() + exit(1) return txrep