Skip to content

Commit

Permalink
fix: misc fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jesseylin committed Aug 25, 2023
1 parent c843fa0 commit fcc7f12
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

@dataclass
class SystemParams(Defaults):
sequence_max = defaults(100.0, min=1.0, max=1000.0)
sequence_density = defaults(100, min=1, max=1000)
sequence_max: float = defaults(100.0, min=1.0, max=1000.0)
sequence_density: int = defaults(100, min=1, max=1000)


@SystemParams.expose_config()
Expand Down Expand Up @@ -79,3 +79,11 @@ def _tensor_map(
y_tensor = torch.sum(mixture_tensor, dim=0)

return x_tensor, y_tensor


def test():
assert len(ExponentialMixture.CONFIG_DEFINITION) > 0


if __name__ == "__main__":
test()
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, premap_key: str = "input", postmap_key: str = "output") -> No
self.locator = BlobLocator()

@abstractmethod
def map(self, input: Dict) -> Dict:
def map(self, input: Dict, override_existing: bool = True) -> Dict:
"""Map input to output.
A `Map` operates on regular Python dicts, but it views them as
Expand All @@ -43,9 +43,12 @@ def map(self, input: Dict) -> Dict:
Args:
input:
generic dict containing input data to the map at `premap_key`
Generic dict containing input data to the map at `premap_key`
override_existing:
Flag which controls whether the value associated to
`postmap_key` should be overridden if it exists alrady
Returns:
A dict containing output data from the map at `postmap_key`
A dict containing output data from the map at `postmap_key`
"""
raise NotImplementedError

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,10 @@ def get_auto_disc_registered_modules_info(registered_modules):
infos = []
for module_name, module_class_path in registered_modules.items():
module_class = get_cls_from_path(module_class_path)
if module_class is None:
raise ValueError(
f"Could not retrieve class from path: {module_class_path}."
)
info = {}
info["name"] = module_name
try:
Expand Down

0 comments on commit fcc7f12

Please sign in to comment.