Skip to content

Commit

Permalink
updated gmpind
Browse files Browse the repository at this point in the history
  • Loading branch information
RayLei-TRI committed Aug 30, 2023
1 parent 38c9e71 commit 5e93457
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
8 changes: 4 additions & 4 deletions GMPFeaturizer/GMP_individual/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from scipy import sparse

from ..base_feature import BaseFeature
from ..constants import ATOM_SYMBOL_TO_INDEX_DICT, ATOM_LIST, PI, get_GMP_group_info
from ..constants import ATOM_SYMBOL_TO_INDEX_DICT, ATOM_LIST, PI, get_GMP_group_info, get_GMP_group_info2
from ..util import (
_gen_2Darray_for_ffi,
list_symbols_to_indices,
Expand Down Expand Up @@ -737,7 +737,7 @@ def _get_feature_setup(self):
group_info_list = []
if self.custom_cutoff == 4:
for order, sigma in self.desc_list:
group_info_list += get_GMP_group_info(order)
group_info_list.append(get_GMP_group_info2(order))
if order == -1:
feature_setup.append(
[
Expand Down Expand Up @@ -802,8 +802,8 @@ def _prepare_feature_parameters(self):
)
self.params_set["num"] = len(self.params_set["total"])

self.params_set["total_num"] = int(np.sum(self.group_info_list))

# self.params_set["total_num"] = int(np.sum(self.group_info_list))
self.params_set["total_num"] = int(np.sum([entry[1] for entry in self.group_info_list]))
# if "prime_threshold" in self.GMPs:
# self.params_set["prime_threshold"] = float(self.GMPs["prime_threshold"])

Expand Down
23 changes: 23 additions & 0 deletions GMPFeaturizer/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,3 +396,26 @@ def get_GMP_group_info(mcsh_order):
if mcsh_order == 9:
return [3,6,6,3,6,6,6,6,3,3,6,1]

def get_GMP_group_info2(mcsh_order):
if mcsh_order == -1:
return (-1, 1)
if mcsh_order == 0:
return (0, 1)
if mcsh_order == 1:
return (1, 3)
if mcsh_order == 2:
return (2, 6)
if mcsh_order == 3:
return (3, 10)
if mcsh_order == 4:
return (4, 15)
if mcsh_order == 5:
return (5, 21)
if mcsh_order == 6:
return (6, 28)
if mcsh_order == 7:
return (7, 36)
if mcsh_order == 8:
return (8, 45)
if mcsh_order == 9:
return (9, 55)

0 comments on commit 5e93457

Please sign in to comment.