diff --git a/alf/networks/critic_networks.py b/alf/networks/critic_networks.py index 72a910cc2..8a7fe3a62 100644 --- a/alf/networks/critic_networks.py +++ b/alf/networks/critic_networks.py @@ -67,10 +67,12 @@ def __init__(self, input_tensor_spec, output_tensor_spec=TensorSpec(()), observation_input_processors=None, + observation_input_preprocessors_ctor=None, observation_preprocessing_combiner=None, observation_conv_layer_params=None, observation_fc_layer_params=None, action_input_processors=None, + action_input_processors_ctor=None, action_preprocessing_combiner=None, action_fc_layer_params=None, observation_action_combiner=None, @@ -89,6 +91,10 @@ def __init__(self, observation_input_preprocessors (nested Network|nn.Module|None): a nest of input preprocessors, each of which will be applied to the corresponding observation input. + observation_input_preprocessors_ctor (Callable): if ``observation_input_preprocessors`` + is None and ``observation_input_preprocessors_ctor`` is provided, then + ``observation_input_preprocessors`` will be constructed by calling + ``observation_input_preprocessors_ctor(observation_spec)``. observation_preprocessing_combiner (NestCombiner): preprocessing called on complex observation inputs. observation_conv_layer_params (tuple[tuple]): a tuple of tuples where each @@ -99,6 +105,10 @@ def __init__(self, action_input_processors (nested Network|nn.Module|None): a nest of input preprocessors, each of which will be applied to the corresponding action input. + action_input_processors_ctor (Callable): if ``action_input_processors`` + is None and ``action_input_processors_ctor`` is provided, then + ``action_input_processors`` will be constructed by calling + ``action_input_processors_ctor(action_spec)``. action_preprocessing_combiner (NestCombiner): preprocessing called to combine complex action inputs. action_fc_layer_params (tuple[int]): a tuple of integers representing @@ -135,6 +145,7 @@ def __init__(self, obs_encoder = EncodingNetwork( observation_spec, input_preprocessors=observation_input_processors, + input_preprocessors_ctor=observation_input_preprocessors_ctor, preprocessing_combiner=observation_preprocessing_combiner, conv_layer_params=observation_conv_layer_params, fc_layer_params=observation_fc_layer_params, @@ -149,6 +160,7 @@ def __init__(self, action_encoder = EncodingNetwork( action_spec, input_preprocessors=action_input_processors, + input_preprocessors_ctor=action_input_processors_ctor, preprocessing_combiner=action_preprocessing_combiner, fc_layer_params=action_fc_layer_params, activation=activation,