From 3c34b5e0e2d15bf4d0c874932cf59ae49bd5edd0 Mon Sep 17 00:00:00 2001 From: Chen001117 Date: Sat, 11 Nov 2023 22:38:31 -0500 Subject: [PATCH 1/8] not using shared model --- examples/nlp/ds_config.json | 4 +- examples/nlp/eval_ds_config.json | 4 +- examples/nlp/nlp_ppo_ds.yaml | 4 +- examples/nlp/train_ppo.py | 9 +- openrl/algorithms/ppo.py | 14 +- openrl/modules/networks/policy_network_gpt.py | 169 ++++++++++++++++++ openrl/modules/networks/value_network_gpt.py | 104 +++++++++++ 7 files changed, 294 insertions(+), 14 deletions(-) create mode 100644 openrl/modules/networks/policy_network_gpt.py create mode 100644 openrl/modules/networks/value_network_gpt.py diff --git a/examples/nlp/ds_config.json b/examples/nlp/ds_config.json index 544bc405..d3b68fe1 100644 --- a/examples/nlp/ds_config.json +++ b/examples/nlp/ds_config.json @@ -1,6 +1,6 @@ { - "train_batch_size": 32, - "train_micro_batch_size_per_gpu": 16, + "train_batch_size": 16, + "train_micro_batch_size_per_gpu": 4, "steps_per_print": 10, "zero_optimization": { "stage": 2, diff --git a/examples/nlp/eval_ds_config.json b/examples/nlp/eval_ds_config.json index 58c08252..e9429896 100644 --- a/examples/nlp/eval_ds_config.json +++ b/examples/nlp/eval_ds_config.json @@ -1,6 +1,6 @@ { - "train_batch_size": 32, - "train_micro_batch_size_per_gpu": 16, + "train_batch_size": 16, + "train_micro_batch_size_per_gpu": 4, "steps_per_print": 10, "zero_optimization": { "stage": 0, diff --git a/examples/nlp/nlp_ppo_ds.yaml b/examples/nlp/nlp_ppo_ds.yaml index ab0c0b6c..3a031ae6 100644 --- a/examples/nlp/nlp_ppo_ds.yaml +++ b/examples/nlp/nlp_ppo_ds.yaml @@ -7,9 +7,9 @@ use_valuenorm: true use_adv_normalize: true wandb_entity: "openrl-lab" ppo_epoch: 5 -episode_length: 128 +episode_length: 64 num_mini_batch: 20 -use_share_model: true +# use_share_model: true hidden_size: 1 diff --git a/examples/nlp/train_ppo.py b/examples/nlp/train_ppo.py index 728e4aa5..384d8f9d 100644 --- a/examples/nlp/train_ppo.py +++ b/examples/nlp/train_ppo.py @@ -3,9 +3,8 @@ from openrl.configs.config import create_config_parser from openrl.envs.common import make from openrl.modules.common import PPONet as Net -from openrl.modules.networks.policy_value_network_gpt import ( - PolicyValueNetworkGPT as PolicyValueNetwork, -) +from openrl.modules.networks.value_network_gpt import ValueNetworkGPT as ValueNetwork +from openrl.modules.networks.policy_network_gpt import PolicyNetworkGPT as PolicyNetwork from openrl.runners.common import PPOAgent as Agent @@ -29,11 +28,11 @@ def train(): ) # create the neural network - model_dict = {"model": PolicyValueNetwork} + model_dict = {"policy": PolicyNetwork, "critic": ValueNetwork} net = Net(env, device="cuda", cfg=cfg, model_dict=model_dict) # initialize the trainer - agent = Agent(net, use_wandb=True) + agent = Agent(net, use_wandb=False) # start training agent.train(total_time_steps=100000) diff --git a/openrl/algorithms/ppo.py b/openrl/algorithms/ppo.py index e72e01bb..fafd657e 100644 --- a/openrl/algorithms/ppo.py +++ b/openrl/algorithms/ppo.py @@ -45,7 +45,8 @@ def __init__( def ppo_update(self, sample, turn_on=True): for optimizer in self.algo_module.optimizers.values(): - optimizer.zero_grad() + if not self.use_deepspeed: + optimizer.zero_grad() ( critic_obs_batch, @@ -152,8 +153,15 @@ def ppo_update(self, sample, turn_on=True): self.algo_module.scaler.update() else: - for optimizer in self.algo_module.optimizers.values(): - optimizer.step() + if self.use_deepspeed: + if self._use_share_model: + 
self.algo_module.optimizers["model"].step() + else: + self.algo_module.optimizers["policy"].step() + self.algo_module.optimizers["critic"].step() + else: + for optimizer in self.algo_module.optimizers.values(): + optimizer.step() if self.world_size > 1: torch.cuda.synchronize() diff --git a/openrl/modules/networks/policy_network_gpt.py b/openrl/modules/networks/policy_network_gpt.py new file mode 100644 index 00000000..cad3157e --- /dev/null +++ b/openrl/modules/networks/policy_network_gpt.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2021 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""""" +from typing import Any, Optional, Dict + +import numpy as np +import torch +import torch.nn as nn + +from openrl.buffers.utils.util import get_policy_obs, get_policy_obs_space +from openrl.modules.networks.base_policy_network import BasePolicyNetwork +from openrl.modules.networks.utils.act import ACTLayer +from openrl.modules.networks.utils.cnn import CNNBase +from openrl.modules.networks.utils.mix import MIXBase +from openrl.modules.networks.utils.mlp import MLPBase, MLPLayer +from openrl.modules.networks.utils.popart import PopArt +from openrl.modules.networks.utils.rnn import RNNLayer +from openrl.modules.networks.utils.util import init +from openrl.utils.util import check_v2 as check +from openrl.envs.nlp.utils.distribution import CategoricalDistribution + +from transformers.modeling_utils import unwrap_model + +class PolicyNetworkGPT(BasePolicyNetwork): + def __init__( + self, + cfg, + input_space, + action_space, + device=torch.device("cpu"), + use_half=False, + disable_drop_out: bool = True, + extra_args=None, + ) -> None: + + self.use_half = use_half + self.tpdv = dict(dtype=torch.float32, device=device) + + super(PolicyNetworkGPT, self).__init__(cfg, device) + + self.disable_drop_out = disable_drop_out + + + + self._action_dist = CategoricalDistribution(action_space.n) + + from transformers import AutoConfig, AutoModelForCausalLM + config = AutoConfig.from_pretrained(cfg.model_path) + config_dict = config.to_dict() + for key in config_dict: + if "drop" in key: + config_dict[key] = 0.0 + config = config.from_dict(config_dict) + self._policy_model = AutoModelForCausalLM.from_pretrained( + cfg.model_path, config=config + ) + self._policy_model.config.use_cache = False + + def forward(self, forward_type, *args, **kwargs): + if forward_type == "original": + return self.forward_original(*args, **kwargs) + elif forward_type == "eval_actions": + return self.eval_actions(*args, **kwargs) + else: + raise NotImplementedError + + def _prepare_inputs_for_model( + self, + model: Any, + input_ids: torch.tensor, + model_kwargs: Optional[Dict[str, torch.tensor]] = None, + ): + model_inputs = unwrap_model(model).prepare_inputs_for_generation( + input_ids, **model_kwargs + ) + return model_inputs + + def forward_original( + self, raw_obs, rnn_states, masks, action_masks=None, deterministic=False + ): + for key in raw_obs.keys(): + raw_obs[key] = 
torch.from_numpy(raw_obs[key]) if type(raw_obs[key]) == np.ndarray else raw_obs[key] + raw_obs[key] = raw_obs[key].to(self._policy_model.device) + # raw_obs[key] = check(raw_obs[key], self.use_half, self.tpdv) + # if self._use_fp16: + # raw_obs[key] = raw_obs[key].half() + rnn_states = check(rnn_states) + + input_ids = raw_obs["input_encoded_pt"].int() + attention_mask = raw_obs["input_attention_mask_pt"] + + past_model_kwargs = None + + if past_model_kwargs is None: + past_model_kwargs = { + "attention_mask": attention_mask, + } + + model_inputs = self._prepare_inputs_for_model( + self._policy_model, input_ids, past_model_kwargs + ) + + # forward pass to transformers + output = self._policy_model(**model_inputs) + + # compute action probs - policy head + next_token_logits = output.logits[:, -1] + dist = self._action_dist.proba_distribution(action_logits=next_token_logits) + + actions = dist.mode() if deterministic else dist.sample() + action_log_probs = dist.log_prob(actions) + + return actions.unsqueeze(-1), action_log_probs.unsqueeze(-1), rnn_states + + def eval_actions( + self, obs, rnn_states, action, masks, action_masks=None, active_masks=None + ): + for key in obs.keys(): + obs[key] = torch.from_numpy(obs[key]) if type(obs[key]) == np.ndarray else obs[key] + obs[key] = obs[key].to(self._policy_model.device) + # obs[key] = check(obs[key], self.use_half, self.tpdv) + # if self._use_fp16: + # obs[key] = obs[key].half() + action = check(action).to(self._policy_model.device).squeeze() + rnn_states = check(rnn_states) + + input_ids = obs["input_encoded_pt"].int() + attention_mask = obs["input_attention_mask_pt"] + + past_model_kwargs = None + + if past_model_kwargs is None: + past_model_kwargs = { + "attention_mask": attention_mask, + } + + model_inputs = self._prepare_inputs_for_model( + self._policy_model, input_ids, past_model_kwargs + ) + + # forward pass to transformers + output = self._policy_model(**model_inputs) + + # compute action probs - policy head + next_token_logits = output.logits[:, -1] + dist = self._action_dist.proba_distribution(action_logits=next_token_logits) + + action_log_probs = dist.log_prob(action) + dist_entropy = dist.entropy() + values = None + + return action_log_probs.unsqueeze(-1), dist_entropy.mean(), values + + def get_policy_values(self, obs, rnn_states, masks): + raise NotImplementedError \ No newline at end of file diff --git a/openrl/modules/networks/value_network_gpt.py b/openrl/modules/networks/value_network_gpt.py new file mode 100644 index 00000000..13db87b8 --- /dev/null +++ b/openrl/modules/networks/value_network_gpt.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2021 The OpenRL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""""" +from typing import Any, Optional, Dict + +import numpy as np +import torch +import torch.nn as nn + +from openrl.buffers.utils.util import get_critic_obs_space +from openrl.modules.networks.base_value_network import BaseValueNetwork +from openrl.modules.networks.utils.cnn import CNNBase +from openrl.modules.networks.utils.mix import MIXBase +from openrl.modules.networks.utils.mlp import MLPBase, MLPLayer +from openrl.modules.networks.utils.popart import PopArt +from openrl.modules.networks.utils.rnn import RNNLayer +from openrl.modules.networks.utils.util import init +from openrl.modules.utils.valuenorm import ValueNorm +from openrl.utils.util import check_v2 as check + +from transformers.modeling_utils import unwrap_model + +class ValueNetworkGPT(BaseValueNetwork): + def __init__( + self, + cfg, + input_space, + action_space=None, + use_half=False, + device=torch.device("cpu"), + extra_args=None, + ): + + self.use_half = use_half + self.tpdv = dict(dtype=torch.float32, device=device) + + super(ValueNetworkGPT, self).__init__(cfg, device) + + from transformers import AutoModelForCausalLM + + self._value_model = AutoModelForCausalLM.from_pretrained(cfg.model_path) + self._value_model.config.use_cache = False + self._value_head = nn.Linear( + self._value_model.config.hidden_size, 1, bias=False + ) + self.value_normalizer = ( + ValueNorm(1, device=device) if self._use_valuenorm else None + ) + + self._value_head.to(self.device) + + + def _prepare_inputs_for_model( + self, + model: Any, + input_ids: torch.tensor, + model_kwargs: Optional[Dict[str, torch.tensor]] = None, + ): + model_inputs = unwrap_model(model).prepare_inputs_for_generation( + input_ids, **model_kwargs + ) + return model_inputs + + def forward(self, critic_obs, rnn_states, masks): + for key in critic_obs.keys(): + critic_obs[key] = torch.from_numpy(critic_obs[key]) if type(critic_obs[key]) == np.ndarray else critic_obs[key] + critic_obs[key] = critic_obs[key].to(self._value_model.device) + # critic_obs[key] = check(critic_obs[key], self.use_half, self.tpdv) + # if self._use_fp16: + # critic_obs[key] = critic_obs[key].half() + masks = check(masks).to(self._value_model.device) + rnn_states = check(rnn_states) + + input_ids = critic_obs["input_encoded_pt"].int() + attention_mask = critic_obs["input_attention_mask_pt"] + + past_model_kwargs = None + if not past_model_kwargs: + past_model_kwargs = { + "attention_mask": attention_mask, + } + + model_inputs = self._prepare_inputs_for_model( + self._value_model, input_ids, past_model_kwargs + ) + output = self._value_model(output_hidden_states=True, **model_inputs) + last_tokens_hidden = output.hidden_states[-1][:, -1] + values = self._value_head.forward(last_tokens_hidden) + + return values, rnn_states From 2a1cf9cac6ed9e514081f5670e17fdb16245e6d5 Mon Sep 17 00:00:00 2001 From: Chen001117 Date: Tue, 28 Nov 2023 11:50:10 +0800 Subject: [PATCH 2/8] update data parallel and model parallel --- examples/nlp/ds_config.json | 4 +- examples/nlp/eval_ds_config.json | 4 +- examples/nlp/nlp_ppo.yaml | 12 ++--- examples/nlp/nlp_ppo_ds.yaml | 13 +++-- examples/nlp/train_ppo.py | 2 +- openrl/envs/nlp/rewards/kl_penalty.py | 27 +++++++--- openrl/modules/networks/policy_network_gpt.py | 54 ++++++++++++++----- openrl/modules/networks/value_network_gpt.py | 44 ++++++++++++--- openrl/modules/utils/valuenorm.py | 18 +++---- openrl/utils/logger.py | 42 +++++++++------ 10 files changed, 151 insertions(+), 69 deletions(-) diff --git a/examples/nlp/ds_config.json b/examples/nlp/ds_config.json 
index d3b68fe1..544bc405 100644 --- a/examples/nlp/ds_config.json +++ b/examples/nlp/ds_config.json @@ -1,6 +1,6 @@ { - "train_batch_size": 16, - "train_micro_batch_size_per_gpu": 4, + "train_batch_size": 32, + "train_micro_batch_size_per_gpu": 16, "steps_per_print": 10, "zero_optimization": { "stage": 2, diff --git a/examples/nlp/eval_ds_config.json b/examples/nlp/eval_ds_config.json index e9429896..58c08252 100644 --- a/examples/nlp/eval_ds_config.json +++ b/examples/nlp/eval_ds_config.json @@ -1,6 +1,6 @@ { - "train_batch_size": 16, - "train_micro_batch_size_per_gpu": 4, + "train_batch_size": 32, + "train_micro_batch_size_per_gpu": 16, "steps_per_print": 10, "zero_optimization": { "stage": 0, diff --git a/examples/nlp/nlp_ppo.yaml b/examples/nlp/nlp_ppo.yaml index b46e6211..1ba77379 100644 --- a/examples/nlp/nlp_ppo.yaml +++ b/examples/nlp/nlp_ppo.yaml @@ -9,23 +9,21 @@ wandb_entity: "openrl-lab" ppo_epoch: 5 episode_length: 128 num_mini_batch: 20 -use_share_model: true hidden_size: 1 - -model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog +model_path: /home/chenwenze/data_server/huggingface/models/facebook/opt-125m env: args: { - 'tokenizer_path': 'gpt2', - 'data_path': 'daily_dialog', + 'tokenizer_path': '/home/chenwenze/data_server/huggingface/models/facebook/opt-125m', + 'data_path': '/home/chenwenze/data_server/huggingface/datasets/daily_dialog', } vec_info_class: id: "NLPVecInfo" reward_class: id: "NLPReward" args: { - "ref_model": "rajkumarrrk/gpt2-fine-tuned-on-daily-dialog", - "intent_model": "rajkumarrrk/roberta-daily-dialog-intent-classifier", + "ref_model": "/home/chenwenze/data_server/huggingface/models/facebook/opt-125m", + "intent_model": "/home/chenwenze/data_server/huggingface/models/rajkumarrrk/roberta-daily-dialog-intent-classifier", } \ No newline at end of file diff --git a/examples/nlp/nlp_ppo_ds.yaml b/examples/nlp/nlp_ppo_ds.yaml index 3a031ae6..c9d4ad60 100644 --- a/examples/nlp/nlp_ppo_ds.yaml +++ b/examples/nlp/nlp_ppo_ds.yaml @@ -7,9 +7,8 @@ use_valuenorm: true use_adv_normalize: true wandb_entity: "openrl-lab" ppo_epoch: 5 -episode_length: 64 +episode_length: 128 num_mini_batch: 20 -# use_share_model: true hidden_size: 1 @@ -18,11 +17,11 @@ use_fp16: false use_offload: false deepspeed_config: ds_config.json -model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog +model_path: /home/chenwenze/data_server/huggingface/models/facebook/opt-125m env: args: { - 'tokenizer_path': 'gpt2', - 'data_path': 'daily_dialog', + 'tokenizer_path': '/home/chenwenze/data_server/huggingface/models/gpt2', + 'data_path': '/home/chenwenze/data_server/huggingface/datasets/daily_dialog', } vec_info_class: id: "NLPVecInfo" @@ -31,8 +30,8 @@ reward_class: args: { "use_deepspeed": true, "ref_ds_config": "eval_ds_config.json", - "ref_model": "rajkumarrrk/gpt2-fine-tuned-on-daily-dialog", + "ref_model": "/home/chenwenze/data_server/huggingface/models/facebook/opt-125m", "intent_ds_config": "eval_ds_config.json", - "intent_model": "rajkumarrrk/roberta-daily-dialog-intent-classifier", + "intent_model": "/home/chenwenze/data_server/huggingface/models/rajkumarrrk/roberta-daily-dialog-intent-classifier", } \ No newline at end of file diff --git a/examples/nlp/train_ppo.py b/examples/nlp/train_ppo.py index 384d8f9d..18347a6b 100644 --- a/examples/nlp/train_ppo.py +++ b/examples/nlp/train_ppo.py @@ -32,7 +32,7 @@ def train(): net = Net(env, device="cuda", cfg=cfg, model_dict=model_dict) # initialize the trainer - agent = Agent(net, use_wandb=False) + agent = Agent(net, use_wandb=True) # start 
training agent.train(total_time_steps=100000) diff --git a/openrl/envs/nlp/rewards/kl_penalty.py b/openrl/envs/nlp/rewards/kl_penalty.py index fe9e9594..7f5a6426 100644 --- a/openrl/envs/nlp/rewards/kl_penalty.py +++ b/openrl/envs/nlp/rewards/kl_penalty.py @@ -35,10 +35,16 @@ def __init__( ds_config: str = "default", ): super().__init__() + + self.device = "cuda" + self.use_data_parallel = False + self.use_model_parallel = False self.use_deepspeed = use_deepspeed + assert not (self.use_deepspeed and self.use_data_parallel) + assert not (self.use_deepspeed and self.use_model_parallel) + assert not (self.use_data_parallel and self.use_model_parallel) # reference model - self._apply_model_parallel = apply_model_parallel if ref_model == "builtin_ref": from transformers import GPT2Config, GPT2LMHeadModel @@ -65,10 +71,11 @@ def __init__( self._ref_engine, *_ = deepspeed.initialize(model=self, config=ds_config) elif torch.cuda.is_available(): - if self._apply_model_parallel and self._ref_net.is_parallelizable: + if self.use_model_parallel: self._ref_net.parallelize() - else: # else defaults to data parallel + elif self.use_data_parallel: # else defaults to data parallel self._ref_net = torch.nn.DataParallel(self._ref_net) + self._ref_net = self._ref_net.to(self.device) # alpha adjustment self._alpha = 0.2 @@ -144,7 +151,7 @@ def _prepare_inputs_for_model( input_ids, **model_kwargs ) - if self._apply_model_parallel and unwrap_model(model).is_parallelizable: + if self.use_model_parallel: # if model is in parallel mode, move the tensors to the first device model_inputs = { key: ( @@ -155,8 +162,16 @@ def _prepare_inputs_for_model( ) for key, value in model_inputs.items() } - - if self.use_deepspeed: + elif self.use_data_parallel: + model_inputs = { + key: ( + value.to(self.device) + if isinstance(value, torch.Tensor) + else value + ) + for key, value in model_inputs.items() + } + elif self.use_deepspeed: model_inputs = { key: value.to("cuda") if isinstance(value, torch.Tensor) else value for key, value in model_inputs.items() diff --git a/openrl/modules/networks/policy_network_gpt.py b/openrl/modules/networks/policy_network_gpt.py index cad3157e..5c97feef 100644 --- a/openrl/modules/networks/policy_network_gpt.py +++ b/openrl/modules/networks/policy_network_gpt.py @@ -47,15 +47,21 @@ def __init__( extra_args=None, ) -> None: + self.device = device self.use_half = use_half - self.tpdv = dict(dtype=torch.float32, device=device) + + self.use_data_parallel = False + self.use_model_parallel = False + self.use_deepspeed = cfg.use_deepspeed + + assert not (self.use_deepspeed and self.use_data_parallel) + assert not (self.use_deepspeed and self.use_model_parallel) + assert not (self.use_data_parallel and self.use_model_parallel) super(PolicyNetworkGPT, self).__init__(cfg, device) self.disable_drop_out = disable_drop_out - - self._action_dist = CategoricalDistribution(action_space.n) from transformers import AutoConfig, AutoModelForCausalLM @@ -70,6 +76,14 @@ def __init__( ) self._policy_model.config.use_cache = False + if torch.cuda.is_available(): + if self.use_model_parallel: + self._policy_model.parallelize() + elif self.use_data_parallel: + self._policy_model = torch.nn.DataParallel(self._policy_model) + self._policy_model = self._policy_model.to(self.device) + + def forward(self, forward_type, *args, **kwargs): if forward_type == "original": return self.forward_original(*args, **kwargs) @@ -87,6 +101,18 @@ def _prepare_inputs_for_model( model_inputs = unwrap_model(model).prepare_inputs_for_generation( 
input_ids, **model_kwargs ) + + if self.use_model_parallel: + model_inputs = { + key: ( + value.to(model.transformer.first_device) + if isinstance(value, torch.Tensor) + and hasattr(model.transformer, "first_device") + else value + ) + for key, value in model_inputs.items() + } + return model_inputs def forward_original( @@ -94,10 +120,11 @@ def forward_original( ): for key in raw_obs.keys(): raw_obs[key] = torch.from_numpy(raw_obs[key]) if type(raw_obs[key]) == np.ndarray else raw_obs[key] - raw_obs[key] = raw_obs[key].to(self._policy_model.device) - # raw_obs[key] = check(raw_obs[key], self.use_half, self.tpdv) - # if self._use_fp16: - # raw_obs[key] = raw_obs[key].half() + if self.use_data_parallel: + raw_obs[key] = raw_obs[key].to(self.device) + else: + raw_obs[key] = raw_obs[key].to(self._policy_model.device) + rnn_states = check(rnn_states) input_ids = raw_obs["input_encoded_pt"].int() @@ -131,11 +158,14 @@ def eval_actions( ): for key in obs.keys(): obs[key] = torch.from_numpy(obs[key]) if type(obs[key]) == np.ndarray else obs[key] - obs[key] = obs[key].to(self._policy_model.device) - # obs[key] = check(obs[key], self.use_half, self.tpdv) - # if self._use_fp16: - # obs[key] = obs[key].half() - action = check(action).to(self._policy_model.device).squeeze() + if self.use_data_parallel: + obs[key] = obs[key].to(self.device) + else: + obs[key] = obs[key].to(self._policy_model.device) + if self.use_data_parallel: + action = check(action).to(self.device).squeeze() + else: + action = check(action).to(self._policy_model.device).squeeze() rnn_states = check(rnn_states) input_ids = obs["input_encoded_pt"].int() diff --git a/openrl/modules/networks/value_network_gpt.py b/openrl/modules/networks/value_network_gpt.py index 13db87b8..4815cff7 100644 --- a/openrl/modules/networks/value_network_gpt.py +++ b/openrl/modules/networks/value_network_gpt.py @@ -45,8 +45,15 @@ def __init__( extra_args=None, ): + self.device = device self.use_half = use_half - self.tpdv = dict(dtype=torch.float32, device=device) + + self.use_data_parallel = False + self.use_model_parallel = False + self.use_deepspeed = cfg.use_deepspeed + assert not (self.use_deepspeed and self.use_data_parallel) + assert not (self.use_deepspeed and self.use_model_parallel) + assert not (self.use_data_parallel and self.use_model_parallel) super(ValueNetworkGPT, self).__init__(cfg, device) @@ -63,6 +70,15 @@ def __init__( self._value_head.to(self.device) + if torch.cuda.is_available(): + if self.use_model_parallel: + self._value_model.parallelize() + elif self.use_data_parallel: + self._value_model = torch.nn.DataParallel(self._value_model) + self._value_model = self._value_model.to(self.device) + self._value_head = torch.nn.DataParallel(self._value_head) + self._value_head = self._value_head.to(self.device) + def _prepare_inputs_for_model( self, @@ -73,16 +89,28 @@ def _prepare_inputs_for_model( model_inputs = unwrap_model(model).prepare_inputs_for_generation( input_ids, **model_kwargs ) + + if self.use_model_parallel: + model_inputs = { + key: ( + value.to(model.transformer.first_device) + if isinstance(value, torch.Tensor) + and hasattr(model.transformer, "first_device") + else value + ) + for key, value in model_inputs.items() + } + return model_inputs def forward(self, critic_obs, rnn_states, masks): for key in critic_obs.keys(): critic_obs[key] = torch.from_numpy(critic_obs[key]) if type(critic_obs[key]) == np.ndarray else critic_obs[key] - critic_obs[key] = critic_obs[key].to(self._value_model.device) - # critic_obs[key] = 
check(critic_obs[key], self.use_half, self.tpdv) - # if self._use_fp16: - # critic_obs[key] = critic_obs[key].half() - masks = check(masks).to(self._value_model.device) + if self.use_data_parallel: + critic_obs[key] = critic_obs[key].to(self.device) + else: + critic_obs[key] = critic_obs[key].to(self._value_model.device) + rnn_states = check(rnn_states) input_ids = critic_obs["input_encoded_pt"].int() @@ -99,6 +127,10 @@ def forward(self, critic_obs, rnn_states, masks): ) output = self._value_model(output_hidden_states=True, **model_inputs) last_tokens_hidden = output.hidden_states[-1][:, -1] + + if self.use_model_parallel: + last_tokens_hidden = last_tokens_hidden.to(self.device) + values = self._value_head.forward(last_tokens_hidden) return values, rnn_states diff --git a/openrl/modules/utils/valuenorm.py b/openrl/modules/utils/valuenorm.py index bed1d705..0367084a 100644 --- a/openrl/modules/utils/valuenorm.py +++ b/openrl/modules/utils/valuenorm.py @@ -24,15 +24,15 @@ def __init__( self.per_element_update = per_element_update self.tpdv = dict(dtype=torch.float32, device=device) - # self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv) - # self.running_mean_sq = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv) - # self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False).to(**self.tpdv) - - self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False) - self.running_mean_sq = nn.Parameter( - torch.zeros(input_shape), requires_grad=False - ) - self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False) + self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv) + self.running_mean_sq = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv) + self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False).to(**self.tpdv) + + # self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False) + # self.running_mean_sq = nn.Parameter( + # torch.zeros(input_shape), requires_grad=False + # ) + # self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False) self.reset_parameters() diff --git a/openrl/utils/logger.py b/openrl/utils/logger.py index 3fe61b53..d9c49f34 100644 --- a/openrl/utils/logger.py +++ b/openrl/utils/logger.py @@ -46,6 +46,10 @@ def __init__( self.use_wandb = use_wandb self.use_tensorboard = use_tensorboard + self.skip_logging = False + if cfg.use_deepspeed and cfg.local_rank != 0: + self.skip_logging = True + self.log_level = log_level self.log_path = log_path self.project_name = project_name @@ -126,20 +130,21 @@ def _init(self) -> None: ) if self.use_wandb: - wandb.init( - config=self.cfg, - project=self.project_name, - entity=self.wandb_entity, - notes=socket.gethostname(), - name=self.scenario_name - + "_" - + str(self.exp_name) - + "_seed" - + str(self.cfg.seed), - dir=str(run_dir), - job_type="training", - reinit=True, - ) + if not self.skip_logging: + wandb.init( + config=self.cfg, + project=self.project_name, + entity=self.wandb_entity, + notes=socket.gethostname(), + name=self.scenario_name + + "_" + + str(self.exp_name) + + "_seed" + + str(self.cfg.seed), + dir=str(run_dir), + job_type="training", + reinit=True, + ) elif self.use_tensorboard: from tensorboardX import SummaryWriter @@ -152,7 +157,8 @@ def _init(self) -> None: def close(self): if self.use_wandb: - wandb.finish() + if not self.skip_logging: + wandb.finish() def info(self, msg: str): 
logging.info(msg) @@ -167,7 +173,8 @@ def log_learner_info( return for k, v in infos.items(): if self.use_wandb: - wandb.log({"Learner_{}/{}".format(leaner_id, k): v}, step=step) + if not self.skip_logging: + wandb.log({"Learner_{}/{}".format(leaner_id, k): v}, step=step) elif self.use_tensorboard: self.writter.add_scalars( "Learner_{}/{}".format(leaner_id, k), @@ -192,7 +199,8 @@ def log_info( logging_info_str += f"\t{k}: {v}\n" if self.use_wandb: - wandb.log({k: v}, step=step) + if not self.skip_logging: + wandb.log({k: v}, step=step) elif self.use_tensorboard: self.writter.add_scalars(k, {k: v}, step) if self.log_to_terminal: From 05295b87dacc0ebc0a4b8a062beca77b04039429 Mon Sep 17 00:00:00 2001 From: Chen001117 Date: Tue, 19 Dec 2023 13:08:30 +0800 Subject: [PATCH 3/8] ds_support --- examples/nlp/ds_config.json | 4 +- examples/nlp/nlp_ppo.yaml | 10 ++--- examples/nlp/nlp_ppo_ds.yaml | 10 ++--- openrl/envs/nlp/rewards/intent.py | 29 +++++++++---- openrl/envs/nlp/rewards/kl_penalty.py | 29 ++++++++----- openrl/modules/networks/policy_network_gpt.py | 41 ++++++++++++------- openrl/modules/networks/value_network_gpt.py | 29 ++++++++----- 7 files changed, 96 insertions(+), 56 deletions(-) diff --git a/examples/nlp/ds_config.json b/examples/nlp/ds_config.json index 544bc405..3de0eb2d 100644 --- a/examples/nlp/ds_config.json +++ b/examples/nlp/ds_config.json @@ -3,9 +3,7 @@ "train_micro_batch_size_per_gpu": 16, "steps_per_print": 10, "zero_optimization": { - "stage": 2, - "reduce_bucket_size": 5e7, - "allgather_bucket_size": 5e7 + "stage": 2 }, "fp16": {"enabled": false, "loss_scale_window": 100} } \ No newline at end of file diff --git a/examples/nlp/nlp_ppo.yaml b/examples/nlp/nlp_ppo.yaml index 1ba77379..caf97bb2 100644 --- a/examples/nlp/nlp_ppo.yaml +++ b/examples/nlp/nlp_ppo.yaml @@ -12,18 +12,18 @@ num_mini_batch: 20 hidden_size: 1 -model_path: /home/chenwenze/data_server/huggingface/models/facebook/opt-125m +model_path: /rajkumarrrk/gpt2-fine-tuned-on-daily-dialog env: args: { - 'tokenizer_path': '/home/chenwenze/data_server/huggingface/models/facebook/opt-125m', - 'data_path': '/home/chenwenze/data_server/huggingface/datasets/daily_dialog', + 'tokenizer_path': 'gpt2', + 'data_path': 'daily_dialog', } vec_info_class: id: "NLPVecInfo" reward_class: id: "NLPReward" args: { - "ref_model": "/home/chenwenze/data_server/huggingface/models/facebook/opt-125m", - "intent_model": "/home/chenwenze/data_server/huggingface/models/rajkumarrrk/roberta-daily-dialog-intent-classifier", + "ref_model": "/rajkumarrrk/gpt2-fine-tuned-on-daily-dialog", + "intent_model": "/rajkumarrrk/roberta-daily-dialog-intent-classifier", } \ No newline at end of file diff --git a/examples/nlp/nlp_ppo_ds.yaml b/examples/nlp/nlp_ppo_ds.yaml index c9d4ad60..1f2ad7f8 100644 --- a/examples/nlp/nlp_ppo_ds.yaml +++ b/examples/nlp/nlp_ppo_ds.yaml @@ -17,11 +17,11 @@ use_fp16: false use_offload: false deepspeed_config: ds_config.json -model_path: /home/chenwenze/data_server/huggingface/models/facebook/opt-125m +model_path: /rajkumarrrk/gpt2-fine-tuned-on-daily-dialog/ env: args: { - 'tokenizer_path': '/home/chenwenze/data_server/huggingface/models/gpt2', - 'data_path': '/home/chenwenze/data_server/huggingface/datasets/daily_dialog', + 'tokenizer_path': 'gpt2', + 'data_path': 'daily_dialog', } vec_info_class: id: "NLPVecInfo" @@ -30,8 +30,8 @@ reward_class: args: { "use_deepspeed": true, "ref_ds_config": "eval_ds_config.json", - "ref_model": "/home/chenwenze/data_server/huggingface/models/facebook/opt-125m", + "ref_model": 
/rajkumarrrk/gpt2-fine-tuned-on-daily-dialog/, "intent_ds_config": "eval_ds_config.json", - "intent_model": "/home/chenwenze/data_server/huggingface/models/rajkumarrrk/roberta-daily-dialog-intent-classifier", + "intent_model": "/rajkumarrrk/roberta-daily-dialog-intent-classifier", } \ No newline at end of file diff --git a/openrl/envs/nlp/rewards/intent.py b/openrl/envs/nlp/rewards/intent.py index 0d449d13..f0929932 100644 --- a/openrl/envs/nlp/rewards/intent.py +++ b/openrl/envs/nlp/rewards/intent.py @@ -36,6 +36,10 @@ def __init__( self._intent_coeff = intent_coeff self.use_deepspeed = use_deepspeed + self.use_half = False + self.use_data_parallel = not use_deepspeed # default to use data parallel + self.use_model_parallel = False + if intent_model == "builtin_intent": from transformers import GPT2Config, GPT2LMHeadModel @@ -80,16 +84,16 @@ def __init__(self, input_ids, attention_mask): self._device = "cuda" self._model = self._model.to("cuda") self._model, *_ = deepspeed.initialize(model=self._model, config=ds_config) + self.use_fp16 = ds_config["fp16"]["enabled"] else: - if torch.cuda.is_available(): - manager = LocalGPUManager() - manager.log_info() - self._device = f"cuda:{manager.get_gpu()}" - else: - self._device = "cpu" - print("Intent Model choose to use device:{}".format(self._device)) - - self._model = self._model.to(self._device) + self._device = "cuda" + if self.use_model_parallel: + self._model.parallelize() + elif self.use_data_parallel: + if self.use_half: + self._model = self._model.half() + self._model = torch.nn.DataParallel(self._model) + self._model = self._model.to(self._device) def __call__( self, @@ -120,6 +124,13 @@ def get_input_for_classifier(prompt, generated_text): input_texts, return_tensors="pt", truncation=True, padding=True ) + if self.use_half: + encoded.input_ids = encoded.input_ids.int() + encoded.attention_mask = encoded.attention_mask.int() + else: + encoded.input_ids = encoded.input_ids.long() + encoded.attention_mask = encoded.attention_mask.long() + with torch.no_grad(): outputs = self._model( input_ids=encoded.input_ids.to(self._device), diff --git a/openrl/envs/nlp/rewards/kl_penalty.py b/openrl/envs/nlp/rewards/kl_penalty.py index 7f5a6426..406c9215 100644 --- a/openrl/envs/nlp/rewards/kl_penalty.py +++ b/openrl/envs/nlp/rewards/kl_penalty.py @@ -37,9 +37,10 @@ def __init__( super().__init__() self.device = "cuda" - self.use_data_parallel = False - self.use_model_parallel = False self.use_deepspeed = use_deepspeed + self.use_half = False + self.use_data_parallel = not use_deepspeed + self.use_model_parallel = False assert not (self.use_deepspeed and self.use_data_parallel) assert not (self.use_deepspeed and self.use_model_parallel) assert not (self.use_data_parallel and self.use_model_parallel) @@ -70,10 +71,12 @@ def __init__( self.use_fp16 = False self._ref_engine, *_ = deepspeed.initialize(model=self, config=ds_config) - elif torch.cuda.is_available(): + else: if self.use_model_parallel: self._ref_net.parallelize() elif self.use_data_parallel: # else defaults to data parallel + if self.use_half: + self._ref_net = self._ref_net.half() self._ref_net = torch.nn.DataParallel(self._ref_net) self._ref_net = self._ref_net.to(self.device) @@ -113,24 +116,30 @@ def __call__( self._ref_net, input_ids, past_model_kwargs ) - if self.use_deepspeed: - if self.use_fp16: - for key in ["input_ids", "position_ids"]: - model_inputs[key] = model_inputs[key].half().int() - for key in ["attention_mask"]: - model_inputs[key] = model_inputs[key].half() + if 
self.use_half: + for key in ["input_ids", "position_ids", "attention_mask"]: + if key in model_inputs: + model_inputs[key] = model_inputs[key].int() + else: + for key in ["input_ids", "position_ids", "attention_mask"]: + if key in model_inputs: + model_inputs[key] = model_inputs[key].long() + with torch.no_grad(): output = self._ref_net(output_hidden_states=True, **model_inputs) output["past_key_values"] = None next_token_logits = output.logits[:, -1, :] + if self.use_deepspeed and self.use_fp16: + next_token_logits = next_token_logits.double() dist = self._action_dist.proba_distribution(action_logits=next_token_logits) action_input = actions.to(next_token_logits.device) ref_log_prob = dist.log_prob(action_input) ref_log_prob = ref_log_prob.reshape(action_log_probs.shape) + kl_div = action_log_probs.copy() - ref_log_prob.detach().cpu().numpy() - rew = -self._alpha * kl_div + rew = -self._alpha * kl_div infos = [] for kl in kl_div: infos.append( diff --git a/openrl/modules/networks/policy_network_gpt.py b/openrl/modules/networks/policy_network_gpt.py index 5c97feef..0cda244e 100644 --- a/openrl/modules/networks/policy_network_gpt.py +++ b/openrl/modules/networks/policy_network_gpt.py @@ -48,11 +48,11 @@ def __init__( ) -> None: self.device = device - self.use_half = use_half - - self.use_data_parallel = False - self.use_model_parallel = False + self.use_fp16 = cfg.use_fp16 self.use_deepspeed = cfg.use_deepspeed + self.use_half = False + self.use_data_parallel = not cfg.use_deepspeed # default to use data parallel + self.use_model_parallel = False assert not (self.use_deepspeed and self.use_data_parallel) assert not (self.use_deepspeed and self.use_model_parallel) @@ -80,6 +80,8 @@ def __init__( if self.use_model_parallel: self._policy_model.parallelize() elif self.use_data_parallel: + if self.use_half: + self._policy_model = self._policy_model.half() self._policy_model = torch.nn.DataParallel(self._policy_model) self._policy_model = self._policy_model.to(self.device) @@ -120,15 +122,22 @@ def forward_original( ): for key in raw_obs.keys(): raw_obs[key] = torch.from_numpy(raw_obs[key]) if type(raw_obs[key]) == np.ndarray else raw_obs[key] + rnn_states = check(rnn_states) + + if self.use_half: + input_ids = raw_obs["input_encoded_pt"].int() + attention_mask = raw_obs["input_attention_mask_pt"].int() + else: + input_ids = raw_obs["input_encoded_pt"].long() + attention_mask = raw_obs["input_attention_mask_pt"].long() + + for key in raw_obs.keys(): if self.use_data_parallel: - raw_obs[key] = raw_obs[key].to(self.device) + input_ids = input_ids.to(self.device) + attention_mask = attention_mask.to(self.device) else: - raw_obs[key] = raw_obs[key].to(self._policy_model.device) - - rnn_states = check(rnn_states) - - input_ids = raw_obs["input_encoded_pt"].int() - attention_mask = raw_obs["input_attention_mask_pt"] + input_ids = input_ids.to(self._policy_model.device) + attention_mask = attention_mask.to(self._policy_model.device) past_model_kwargs = None @@ -145,7 +154,7 @@ def forward_original( output = self._policy_model(**model_inputs) # compute action probs - policy head - next_token_logits = output.logits[:, -1] + next_token_logits = output.logits[:, -1] dist = self._action_dist.proba_distribution(action_logits=next_token_logits) actions = dist.mode() if deterministic else dist.sample() @@ -168,8 +177,12 @@ def eval_actions( action = check(action).to(self._policy_model.device).squeeze() rnn_states = check(rnn_states) - input_ids = obs["input_encoded_pt"].int() - attention_mask = 
obs["input_attention_mask_pt"] + if self.half: + input_ids = obs["input_encoded_pt"].int() + attention_mask = obs["input_attention_mask_pt"].int() + else: + input_ids = obs["input_encoded_pt"].long() + attention_mask = obs["input_attention_mask_pt"].long() past_model_kwargs = None diff --git a/openrl/modules/networks/value_network_gpt.py b/openrl/modules/networks/value_network_gpt.py index 4815cff7..b4ed9b1c 100644 --- a/openrl/modules/networks/value_network_gpt.py +++ b/openrl/modules/networks/value_network_gpt.py @@ -46,11 +46,12 @@ def __init__( ): self.device = device - self.use_half = use_half - self.use_data_parallel = False - self.use_model_parallel = False + self.use_fp16 = cfg.use_fp16 self.use_deepspeed = cfg.use_deepspeed + self.use_half = False + self.use_data_parallel = not cfg.use_deepspeed + self.use_model_parallel = False assert not (self.use_deepspeed and self.use_data_parallel) assert not (self.use_deepspeed and self.use_model_parallel) assert not (self.use_data_parallel and self.use_model_parallel) @@ -62,18 +63,22 @@ def __init__( self._value_model = AutoModelForCausalLM.from_pretrained(cfg.model_path) self._value_model.config.use_cache = False self._value_head = nn.Linear( - self._value_model.config.hidden_size, 1, bias=False + self._value_model.config.n_embd, 1, bias=False # gpt2 + # self._value_model.config.word_embed_proj_dim, 1, bias=False # opt-x ) self.value_normalizer = ( ValueNorm(1, device=device) if self._use_valuenorm else None ) - self._value_head.to(self.device) - - if torch.cuda.is_available(): + if self.use_deepspeed: + self._value_head.to(self.device) + else: if self.use_model_parallel: self._value_model.parallelize() elif self.use_data_parallel: + if self.use_half: + self._value_model = self._value_model.half() + self._value_head = self._value_head.half() self._value_model = torch.nn.DataParallel(self._value_model) self._value_model = self._value_model.to(self.device) self._value_head = torch.nn.DataParallel(self._value_head) @@ -113,9 +118,13 @@ def forward(self, critic_obs, rnn_states, masks): rnn_states = check(rnn_states) - input_ids = critic_obs["input_encoded_pt"].int() - attention_mask = critic_obs["input_attention_mask_pt"] - + if self.use_half: + input_ids = critic_obs["input_encoded_pt"].int() + attention_mask = critic_obs["input_attention_mask_pt"].int() + else: + input_ids = critic_obs["input_encoded_pt"].long() + attention_mask = critic_obs["input_attention_mask_pt"].long() + past_model_kwargs = None if not past_model_kwargs: past_model_kwargs = { From cd7f5b075839e719fb66117259bcf98b750e60d5 Mon Sep 17 00:00:00 2001 From: Chen001117 Date: Tue, 19 Dec 2023 13:08:48 +0800 Subject: [PATCH 4/8] meteor_init_bug --- openrl/envs/nlp/daily_dialog_env.py | 18 +++++++++++++++--- openrl/envs/vec_env/wrappers/reward_wrapper.py | 5 ++--- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/openrl/envs/nlp/daily_dialog_env.py b/openrl/envs/nlp/daily_dialog_env.py index 61e68946..98dd4d85 100644 --- a/openrl/envs/nlp/daily_dialog_env.py +++ b/openrl/envs/nlp/daily_dialog_env.py @@ -113,8 +113,20 @@ def __init__( self.__time_step = None self.reward_function = None - def set_reward(self, reward_fn): - self.reward_function = reward_fn + self.set_reward() + + def set_reward(self, reward_fn=None): + + from openrl.envs.nlp.rewards.meteor import Meteor + meteor_config = { + "meteor_coeff": 0.5, + "test": False, + } + self.reward_function = { + "meteor": Meteor(**meteor_config), + } + + # self.reward_function = reward_fn def step_word(self, word: str) 
-> Tuple[Dict[str, torch.tensor], int, bool, dict]: action = self.tokenizer.encode(word)[1] @@ -135,7 +147,7 @@ def step( done = done or self.__current_obs.context_text.endswith(DailyDialog.EOU_TOKEN) reward = 0.0 - reward_info = dict() + reward_info = dict() if done and self.reward_function: for reward_function in self.reward_function.values(): diff --git a/openrl/envs/vec_env/wrappers/reward_wrapper.py b/openrl/envs/vec_env/wrappers/reward_wrapper.py index d0a4d630..2b5ca266 100644 --- a/openrl/envs/vec_env/wrappers/reward_wrapper.py +++ b/openrl/envs/vec_env/wrappers/reward_wrapper.py @@ -24,13 +24,12 @@ from openrl.envs.vec_env.wrappers.base_wrapper import VecEnvWrapper from openrl.rewards.base_reward import BaseReward - class RewardWrapper(VecEnvWrapper): def __init__(self, env: BaseVecEnv, reward_class: BaseReward): super().__init__(env) self.reward_class = reward_class - if len(self.reward_class.inner_rew_funcs) > 0: - env.call("set_reward", **{"reward_fn": self.reward_class.inner_rew_funcs}) + # if len(self.reward_class.inner_rew_funcs) > 0: + # env.call("set_reward", **{"reward_fn": self.reward_class.inner_rew_funcs}) def step( self, action: ActType, extra_data: Optional[Dict[str, Any]] From 70163284cf9976ab6c15a5b6cd19f6486eed0911 Mon Sep 17 00:00:00 2001 From: Wen-Tse Chen Date: Tue, 19 Dec 2023 21:45:13 -0500 Subject: [PATCH 5/8] meteor_init_bug --- examples/nlp/nlp_ppo.yaml | 6 +++--- examples/nlp/nlp_ppo_ds.yaml | 6 +++--- openrl/rewards/nlp_reward.py | 18 +++++++++++------- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/examples/nlp/nlp_ppo.yaml b/examples/nlp/nlp_ppo.yaml index caf97bb2..918a75b8 100644 --- a/examples/nlp/nlp_ppo.yaml +++ b/examples/nlp/nlp_ppo.yaml @@ -12,7 +12,7 @@ num_mini_batch: 20 hidden_size: 1 -model_path: /rajkumarrrk/gpt2-fine-tuned-on-daily-dialog +model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog env: args: { 'tokenizer_path': 'gpt2', @@ -23,7 +23,7 @@ vec_info_class: reward_class: id: "NLPReward" args: { - "ref_model": "/rajkumarrrk/gpt2-fine-tuned-on-daily-dialog", - "intent_model": "/rajkumarrrk/roberta-daily-dialog-intent-classifier", + "ref_model": "rajkumarrrk/gpt2-fine-tuned-on-daily-dialog", + "intent_model": "rajkumarrrk/roberta-daily-dialog-intent-classifier", } \ No newline at end of file diff --git a/examples/nlp/nlp_ppo_ds.yaml b/examples/nlp/nlp_ppo_ds.yaml index 1f2ad7f8..88dac18c 100644 --- a/examples/nlp/nlp_ppo_ds.yaml +++ b/examples/nlp/nlp_ppo_ds.yaml @@ -17,7 +17,7 @@ use_fp16: false use_offload: false deepspeed_config: ds_config.json -model_path: /rajkumarrrk/gpt2-fine-tuned-on-daily-dialog/ +model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog env: args: { 'tokenizer_path': 'gpt2', @@ -30,8 +30,8 @@ reward_class: args: { "use_deepspeed": true, "ref_ds_config": "eval_ds_config.json", - "ref_model": /rajkumarrrk/gpt2-fine-tuned-on-daily-dialog/, + "ref_model": "rajkumarrrk/gpt2-fine-tuned-on-daily-dialog", "intent_ds_config": "eval_ds_config.json", - "intent_model": "/rajkumarrrk/roberta-daily-dialog-intent-classifier", + "intent_model": "rajkumarrrk/roberta-daily-dialog-intent-classifier", } \ No newline at end of file diff --git a/openrl/rewards/nlp_reward.py b/openrl/rewards/nlp_reward.py index 51c76fb3..bedfcc59 100644 --- a/openrl/rewards/nlp_reward.py +++ b/openrl/rewards/nlp_reward.py @@ -22,13 +22,17 @@ def __init__( self.rew_infos = [] self.env_infos = [] - meteor_config = { - "meteor_coeff": 0.5, - "test": ref_model == "builtin_ref", - } - self.inner_rew_funcs = { - "meteor": 
Meteor(**meteor_config), - } + # bug unfixed + self.inner_rew_funcs = dict() + + # meteor_config = { + # "meteor_coeff": 0.5, + # "test": ref_model == "builtin_ref", + # } + # self.inner_rew_funcs = { + # "meteor": Meteor(**meteor_config), + # } + kl_config = { "action_space": env.action_space, From 3af758829a1c17ff17e5e05df29df4eb3e11e251 Mon Sep 17 00:00:00 2001 From: Wen-Tse Chen Date: Tue, 19 Dec 2023 22:08:13 -0500 Subject: [PATCH 6/8] update format --- examples/nlp/train_ppo.py | 2 +- openrl/envs/nlp/daily_dialog_env.py | 29 ++++----- openrl/envs/nlp/fake_dialog_env.py | 22 +++---- openrl/envs/nlp/rewards/intent.py | 2 +- openrl/envs/nlp/rewards/kl_penalty.py | 23 +++---- openrl/envs/nlp/utils/metrics/meteor.py | 24 +++---- .../envs/vec_env/wrappers/reward_wrapper.py | 1 + openrl/modules/networks/policy_network_gpt.py | 64 ++++++++++--------- openrl/modules/networks/value_network_gpt.py | 33 ++++++---- openrl/rewards/nlp_reward.py | 3 +- 10 files changed, 100 insertions(+), 103 deletions(-) diff --git a/examples/nlp/train_ppo.py b/examples/nlp/train_ppo.py index 18347a6b..4fefcf52 100644 --- a/examples/nlp/train_ppo.py +++ b/examples/nlp/train_ppo.py @@ -3,8 +3,8 @@ from openrl.configs.config import create_config_parser from openrl.envs.common import make from openrl.modules.common import PPONet as Net -from openrl.modules.networks.value_network_gpt import ValueNetworkGPT as ValueNetwork from openrl.modules.networks.policy_network_gpt import PolicyNetworkGPT as PolicyNetwork +from openrl.modules.networks.value_network_gpt import ValueNetworkGPT as ValueNetwork from openrl.runners.common import PPOAgent as Agent diff --git a/openrl/envs/nlp/daily_dialog_env.py b/openrl/envs/nlp/daily_dialog_env.py index 98dd4d85..332db319 100644 --- a/openrl/envs/nlp/daily_dialog_env.py +++ b/openrl/envs/nlp/daily_dialog_env.py @@ -72,18 +72,16 @@ def __init__( # set the observation and action space here self._vocab_size = self.tokenizer.vocab_size - self.observation_space = DictSpace( - { - "input_encoded_pt": spaces.Box( - low=0, - high=self._vocab_size, - shape=(self._max_text_length + self.max_steps,), - ), - "input_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(self._max_text_length + self.max_steps,) - ), - } - ) + self.observation_space = DictSpace({ + "input_encoded_pt": spaces.Box( + low=0, + high=self._vocab_size, + shape=(self._max_text_length + self.max_steps,), + ), + "input_attention_mask_pt": spaces.Box( + low=0, high=1, shape=(self._max_text_length + self.max_steps,) + ), + }) self.action_space = Discrete(n=self._vocab_size) # see https://github.com/huggingface/transformers/issues/4875 : rounding up to nearest power of 2 for better GPU efficiency @@ -116,8 +114,9 @@ def __init__( self.set_reward() def set_reward(self, reward_fn=None): - + from openrl.envs.nlp.rewards.meteor import Meteor + meteor_config = { "meteor_coeff": 0.5, "test": False, @@ -125,7 +124,7 @@ def set_reward(self, reward_fn=None): self.reward_function = { "meteor": Meteor(**meteor_config), } - + # self.reward_function = reward_fn def step_word(self, word: str) -> Tuple[Dict[str, torch.tensor], int, bool, dict]: @@ -147,7 +146,7 @@ def step( done = done or self.__current_obs.context_text.endswith(DailyDialog.EOU_TOKEN) reward = 0.0 - reward_info = dict() + reward_info = dict() if done and self.reward_function: for reward_function in self.reward_function.values(): diff --git a/openrl/envs/nlp/fake_dialog_env.py b/openrl/envs/nlp/fake_dialog_env.py index 02247bc0..27f9d8f4 100644 --- 
a/openrl/envs/nlp/fake_dialog_env.py +++ b/openrl/envs/nlp/fake_dialog_env.py @@ -30,18 +30,16 @@ def __init__( # set the observation and action space here self._vocab_size = 2 - self.observation_space = DictSpace( - { - "input_encoded_pt": spaces.Box( - low=0, - high=self._vocab_size, - shape=(self._max_text_length + self.max_steps,), - ), - "input_attention_mask_pt": spaces.Box( - low=0, high=1, shape=(self._max_text_length + self.max_steps,) - ), - } - ) + self.observation_space = DictSpace({ + "input_encoded_pt": spaces.Box( + low=0, + high=self._vocab_size, + shape=(self._max_text_length + self.max_steps,), + ), + "input_attention_mask_pt": spaces.Box( + low=0, high=1, shape=(self._max_text_length + self.max_steps,) + ), + }) self.action_space = Discrete(n=self._vocab_size) n = 2 diff --git a/openrl/envs/nlp/rewards/intent.py b/openrl/envs/nlp/rewards/intent.py index f0929932..0a0c4d3e 100644 --- a/openrl/envs/nlp/rewards/intent.py +++ b/openrl/envs/nlp/rewards/intent.py @@ -37,7 +37,7 @@ def __init__( self._intent_coeff = intent_coeff self.use_deepspeed = use_deepspeed self.use_half = False - self.use_data_parallel = not use_deepspeed # default to use data parallel + self.use_data_parallel = not use_deepspeed # default to use data parallel self.use_model_parallel = False if intent_model == "builtin_intent": diff --git a/openrl/envs/nlp/rewards/kl_penalty.py b/openrl/envs/nlp/rewards/kl_penalty.py index 406c9215..9516b788 100644 --- a/openrl/envs/nlp/rewards/kl_penalty.py +++ b/openrl/envs/nlp/rewards/kl_penalty.py @@ -35,7 +35,7 @@ def __init__( ds_config: str = "default", ): super().__init__() - + self.device = "cuda" self.use_deepspeed = use_deepspeed self.use_half = False @@ -116,7 +116,7 @@ def __call__( self._ref_net, input_ids, past_model_kwargs ) - if self.use_half: + if self.use_half: for key in ["input_ids", "position_ids", "attention_mask"]: if key in model_inputs: model_inputs[key] = model_inputs[key].int() @@ -125,7 +125,6 @@ def __call__( if key in model_inputs: model_inputs[key] = model_inputs[key].long() - with torch.no_grad(): output = self._ref_net(output_hidden_states=True, **model_inputs) output["past_key_values"] = None @@ -139,15 +138,13 @@ def __call__( ref_log_prob = ref_log_prob.reshape(action_log_probs.shape) kl_div = action_log_probs.copy() - ref_log_prob.detach().cpu().numpy() - rew = -self._alpha * kl_div + rew = -self._alpha * kl_div infos = [] for kl in kl_div: - infos.append( - { - "alpha": self._alpha, - "kl_div": kl.mean(), - } - ) + infos.append({ + "alpha": self._alpha, + "kl_div": kl.mean(), + }) return rew, infos def _prepare_inputs_for_model( @@ -173,11 +170,7 @@ def _prepare_inputs_for_model( } elif self.use_data_parallel: model_inputs = { - key: ( - value.to(self.device) - if isinstance(value, torch.Tensor) - else value - ) + key: value.to(self.device) if isinstance(value, torch.Tensor) else value for key, value in model_inputs.items() } elif self.use_deepspeed: diff --git a/openrl/envs/nlp/utils/metrics/meteor.py b/openrl/envs/nlp/utils/metrics/meteor.py index ab15e66d..c2345fa9 100644 --- a/openrl/envs/nlp/utils/metrics/meteor.py +++ b/openrl/envs/nlp/utils/metrics/meteor.py @@ -88,20 +88,16 @@ def _info(self): citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, features=[ - datasets.Features( - { - "predictions": datasets.Value("string", id="sequence"), - "references": datasets.Sequence( - datasets.Value("string", id="sequence"), id="references" - ), - } - ), - datasets.Features( - { - "predictions": datasets.Value("string", 
id="sequence"), - "references": datasets.Value("string", id="sequence"), - } - ), + datasets.Features({ + "predictions": datasets.Value("string", id="sequence"), + "references": datasets.Sequence( + datasets.Value("string", id="sequence"), id="references" + ), + }), + datasets.Features({ + "predictions": datasets.Value("string", id="sequence"), + "references": datasets.Value("string", id="sequence"), + }), ], codebase_urls=[ "https://github.com/nltk/nltk/blob/develop/nltk/translate/meteor_score.py" diff --git a/openrl/envs/vec_env/wrappers/reward_wrapper.py b/openrl/envs/vec_env/wrappers/reward_wrapper.py index 2b5ca266..25cdc424 100644 --- a/openrl/envs/vec_env/wrappers/reward_wrapper.py +++ b/openrl/envs/vec_env/wrappers/reward_wrapper.py @@ -24,6 +24,7 @@ from openrl.envs.vec_env.wrappers.base_wrapper import VecEnvWrapper from openrl.rewards.base_reward import BaseReward + class RewardWrapper(VecEnvWrapper): def __init__(self, env: BaseVecEnv, reward_class: BaseReward): super().__init__(env) diff --git a/openrl/modules/networks/policy_network_gpt.py b/openrl/modules/networks/policy_network_gpt.py index 0cda244e..906f1fb5 100644 --- a/openrl/modules/networks/policy_network_gpt.py +++ b/openrl/modules/networks/policy_network_gpt.py @@ -15,13 +15,15 @@ # limitations under the License. """""" -from typing import Any, Optional, Dict +from typing import Any, Dict, Optional import numpy as np import torch import torch.nn as nn +from transformers.modeling_utils import unwrap_model from openrl.buffers.utils.util import get_policy_obs, get_policy_obs_space +from openrl.envs.nlp.utils.distribution import CategoricalDistribution from openrl.modules.networks.base_policy_network import BasePolicyNetwork from openrl.modules.networks.utils.act import ACTLayer from openrl.modules.networks.utils.cnn import CNNBase @@ -31,9 +33,7 @@ from openrl.modules.networks.utils.rnn import RNNLayer from openrl.modules.networks.utils.util import init from openrl.utils.util import check_v2 as check -from openrl.envs.nlp.utils.distribution import CategoricalDistribution -from transformers.modeling_utils import unwrap_model class PolicyNetworkGPT(BasePolicyNetwork): def __init__( @@ -46,25 +46,26 @@ def __init__( disable_drop_out: bool = True, extra_args=None, ) -> None: - + self.device = device self.use_fp16 = cfg.use_fp16 self.use_deepspeed = cfg.use_deepspeed self.use_half = False - self.use_data_parallel = not cfg.use_deepspeed # default to use data parallel + self.use_data_parallel = not cfg.use_deepspeed # default to use data parallel self.use_model_parallel = False assert not (self.use_deepspeed and self.use_data_parallel) assert not (self.use_deepspeed and self.use_model_parallel) assert not (self.use_data_parallel and self.use_model_parallel) - + super(PolicyNetworkGPT, self).__init__(cfg, device) - + self.disable_drop_out = disable_drop_out - + self._action_dist = CategoricalDistribution(action_space.n) - + from transformers import AutoConfig, AutoModelForCausalLM + config = AutoConfig.from_pretrained(cfg.model_path) config_dict = config.to_dict() for key in config_dict: @@ -85,7 +86,6 @@ def __init__( self._policy_model = torch.nn.DataParallel(self._policy_model) self._policy_model = self._policy_model.to(self.device) - def forward(self, forward_type, *args, **kwargs): if forward_type == "original": return self.forward_original(*args, **kwargs) @@ -93,7 +93,7 @@ def forward(self, forward_type, *args, **kwargs): return self.eval_actions(*args, **kwargs) else: raise NotImplementedError - + def 
_prepare_inputs_for_model( self, model: Any, @@ -121,7 +121,11 @@ def forward_original( self, raw_obs, rnn_states, masks, action_masks=None, deterministic=False ): for key in raw_obs.keys(): - raw_obs[key] = torch.from_numpy(raw_obs[key]) if type(raw_obs[key]) == np.ndarray else raw_obs[key] + raw_obs[key] = ( + torch.from_numpy(raw_obs[key]) + if type(raw_obs[key]) == np.ndarray + else raw_obs[key] + ) rnn_states = check(rnn_states) if self.use_half: @@ -138,35 +142,37 @@ def forward_original( else: input_ids = input_ids.to(self._policy_model.device) attention_mask = attention_mask.to(self._policy_model.device) - + past_model_kwargs = None - + if past_model_kwargs is None: past_model_kwargs = { "attention_mask": attention_mask, } - + model_inputs = self._prepare_inputs_for_model( self._policy_model, input_ids, past_model_kwargs ) - + # forward pass to transformers output = self._policy_model(**model_inputs) - + # compute action probs - policy head - next_token_logits = output.logits[:, -1] + next_token_logits = output.logits[:, -1] dist = self._action_dist.proba_distribution(action_logits=next_token_logits) - + actions = dist.mode() if deterministic else dist.sample() action_log_probs = dist.log_prob(actions) - + return actions.unsqueeze(-1), action_log_probs.unsqueeze(-1), rnn_states def eval_actions( self, obs, rnn_states, action, masks, action_masks=None, active_masks=None ): for key in obs.keys(): - obs[key] = torch.from_numpy(obs[key]) if type(obs[key]) == np.ndarray else obs[key] + obs[key] = ( + torch.from_numpy(obs[key]) if type(obs[key]) == np.ndarray else obs[key] + ) if self.use_data_parallel: obs[key] = obs[key].to(self.device) else: @@ -176,32 +182,32 @@ def eval_actions( else: action = check(action).to(self._policy_model.device).squeeze() rnn_states = check(rnn_states) - + if self.half: input_ids = obs["input_encoded_pt"].int() attention_mask = obs["input_attention_mask_pt"].int() else: input_ids = obs["input_encoded_pt"].long() attention_mask = obs["input_attention_mask_pt"].long() - + past_model_kwargs = None - + if past_model_kwargs is None: past_model_kwargs = { "attention_mask": attention_mask, } - + model_inputs = self._prepare_inputs_for_model( self._policy_model, input_ids, past_model_kwargs ) - + # forward pass to transformers output = self._policy_model(**model_inputs) - + # compute action probs - policy head next_token_logits = output.logits[:, -1] dist = self._action_dist.proba_distribution(action_logits=next_token_logits) - + action_log_probs = dist.log_prob(action) dist_entropy = dist.entropy() values = None @@ -209,4 +215,4 @@ def eval_actions( return action_log_probs.unsqueeze(-1), dist_entropy.mean(), values def get_policy_values(self, obs, rnn_states, masks): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/openrl/modules/networks/value_network_gpt.py b/openrl/modules/networks/value_network_gpt.py index b4ed9b1c..afffffc2 100644 --- a/openrl/modules/networks/value_network_gpt.py +++ b/openrl/modules/networks/value_network_gpt.py @@ -15,11 +15,12 @@ # limitations under the License. 
"""""" -from typing import Any, Optional, Dict +from typing import Any, Dict, Optional import numpy as np import torch import torch.nn as nn +from transformers.modeling_utils import unwrap_model from openrl.buffers.utils.util import get_critic_obs_space from openrl.modules.networks.base_value_network import BaseValueNetwork @@ -32,7 +33,6 @@ from openrl.modules.utils.valuenorm import ValueNorm from openrl.utils.util import check_v2 as check -from transformers.modeling_utils import unwrap_model class ValueNetworkGPT(BaseValueNetwork): def __init__( @@ -44,7 +44,7 @@ def __init__( device=torch.device("cpu"), extra_args=None, ): - + self.device = device self.use_fp16 = cfg.use_fp16 @@ -55,21 +55,23 @@ def __init__( assert not (self.use_deepspeed and self.use_data_parallel) assert not (self.use_deepspeed and self.use_model_parallel) assert not (self.use_data_parallel and self.use_model_parallel) - + super(ValueNetworkGPT, self).__init__(cfg, device) - + from transformers import AutoModelForCausalLM - + self._value_model = AutoModelForCausalLM.from_pretrained(cfg.model_path) self._value_model.config.use_cache = False self._value_head = nn.Linear( - self._value_model.config.n_embd, 1, bias=False # gpt2 + self._value_model.config.n_embd, + 1, + bias=False, # gpt2 # self._value_model.config.word_embed_proj_dim, 1, bias=False # opt-x ) self.value_normalizer = ( ValueNorm(1, device=device) if self._use_valuenorm else None ) - + if self.use_deepspeed: self._value_head.to(self.device) else: @@ -84,7 +86,6 @@ def __init__( self._value_head = torch.nn.DataParallel(self._value_head) self._value_head = self._value_head.to(self.device) - def _prepare_inputs_for_model( self, model: Any, @@ -105,19 +106,23 @@ def _prepare_inputs_for_model( ) for key, value in model_inputs.items() } - + return model_inputs def forward(self, critic_obs, rnn_states, masks): for key in critic_obs.keys(): - critic_obs[key] = torch.from_numpy(critic_obs[key]) if type(critic_obs[key]) == np.ndarray else critic_obs[key] + critic_obs[key] = ( + torch.from_numpy(critic_obs[key]) + if type(critic_obs[key]) == np.ndarray + else critic_obs[key] + ) if self.use_data_parallel: critic_obs[key] = critic_obs[key].to(self.device) else: critic_obs[key] = critic_obs[key].to(self._value_model.device) - + rnn_states = check(rnn_states) - + if self.use_half: input_ids = critic_obs["input_encoded_pt"].int() attention_mask = critic_obs["input_attention_mask_pt"].int() @@ -130,7 +135,7 @@ def forward(self, critic_obs, rnn_states, masks): past_model_kwargs = { "attention_mask": attention_mask, } - + model_inputs = self._prepare_inputs_for_model( self._value_model, input_ids, past_model_kwargs ) diff --git a/openrl/rewards/nlp_reward.py b/openrl/rewards/nlp_reward.py index bedfcc59..38cd306a 100644 --- a/openrl/rewards/nlp_reward.py +++ b/openrl/rewards/nlp_reward.py @@ -24,7 +24,7 @@ def __init__( # bug unfixed self.inner_rew_funcs = dict() - + # meteor_config = { # "meteor_coeff": 0.5, # "test": ref_model == "builtin_ref", @@ -32,7 +32,6 @@ def __init__( # self.inner_rew_funcs = { # "meteor": Meteor(**meteor_config), # } - kl_config = { "action_space": env.action_space, From f8879b3ec171b17d16bed8a72b6fe80f4690a0cc Mon Sep 17 00:00:00 2001 From: Wen-Tse Chen Date: Wed, 20 Dec 2023 00:22:57 -0500 Subject: [PATCH 7/8] fix test w/o gpu bug --- openrl/envs/nlp/rewards/intent.py | 9 ++++++--- openrl/envs/nlp/rewards/kl_penalty.py | 9 +++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/openrl/envs/nlp/rewards/intent.py 
b/openrl/envs/nlp/rewards/intent.py index 0a0c4d3e..2c82e96f 100644 --- a/openrl/envs/nlp/rewards/intent.py +++ b/openrl/envs/nlp/rewards/intent.py @@ -41,6 +41,10 @@ def __init__( self.use_model_parallel = False if intent_model == "builtin_intent": + + self._device = "cpu" + self.use_data_parallel = False + from transformers import GPT2Config, GPT2LMHeadModel class TestTokenizer: @@ -66,6 +70,7 @@ def __init__(self, input_ids, attention_mask): self._model = GPT2LMHeadModel(config) else: + self._device = "cuda" model_path = data_abs_path(intent_model) self._tokenizer = AutoTokenizer.from_pretrained(intent_model) self._model = AutoModelForSequenceClassification.from_pretrained(model_path) @@ -81,12 +86,10 @@ def __init__(self, input_ids, attention_mask): with open(ds_config) as file: ds_config = json.load(file) - self._device = "cuda" - self._model = self._model.to("cuda") + self._model = self._model.to(self._device) self._model, *_ = deepspeed.initialize(model=self._model, config=ds_config) self.use_fp16 = ds_config["fp16"]["enabled"] else: - self._device = "cuda" if self.use_model_parallel: self._model.parallelize() elif self.use_data_parallel: diff --git a/openrl/envs/nlp/rewards/kl_penalty.py b/openrl/envs/nlp/rewards/kl_penalty.py index 9516b788..3cfafd4b 100644 --- a/openrl/envs/nlp/rewards/kl_penalty.py +++ b/openrl/envs/nlp/rewards/kl_penalty.py @@ -47,6 +47,10 @@ def __init__( # reference model if ref_model == "builtin_ref": + + self.device = "cpu" + self.use_data_parallel = False + from transformers import GPT2Config, GPT2LMHeadModel config = GPT2Config() @@ -77,8 +81,9 @@ def __init__( elif self.use_data_parallel: # else defaults to data parallel if self.use_half: self._ref_net = self._ref_net.half() - self._ref_net = torch.nn.DataParallel(self._ref_net) - self._ref_net = self._ref_net.to(self.device) + else: + self._ref_net = torch.nn.DataParallel(self._ref_net) + self._ref_net = self._ref_net.to(self.device) # alpha adjustment self._alpha = 0.2 From fc020301258336a13d9d2a20d14477e38f485879 Mon Sep 17 00:00:00 2001 From: Wen-Tse Chen Date: Wed, 20 Dec 2023 00:35:36 -0500 Subject: [PATCH 8/8] fix set reward bug --- openrl/envs/nlp/daily_dialog_env.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/openrl/envs/nlp/daily_dialog_env.py b/openrl/envs/nlp/daily_dialog_env.py index 332db319..2aa08684 100644 --- a/openrl/envs/nlp/daily_dialog_env.py +++ b/openrl/envs/nlp/daily_dialog_env.py @@ -111,21 +111,9 @@ def __init__( self.__time_step = None self.reward_function = None - self.set_reward() - def set_reward(self, reward_fn=None): - from openrl.envs.nlp.rewards.meteor import Meteor - - meteor_config = { - "meteor_coeff": 0.5, - "test": False, - } - self.reward_function = { - "meteor": Meteor(**meteor_config), - } - - # self.reward_function = reward_fn + self.reward_function = reward_fn def step_word(self, word: str) -> Tuple[Dict[str, torch.tensor], int, bool, dict]: action = self.tokenizer.encode(word)[1]
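
Note on the final patch ("fix set reward bug"): the daily-dialog environment no longer builds its own Meteor reward inside set_reward(); it simply stores whatever reward_fn the caller passes in, and the constructor no longer calls set_reward() on its own. The short Python sketch below illustrates how a caller might now supply that reward mapping. It is a minimal illustration, not part of the patch series: the make_meteor_reward helper and the env.set_reward(...) call site are assumptions, while the Meteor import path and its meteor_coeff / test arguments are taken verbatim from the code the patch removes.

from typing import Any, Dict

from openrl.envs.nlp.rewards.meteor import Meteor


def make_meteor_reward(meteor_coeff: float = 0.5, test: bool = False) -> Dict[str, Any]:
    # Build the reward mapping that set_reward() now expects from the caller,
    # mirroring the dict the environment used to construct internally before
    # this patch ({"meteor": Meteor(meteor_coeff=0.5, test=False)}).
    return {"meteor": Meteor(meteor_coeff=meteor_coeff, test=test)}


# Hypothetical call site (env being a daily-dialog environment instance):
#   env.set_reward(make_meteor_reward())
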