Commit f8879b3
fix test w/o gpu bug
Wen-Tse Chen committed Dec 20, 2023
1 parent 3af7588 commit f8879b3
Showing 2 changed files with 13 additions and 5 deletions.
openrl/envs/nlp/rewards/intent.py (9 changes: 6 additions & 3 deletions)
@@ -41,6 +41,10 @@ def __init__(
         self.use_model_parallel = False
 
         if intent_model == "builtin_intent":
+
+            self._device = "cpu"
+            self.use_data_parallel = False
+
             from transformers import GPT2Config, GPT2LMHeadModel
 
             class TestTokenizer:
@@ -66,6 +70,7 @@ def __init__(self, input_ids, attention_mask):
             self._model = GPT2LMHeadModel(config)
 
         else:
+            self._device = "cuda"
             model_path = data_abs_path(intent_model)
             self._tokenizer = AutoTokenizer.from_pretrained(intent_model)
             self._model = AutoModelForSequenceClassification.from_pretrained(model_path)
@@ -81,12 +86,10 @@ def __init__(self, input_ids, attention_mask):
             with open(ds_config) as file:
                 ds_config = json.load(file)
 
-            self._device = "cuda"
-            self._model = self._model.to("cuda")
+            self._model = self._model.to(self._device)
             self._model, *_ = deepspeed.initialize(model=self._model, config=ds_config)
             self.use_fp16 = ds_config["fp16"]["enabled"]
         else:
-            self._device = "cuda"
             if self.use_model_parallel:
                 self._model.parallelize()
             elif self.use_data_parallel:
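
Read together, the three hunks stop hard-coding "cuda" in intent.py: the builtin test model now pins itself to the CPU and disables data parallelism, real intent models keep the previous "cuda" default, and the deepspeed path moves the model to the stored device instead of calling .to("cuda"). A minimal runnable sketch of that device-selection pattern, assuming PyTorch; pick_device and the Linear stand-in are illustrative, not repository code:

import torch

def pick_device(model_name: str) -> str:
    # Hypothetical helper mirroring the commit's pattern: the builtin test
    # model pins "cpu" so tests can run without a GPU, while real intent
    # models keep the original hard-coded "cuda".
    if model_name == "builtin_intent":
        return "cpu"
    return "cuda"

device = pick_device("builtin_intent")
model = torch.nn.Linear(4, 2)            # stand-in for the builtin GPT-2 test model
model = model.to(device)                 # was effectively .to("cuda") before the fix
print(next(model.parameters()).device)   # cpu, so no CUDA runtime is needed

The third hunk is what makes this work end to end: before the fix, the deepspeed branch called .to("cuda") unconditionally, which fails on machines without a GPU.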
openrl/envs/nlp/rewards/kl_penalty.py (9 changes: 7 additions & 2 deletions)
@@ -47,6 +47,10 @@ def __init__(
 
         # reference model
         if ref_model == "builtin_ref":
+
+            self.device = "cpu"
+            self.use_data_parallel = False
+
             from transformers import GPT2Config, GPT2LMHeadModel
 
             config = GPT2Config()
@@ -77,8 +81,9 @@ def __init__(
             elif self.use_data_parallel:  # else defaults to data parallel
                 if self.use_half:
                     self._ref_net = self._ref_net.half()
-                self._ref_net = torch.nn.DataParallel(self._ref_net)
-                self._ref_net = self._ref_net.to(self.device)
+                else:
+                    self._ref_net = torch.nn.DataParallel(self._ref_net)
+                    self._ref_net = self._ref_net.to(self.device)
 
         # alpha adjustment
         self._alpha = 0.2
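
kl_penalty.py gets the same treatment: the builtin reference model pins self.device = "cpu" and disables data parallelism, and the data-parallel branch is restructured so that torch.nn.DataParallel wrapping and the .to(self.device) move happen only on the non-half path. A minimal sketch of the resulting branch, assuming PyTorch; wrap_ref_net and the Linear stand-in are illustrative, not repository code:

import torch

def wrap_ref_net(
    ref_net: torch.nn.Module, use_half: bool, device: str
) -> torch.nn.Module:
    # Hypothetical helper mirroring the patched branch in kl_penalty.py:
    # half-precision models are only cast, full-precision models are
    # wrapped in DataParallel and moved to the stored device.
    if use_half:
        ref_net = ref_net.half()
    else:
        ref_net = torch.nn.DataParallel(ref_net)
        ref_net = ref_net.to(device)
    return ref_net

net = wrap_ref_net(torch.nn.Linear(4, 2), use_half=False, device="cpu")

Note that with ref_model == "builtin_ref" the first hunk sets self.use_data_parallel = False, so a CPU-only test run never enters this branch; the restructuring only affects GPU runs.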
