Skip to content

Commit

Permalink
add input_preprocessors_ctor to critic_network (#1541)
Browse files Browse the repository at this point in the history
  • Loading branch information
runjerry authored Sep 26, 2023
1 parent 33bf2d6 commit 5e30715
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions alf/networks/critic_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,12 @@ def __init__(self,
input_tensor_spec,
output_tensor_spec=TensorSpec(()),
observation_input_processors=None,
observation_input_preprocessors_ctor=None,
observation_preprocessing_combiner=None,
observation_conv_layer_params=None,
observation_fc_layer_params=None,
action_input_processors=None,
action_input_processors_ctor=None,
action_preprocessing_combiner=None,
action_fc_layer_params=None,
observation_action_combiner=None,
Expand All @@ -89,6 +91,10 @@ def __init__(self,
observation_input_preprocessors (nested Network|nn.Module|None): a nest of
input preprocessors, each of which will be applied to the
corresponding observation input.
observation_input_preprocessors_ctor (Callable): if ``observation_input_preprocessors``
is None and ``observation_input_preprocessors_ctor`` is provided, then
``observation_input_preprocessors`` will be constructed by calling
``observation_input_preprocessors_ctor(observation_spec)``.
observation_preprocessing_combiner (NestCombiner): preprocessing called
on complex observation inputs.
observation_conv_layer_params (tuple[tuple]): a tuple of tuples where each
Expand All @@ -99,6 +105,10 @@ def __init__(self,
action_input_processors (nested Network|nn.Module|None): a nest of
input preprocessors, each of which will be applied to the
corresponding action input.
action_input_processors_ctor (Callable): if ``action_input_processors``
is None and ``action_input_processors_ctor`` is provided, then
``action_input_processors`` will be constructed by calling
``action_input_processors_ctor(action_spec)``.
action_preprocessing_combiner (NestCombiner): preprocessing called
to combine complex action inputs.
action_fc_layer_params (tuple[int]): a tuple of integers representing
Expand Down Expand Up @@ -135,6 +145,7 @@ def __init__(self,
obs_encoder = EncodingNetwork(
observation_spec,
input_preprocessors=observation_input_processors,
input_preprocessors_ctor=observation_input_preprocessors_ctor,
preprocessing_combiner=observation_preprocessing_combiner,
conv_layer_params=observation_conv_layer_params,
fc_layer_params=observation_fc_layer_params,
Expand All @@ -149,6 +160,7 @@ def __init__(self,
action_encoder = EncodingNetwork(
action_spec,
input_preprocessors=action_input_processors,
input_preprocessors_ctor=action_input_processors_ctor,
preprocessing_combiner=action_preprocessing_combiner,
fc_layer_params=action_fc_layer_params,
activation=activation,
Expand Down

0 comments on commit 5e30715

Please sign in to comment.