Commit 3af7588: update format
Wen-Tse Chen committed Dec 20, 2023
1 parent: d470127
Showing 10 changed files with 100 additions and 103 deletions.
examples/nlp/train_ppo.py (2 changes: 1 addition & 1 deletion)
@@ -3,8 +3,8 @@
from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.modules.networks.value_network_gpt import ValueNetworkGPT as ValueNetwork
from openrl.modules.networks.policy_network_gpt import PolicyNetworkGPT as PolicyNetwork
from openrl.modules.networks.value_network_gpt import ValueNetworkGPT as ValueNetwork
from openrl.runners.common import PPOAgent as Agent


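For context, the imports above are the building blocks of OpenRL's usual training-script pattern (config parser, make, Net, Agent). A minimal sketch of that pattern follows; the model_dict argument and its "policy"/"critic" keys are assumptions made for illustration, not taken from this commit.

# Sketch only: how a script like train_ppo.py is typically wired in OpenRL.
# The model_dict wiring below is an assumption; the actual script may register
# the GPT policy/value networks differently.
from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.modules.networks.policy_network_gpt import PolicyNetworkGPT as PolicyNetwork
from openrl.modules.networks.value_network_gpt import ValueNetworkGPT as ValueNetwork
from openrl.runners.common import PPOAgent as Agent

cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args()

env = make("daily_dialog", env_num=2, cfg=cfg)  # NLP dialogue environment
net = Net(
    env,
    cfg=cfg,
    device="cuda",
    model_dict={"policy": PolicyNetwork, "critic": ValueNetwork},  # assumed wiring
)
agent = Agent(net)
agent.train(total_time_steps=100000)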
openrl/envs/nlp/daily_dialog_env.py (29 changes: 14 additions & 15 deletions)
@@ -72,18 +72,16 @@ def __init__(
# set the observation and action space here
self._vocab_size = self.tokenizer.vocab_size

self.observation_space = DictSpace(
{
"input_encoded_pt": spaces.Box(
low=0,
high=self._vocab_size,
shape=(self._max_text_length + self.max_steps,),
),
"input_attention_mask_pt": spaces.Box(
low=0, high=1, shape=(self._max_text_length + self.max_steps,)
),
}
)
self.observation_space = DictSpace({
"input_encoded_pt": spaces.Box(
low=0,
high=self._vocab_size,
shape=(self._max_text_length + self.max_steps,),
),
"input_attention_mask_pt": spaces.Box(
low=0, high=1, shape=(self._max_text_length + self.max_steps,)
),
})
self.action_space = Discrete(n=self._vocab_size)
# see https://github.com/huggingface/transformers/issues/4875 : rounding up to nearest power of 2 for better GPU efficiency

@@ -116,16 +114,17 @@ def __init__(
self.set_reward()

def set_reward(self, reward_fn=None):

from openrl.envs.nlp.rewards.meteor import Meteor

meteor_config = {
"meteor_coeff": 0.5,
"test": False,
}
self.reward_function = {
"meteor": Meteor(**meteor_config),
}

# self.reward_function = reward_fn

def step_word(self, word: str) -> Tuple[Dict[str, torch.tensor], int, bool, dict]:
@@ -147,7 +146,7 @@ def step(
done = done or self.__current_obs.context_text.endswith(DailyDialog.EOU_TOKEN)

reward = 0.0
reward_info = dict()
reward_info = dict()

if done and self.reward_function:
for reward_function in self.reward_function.values():
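The observation and action spaces constructed above can be reproduced with plain gymnasium spaces. The sketch below uses assumed sizes, and gymnasium.spaces.Dict stands in for OpenRL's DictSpace wrapper.

# Standalone sketch of the spaces defined in daily_dialog_env.py above,
# using gymnasium spaces directly and assumed sizes.
from gymnasium import spaces

max_text_length, max_steps, vocab_size = 128, 20, 50257  # assumed values

observation_space = spaces.Dict({
    "input_encoded_pt": spaces.Box(
        low=0, high=vocab_size, shape=(max_text_length + max_steps,)
    ),
    "input_attention_mask_pt": spaces.Box(
        low=0, high=1, shape=(max_text_length + max_steps,)
    ),
})
action_space = spaces.Discrete(n=vocab_size)

obs = observation_space.sample()   # dict with "input_encoded_pt" / "input_attention_mask_pt" arrays
token_id = action_space.sample()   # one vocabulary index, i.e. the next token to emit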
openrl/envs/nlp/fake_dialog_env.py (22 changes: 10 additions & 12 deletions)
@@ -30,18 +30,16 @@ def __init__(
# set the observation and action space here
self._vocab_size = 2

self.observation_space = DictSpace(
{
"input_encoded_pt": spaces.Box(
low=0,
high=self._vocab_size,
shape=(self._max_text_length + self.max_steps,),
),
"input_attention_mask_pt": spaces.Box(
low=0, high=1, shape=(self._max_text_length + self.max_steps,)
),
}
)
self.observation_space = DictSpace({
"input_encoded_pt": spaces.Box(
low=0,
high=self._vocab_size,
shape=(self._max_text_length + self.max_steps,),
),
"input_attention_mask_pt": spaces.Box(
low=0, high=1, shape=(self._max_text_length + self.max_steps,)
),
})
self.action_space = Discrete(n=self._vocab_size)

n = 2
openrl/envs/nlp/rewards/intent.py (2 changes: 1 addition & 1 deletion)
@@ -37,7 +37,7 @@ def __init__(
self._intent_coeff = intent_coeff
self.use_deepspeed = use_deepspeed
self.use_half = False
self.use_data_parallel = not use_deepspeed # default to use data parallel
self.use_data_parallel = not use_deepspeed # default to use data parallel
self.use_model_parallel = False

if intent_model == "builtin_intent":
openrl/envs/nlp/rewards/kl_penalty.py (23 changes: 8 additions & 15 deletions)
@@ -35,7 +35,7 @@ def __init__(
ds_config: str = "default",
):
super().__init__()

self.device = "cuda"
self.use_deepspeed = use_deepspeed
self.use_half = False
@@ -116,7 +116,7 @@ def __call__(
self._ref_net, input_ids, past_model_kwargs
)

if self.use_half:
if self.use_half:
for key in ["input_ids", "position_ids", "attention_mask"]:
if key in model_inputs:
model_inputs[key] = model_inputs[key].int()
@@ -125,7 +125,6 @@
if key in model_inputs:
model_inputs[key] = model_inputs[key].long()


with torch.no_grad():
output = self._ref_net(output_hidden_states=True, **model_inputs)
output["past_key_values"] = None
@@ -139,15 +138,13 @@
ref_log_prob = ref_log_prob.reshape(action_log_probs.shape)

kl_div = action_log_probs.copy() - ref_log_prob.detach().cpu().numpy()
rew = -self._alpha * kl_div
rew = -self._alpha * kl_div
infos = []
for kl in kl_div:
infos.append(
{
"alpha": self._alpha,
"kl_div": kl.mean(),
}
)
infos.append({
"alpha": self._alpha,
"kl_div": kl.mean(),
})
return rew, infos

def _prepare_inputs_for_model(
@@ -173,11 +170,7 @@ def _prepare_inputs_for_model(
}
elif self.use_data_parallel:
model_inputs = {
key: (
value.to(self.device)
if isinstance(value, torch.Tensor)
else value
)
key: value.to(self.device) if isinstance(value, torch.Tensor) else value
for key, value in model_inputs.items()
}
elif self.use_deepspeed:
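The reward computation in __call__ above amounts to a per-token penalty of -alpha * (log pi(a) - log pi_ref(a)), the log-ratio between the trained policy and a frozen reference model. A small numpy sketch with made-up values:

# Numeric sketch of the KL penalty computed above (all values are made up).
import numpy as np

alpha = 0.2                                                 # stands in for self._alpha
action_log_probs = np.array([[-2.1, -0.9], [-3.4, -1.2]])   # log pi(a_t|s_t), current policy
ref_log_prob = np.array([[-2.3, -1.5], [-3.0, -1.0]])       # log pi_ref(a_t|s_t), reference model

kl_div = action_log_probs - ref_log_prob   # positive where the policy puts more mass on the action than the reference
rew = -alpha * kl_div                      # drift is penalized, matching rew = -self._alpha * kl_div
infos = [{"alpha": alpha, "kl_div": kl.mean()} for kl in kl_div]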
openrl/envs/nlp/utils/metrics/meteor.py (24 changes: 10 additions & 14 deletions)
@@ -88,20 +88,16 @@ def _info(self):
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=[
datasets.Features(
{
"predictions": datasets.Value("string", id="sequence"),
"references": datasets.Sequence(
datasets.Value("string", id="sequence"), id="references"
),
}
),
datasets.Features(
{
"predictions": datasets.Value("string", id="sequence"),
"references": datasets.Value("string", id="sequence"),
}
),
datasets.Features({
"predictions": datasets.Value("string", id="sequence"),
"references": datasets.Sequence(
datasets.Value("string", id="sequence"), id="references"
),
}),
datasets.Features({
"predictions": datasets.Value("string", id="sequence"),
"references": datasets.Value("string", id="sequence"),
}),
],
codebase_urls=[
"https://github.com/nltk/nltk/blob/develop/nltk/translate/meteor_score.py"
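The two Features schemas above mean the metric accepts either a single reference string or a list of reference strings per prediction. Hypothetical inputs for both cases (the compute calls follow the standard datasets-metric API and are shown only as a usage hint):

# Hypothetical inputs matching the two feature schemas declared above.
predictions = ["the cat sat on the mat"]

# One reference string per prediction.
references_single = ["a cat was sitting on the mat"]

# A list of reference strings per prediction.
references_multi = [["a cat was sitting on the mat", "the cat lay on the mat"]]

# With a datasets-style metric object these would be scored as, e.g.:
# results = meteor.compute(predictions=predictions, references=references_single)
# results = meteor.compute(predictions=predictions, references=references_multi)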
openrl/envs/vec_env/wrappers/reward_wrapper.py (1 change: 1 addition & 0 deletions)
@@ -24,6 +24,7 @@
from openrl.envs.vec_env.wrappers.base_wrapper import VecEnvWrapper
from openrl.rewards.base_reward import BaseReward


class RewardWrapper(VecEnvWrapper):
def __init__(self, env: BaseVecEnv, reward_class: BaseReward):
super().__init__(env)
openrl/modules/networks/policy_network_gpt.py (64 changes: 35 additions & 29 deletions)
@@ -15,13 +15,15 @@
# limitations under the License.

""""""
from typing import Any, Optional, Dict
from typing import Any, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
from transformers.modeling_utils import unwrap_model

from openrl.buffers.utils.util import get_policy_obs, get_policy_obs_space
from openrl.envs.nlp.utils.distribution import CategoricalDistribution
from openrl.modules.networks.base_policy_network import BasePolicyNetwork
from openrl.modules.networks.utils.act import ACTLayer
from openrl.modules.networks.utils.cnn import CNNBase
@@ -31,9 +33,7 @@
from openrl.modules.networks.utils.rnn import RNNLayer
from openrl.modules.networks.utils.util import init
from openrl.utils.util import check_v2 as check
from openrl.envs.nlp.utils.distribution import CategoricalDistribution

from transformers.modeling_utils import unwrap_model

class PolicyNetworkGPT(BasePolicyNetwork):
def __init__(
@@ -46,25 +46,26 @@ def __init__(
disable_drop_out: bool = True,
extra_args=None,
) -> None:

self.device = device
self.use_fp16 = cfg.use_fp16
self.use_deepspeed = cfg.use_deepspeed
self.use_half = False
self.use_data_parallel = not cfg.use_deepspeed # default to use data parallel
self.use_data_parallel = not cfg.use_deepspeed # default to use data parallel
self.use_model_parallel = False

assert not (self.use_deepspeed and self.use_data_parallel)
assert not (self.use_deepspeed and self.use_model_parallel)
assert not (self.use_data_parallel and self.use_model_parallel)

super(PolicyNetworkGPT, self).__init__(cfg, device)

self.disable_drop_out = disable_drop_out

self._action_dist = CategoricalDistribution(action_space.n)

from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained(cfg.model_path)
config_dict = config.to_dict()
for key in config_dict:
@@ -85,15 +86,14 @@ def __init__(
self._policy_model = torch.nn.DataParallel(self._policy_model)
self._policy_model = self._policy_model.to(self.device)


def forward(self, forward_type, *args, **kwargs):
if forward_type == "original":
return self.forward_original(*args, **kwargs)
elif forward_type == "eval_actions":
return self.eval_actions(*args, **kwargs)
else:
raise NotImplementedError

def _prepare_inputs_for_model(
self,
model: Any,
@@ -121,7 +121,11 @@ def forward_original(
self, raw_obs, rnn_states, masks, action_masks=None, deterministic=False
):
for key in raw_obs.keys():
raw_obs[key] = torch.from_numpy(raw_obs[key]) if type(raw_obs[key]) == np.ndarray else raw_obs[key]
raw_obs[key] = (
torch.from_numpy(raw_obs[key])
if type(raw_obs[key]) == np.ndarray
else raw_obs[key]
)
rnn_states = check(rnn_states)

if self.use_half:
@@ -138,35 +142,37 @@ def forward_original(
else:
input_ids = input_ids.to(self._policy_model.device)
attention_mask = attention_mask.to(self._policy_model.device)

past_model_kwargs = None

if past_model_kwargs is None:
past_model_kwargs = {
"attention_mask": attention_mask,
}

model_inputs = self._prepare_inputs_for_model(
self._policy_model, input_ids, past_model_kwargs
)

# forward pass to transformers
output = self._policy_model(**model_inputs)

# compute action probs - policy head
next_token_logits = output.logits[:, -1]
next_token_logits = output.logits[:, -1]
dist = self._action_dist.proba_distribution(action_logits=next_token_logits)

actions = dist.mode() if deterministic else dist.sample()
action_log_probs = dist.log_prob(actions)

return actions.unsqueeze(-1), action_log_probs.unsqueeze(-1), rnn_states

def eval_actions(
self, obs, rnn_states, action, masks, action_masks=None, active_masks=None
):
for key in obs.keys():
obs[key] = torch.from_numpy(obs[key]) if type(obs[key]) == np.ndarray else obs[key]
obs[key] = (
torch.from_numpy(obs[key]) if type(obs[key]) == np.ndarray else obs[key]
)
if self.use_data_parallel:
obs[key] = obs[key].to(self.device)
else:
@@ -176,37 +182,37 @@ def eval_actions(
else:
action = check(action).to(self._policy_model.device).squeeze()
rnn_states = check(rnn_states)

if self.half:
input_ids = obs["input_encoded_pt"].int()
attention_mask = obs["input_attention_mask_pt"].int()
else:
input_ids = obs["input_encoded_pt"].long()
attention_mask = obs["input_attention_mask_pt"].long()

past_model_kwargs = None

if past_model_kwargs is None:
past_model_kwargs = {
"attention_mask": attention_mask,
}

model_inputs = self._prepare_inputs_for_model(
self._policy_model, input_ids, past_model_kwargs
)

# forward pass to transformers
output = self._policy_model(**model_inputs)

# compute action probs - policy head
next_token_logits = output.logits[:, -1]
dist = self._action_dist.proba_distribution(action_logits=next_token_logits)

action_log_probs = dist.log_prob(action)
dist_entropy = dist.entropy()
values = None

return action_log_probs.unsqueeze(-1), dist_entropy.mean(), values

def get_policy_values(self, obs, rnn_states, masks):
raise NotImplementedError
raise NotImplementedError
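The action-selection step in forward_original above follows a standard pattern: take the logits of the last position, build a categorical distribution over the vocabulary, then sample (or take the mode) and record the log-probability. A self-contained torch sketch with made-up shapes, where torch.distributions.Categorical stands in for OpenRL's CategoricalDistribution helper:

# Sketch of the last-token sampling pattern used above (shapes are made up).
import torch

batch, seq_len, vocab_size = 2, 16, 50257
logits = torch.randn(batch, seq_len, vocab_size)   # stand-in for output.logits

next_token_logits = logits[:, -1]                  # (batch, vocab_size)
dist = torch.distributions.Categorical(logits=next_token_logits)

deterministic = False
actions = next_token_logits.argmax(dim=-1) if deterministic else dist.sample()
action_log_probs = dist.log_prob(actions)          # (batch,)

# forward_original returns (batch, 1) shaped actions and log-probs.
actions = actions.unsqueeze(-1)
action_log_probs = action_log_probs.unsqueeze(-1)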