Skip to content

Commit

Permalink
Switching from frozendict to immutabledict
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 363554002
  • Loading branch information
jaehunro authored and fedjax authors committed Mar 18, 2021
1 parent f95ea1a commit 59204c3
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
14 changes: 7 additions & 7 deletions fedjax/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from fedjax.core.typing import Params
from fedjax.core.typing import PRNGKey
from fedjax.core.typing import Updates
import frozendict
import haiku as hk
import immutabledict
import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -69,9 +69,9 @@ class Model:
apply_fn: Callable[..., jnp.ndarray]
loss_fn: MetricsFn
reg_fn: Callable[[Params], jnp.ndarray] = lambda p: 0.
metrics_fn_map: Mapping[str, MetricsFn] = frozendict.frozendict()
train_kwargs: Mapping[str, Any] = frozendict.frozendict()
test_kwargs: Mapping[str, Any] = frozendict.frozendict()
metrics_fn_map: Mapping[str, MetricsFn] = immutabledict.immutabledict()
train_kwargs: Mapping[str, Any] = immutabledict.immutabledict()
test_kwargs: Mapping[str, Any] = immutabledict.immutabledict()
modify_grads_fn: Callable[[Updates], Updates] = lambda g: g

def init_params(self, rng: PRNGKey) -> Params:
Expand Down Expand Up @@ -117,9 +117,9 @@ def _get_defaults(reg_fn, metrics_fn_map, train_kwargs, test_kwargs):
metrics_fn_map = metrics_fn_map or collections.OrderedDict()
train_kwargs = train_kwargs or {}
test_kwargs = test_kwargs or {}
frozen_metrics_fn_map = frozendict.frozendict(metrics_fn_map)
frozen_train_kwargs = frozendict.frozendict(train_kwargs)
frozen_test_kwargs = frozendict.frozendict(test_kwargs)
frozen_metrics_fn_map = immutabledict.immutabledict(metrics_fn_map)
frozen_train_kwargs = immutabledict.immutabledict(train_kwargs)
frozen_test_kwargs = immutabledict.immutabledict(test_kwargs)
return reg_fn, frozen_metrics_fn_map, frozen_train_kwargs, frozen_test_kwargs


Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
packages=find_namespace_packages(exclude=['*_test.py']),
install_requires=[
'dm-haiku',
'frozendict',
'immutabledict',
'jax',
'jaxlib',
'optax',
Expand Down

0 comments on commit 59204c3

Please sign in to comment.