Skip to content

Commit

Permalink
Make the error message of parallel_environment.cpp easier to decode. (#…
Browse files Browse the repository at this point in the history
…1533)

Instead of reporting mismatch between potential two huge structures,
we report which field causes the mismatch.
  • Loading branch information
emailweixu authored Sep 14, 2023
1 parent 2485b77 commit 806336e
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 5 deletions.
2 changes: 1 addition & 1 deletion alf/environments/fast_parallel_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def call_envs_with_same_args(self, func_name, *args, **kwargs):
func_name (str): name of the function to call
*args: args to pass to the function
**kwargs: kwargs to pass to the function
return:
list: list of results from each environment
"""
Expand Down
8 changes: 4 additions & 4 deletions alf/environments/parallel_environment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,10 @@ void CheckStrides(const py::buffer_info& info,
void SharedDataBuffer::CheckSize(const py::list& arrays,
const py::object& nested_array) {
if (arrays.size() != sizes_.size()) {
throw std::runtime_error(
py::str("array structure mismatch. Expected: {} arrays {}, Got: {} "
"arrays {}.")
.format(sizes_.size(), data_spec_, arrays.size(), nested_array));
py::object consistent_with_spec =
py::module::import("alf.utils.spec_utils").attr("consistent_with_spec");
consistent_with_spec(nested_array, data_spec_);
throw std::runtime_error(py::str("array structure mismatch."));
}
}

Expand Down
38 changes: 38 additions & 0 deletions alf/utils/spec_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Collection of spec utility functions."""

import numpy as np
import sys
import torch
from typing import Iterable

Expand Down Expand Up @@ -114,3 +115,40 @@ def is_same_spec(spec1, spec2):
return False
same = nest.map_structure(lambda s1, s2: s1 == s2, spec1, spec2)
return all(nest.flatten(same))


def consistent_with_spec(nested, spec):
"""Check whether the nested structure is consistent with the spec.
Besides return the bool value, it will also print out the mismatching
information.
Args:
nested (nested Tensor): the nested structure to be checked
spec (nested TensorSpec): the spec
Returns:
bool: True if the nested structure is consistent with the spec
"""
nested_paths = nest.flatten(
nest.py_map_structure_with_path(lambda path, x: path, nested))
spec_paths = nest.flatten(
nest.py_map_structure_with_path(lambda path, x: path, spec))
if set(nested_paths) != set(spec_paths):
print("The nest does not match the spec!", file=sys.stderr)
not_in_spec = set(nested_paths) - set(spec_paths)
not_in_obs = set(spec_paths) - set(nested_paths)
if not_in_spec:
print("Not in spec: ", not_in_spec, file=sys.stderr)
if not_in_obs:
print("Not in nested: ", not_in_obs, file=sys.stderr)
return False

def _check_spec(path, x, s):
if not (x.shape == s.shape and x.dtype == np.dtype(s.dtype_str)):
print("Spec mismatch at path: ", path, file=sys.stderr)
return False
else:
return True

ok = nest.py_map_structure_with_path(_check_spec, nested, spec)
return all(nest.flatten(ok))

0 comments on commit 806336e

Please sign in to comment.