deepspeed support
huangshiyu13 authored Dec 20, 2023
2 parents 76d1e05 + fc02030 commit 8185373
Showing 16 changed files with 528 additions and 118 deletions.
4 changes: 1 addition & 3 deletions examples/nlp/ds_config.json
@@ -3,9 +3,7 @@
   "train_micro_batch_size_per_gpu": 16,
   "steps_per_print": 10,
   "zero_optimization": {
-    "stage": 2,
-    "reduce_bucket_size": 5e7,
-    "allgather_bucket_size": 5e7
+    "stage": 2
   },
   "fp16": {"enabled": false, "loss_scale_window": 100}
 }
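The trimmed config keeps only ZeRO stage 2 and leaves the bucket sizes at DeepSpeed's defaults. A minimal sketch of how a JSON config like this is handed to deepspeed.initialize; the GPT-2 stand-in model is illustrative and not part of this commit:

    # Minimal sketch: load the trimmed ZeRO stage-2 JSON and wrap a model with
    # DeepSpeed. Only the config keys come from this commit; the model is a stand-in.
    import json

    import deepspeed
    from transformers import GPT2Config, GPT2LMHeadModel

    with open("examples/nlp/ds_config.json") as f:
        ds_config = json.load(f)

    model = GPT2LMHeadModel(GPT2Config())
    # Without an optimizer in the config, initialize() still returns an engine
    # usable for forward passes; training engines would also pass an optimizer.
    engine, *_ = deepspeed.initialize(model=model, config=ds_config)

    use_fp16 = ds_config["fp16"]["enabled"]  # e.g. the intent reward below reads this flag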
2 changes: 0 additions & 2 deletions examples/nlp/nlp_ppo.yaml
@@ -9,11 +9,9 @@ wandb_entity: "openrl-lab"
 ppo_epoch: 5
 episode_length: 128
 num_mini_batch: 20
-use_share_model: true
 
 hidden_size: 1
 
-
 model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog
 env:
   args: {
1 change: 0 additions & 1 deletion examples/nlp/nlp_ppo_ds.yaml
@@ -9,7 +9,6 @@ wandb_entity: "openrl-lab"
 ppo_epoch: 5
 episode_length: 128
 num_mini_batch: 20
-use_share_model: true
 
 hidden_size: 1
 
7 changes: 3 additions & 4 deletions examples/nlp/train_ppo.py
@@ -3,9 +3,8 @@
 from openrl.configs.config import create_config_parser
 from openrl.envs.common import make
 from openrl.modules.common import PPONet as Net
-from openrl.modules.networks.policy_value_network_gpt import (
-    PolicyValueNetworkGPT as PolicyValueNetwork,
-)
+from openrl.modules.networks.policy_network_gpt import PolicyNetworkGPT as PolicyNetwork
+from openrl.modules.networks.value_network_gpt import ValueNetworkGPT as ValueNetwork
 from openrl.runners.common import PPOAgent as Agent
 
 
@@ -29,7 +28,7 @@ def train():
     )
 
     # create the neural network
-    model_dict = {"model": PolicyValueNetwork}
+    model_dict = {"policy": PolicyNetwork, "critic": ValueNetwork}
     net = Net(env, device="cuda", cfg=cfg, model_dict=model_dict)
 
     # initialize the trainer
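Each DeepSpeed engine wraps a single optimizer, which appears to be why the shared PolicyValueNetworkGPT is split into separate policy and value networks here, and why use_share_model is dropped from the YAML configs above. A sketch of how the resulting script fits together; only the imports and the model_dict line are taken from this diff, while the config, env and agent calls are reconstructed from typical OpenRL example usage and may differ from the actual file:

    # Hedged reconstruction of the updated entry point around the model_dict change.
    from openrl.configs.config import create_config_parser
    from openrl.envs.common import make
    from openrl.modules.common import PPONet as Net
    from openrl.modules.networks.policy_network_gpt import PolicyNetworkGPT as PolicyNetwork
    from openrl.modules.networks.value_network_gpt import ValueNetworkGPT as ValueNetwork
    from openrl.runners.common import PPOAgent as Agent


    def train():
        cfg_parser = create_config_parser()
        cfg = cfg_parser.parse_args()  # e.g. python train_ppo.py --config nlp_ppo.yaml

        env = make("daily_dialog", env_num=2, asynchronous=True, cfg=cfg)

        # one network class per optimizer; "policy" and "critic" are the keys the
        # PPO algorithm (and, under DeepSpeed, its separate engines) look up
        model_dict = {"policy": PolicyNetwork, "critic": ValueNetwork}
        net = Net(env, device="cuda", cfg=cfg, model_dict=model_dict)

        agent = Agent(net)
        agent.train(total_time_steps=100000)


    if __name__ == "__main__":
        train()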
14 changes: 11 additions & 3 deletions openrl/algorithms/ppo.py
@@ -45,7 +45,8 @@ def __init__(
 
     def ppo_update(self, sample, turn_on=True):
         for optimizer in self.algo_module.optimizers.values():
-            optimizer.zero_grad()
+            if not self.use_deepspeed:
+                optimizer.zero_grad()
 
         (
             critic_obs_batch,
@@ -152,8 +153,15 @@ def ppo_update(self, sample, turn_on=True):
 
             self.algo_module.scaler.update()
         else:
-            for optimizer in self.algo_module.optimizers.values():
-                optimizer.step()
+            if self.use_deepspeed:
+                if self._use_share_model:
+                    self.algo_module.optimizers["model"].step()
+                else:
+                    self.algo_module.optimizers["policy"].step()
+                    self.algo_module.optimizers["critic"].step()
+            else:
+                for optimizer in self.algo_module.optimizers.values():
+                    optimizer.step()
 
         if self.world_size > 1:
             torch.cuda.synchronize()
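Note on the branch above: a DeepSpeed engine owns its gradient buffers, loss scaling and clipping, so the usual zero_grad()/step() bookkeeping moves into the engine, and each wrapped optimizer is stepped explicitly. A minimal sketch of the two loops; the engine and optimizer objects are illustrative, not OpenRL code:

    # Why zero_grad() is skipped under DeepSpeed: the canonical engine loop is
    # backward() followed by step(), with gradient zeroing handled internally.
    import torch


    def vanilla_update(optimizer: torch.optim.Optimizer, loss: torch.Tensor) -> None:
        optimizer.zero_grad()   # manual zeroing required
        loss.backward()
        optimizer.step()


    def deepspeed_update(engine, loss: torch.Tensor) -> None:
        engine.backward(loss)   # engine scales the loss and accumulates gradients
        engine.step()           # applies the update and zeroes gradients internally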
25 changes: 12 additions & 13 deletions openrl/envs/nlp/daily_dialog_env.py
@@ -72,18 +72,16 @@
         # set the observation and action space here
         self._vocab_size = self.tokenizer.vocab_size
 
-        self.observation_space = DictSpace(
-            {
-                "input_encoded_pt": spaces.Box(
-                    low=0,
-                    high=self._vocab_size,
-                    shape=(self._max_text_length + self.max_steps,),
-                ),
-                "input_attention_mask_pt": spaces.Box(
-                    low=0, high=1, shape=(self._max_text_length + self.max_steps,)
-                ),
-            }
-        )
+        self.observation_space = DictSpace({
+            "input_encoded_pt": spaces.Box(
+                low=0,
+                high=self._vocab_size,
+                shape=(self._max_text_length + self.max_steps,),
+            ),
+            "input_attention_mask_pt": spaces.Box(
+                low=0, high=1, shape=(self._max_text_length + self.max_steps,)
+            ),
+        })
         self.action_space = Discrete(n=self._vocab_size)
         # see https://github.com/huggingface/transformers/issues/4875 : rounding up to nearest power of 2 for better GPU efficiency
 
@@ -113,7 +111,8 @@ def __init__(
         self.__time_step = None
         self.reward_function = None
 
-    def set_reward(self, reward_fn):
+    def set_reward(self, reward_fn=None):
+
         self.reward_function = reward_fn
 
     def step_word(self, word: str) -> Tuple[Dict[str, torch.tensor], int, bool, dict]:
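For reference, the observation and action spaces rebuilt above can be reproduced with plain gymnasium spaces, used here as a stand-in for OpenRL's DictSpace; the sizes are illustrative rather than taken from the tokenizer or config:

    # Standalone sketch of the spaces built in __init__ above. max_text_length,
    # max_steps and vocab_size are made-up values standing in for the real ones.
    from gymnasium import spaces

    max_text_length, max_steps, vocab_size = 64, 16, 50257

    observation_space = spaces.Dict({
        "input_encoded_pt": spaces.Box(
            low=0, high=vocab_size, shape=(max_text_length + max_steps,)
        ),
        "input_attention_mask_pt": spaces.Box(
            low=0, high=1, shape=(max_text_length + max_steps,)
        ),
    })
    action_space = spaces.Discrete(n=vocab_size)

    sample = observation_space.sample()
    assert sample["input_encoded_pt"].shape == (max_text_length + max_steps,)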
22 changes: 10 additions & 12 deletions openrl/envs/nlp/fake_dialog_env.py
@@ -30,18 +30,16 @@ def __init__(
         # set the observation and action space here
         self._vocab_size = 2
 
-        self.observation_space = DictSpace(
-            {
-                "input_encoded_pt": spaces.Box(
-                    low=0,
-                    high=self._vocab_size,
-                    shape=(self._max_text_length + self.max_steps,),
-                ),
-                "input_attention_mask_pt": spaces.Box(
-                    low=0, high=1, shape=(self._max_text_length + self.max_steps,)
-                ),
-            }
-        )
+        self.observation_space = DictSpace({
+            "input_encoded_pt": spaces.Box(
+                low=0,
+                high=self._vocab_size,
+                shape=(self._max_text_length + self.max_steps,),
+            ),
+            "input_attention_mask_pt": spaces.Box(
+                low=0, high=1, shape=(self._max_text_length + self.max_steps,)
+            ),
+        })
         self.action_space = Discrete(n=self._vocab_size)
 
         n = 2
36 changes: 25 additions & 11 deletions openrl/envs/nlp/rewards/intent.py
@@ -36,7 +36,15 @@
 
         self._intent_coeff = intent_coeff
         self.use_deepspeed = use_deepspeed
+        self.use_half = False
+        self.use_data_parallel = not use_deepspeed  # default to use data parallel
+        self.use_model_parallel = False
+
         if intent_model == "builtin_intent":
+
+            self._device = "cpu"
+            self.use_data_parallel = False
+
             from transformers import GPT2Config, GPT2LMHeadModel
 
             class TestTokenizer:
@@ -62,6 +70,7 @@ def __init__(self, input_ids, attention_mask):
             self._model = GPT2LMHeadModel(config)
 
         else:
+            self._device = "cuda"
             model_path = data_abs_path(intent_model)
             self._tokenizer = AutoTokenizer.from_pretrained(intent_model)
             self._model = AutoModelForSequenceClassification.from_pretrained(model_path)
@@ -77,19 +86,17 @@ def __init__(self, input_ids, attention_mask):
             with open(ds_config) as file:
                 ds_config = json.load(file)
 
-            self._device = "cuda"
-            self._model = self._model.to("cuda")
+            self._model = self._model.to(self._device)
             self._model, *_ = deepspeed.initialize(model=self._model, config=ds_config)
+            self.use_fp16 = ds_config["fp16"]["enabled"]
         else:
-            if torch.cuda.is_available():
-                manager = LocalGPUManager()
-                manager.log_info()
-                self._device = f"cuda:{manager.get_gpu()}"
-            else:
-                self._device = "cpu"
-            print("Intent Model choose to use device:{}".format(self._device))
-
-            self._model = self._model.to(self._device)
+            if self.use_model_parallel:
+                self._model.parallelize()
+            elif self.use_data_parallel:
+                if self.use_half:
+                    self._model = self._model.half()
+                self._model = torch.nn.DataParallel(self._model)
+                self._model = self._model.to(self._device)
 
     def __call__(
         self,
@@ -120,6 +127,13 @@ def get_input_for_classifier(prompt, generated_text):
             input_texts, return_tensors="pt", truncation=True, padding=True
         )
 
+        if self.use_half:
+            encoded.input_ids = encoded.input_ids.int()
+            encoded.attention_mask = encoded.attention_mask.int()
+        else:
+            encoded.input_ids = encoded.input_ids.long()
+            encoded.attention_mask = encoded.attention_mask.long()
+
         with torch.no_grad():
             outputs = self._model(
                 input_ids=encoded.input_ids.to(self._device),
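The new use_half branch casts token ids to int32 when the classifier runs in fp16 and to int64 otherwise. A self-contained sketch of that dispatch, with a small public sentiment checkpoint standing in for the fine-tuned intent classifier:

    # Sketch of the precision-dependent input casting used by the intent reward.
    # The checkpoint name is a stand-in; use_half mirrors the flag added above.
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_half = False  # set True only for fp16 checkpoints on GPU

    name = "distilbert-base-uncased-finetuned-sst-2-english"
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModelForSequenceClassification.from_pretrained(name)
    if use_half:
        model = model.half()
    model = model.to(device)

    encoded = tokenizer(["hello there!"], return_tensors="pt", truncation=True, padding=True)
    ids = encoded.input_ids.int() if use_half else encoded.input_ids.long()
    mask = encoded.attention_mask.int() if use_half else encoded.attention_mask.long()

    with torch.no_grad():
        logits = model(input_ids=ids.to(device), attention_mask=mask.to(device)).logits
    probs = torch.softmax(logits, dim=-1)  # per-class probabilities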
62 changes: 42 additions & 20 deletions openrl/envs/nlp/rewards/kl_penalty.py
@@ -35,11 +35,22 @@
         ds_config: str = "default",
     ):
         super().__init__()
 
+        self.device = "cuda"
         self.use_deepspeed = use_deepspeed
+        self.use_half = False
+        self.use_data_parallel = not use_deepspeed
+        self.use_model_parallel = False
+        assert not (self.use_deepspeed and self.use_data_parallel)
+        assert not (self.use_deepspeed and self.use_model_parallel)
+        assert not (self.use_data_parallel and self.use_model_parallel)
+
         # reference model
-        self._apply_model_parallel = apply_model_parallel
         if ref_model == "builtin_ref":
+
+            self.device = "cpu"
+            self.use_data_parallel = False
+
             from transformers import GPT2Config, GPT2LMHeadModel
 
             config = GPT2Config()
@@ -64,11 +75,15 @@ def __init__(
                 self.use_fp16 = False
 
             self._ref_engine, *_ = deepspeed.initialize(model=self, config=ds_config)
-        elif torch.cuda.is_available():
-            if self._apply_model_parallel and self._ref_net.is_parallelizable:
+        else:
+            if self.use_model_parallel:
                 self._ref_net.parallelize()
-            else:  # else defaults to data parallel
-                self._ref_net = torch.nn.DataParallel(self._ref_net)
+            elif self.use_data_parallel:  # else defaults to data parallel
+                if self.use_half:
+                    self._ref_net = self._ref_net.half()
+                else:
+                    self._ref_net = torch.nn.DataParallel(self._ref_net)
+                self._ref_net = self._ref_net.to(self.device)
 
         # alpha adjustment
         self._alpha = 0.2
@@ -106,32 +121,35 @@ def __call__(
             self._ref_net, input_ids, past_model_kwargs
         )
 
-        if self.use_deepspeed:
-            if self.use_fp16:
-                for key in ["input_ids", "position_ids"]:
-                    model_inputs[key] = model_inputs[key].half().int()
-                for key in ["attention_mask"]:
-                    model_inputs[key] = model_inputs[key].half()
+        if self.use_half:
+            for key in ["input_ids", "position_ids", "attention_mask"]:
+                if key in model_inputs:
+                    model_inputs[key] = model_inputs[key].int()
+        else:
+            for key in ["input_ids", "position_ids", "attention_mask"]:
+                if key in model_inputs:
+                    model_inputs[key] = model_inputs[key].long()
 
         with torch.no_grad():
             output = self._ref_net(output_hidden_states=True, **model_inputs)
             output["past_key_values"] = None
             next_token_logits = output.logits[:, -1, :]
+            if self.use_deepspeed and self.use_fp16:
+                next_token_logits = next_token_logits.double()
             dist = self._action_dist.proba_distribution(action_logits=next_token_logits)
             action_input = actions.to(next_token_logits.device)
             ref_log_prob = dist.log_prob(action_input)
 
         ref_log_prob = ref_log_prob.reshape(action_log_probs.shape)
 
         kl_div = action_log_probs.copy() - ref_log_prob.detach().cpu().numpy()
         rew = -self._alpha * kl_div
         infos = []
         for kl in kl_div:
-            infos.append(
-                {
-                    "alpha": self._alpha,
-                    "kl_div": kl.mean(),
-                }
-            )
+            infos.append({
+                "alpha": self._alpha,
+                "kl_div": kl.mean(),
+            })
         return rew, infos
 
     def _prepare_inputs_for_model(
@@ -144,7 +162,7 @@ def _prepare_inputs_for_model(
             input_ids, **model_kwargs
         )
 
-        if self._apply_model_parallel and unwrap_model(model).is_parallelizable:
+        if self.use_model_parallel:
            # if model is in parallel mode, move the tensors to the first device
            model_inputs = {
                key: (
@@ -155,8 +173,12 @@ def _prepare_inputs_for_model(
                 )
                 for key, value in model_inputs.items()
             }
-
-        if self.use_deepspeed:
+        elif self.use_data_parallel:
+            model_inputs = {
+                key: value.to(self.device) if isinstance(value, torch.Tensor) else value
+                for key, value in model_inputs.items()
+            }
+        elif self.use_deepspeed:
             model_inputs = {
                 key: value.to("cuda") if isinstance(value, torch.Tensor) else value
                 for key, value in model_inputs.items()
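Numerically, the penalty computed in __call__ is the per-token difference of log-probabilities scaled by -alpha (0.2 by default, as set in __init__). A tiny worked example with made-up log-probs:

    # KL penalty in numbers: positive reward where the policy assigns *lower*
    # probability than the reference, negative where it drifts above it.
    import numpy as np

    alpha = 0.2
    action_log_probs = np.array([[-2.1, -1.7, -3.0]])  # log pi(a_t | s_t), current policy
    ref_log_prob = np.array([[-2.5, -1.6, -2.4]])      # log pi_ref(a_t | s_t), reference

    kl_div = action_log_probs - ref_log_prob           # per-token KL estimate
    rew = -alpha * kl_div                              # shaping term added to the task reward

    infos = [{"alpha": alpha, "kl_div": kl.mean()} for kl in kl_div]
    print(rew)                 # approximately [[-0.08  0.02  0.12]]
    print(infos[0]["kl_div"])  # approximately -0.1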
24 changes: 10 additions & 14 deletions openrl/envs/nlp/utils/metrics/meteor.py
@@ -88,20 +88,16 @@ def _info(self):
             citation=_CITATION,
             inputs_description=_KWARGS_DESCRIPTION,
             features=[
-                datasets.Features(
-                    {
-                        "predictions": datasets.Value("string", id="sequence"),
-                        "references": datasets.Sequence(
-                            datasets.Value("string", id="sequence"), id="references"
-                        ),
-                    }
-                ),
-                datasets.Features(
-                    {
-                        "predictions": datasets.Value("string", id="sequence"),
-                        "references": datasets.Value("string", id="sequence"),
-                    }
-                ),
+                datasets.Features({
+                    "predictions": datasets.Value("string", id="sequence"),
+                    "references": datasets.Sequence(
+                        datasets.Value("string", id="sequence"), id="references"
+                    ),
+                }),
+                datasets.Features({
+                    "predictions": datasets.Value("string", id="sequence"),
+                    "references": datasets.Value("string", id="sequence"),
+                }),
             ],
             codebase_urls=[
                 "https://github.com/nltk/nltk/blob/develop/nltk/translate/meteor_score.py"
4 changes: 2 additions & 2 deletions openrl/envs/vec_env/wrappers/reward_wrapper.py
@@ -29,8 +29,8 @@ class RewardWrapper(VecEnvWrapper):
     def __init__(self, env: BaseVecEnv, reward_class: BaseReward):
         super().__init__(env)
         self.reward_class = reward_class
-        if len(self.reward_class.inner_rew_funcs) > 0:
-            env.call("set_reward", **{"reward_fn": self.reward_class.inner_rew_funcs})
+        # if len(self.reward_class.inner_rew_funcs) > 0:
+        #     env.call("set_reward", **{"reward_fn": self.reward_class.inner_rew_funcs})
 
     def step(
         self, action: ActType, extra_data: Optional[Dict[str, Any]]
