From 5e934577e2552ded5458beeebf9652e8ef90ccfc Mon Sep 17 00:00:00 2001 From: Ray-Lei Date: Wed, 30 Aug 2023 12:25:18 -0400 Subject: [PATCH] updated gmpind --- GMPFeaturizer/GMP_individual/__init__.py | 8 ++++---- GMPFeaturizer/constants.py | 23 +++++++++++++++++++++++ 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/GMPFeaturizer/GMP_individual/__init__.py b/GMPFeaturizer/GMP_individual/__init__.py index 7b19008..e5f3b57 100644 --- a/GMPFeaturizer/GMP_individual/__init__.py +++ b/GMPFeaturizer/GMP_individual/__init__.py @@ -10,7 +10,7 @@ from scipy import sparse from ..base_feature import BaseFeature -from ..constants import ATOM_SYMBOL_TO_INDEX_DICT, ATOM_LIST, PI, get_GMP_group_info +from ..constants import ATOM_SYMBOL_TO_INDEX_DICT, ATOM_LIST, PI, get_GMP_group_info, get_GMP_group_info2 from ..util import ( _gen_2Darray_for_ffi, list_symbols_to_indices, @@ -737,7 +737,7 @@ def _get_feature_setup(self): group_info_list = [] if self.custom_cutoff == 4: for order, sigma in self.desc_list: - group_info_list += get_GMP_group_info(order) + group_info_list.append(get_GMP_group_info2(order)) if order == -1: feature_setup.append( [ @@ -802,8 +802,8 @@ def _prepare_feature_parameters(self): ) self.params_set["num"] = len(self.params_set["total"]) - self.params_set["total_num"] = int(np.sum(self.group_info_list)) - + # self.params_set["total_num"] = int(np.sum(self.group_info_list)) + self.params_set["total_num"] = int(np.sum([entry[1] for entry in self.group_info_list])) # if "prime_threshold" in self.GMPs: # self.params_set["prime_threshold"] = float(self.GMPs["prime_threshold"]) diff --git a/GMPFeaturizer/constants.py b/GMPFeaturizer/constants.py index 3112fd8..334bbdb 100644 --- a/GMPFeaturizer/constants.py +++ b/GMPFeaturizer/constants.py @@ -396,3 +396,26 @@ def get_GMP_group_info(mcsh_order): if mcsh_order == 9: return [3,6,6,3,6,6,6,6,3,3,6,1] +def get_GMP_group_info2(mcsh_order): + if mcsh_order == -1: + return (-1, 1) + if mcsh_order == 0: + return (0, 1) + if mcsh_order == 1: + return (1, 3) + if mcsh_order == 2: + return (2, 6) + if mcsh_order == 3: + return (3, 10) + if mcsh_order == 4: + return (4, 15) + if mcsh_order == 5: + return (5, 21) + if mcsh_order == 6: + return (6, 28) + if mcsh_order == 7: + return (7, 36) + if mcsh_order == 8: + return (8, 45) + if mcsh_order == 9: + return (9, 55) \ No newline at end of file