Skip to content

Commit

Permalink
Added cross validation
Browse files Browse the repository at this point in the history
  • Loading branch information
SurgeArrester committed Mar 6, 2021
1 parent 2e615e8 commit bd346d8
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 2 deletions.
70 changes: 70 additions & 0 deletions ElM2D/ElM2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,76 @@ def sort(self):
return sorted_indices


def cross_validate(self, y=None, k=5, shuffle=True, seed=42):
"""
Implementation of cross validation with K-Folds.
Splits the formula_list into k equal sized partitions and returns five
tuples of training and test sets. Returns a list of length k, each item
containing 2 (4 with target data) numpy arrays of formulae of
length n - n/k and n/k.
Parameters:
y=None: (optional) a numpy array of target properties to cross validate
k=5: Number of k-folds
shuffle=True: whether to shuffle the input formulae or not
Usage:
cvs = mapper.cross_validate()
for i, (X_train, X_test) in enumerate(cvs):
sub_mapper = ElM2D()
sub_mapper.fit(X_train)
sub_mapper.save(f"train_elm2d_{i}.pk")
sub_mapper.fit(X_test)
sub_mapper.save(f"test_elm2d_{i}.pk")
...
cvs = mapper.cross_validate(y=df["target"])
for X_train, X_test, y_train, y_test in cvs:
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
errors.append(mae(y_pred, y_test))
print(np.mean(errors))
"""
inds = np.arange(len(self.formula_list)) # TODO Exception

if shuffle:
np.random.seed(seed)
np.random.shuffle(inds)

formulas = self.formula_list.to_numpy(str)[inds]
splits = np.array_split(formulas, k)

X_ret = []

for i in range(k):
train_splits = np.delete(np.arange(k), i)
X_train = splits[train_splits[0]]

for index in train_splits[1:]:
X_train = np.concatenate((X_train, splits[index]))

X_test = splits[i]
X_ret.append((X_train, X_test))

if y is None:
return X_ret

y = y[inds]
y_splits = np.array_split(y, k)
y_ret = []

for i in range(k):
train_splits = np.delete(np.arange(k), i)
y_train = y_splits[train_splits[0]]

for index in train_splits[1:]:
y_train = np.concatenate((y_train, y_splits[index]))

y_test = y_splits[i]
y_ret.append((y_train, y_test))

return [(X_ret[i][0], X_ret[i][1], y_ret[i][0], y_ret[i][1]) for i in range(k)]

def _process_list(self, formula_list, n_proc, metric="mod_petti"):
'''
Given an iterable list of formulas in composition form
Expand Down
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,35 @@ mapper.export_dm("large_df_emb_UMAP.csv")
mapper = ElM2D()
mapper.import_dm("large_df_dm.csv")
mapper.import_embedding("large_df_emb_UMAP.csv")
mapper.formula_list = df["formula"]
```

### Cross Validation

```python
cvs = mapper.cross_validation()
for i, (X_train, X_test) in enumerate(cvs):
sub_mapper = ElM2D()

sub_mapper.fit(X_train)
sub_mapper.save(f"train_elm2d_{i}.pk")

sub_mapper.fit(X_test)
sub_mapper.save(f"test_elm2d_{i}.pk")
...
from sklearn.metrics import mean_average_error as mae

cvs = mapper.cross_validation(y=df["target"])

for X_train, X_test, y_train, y_test in cvs:
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
errors.append(mae(y_pred, y_test))

print(np.mean(errors))
```


## Citing

If you would like to cite this code in your work, please use the Chemistry of Materials reference
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
setup(
name = 'ElM2D',
packages = ['ElM2D'],
version = '0.2.9',
version = '0.2.10',
license='GPL3',
description = 'A high performance mapping class to embed large datasets of ionic compositions with respect to the ElMD metric.',
author = 'Cameron Hagreaves',
author_email = 'cameron.h@rgreaves.me.uk',
url = 'https://github.com/lrcfmd/ElM2D/',
download_url = 'https://github.com/lrcfmd/ElM2D/archive/0.2.9.tar.gz',
download_url = 'https://github.com/lrcfmd/ElM2D/archive/0.2.10.tar.gz',
keywords = ['ChemInformatics', 'Materials Science', 'Machine Learning', 'Materials Representation'],
install_requires=[
'cython',
Expand Down

0 comments on commit bd346d8

Please sign in to comment.