text_generation_utils memory reduction if no logprob needed (NVIDIA#6773)

* repro for GPT eval model-parallel memory issue

Signed-off-by: Yang Zhang <yangzhang@nvidia.com>

* add print statements for memory allocation

Signed-off-by: Yang Zhang <yangzhang@nvidia.com>

* adjusted hot fix that prevents the softmax over the entire output embedding; memory is now bottlenecked by the attention softmax, which needs to be addressed with FA (FlashAttention) or long-attention approaches

Signed-off-by: Yang Zhang <yangzhang@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* using compute_logprob to configure inference

Signed-off-by: Yang Zhang <yangzhang@nvidia.com>

* enable compute logprob for peft

Signed-off-by: Yang Zhang <yangzhang@nvidia.com>

* remove print statements

Signed-off-by: Yang Zhang <yangzhang@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix ci

Signed-off-by: Yang Zhang <yangzhang@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added docstrings

Signed-off-by: Yang Zhang <yangzhang@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add missing config

Signed-off-by: Yang Zhang <yangzhang@nvidia.com>

* remove truncate prompt length feature

Signed-off-by: Yang Zhang <yangzhang@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tensor needs to be contiguous before all-gather

Signed-off-by: Yang Zhang <yangzhang@nvidia.com>

---------

Signed-off-by: Yang Zhang <yangzhang@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Evelina <10428420+ekmb@users.noreply.github.com>
Co-authored-by: Sandeep Subramanian <sandeep.subramanian.1@umontreal.ca>
4 people authored Jun 6, 2023
1 parent f9bb1b0 commit 3063e32
Showing 6 changed files with 83 additions and 58 deletions.
22 changes: 12 additions & 10 deletions examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py
@@ -155,7 +155,7 @@ def main(cfg) -> None:

if os.path.isdir(cfg.model.restore_from_path):
save_restore_connector.model_extracted_dir = cfg.model.restore_from_path
model = NLPModel.restore_from(
model = MegatronGPTSFTModel.restore_from(
restore_path=cfg.model.restore_from_path,
trainer=trainer,
override_config_path=peft_model_cfg,
@@ -180,15 +180,17 @@ def main(cfg) -> None:
for batch in response:
batch_sentences = [s for s in batch['sentences']]
batch_tokens = [s for s in batch['tokens']]
batch_logprob = [s.tolist() for s in batch['logprob']]
for s, t, l in zip(batch_sentences, batch_tokens, batch_logprob):
if cfg.inference.get("verbose", False):
d = {
'sentence': s,
'tokens_with_logprobs': ', '.join([f"{_t} {_l:.4f}" for _t, _l in zip(t, l)]),
}
f.write(json.dumps(d, sort_keys=True, indent=2) + '\n')
else:
if cfg.inference.compute_logprob:
batch_logprob = [s.tolist() for s in batch['logprob']]
for s, t, l in zip(batch_sentences, batch_tokens, batch_logprob):
if cfg.inference.get("verbose", False):
d = {
'sentence': s,
'tokens_with_logprobs': ', '.join([f"{_t} {_l:.4f}" for _t, _l in zip(t, l)]),
}
f.write(json.dumps(d, sort_keys=True, indent=2) + '\n')
else:
for s in batch_sentences:
d = {'sentence': s}
f.write(json.dumps(d) + '\n')
print("predictions saved to {}".format(cfg.inference.outfile_path))
@@ -1111,7 +1111,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
inference_config = inference_config.copy()
compute_logprob = inference_config['compute_logprob']
if compute_logprob:
del inference_config['compute_logprob']
inference_config['inputs'] = batch
inference_config['tokens_to_generate'] = 1
inference_config['all_probs'] = True
@@ -1121,7 +1120,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
compute_prob_response = get_computeprob_response(self.tokenizer, response, batch)
return compute_prob_response
else:
del inference_config['compute_logprob']
inference_config['inputs'] = batch
return generate(self, **inference_config)

@@ -35,6 +35,7 @@
LengthParam,
SamplingParam,
generate,
get_computeprob_response,
megatron_gpt_generate,
)
from nemo.utils import AppState, logging
@@ -539,7 +540,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
inference_config = inference_config.copy()
compute_logprob = inference_config['compute_logprob']
if compute_logprob:
del inference_config['compute_logprob']
inference_config['inputs'] = batch
inference_config['tokens_to_generate'] = 1
inference_config['all_probs'] = True
@@ -549,8 +549,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
compute_prob_response = get_computeprob_response(self.tokenizer, response, batch)
return compute_prob_response
else:
del inference_config['compute_logprob']

# for megatron_gpt_eval.py
if isinstance(batch, list):
inference_config['inputs'] = batch
@@ -464,7 +464,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
inference_config = inference_config.copy()
compute_logprob = inference_config['compute_logprob']
if compute_logprob:
del inference_config['compute_logprob']
inference_config['inputs'] = batch
inference_config['tokens_to_generate'] = 1
inference_config['all_probs'] = True
@@ -474,7 +473,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
compute_prob_response = get_computeprob_response(self.tokenizer, response, batch)
return compute_prob_response
else:
del inference_config['compute_logprob']
inference_config['inputs'] = batch
return generate(self, **inference_config, strategy=self.inference_strategy)

@@ -53,7 +53,6 @@ def __init__(self, model):

def forward_step(self, batch, tensor_shape):
fwd_bwd_function = get_forward_backward_func()

output_tensor = fwd_bwd_function(
forward_step_func=self.model.get_forward_output_only_func(),
data_iterator=iter([batch,]),
110 changes: 70 additions & 40 deletions nemo/collections/nlp/modules/common/text_generation_utils.py
@@ -97,6 +97,7 @@ def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_para
inputs=inputs,
tokens_to_generate=length_params['max_length'],
all_probs=sampling_params['all_probs'],
compute_logprob=sampling_params['compute_logprob'],
temperature=sampling_params['temperature'],
add_BOS=sampling_params['add_BOS'],
top_k=sampling_params['top_k'],
@@ -116,6 +117,7 @@ def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_para
inputs=inputs,
tokens_to_generate=length_params['max_length'],
all_probs=sampling_params['all_probs'],
compute_logprob=sampling_params['compute_logprob'],
temperature=sampling_params['temperature'],
add_BOS=sampling_params['add_BOS'],
top_k=sampling_params['top_k'],
@@ -269,6 +271,7 @@ def send_generate_info(
context_length_tensor,
tokens_to_generate,
all_probs,
compute_logprob,
temperature,
top_k,
top_p,
@@ -288,6 +291,7 @@
context_tokens_tensor.size(1), # seq_len
tokens_to_generate,
all_probs,
compute_logprob, # whether to compute log probabilities matrix
temperature,
top_k,
top_p,
@@ -317,18 +321,19 @@ def receive_generate_info():
"""
model_parallel_group = parallel_state.get_model_parallel_group()
src = get_model_parallel_src_rank()
input_info_tensor = torch.empty(10, dtype=torch.float32, device=torch.cuda.current_device())
input_info_tensor = torch.empty(11, dtype=torch.float32, device=torch.cuda.current_device())
torch.distributed.broadcast(input_info_tensor, src, model_parallel_group)
batch_size = int(input_info_tensor[0].item())
seq_len = int(input_info_tensor[1].item())
tokens_to_generate = int(input_info_tensor[2].item())
all_probs = bool(input_info_tensor[3].item())
temperature = float(input_info_tensor[4].item())
top_k = int(input_info_tensor[5].item())
top_p = float(input_info_tensor[6].item())
greedy = bool(input_info_tensor[7].item())
repetition_penalty = float(input_info_tensor[8].item())
min_tokens_to_generate = int(input_info_tensor[9].item())
compute_logprob = bool(input_info_tensor[4].item()) # whether to compute log probabilities matrix
temperature = float(input_info_tensor[5].item())
top_k = int(input_info_tensor[6].item())
top_p = float(input_info_tensor[7].item())
greedy = bool(input_info_tensor[8].item())
repetition_penalty = float(input_info_tensor[9].item())
min_tokens_to_generate = int(input_info_tensor[10].item())

context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
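To make the packing scheme above concrete, here is a self-contained sketch (hypothetical helper names, not part of the commit) of the 11-entry float tensor that `send_generate_info`/`receive_generate_info` broadcast, with `compute_logprob` at index 4 as in the diff; the index order must match exactly on both sides:

```python
import torch

# Pack every scalar generation knob into one float32 tensor so a single
# torch.distributed.broadcast can carry them all from the model-parallel source rank.
def pack_generate_info(batch_size, seq_len, tokens_to_generate, all_probs, compute_logprob,
                       temperature, top_k, top_p, greedy, repetition_penalty,
                       min_tokens_to_generate):
    values = [batch_size, seq_len, tokens_to_generate, float(all_probs), float(compute_logprob),
              temperature, top_k, top_p, float(greedy), repetition_penalty,
              min_tokens_to_generate]
    return torch.tensor(values, dtype=torch.float32)  # 11 entries (was 10 before this commit)

# Unpack in the same order on the receiving ranks.
def unpack_generate_info(t):
    return dict(
        batch_size=int(t[0].item()), seq_len=int(t[1].item()),
        tokens_to_generate=int(t[2].item()), all_probs=bool(t[3].item()),
        compute_logprob=bool(t[4].item()),          # the new field
        temperature=float(t[5].item()), top_k=int(t[6].item()), top_p=float(t[7].item()),
        greedy=bool(t[8].item()), repetition_penalty=float(t[9].item()),
        min_tokens_to_generate=int(t[10].item()),
    )

info = unpack_generate_info(pack_generate_info(4, 128, 30, False, False, 1.0, 0, 0.9, True, 1.2, 0))
print(info["compute_logprob"])  # -> False
```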
@@ -349,6 +354,7 @@ def receive_generate_info():
context_tokens_tensor,
tokens_to_generate,
all_probs,
compute_logprob,
temperature,
top_k,
top_p,
@@ -370,6 +376,7 @@ def synced_generate(
top_k=0,
top_p=0.0,
greedy=False,
compute_logprob=False,
repetition_penalty=1.2,
min_tokens_to_generate=0,
end_strings=[],
@@ -394,6 +401,7 @@ def synced_generate(
context_length_tensor,
tokens_to_generate,
all_probs,
compute_logprob=compute_logprob,
temperature=temperature,
end_strings=end_strings,
extra={
@@ -411,7 +419,8 @@ def synced_generate(
if parallel_state.is_pipeline_last_stage():
src = parallel_state.get_pipeline_model_parallel_last_rank()
group = parallel_state.get_embedding_group()
torch.distributed.broadcast(output_logits, src, group)
if compute_logprob:
torch.distributed.broadcast(output_logits, src, group)
if all_probs:
src = parallel_state.get_pipeline_model_parallel_last_rank()
group = parallel_state.get_embedding_group()
@@ -422,15 +431,18 @@ def synced_generate(
src = parallel_state.get_pipeline_model_parallel_last_rank()
group = parallel_state.get_embedding_group()

precision = model._trainer.precision
if precision in [16, "16"]:
dtype = torch.float16
elif precision == "bf16":
dtype = torch.bfloat16
else:
dtype = torch.float32
output_logits = torch.empty(tokens.size(0), context_length - 1, dtype=dtype, device=torch.device("cuda"))
torch.distributed.broadcast(output_logits, src, group)
if compute_logprob:
precision = model._trainer.precision
if precision in [16, "16"]:
dtype = torch.float16
elif precision == "bf16":
dtype = torch.bfloat16
else:
dtype = torch.float32
output_logits = torch.empty(
tokens.size(0), context_length - 1, dtype=dtype, device=torch.device("cuda")
)
torch.distributed.broadcast(output_logits, src, group)

if all_probs:
src = parallel_state.get_pipeline_model_parallel_last_rank()
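The change above only broadcasts `output_logits` when `compute_logprob` is set, and derives the buffer dtype from the trainer precision. A short receive-side sketch (hypothetical helper, not from the commit) of why that dtype lookup is needed:

```python
import torch
import torch.distributed as dist

# torch.distributed.broadcast fills an existing buffer in place, so non-source ranks
# must pre-allocate a tensor whose shape and dtype exactly match what the last
# pipeline stage sends; the dtype is therefore derived from the trainer precision.
def recv_output_logits(batch_size, context_length, precision, src, group):
    if precision in [16, "16"]:
        dtype = torch.float16
    elif precision == "bf16":
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    buf = torch.empty(batch_size, context_length - 1, dtype=dtype, device="cuda")
    dist.broadcast(buf, src, group)   # every rank in `group` ends up with the same values
    return buf
```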
@@ -457,6 +469,7 @@ def generate(
top_k=0,
top_p=0.0,
greedy=False,
compute_logprob=False,
repetition_penalty=1.0,
min_tokens_to_generate=0,
end_strings=['<|endoftext|>'],
@@ -504,6 +517,7 @@ def generate(
context_length_tensor,
tokens_to_generate,
all_probs,
compute_logprob,
temperature,
top_k,
top_p,
@@ -518,6 +532,7 @@ def generate(
context_tokens_tensor,
tokens_to_generate,
all_probs,
compute_logprob,
temperature,
top_k,
top_p,
@@ -535,6 +550,7 @@ def generate(
tokens_to_generate,
all_probs,
temperature,
compute_logprob=compute_logprob,
top_k=top_k,
top_p=top_p,
greedy=greedy,
@@ -619,6 +635,7 @@ def sample_sequence_batch(
context_lengths,
tokens_to_generate,
all_probs=False,
compute_logprob=False,
type_ids=None,
temperature=None,
end_strings=['<|endoftext|>'],
@@ -673,11 +690,18 @@ def sample_sequence_batch(
output = inference_strategy.forward_step(batch, tensor_shape)

if parallel_state.is_pipeline_last_stage():
output = output[0]['logits']

output = tensor_parallel.gather_from_tensor_model_parallel_region(output)
assert output is not None
logits = output[:, -1].view(batch_size, -1).contiguous()
if compute_logprob:
output = output[0]['logits']
output = tensor_parallel.gather_from_tensor_model_parallel_region(output)
assert output is not None
logits = output[:, -1].view(batch_size, -1).contiguous()

else:
logits = output[0]['logits'][:, -1].contiguous()
logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
assert logits is not None
logits = logits.view(batch_size, -1)

# make sure it will generate at least min_length
min_length = extra.get('min_tokens_to_generate', 0)
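Rough arithmetic (assumed sizes, not from the commit) for why the new branch slices to the last position before the tensor-parallel gather; note the slice yields a non-contiguous view, hence the `.contiguous()` call before the gather, noted in the commit message:

```python
# With tensor parallelism each rank holds a vocab shard, and
# gather_from_tensor_model_parallel_region materializes the full vocab dimension.
batch, seq_len, vocab = 4, 4096, 51200      # illustrative sizes
bytes_per_elem = 2                          # fp16

gather_then_slice = batch * seq_len * vocab * bytes_per_elem   # gather [b, s, v], then take [:, -1]
slice_then_gather = batch * 1 * vocab * bytes_per_elem         # take [:, -1] per shard, then gather

print(f"gather-then-slice: {gather_then_slice / 2**30:.2f} GiB materialized per rank")
print(f"slice-then-gather: {slice_then_gather / 2**20:.2f} MiB materialized per rank")
```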
@@ -689,6 +713,7 @@ def sample_sequence_batch(
logits[:, tokenizer.vocab_size :] = -float('Inf')

# started indicates whether the current token step passes the context_length, so we make sure not to overwrite the context tokens

started = context_lengths <= context_length
if extra.get('greedy', False):
prev = torch.argmax(logits, dim=-1).view(-1)
@@ -716,23 +741,25 @@ def sample_sequence_batch(
# Insert either new predicted or next prompt token
tokens[:, context_length] = new_tokens

if output_logits is None:
output = F.log_softmax(output[:, :context_length, :], 2)
indices = torch.unsqueeze(tokens[:, 1 : context_length + 1], 2)
output_logits = torch.gather(output, 2, indices).squeeze(2)
all_generated_indices = indices[:, :, 0]
if all_probs:
full_logits = output
else:
output = F.log_softmax(output, 2)
indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
new_output_logits = torch.gather(output, 2, indices).squeeze(2)
if compute_logprob:
if output_logits is None:
output = F.log_softmax(output[:, :context_length, :], 2)

# TODO(rprenger) we're copying output_logits every time. Should pre-allocate
output_logits = torch.cat([output_logits, new_output_logits], 1)
all_generated_indices = torch.cat([all_generated_indices, indices[:, :, 0]], 1)
if all_probs:
full_logits = torch.cat([full_logits, output], 1)
indices = torch.unsqueeze(tokens[:, 1 : context_length + 1], 2)
output_logits = torch.gather(output, 2, indices).squeeze(2)
all_generated_indices = indices[:, :, 0]
if all_probs:
full_logits = output
else:
output = F.log_softmax(output, 2)
indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
new_output_logits = torch.gather(output, 2, indices).squeeze(2)

# TODO(rprenger) we're copying output_logits every time. Should pre-allocate
output_logits = torch.cat([output_logits, new_output_logits], 1)
all_generated_indices = torch.cat([all_generated_indices, indices[:, :, 0]], 1)
if all_probs:
full_logits = torch.cat([full_logits, output], 1)

src = parallel_state.get_pipeline_model_parallel_last_rank()
group = parallel_state.get_embedding_group()
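A toy sketch (separate from the commit) of what the `compute_logprob` gate above avoids: picking the next token only needs the distribution for the last position, so the log-softmax over the full `[batch, seq, vocab]` activation, and the growing `output_logits`/`full_logits` buffers, can be skipped entirely. Sizes in the second half are assumptions:

```python
import torch
import torch.nn.functional as F

# Functional toy example: next-token sampling needs log-probs only for the last position.
logits = torch.randn(2, 16, 128)                               # [batch, seq, vocab], toy sizes
next_token_logprobs = F.log_softmax(logits[:, -1, :], dim=-1)  # [batch, vocab]
full_logprobs = F.log_softmax(logits, dim=-1)                  # [batch, seq, vocab], only needed for logprob output

# Back-of-the-envelope memory for more realistic (assumed) sizes, fp32 activations:
batch, seq, vocab = 4, 2048, 50257
print(f"last-token only: {batch * vocab * 4 / 2**20:.1f} MiB")
print(f"full sequence:   {batch * seq * vocab * 4 / 2**30:.2f} GiB")
```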
@@ -752,10 +779,13 @@ def sample_sequence_batch(
src = parallel_state.get_pipeline_model_parallel_last_rank()
group = parallel_state.get_pipeline_model_parallel_group()
torch.distributed.broadcast(done, src, group)
if all_probs:
yield tokens, lengths, output_logits, full_logits
if compute_logprob:
if all_probs:
yield tokens, lengths, output_logits, full_logits
else:
yield tokens, lengths, output_logits, None
else:
yield tokens, lengths, output_logits, None
yield tokens, lengths, None, None

else:
if parallel_state.is_pipeline_first_stage():
