Skip to content

Commit

Permalink
add option to get bwdf for specific label and none normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
naik-aakash committed Oct 30, 2024
1 parent aac8567 commit c79a750
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/lobsterpy/featurize/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ class BatchIcoxxlistFeaturizer:
def __init__(
self,
path_to_lobster_calcs: str | Path,
normalization: Literal["formula_units", "area"] = "formula_units",
normalization: Literal["formula_units", "area", "none"] = "formula_units",
bin_width: float = 0.02,
max_length: float = 6.0,
min_length: float = 0.0,
Expand Down
54 changes: 51 additions & 3 deletions src/lobsterpy/featurize/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,7 @@ def __init__(
bin_width: float = 0.02,
max_length: float = 6.0,
min_length: float = 0.0,
normalization: Literal["formula_units", "area"] = "formula_units",
normalization: Literal["formula_units", "area", "none"] = "formula_units",
are_cobis: bool = False,
are_coops: bool = False,
):
Expand Down Expand Up @@ -1306,7 +1306,7 @@ def calc_bwdf(self):
if self.normalization == "area":
total_area = np.sum(np.abs(bwdf_atom_pair[atom_pair]["icoxx_binned"]) * self.bin_width)
bwdf_atom_pair[atom_pair]["icoxx_binned"] = bwdf_atom_pair[atom_pair]["icoxx_binned"] / total_area
else:
elif self.normalization == "formula_units":
formula_units = self.structure.composition.get_reduced_formula_and_factor()[-1]
bwdf_atom_pair[atom_pair]["icoxx_binned"] = (
bwdf_atom_pair[atom_pair]["icoxx_binned"] / formula_units
Expand Down Expand Up @@ -1412,12 +1412,60 @@ def calc_site_bwdf(self, site_index: int) -> dict:
if self.normalization == "area":
total_area = np.sum(np.abs(site_bwdf["icoxx_binned"]) * self.bin_width)
site_bwdf["icoxx_binned"] = site_bwdf["icoxx_binned"] / total_area
else:
elif self.normalization == "formula_units":
formula_units = self.structure.composition.get_reduced_formula_and_factor()[-1]
site_bwdf["icoxx_binned"] = site_bwdf["icoxx_binned"] / formula_units

return site_bwdf

def calc_label_bwdf(self, bond_label: str) -> dict:
"""
Compute BWDF from ICOXXLIST.lobster data for a bond label.
Args:
bond_label: bond label for which BWDF needs to be computed
Returns:
BWDF for a bond label as a dictionary
"""
index = self.icoxxlist.icohpcollection._list_labels.index(bond_label)
bond_length = self.icoxxlist.icohpcollection._list_length[index]
atom1 = self.icoxxlist.icohpcollection._list_atom1[index]
atom2 = self.icoxxlist.icohpcollection._list_atom2[index]
icoxx = sum(self.icoxxlist.icohpcollection._list_icohp[index].values())
trans = self.icoxxlist.icohpcollection._list_translation[index]

# Complete data
complete_data = [(sorted([atom1, atom2]), bond_length, trans, icoxx)]

# Calculate number of bins
n_bins = int(np.ceil((self.max_length - self.min_length) / self.bin_width))

# Get bin edges and centers
bin_edges = np.round(np.linspace(self.min_length, self.max_length, n_bins), 5)
bin_centers = bin_edges[:-1] + self.bin_width / 2

# Initialize dictionary for storing binned data by atom pair
label_bwdf = {"icoxx_binned": np.zeros(bin_centers.shape)}

for interactions in complete_data:
for ii, l1, l2 in zip(range(len(bin_centers)), bin_edges[:-1], bin_edges[1:]):
if interactions[1] >= l1 and interactions[1] < l2:
label_bwdf["icoxx_binned"][ii] += interactions[3] # sum icoxx values in the bin

label_bwdf["centers"] = bin_centers
label_bwdf["edges"] = bin_edges
label_bwdf["bin_width"] = self.bin_width

if self.normalization == "area":
total_area = np.sum(np.abs(label_bwdf["icoxx_binned"]) * self.bin_width)
label_bwdf["icoxx_binned"] = label_bwdf["icoxx_binned"] / total_area
elif self.normalization == "formula_units":
formula_units = self.structure.composition.get_reduced_formula_and_factor()[-1]
label_bwdf["icoxx_binned"] = label_bwdf["icoxx_binned"] / formula_units

return label_bwdf

@staticmethod
def _get_features_col_names(bwdf: dict) -> list[str]:
"""
Expand Down

0 comments on commit c79a750

Please sign in to comment.