Skip to content

Commit

Permalink
Merge pull request #147 from tomlincr/set_params
Browse files Browse the repository at this point in the history
Add set_params to Estimator base class
  • Loading branch information
LucaCappelletti94 authored Feb 22, 2024
2 parents 62ba66f + 45e833e commit 5e1fe9c
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,9 @@ dmypy.json

# Pyre type checker
.pyre/

# VSCode project settings
.vscode/

# Mac OS X clutter
**/.DS_Store
8 changes: 7 additions & 1 deletion karateclub/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,20 @@ def get_memberships(self):
def get_cluster_centers(self):
"""Getting the cluster centers."""
pass

def get_params(self):
"""Get parameter dictionary for this estimator.."""
rx = re.compile(r'^\_')
params = self.__dict__
params = {key: params[key] for key in params if not rx.search(key)}
return params

def set_params(self, **parameters):
"""Set the parameters of this estimator."""
for parameter, value in parameters.items():
setattr(self, parameter, value)
return self

def _set_seed(self):
"""Creating the initial random seed."""
random.seed(self.seed)
Expand Down
13 changes: 12 additions & 1 deletion test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,15 @@ def test_get_params():
params = model.get_params()
assert len(params) != 0
assert type(params) is dict
assert '_embedding' not in params
assert '_embedding' not in params

def test_set_params():
model = DeepWalk()
default_params = model.get_params()
params = {'dimensions': 1,
'seed': 123}
model.set_params(**params)
new_params = model.get_params()
assert new_params != default_params
assert new_params['dimensions'] == 1
assert new_params['seed'] == 123

0 comments on commit 5e1fe9c

Please sign in to comment.