From 806336e98eca16a3e108901498c3936118dcc5c2 Mon Sep 17 00:00:00 2001 From: emailweixu Date: Thu, 14 Sep 2023 11:42:46 -0700 Subject: [PATCH] Make the error message of parallel_environment.cpp easier to decode. (#1533) Instead of reporting mismatch between potential two huge structures, we report which field causes the mismatch. --- alf/environments/fast_parallel_environment.py | 2 +- alf/environments/parallel_environment.cpp | 8 ++-- alf/utils/spec_utils.py | 38 +++++++++++++++++++ 3 files changed, 43 insertions(+), 5 deletions(-) diff --git a/alf/environments/fast_parallel_environment.py b/alf/environments/fast_parallel_environment.py index 2455dddb8..7b85e5355 100644 --- a/alf/environments/fast_parallel_environment.py +++ b/alf/environments/fast_parallel_environment.py @@ -195,7 +195,7 @@ def call_envs_with_same_args(self, func_name, *args, **kwargs): func_name (str): name of the function to call *args: args to pass to the function **kwargs: kwargs to pass to the function - + return: list: list of results from each environment """ diff --git a/alf/environments/parallel_environment.cpp b/alf/environments/parallel_environment.cpp index 1c7e7b732..df6af6083 100644 --- a/alf/environments/parallel_environment.cpp +++ b/alf/environments/parallel_environment.cpp @@ -128,10 +128,10 @@ void CheckStrides(const py::buffer_info& info, void SharedDataBuffer::CheckSize(const py::list& arrays, const py::object& nested_array) { if (arrays.size() != sizes_.size()) { - throw std::runtime_error( - py::str("array structure mismatch. Expected: {} arrays {}, Got: {} " - "arrays {}.") - .format(sizes_.size(), data_spec_, arrays.size(), nested_array)); + py::object consistent_with_spec = + py::module::import("alf.utils.spec_utils").attr("consistent_with_spec"); + consistent_with_spec(nested_array, data_spec_); + throw std::runtime_error(py::str("array structure mismatch.")); } } diff --git a/alf/utils/spec_utils.py b/alf/utils/spec_utils.py index 56ce511dd..1231932f5 100644 --- a/alf/utils/spec_utils.py +++ b/alf/utils/spec_utils.py @@ -14,6 +14,7 @@ """Collection of spec utility functions.""" import numpy as np +import sys import torch from typing import Iterable @@ -114,3 +115,40 @@ def is_same_spec(spec1, spec2): return False same = nest.map_structure(lambda s1, s2: s1 == s2, spec1, spec2) return all(nest.flatten(same)) + + +def consistent_with_spec(nested, spec): + """Check whether the nested structure is consistent with the spec. + + Besides return the bool value, it will also print out the mismatching + information. + + Args: + nested (nested Tensor): the nested structure to be checked + spec (nested TensorSpec): the spec + Returns: + bool: True if the nested structure is consistent with the spec + """ + nested_paths = nest.flatten( + nest.py_map_structure_with_path(lambda path, x: path, nested)) + spec_paths = nest.flatten( + nest.py_map_structure_with_path(lambda path, x: path, spec)) + if set(nested_paths) != set(spec_paths): + print("The nest does not match the spec!", file=sys.stderr) + not_in_spec = set(nested_paths) - set(spec_paths) + not_in_obs = set(spec_paths) - set(nested_paths) + if not_in_spec: + print("Not in spec: ", not_in_spec, file=sys.stderr) + if not_in_obs: + print("Not in nested: ", not_in_obs, file=sys.stderr) + return False + + def _check_spec(path, x, s): + if not (x.shape == s.shape and x.dtype == np.dtype(s.dtype_str)): + print("Spec mismatch at path: ", path, file=sys.stderr) + return False + else: + return True + + ok = nest.py_map_structure_with_path(_check_spec, nested, spec) + return all(nest.flatten(ok))