deepspeed support #282

Merged
merged 10 commits on Dec 20, 2023
Changes from 8 commits
4 changes: 1 addition & 3 deletions examples/nlp/ds_config.json
@@ -3,9 +3,7 @@
"train_micro_batch_size_per_gpu": 16,
"steps_per_print": 10,
"zero_optimization": {
"stage": 2,
"reduce_bucket_size": 5e7,
"allgather_bucket_size": 5e7
"stage": 2
},
"fp16": {"enabled": false, "loss_scale_window": 100}
}
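
For context, a DeepSpeed JSON config like this one is handed to deepspeed.initialize, which is also how the reward models in this PR wrap their modules. A minimal sketch of that usage, assuming a placeholder model and a client-side Adam optimizer (neither is part of this PR), normally run under the deepspeed launcher:

import json

import deepspeed
import torch

# Placeholder model and optimizer, purely for illustration.
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

with open("examples/nlp/ds_config.json") as f:
    ds_config = json.load(f)

# DeepSpeed wraps the model and optimizer according to the config
# (ZeRO stage 2 here, fp16 disabled).
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config=ds_config,
)
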
2 changes: 0 additions & 2 deletions examples/nlp/nlp_ppo.yaml
@@ -9,11 +9,9 @@ wandb_entity: "openrl-lab"
ppo_epoch: 5
episode_length: 128
num_mini_batch: 20
use_share_model: true

hidden_size: 1


model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog
env:
args: {
1 change: 0 additions & 1 deletion examples/nlp/nlp_ppo_ds.yaml
@@ -9,7 +9,6 @@ wandb_entity: "openrl-lab"
ppo_epoch: 5
episode_length: 128
num_mini_batch: 20
use_share_model: true

hidden_size: 1

7 changes: 3 additions & 4 deletions examples/nlp/train_ppo.py
@@ -3,9 +3,8 @@
from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.modules.networks.policy_value_network_gpt import (
PolicyValueNetworkGPT as PolicyValueNetwork,
)
from openrl.modules.networks.policy_network_gpt import PolicyNetworkGPT as PolicyNetwork
from openrl.modules.networks.value_network_gpt import ValueNetworkGPT as ValueNetwork
from openrl.runners.common import PPOAgent as Agent


@@ -29,7 +28,7 @@ def train():
)

# create the neural network
model_dict = {"model": PolicyValueNetwork}
model_dict = {"policy": PolicyNetwork, "critic": ValueNetwork}
net = Net(env, device="cuda", cfg=cfg, model_dict=model_dict)

# initialize the trainer
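
For readers unfamiliar with this example, the split model_dict plugs into OpenRL's usual training entry point roughly as below. This is a rough sketch assembled from the imports visible in the diff; the environment id, env_num, and total_time_steps values are assumptions, not part of this PR:

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.modules.networks.policy_network_gpt import PolicyNetworkGPT as PolicyNetwork
from openrl.modules.networks.value_network_gpt import ValueNetworkGPT as ValueNetwork
from openrl.runners.common import PPOAgent as Agent

cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args(["--config", "nlp_ppo.yaml"])

# The env id and env_num below are illustrative guesses for this example.
env = make("daily_dialog", env_num=2, cfg=cfg)

# Separate policy and critic networks replace the former shared network.
model_dict = {"policy": PolicyNetwork, "critic": ValueNetwork}
net = Net(env, device="cuda", cfg=cfg, model_dict=model_dict)

agent = Agent(net)
agent.train(total_time_steps=5000)
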
14 changes: 11 additions & 3 deletions openrl/algorithms/ppo.py
@@ -45,7 +45,8 @@ def __init__(

def ppo_update(self, sample, turn_on=True):
for optimizer in self.algo_module.optimizers.values():
optimizer.zero_grad()
if not self.use_deepspeed:
optimizer.zero_grad()

(
critic_obs_batch,
@@ -152,8 +153,15 @@ def ppo_update(self, sample, turn_on=True):

self.algo_module.scaler.update()
else:
for optimizer in self.algo_module.optimizers.values():
optimizer.step()
if self.use_deepspeed:
if self._use_share_model:
self.algo_module.optimizers["model"].step()
else:
self.algo_module.optimizers["policy"].step()
self.algo_module.optimizers["critic"].step()
else:
for optimizer in self.algo_module.optimizers.values():
optimizer.step()

if self.world_size > 1:
torch.cuda.synchronize()
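
For context, zero_grad() is skipped and each DeepSpeed optimizer is stepped explicitly because DeepSpeed's engine owns loss scaling, gradient accumulation, and gradient zeroing. A minimal sketch of that convention with a placeholder model (not OpenRL's API; normally launched via the deepspeed launcher):

import deepspeed
import torch

model = torch.nn.Linear(8, 2)  # placeholder model for illustration
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config={"train_batch_size": 16, "fp16": {"enabled": False}},
)

inputs = torch.randn(16, 8).to(engine.device)
loss = engine(inputs).pow(2).mean()

engine.backward(loss)  # scales the loss and accumulates gradients
engine.step()          # optimizer step; gradient zeroing is handled by the engine
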
39 changes: 25 additions & 14 deletions openrl/envs/nlp/daily_dialog_env.py
@@ -72,18 +72,16 @@ def __init__(
# set the observation and action space here
self._vocab_size = self.tokenizer.vocab_size

self.observation_space = DictSpace(
{
"input_encoded_pt": spaces.Box(
low=0,
high=self._vocab_size,
shape=(self._max_text_length + self.max_steps,),
),
"input_attention_mask_pt": spaces.Box(
low=0, high=1, shape=(self._max_text_length + self.max_steps,)
),
}
)
self.observation_space = DictSpace({
"input_encoded_pt": spaces.Box(
low=0,
high=self._vocab_size,
shape=(self._max_text_length + self.max_steps,),
),
"input_attention_mask_pt": spaces.Box(
low=0, high=1, shape=(self._max_text_length + self.max_steps,)
),
})
self.action_space = Discrete(n=self._vocab_size)
# see https://github.com/huggingface/transformers/issues/4875 : rounding up to nearest power of 2 for better GPU efficiency

@@ -113,8 +111,21 @@ def __init__(
self.__time_step = None
self.reward_function = None

def set_reward(self, reward_fn):
self.reward_function = reward_fn
self.set_reward()

def set_reward(self, reward_fn=None):

from openrl.envs.nlp.rewards.meteor import Meteor

meteor_config = {
"meteor_coeff": 0.5,
"test": False,
}
self.reward_function = {
"meteor": Meteor(**meteor_config),
}

# self.reward_function = reward_fn

def step_word(self, word: str) -> Tuple[Dict[str, torch.tensor], int, bool, dict]:
action = self.tokenizer.encode(word)[1]
22 changes: 10 additions & 12 deletions openrl/envs/nlp/fake_dialog_env.py
@@ -30,18 +30,16 @@ def __init__(
# set the observation and action space here
self._vocab_size = 2

self.observation_space = DictSpace(
{
"input_encoded_pt": spaces.Box(
low=0,
high=self._vocab_size,
shape=(self._max_text_length + self.max_steps,),
),
"input_attention_mask_pt": spaces.Box(
low=0, high=1, shape=(self._max_text_length + self.max_steps,)
),
}
)
self.observation_space = DictSpace({
"input_encoded_pt": spaces.Box(
low=0,
high=self._vocab_size,
shape=(self._max_text_length + self.max_steps,),
),
"input_attention_mask_pt": spaces.Box(
low=0, high=1, shape=(self._max_text_length + self.max_steps,)
),
})
self.action_space = Discrete(n=self._vocab_size)

n = 2
29 changes: 20 additions & 9 deletions openrl/envs/nlp/rewards/intent.py
@@ -36,6 +36,10 @@ def __init__(

self._intent_coeff = intent_coeff
self.use_deepspeed = use_deepspeed
self.use_half = False
self.use_data_parallel = not use_deepspeed # default to use data parallel
self.use_model_parallel = False

if intent_model == "builtin_intent":
from transformers import GPT2Config, GPT2LMHeadModel

@@ -80,16 +84,16 @@ def __init__(self, input_ids, attention_mask):
self._device = "cuda"
self._model = self._model.to("cuda")
self._model, *_ = deepspeed.initialize(model=self._model, config=ds_config)
self.use_fp16 = ds_config["fp16"]["enabled"]
else:
if torch.cuda.is_available():
manager = LocalGPUManager()
manager.log_info()
self._device = f"cuda:{manager.get_gpu()}"
else:
self._device = "cpu"
print("Intent Model choose to use device:{}".format(self._device))

self._model = self._model.to(self._device)
self._device = "cuda"
if self.use_model_parallel:
self._model.parallelize()
elif self.use_data_parallel:
if self.use_half:
self._model = self._model.half()
self._model = torch.nn.DataParallel(self._model)
self._model = self._model.to(self._device)

def __call__(
self,
@@ -120,6 +124,13 @@ def get_input_for_classifier(prompt, generated_text):
input_texts, return_tensors="pt", truncation=True, padding=True
)

if self.use_half:
encoded.input_ids = encoded.input_ids.int()
encoded.attention_mask = encoded.attention_mask.int()
else:
encoded.input_ids = encoded.input_ids.long()
encoded.attention_mask = encoded.attention_mask.long()

with torch.no_grad():
outputs = self._model(
input_ids=encoded.input_ids.to(self._device),
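
For context, the non-DeepSpeed branch above follows the standard pattern of wrapping a Hugging Face classifier in torch.nn.DataParallel and moving tokenized inputs to the model's device. A minimal sketch with an illustrative public checkpoint (not the intent model this PR loads):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Illustrative stand-in for the intent classifier used in the PR.
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda" and torch.cuda.device_count() > 1:
    # DataParallel splits each batch across GPUs; inputs are still sent
    # to the default device and scattered internally.
    model = torch.nn.DataParallel(model)
model = model.to(device)

encoded = tokenizer(
    ["book a table for two", "play some jazz"],
    return_tensors="pt", padding=True, truncation=True,
)
with torch.no_grad():
    logits = model(
        input_ids=encoded.input_ids.to(device),
        attention_mask=encoded.attention_mask.to(device),
    ).logits
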
55 changes: 36 additions & 19 deletions openrl/envs/nlp/rewards/kl_penalty.py
@@ -35,10 +35,17 @@ def __init__(
ds_config: str = "default",
):
super().__init__()

self.device = "cuda"
self.use_deepspeed = use_deepspeed
self.use_half = False
self.use_data_parallel = not use_deepspeed
self.use_model_parallel = False
assert not (self.use_deepspeed and self.use_data_parallel)
assert not (self.use_deepspeed and self.use_model_parallel)
assert not (self.use_data_parallel and self.use_model_parallel)

# reference model
self._apply_model_parallel = apply_model_parallel
if ref_model == "builtin_ref":
from transformers import GPT2Config, GPT2LMHeadModel

@@ -64,11 +71,14 @@ def __init__(
self.use_fp16 = False

self._ref_engine, *_ = deepspeed.initialize(model=self, config=ds_config)
elif torch.cuda.is_available():
if self._apply_model_parallel and self._ref_net.is_parallelizable:
else:
if self.use_model_parallel:
self._ref_net.parallelize()
else: # else defaults to data parallel
elif self.use_data_parallel: # else defaults to data parallel
if self.use_half:
self._ref_net = self._ref_net.half()
self._ref_net = torch.nn.DataParallel(self._ref_net)
self._ref_net = self._ref_net.to(self.device)

# alpha adjustment
self._alpha = 0.2
@@ -106,32 +116,35 @@ def __call__(
self._ref_net, input_ids, past_model_kwargs
)

if self.use_deepspeed:
if self.use_fp16:
for key in ["input_ids", "position_ids"]:
model_inputs[key] = model_inputs[key].half().int()
for key in ["attention_mask"]:
model_inputs[key] = model_inputs[key].half()
if self.use_half:
for key in ["input_ids", "position_ids", "attention_mask"]:
if key in model_inputs:
model_inputs[key] = model_inputs[key].int()
else:
for key in ["input_ids", "position_ids", "attention_mask"]:
if key in model_inputs:
model_inputs[key] = model_inputs[key].long()

with torch.no_grad():
output = self._ref_net(output_hidden_states=True, **model_inputs)
output["past_key_values"] = None
next_token_logits = output.logits[:, -1, :]
if self.use_deepspeed and self.use_fp16:
next_token_logits = next_token_logits.double()
dist = self._action_dist.proba_distribution(action_logits=next_token_logits)
action_input = actions.to(next_token_logits.device)
ref_log_prob = dist.log_prob(action_input)

ref_log_prob = ref_log_prob.reshape(action_log_probs.shape)

kl_div = action_log_probs.copy() - ref_log_prob.detach().cpu().numpy()
rew = -self._alpha * kl_div
infos = []
for kl in kl_div:
infos.append(
{
"alpha": self._alpha,
"kl_div": kl.mean(),
}
)
infos.append({
"alpha": self._alpha,
"kl_div": kl.mean(),
})
return rew, infos

def _prepare_inputs_for_model(
@@ -144,7 +157,7 @@ def _prepare_inputs_for_model(
input_ids, **model_kwargs
)

if self._apply_model_parallel and unwrap_model(model).is_parallelizable:
if self.use_model_parallel:
# if model is in parallel mode, move the tensors to the first device
model_inputs = {
key: (
Expand All @@ -155,8 +168,12 @@ def _prepare_inputs_for_model(
)
for key, value in model_inputs.items()
}

if self.use_deepspeed:
elif self.use_data_parallel:
model_inputs = {
key: value.to(self.device) if isinstance(value, torch.Tensor) else value
for key, value in model_inputs.items()
}
elif self.use_deepspeed:
model_inputs = {
key: value.to("cuda") if isinstance(value, torch.Tensor) else value
for key, value in model_inputs.items()
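
For context, the reward computed above is the usual KL penalty r_KL = -alpha * (log pi(a|s) - log pi_ref(a|s)), which keeps the tuned policy close to the frozen reference model. A minimal NumPy sketch of that arithmetic (the log-prob values are illustrative; in the PR they come from the rollout and self._ref_net):

import numpy as np

alpha = 0.2  # matches self._alpha above

action_log_probs = np.array([-1.2, -0.8])  # log-probs under the current policy
ref_log_prob = np.array([-1.0, -1.0])      # log-probs under the reference model

kl_div = action_log_probs - ref_log_prob   # per-action KL estimate
rew = -alpha * kl_div                      # penalty added to the task reward
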
24 changes: 10 additions & 14 deletions openrl/envs/nlp/utils/metrics/meteor.py
@@ -88,20 +88,16 @@ def _info(self):
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=[
datasets.Features(
{
"predictions": datasets.Value("string", id="sequence"),
"references": datasets.Sequence(
datasets.Value("string", id="sequence"), id="references"
),
}
),
datasets.Features(
{
"predictions": datasets.Value("string", id="sequence"),
"references": datasets.Value("string", id="sequence"),
}
),
datasets.Features({
"predictions": datasets.Value("string", id="sequence"),
"references": datasets.Sequence(
datasets.Value("string", id="sequence"), id="references"
),
}),
datasets.Features({
"predictions": datasets.Value("string", id="sequence"),
"references": datasets.Value("string", id="sequence"),
}),
],
codebase_urls=[
"https://github.com/nltk/nltk/blob/develop/nltk/translate/meteor_score.py"
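
The features above feed nltk's meteor_score, which this module cites as its backend. A minimal sketch of computing the metric directly (illustrative strings, single reference; requires the wordnet corpus via nltk.download("wordnet")):

from nltk.translate.meteor_score import meteor_score

reference = "how are you doing today".split()
hypothesis = "how are you today".split()

# Recent nltk versions expect pre-tokenized token lists; the score lies in [0, 1].
score = meteor_score([reference], hypothesis)
print(round(score, 3))
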
4 changes: 2 additions & 2 deletions openrl/envs/vec_env/wrappers/reward_wrapper.py
@@ -29,8 +29,8 @@ class RewardWrapper(VecEnvWrapper):
def __init__(self, env: BaseVecEnv, reward_class: BaseReward):
super().__init__(env)
self.reward_class = reward_class
if len(self.reward_class.inner_rew_funcs) > 0:
env.call("set_reward", **{"reward_fn": self.reward_class.inner_rew_funcs})
# if len(self.reward_class.inner_rew_funcs) > 0:
# env.call("set_reward", **{"reward_fn": self.reward_class.inner_rew_funcs})

def step(
self, action: ActType, extra_data: Optional[Dict[str, Any]]