diff --git a/robustness/datasets.py b/robustness/datasets.py index 732dd19..e67bebe 100644 --- a/robustness/datasets.py +++ b/robustness/datasets.py @@ -88,7 +88,8 @@ def override_args(self, default_args, new_args): if len(extra_args) > 0: raise ValueError(f"Invalid arguments: {extra_args}") for k in kwargs: req_type = type(default_args[k]) - if (default_args[k] is not None) and (not isinstance(kwargs[k], req_type)): + no_nones = (default_args[k] is not None) and (kwargs[k] is not None) + if no_nones and (not isinstance(kwargs[k], req_type)): raise ValueError(f"Argument {k} should have type {req_type}") return {**default_args, **kwargs}