From 3063e3251bb4dbbf81278084cce132c3e56b4c52 Mon Sep 17 00:00:00 2001
From: Yang Zhang
Date: Tue, 6 Jun 2023 01:52:45 -0400
Subject: [PATCH] text_generation_utils memory reduction if no logprob needed (#6773)

* repro for gpt eval mp mem issue

Signed-off-by: Yang Zhang

* add print statements for memory allocation

Signed-off-by: Yang Zhang

* adjusted hotfix that prevents softmax over the entire output embedding; memory is now bottlenecked by the attention softmax, which needs to be solved with FA or long attention

Signed-off-by: Yang Zhang

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* using compute_logprob to configure inference

Signed-off-by: Yang Zhang

* enable compute logprob for peft

Signed-off-by: Yang Zhang

* remove print statements

Signed-off-by: Yang Zhang

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix ci

Signed-off-by: Yang Zhang

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* added docstrings

Signed-off-by: Yang Zhang

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add missing config

Signed-off-by: Yang Zhang

* remove truncate prompt length feature

Signed-off-by: Yang Zhang

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tensor before all gather needs to be contiguous

Signed-off-by: Yang Zhang

---------

Signed-off-by: Yang Zhang
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Evelina <10428420+ekmb@users.noreply.github.com>
Co-authored-by: Sandeep Subramanian
---
 .../tuning/megatron_gpt_peft_eval.py          |  22 ++--
 .../language_modeling/megatron_gpt_model.py   |   2 -
 .../megatron_gpt_sft_model.py                 |   4 +-
 .../megatron_retrieval_model.py               |   2 -
 .../common/text_generation_strategy.py        |   1 -
 .../modules/common/text_generation_utils.py   | 110 +++++++++++-------
 6 files changed, 83 insertions(+), 58 deletions(-)

diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py b/examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py
index a5bf1ee552cb..fc427a60d172 100644
--- a/examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py
+++ b/examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py
@@ -155,7 +155,7 @@ def main(cfg) -> None:
 
     if os.path.isdir(cfg.model.restore_from_path):
         save_restore_connector.model_extracted_dir = cfg.model.restore_from_path
-    model = NLPModel.restore_from(
+    model = MegatronGPTSFTModel.restore_from(
         restore_path=cfg.model.restore_from_path,
         trainer=trainer,
         override_config_path=peft_model_cfg,
@@ -180,15 +180,17 @@ def main(cfg) -> None:
                 for batch in response:
                     batch_sentences = [s for s in batch['sentences']]
                     batch_tokens = [s for s in batch['tokens']]
-                    batch_logprob = [s.tolist() for s in batch['logprob']]
-                    for s, t, l in zip(batch_sentences, batch_tokens, batch_logprob):
-                        if cfg.inference.get("verbose", False):
-                            d = {
-                                'sentence': s,
-                                'tokens_with_logprobs': ', '.join([f"{_t} {_l:.4f}" for _t, _l in zip(t, l)]),
-                            }
-                            f.write(json.dumps(d, sort_keys=True, indent=2) + '\n')
-                        else:
+                    if cfg.inference.compute_logprob:
+                        batch_logprob = [s.tolist() for s in batch['logprob']]
+                        for s, t, l in zip(batch_sentences, batch_tokens, batch_logprob):
+                            if cfg.inference.get("verbose", False):
+                                d = {
+                                    'sentence': s,
+                                    'tokens_with_logprobs': ', '.join([f"{_t} {_l:.4f}" for _t, _l in zip(t, l)]),
+                                }
+                                f.write(json.dumps(d, sort_keys=True, indent=2) + '\n')
+                    else:
+                        for s in batch_sentences:
                             d = {'sentence': s}
                             f.write(json.dumps(d) + '\n')
             print("predictions saved to {}".format(cfg.inference.outfile_path))
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 9aadb6853190..3530ffcfc371 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1111,7 +1111,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
         inference_config = inference_config.copy()
         compute_logprob = inference_config['compute_logprob']
         if compute_logprob:
-            del inference_config['compute_logprob']
             inference_config['inputs'] = batch
             inference_config['tokens_to_generate'] = 1
             inference_config['all_probs'] = True
@@ -1121,7 +1120,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
             compute_prob_response = get_computeprob_response(self.tokenizer, response, batch)
             return compute_prob_response
         else:
-            del inference_config['compute_logprob']
             inference_config['inputs'] = batch
             return generate(self, **inference_config)
 
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
index 1dc335b86609..9507a01d01f0 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py
@@ -35,6 +35,7 @@
     LengthParam,
     SamplingParam,
     generate,
+    get_computeprob_response,
     megatron_gpt_generate,
 )
 from nemo.utils import AppState, logging
@@ -539,7 +540,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
         inference_config = inference_config.copy()
         compute_logprob = inference_config['compute_logprob']
         if compute_logprob:
-            del inference_config['compute_logprob']
             inference_config['inputs'] = batch
             inference_config['tokens_to_generate'] = 1
             inference_config['all_probs'] = True
@@ -549,8 +549,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
             compute_prob_response = get_computeprob_response(self.tokenizer, response, batch)
             return compute_prob_response
         else:
-            del inference_config['compute_logprob']
-
             # for megatron_gpt_eval.py
             if isinstance(batch, list):
                 inference_config['inputs'] = batch
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py b/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py
index afd8ad54d150..5900513f3547 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py
@@ -464,7 +464,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
         inference_config = inference_config.copy()
         compute_logprob = inference_config['compute_logprob']
         if compute_logprob:
-            del inference_config['compute_logprob']
            inference_config['inputs'] = batch
             inference_config['tokens_to_generate'] = 1
             inference_config['all_probs'] = True
@@ -474,7 +473,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int]
             compute_prob_response = get_computeprob_response(self.tokenizer, response, batch)
             return compute_prob_response
         else:
-            del inference_config['compute_logprob']
             inference_config['inputs'] = batch
             return generate(self, **inference_config, strategy=self.inference_strategy)
 
diff --git a/nemo/collections/nlp/modules/common/text_generation_strategy.py b/nemo/collections/nlp/modules/common/text_generation_strategy.py
index 27ae3b2606d3..310065fc3523 100644
--- a/nemo/collections/nlp/modules/common/text_generation_strategy.py
+++ b/nemo/collections/nlp/modules/common/text_generation_strategy.py
@@ -53,7 +53,6 @@ def __init__(self, model):
 
     def forward_step(self, batch, tensor_shape):
         fwd_bwd_function = get_forward_backward_func()
-
         output_tensor = fwd_bwd_function(
             forward_step_func=self.model.get_forward_output_only_func(),
             data_iterator=iter([batch,]),
diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py
index 3a07a807b11a..a56304970bdc 100644
--- a/nemo/collections/nlp/modules/common/text_generation_utils.py
+++ b/nemo/collections/nlp/modules/common/text_generation_utils.py
@@ -97,6 +97,7 @@ def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_para
             inputs=inputs,
             tokens_to_generate=length_params['max_length'],
             all_probs=sampling_params['all_probs'],
+            compute_logprob=sampling_params['compute_logprob'],
             temperature=sampling_params['temperature'],
             add_BOS=sampling_params['add_BOS'],
             top_k=sampling_params['top_k'],
@@ -116,6 +117,7 @@ def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_para
             inputs=inputs,
             tokens_to_generate=length_params['max_length'],
             all_probs=sampling_params['all_probs'],
+            compute_logprob=sampling_params['compute_logprob'],
             temperature=sampling_params['temperature'],
             add_BOS=sampling_params['add_BOS'],
             top_k=sampling_params['top_k'],
@@ -269,6 +271,7 @@ def send_generate_info(
     context_length_tensor,
     tokens_to_generate,
     all_probs,
+    compute_logprob,
     temperature,
     top_k,
     top_p,
@@ -288,6 +291,7 @@ def send_generate_info(
         context_tokens_tensor.size(1),  # seq_len
        tokens_to_generate,
         all_probs,
+        compute_logprob,  # whether to compute log probabilities matrix
         temperature,
         top_k,
         top_p,
@@ -317,18 +321,19 @@ def receive_generate_info():
     """
     model_parallel_group = parallel_state.get_model_parallel_group()
     src = get_model_parallel_src_rank()
-    input_info_tensor = torch.empty(10, dtype=torch.float32, device=torch.cuda.current_device())
+    input_info_tensor = torch.empty(11, dtype=torch.float32, device=torch.cuda.current_device())
     torch.distributed.broadcast(input_info_tensor, src, model_parallel_group)
     batch_size = int(input_info_tensor[0].item())
     seq_len = int(input_info_tensor[1].item())
     tokens_to_generate = int(input_info_tensor[2].item())
     all_probs = bool(input_info_tensor[3].item())
-    temperature = float(input_info_tensor[4].item())
-    top_k = int(input_info_tensor[5].item())
-    top_p = float(input_info_tensor[6].item())
-    greedy = bool(input_info_tensor[7].item())
-    repetition_penalty = float(input_info_tensor[8].item())
-    min_tokens_to_generate = int(input_info_tensor[9].item())
+    compute_logprob = bool(input_info_tensor[4].item())  # whether to compute log probabilities matrix
+    temperature = float(input_info_tensor[5].item())
+    top_k = int(input_info_tensor[6].item())
+    top_p = float(input_info_tensor[7].item())
+    greedy = bool(input_info_tensor[8].item())
+    repetition_penalty = float(input_info_tensor[9].item())
+    min_tokens_to_generate = int(input_info_tensor[10].item())
     context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
     context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
@@ -349,6 +354,7 @@ def receive_generate_info():
         context_tokens_tensor,
         tokens_to_generate,
         all_probs,
+        compute_logprob,
         temperature,
         top_k,
         top_p,
@@ -370,6 +376,7 @@ def synced_generate(
     top_k=0,
     top_p=0.0,
     greedy=False,
+    compute_logprob=False,
     repetition_penalty=1.2,
     min_tokens_to_generate=0,
     end_strings=[],
@@ -394,6 +401,7 @@ def synced_generate(
             context_length_tensor,
             tokens_to_generate,
             all_probs,
+            compute_logprob=compute_logprob,
            temperature=temperature,
             end_strings=end_strings,
             extra={
@@ -411,7 +419,8 @@ def synced_generate(
     if parallel_state.is_pipeline_last_stage():
         src = parallel_state.get_pipeline_model_parallel_last_rank()
         group = parallel_state.get_embedding_group()
-        torch.distributed.broadcast(output_logits, src, group)
+        if compute_logprob:
+            torch.distributed.broadcast(output_logits, src, group)
         if all_probs:
             src = parallel_state.get_pipeline_model_parallel_last_rank()
             group = parallel_state.get_embedding_group()
@@ -422,15 +431,18 @@ def synced_generate(
             src = parallel_state.get_pipeline_model_parallel_last_rank()
             group = parallel_state.get_embedding_group()
 
-            precision = model._trainer.precision
-            if precision in [16, "16"]:
-                dtype = torch.float16
-            elif precision == "bf16":
-                dtype = torch.bfloat16
-            else:
-                dtype = torch.float32
-            output_logits = torch.empty(tokens.size(0), context_length - 1, dtype=dtype, device=torch.device("cuda"))
-            torch.distributed.broadcast(output_logits, src, group)
+            if compute_logprob:
+                precision = model._trainer.precision
+                if precision in [16, "16"]:
+                    dtype = torch.float16
+                elif precision == "bf16":
+                    dtype = torch.bfloat16
+                else:
+                    dtype = torch.float32
+                output_logits = torch.empty(
+                    tokens.size(0), context_length - 1, dtype=dtype, device=torch.device("cuda")
+                )
+                torch.distributed.broadcast(output_logits, src, group)
 
             if all_probs:
                 src = parallel_state.get_pipeline_model_parallel_last_rank()
@@ -457,6 +469,7 @@ def generate(
     top_k=0,
     top_p=0.0,
     greedy=False,
+    compute_logprob=False,
     repetition_penalty=1.0,
     min_tokens_to_generate=0,
     end_strings=['<|endoftext|>'],
@@ -504,6 +517,7 @@ def generate(
             context_length_tensor,
             tokens_to_generate,
             all_probs,
+            compute_logprob,
             temperature,
             top_k,
             top_p,
@@ -518,6 +532,7 @@ def generate(
             context_tokens_tensor,
             tokens_to_generate,
             all_probs,
+            compute_logprob,
             temperature,
             top_k,
             top_p,
@@ -535,6 +550,7 @@ def generate(
         tokens_to_generate,
         all_probs,
         temperature,
+        compute_logprob=compute_logprob,
        top_k=top_k,
         top_p=top_p,
         greedy=greedy,
@@ -619,6 +635,7 @@ def sample_sequence_batch(
     context_lengths,
     tokens_to_generate,
     all_probs=False,
+    compute_logprob=False,
     type_ids=None,
     temperature=None,
     end_strings=['<|endoftext|>'],
@@ -673,11 +690,18 @@ def sample_sequence_batch(
             output = inference_strategy.forward_step(batch, tensor_shape)
 
             if parallel_state.is_pipeline_last_stage():
-                output = output[0]['logits']
-                output = tensor_parallel.gather_from_tensor_model_parallel_region(output)
-                assert output is not None
-                logits = output[:, -1].view(batch_size, -1).contiguous()
+                if compute_logprob:
+                    output = output[0]['logits']
+                    output = tensor_parallel.gather_from_tensor_model_parallel_region(output)
+                    assert output is not None
+                    logits = output[:, -1].view(batch_size, -1).contiguous()
+
+                else:
+                    logits = output[0]['logits'][:, -1].contiguous()
+                    logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits)
+                    assert logits is not None
+                    logits = logits.view(batch_size, -1)
 
                 # make sure it will generate at least min_length
                 min_length = extra.get('min_tokens_to_generate', 0)
@@ -689,6 +713,7 @@ def sample_sequence_batch(
                     logits[:, tokenizer.vocab_size :] = -float('Inf')
 
                 # started indicates whether the current token step passes the context_length, so we make sure not to overwrite the context tokens
+
                 started = context_lengths <= context_length
                 if extra.get('greedy', False):
                     prev = torch.argmax(logits, dim=-1).view(-1)
@@ -716,23 +741,25 @@ def sample_sequence_batch(
                 # Insert either new predicted or next prompt token
                 tokens[:, context_length] = new_tokens
 
-                if output_logits is None:
-                    output = F.log_softmax(output[:, :context_length, :], 2)
-                    indices = torch.unsqueeze(tokens[:, 1 : context_length + 1], 2)
-                    output_logits = torch.gather(output, 2, indices).squeeze(2)
-                    all_generated_indices = indices[:, :, 0]
-                    if all_probs:
-                        full_logits = output
-                else:
-                    output = F.log_softmax(output, 2)
-                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
-                    new_output_logits = torch.gather(output, 2, indices).squeeze(2)
+                if compute_logprob:
+                    if output_logits is None:
+                        output = F.log_softmax(output[:, :context_length, :], 2)
 
-                    # TODO(rprenger) we're copying output_logits every time. Should pre-allocate
-                    output_logits = torch.cat([output_logits, new_output_logits], 1)
-                    all_generated_indices = torch.cat([all_generated_indices, indices[:, :, 0]], 1)
-                    if all_probs:
-                        full_logits = torch.cat([full_logits, output], 1)
+                        indices = torch.unsqueeze(tokens[:, 1 : context_length + 1], 2)
+                        output_logits = torch.gather(output, 2, indices).squeeze(2)
+                        all_generated_indices = indices[:, :, 0]
+                        if all_probs:
+                            full_logits = output
+                    else:
+                        output = F.log_softmax(output, 2)
+                        indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
+                        new_output_logits = torch.gather(output, 2, indices).squeeze(2)
+
+                        # TODO(rprenger) we're copying output_logits every time. Should pre-allocate
+                        output_logits = torch.cat([output_logits, new_output_logits], 1)
+                        all_generated_indices = torch.cat([all_generated_indices, indices[:, :, 0]], 1)
+                        if all_probs:
+                            full_logits = torch.cat([full_logits, output], 1)
 
                 src = parallel_state.get_pipeline_model_parallel_last_rank()
                 group = parallel_state.get_embedding_group()
@@ -752,10 +779,13 @@ def sample_sequence_batch(
                     src = parallel_state.get_pipeline_model_parallel_last_rank()
                     group = parallel_state.get_pipeline_model_parallel_group()
                     torch.distributed.broadcast(done, src, group)
-                if all_probs:
-                    yield tokens, lengths, output_logits, full_logits
+                if compute_logprob:
+                    if all_probs:
+                        yield tokens, lengths, output_logits, full_logits
+                    else:
+                        yield tokens, lengths, output_logits, None
                 else:
-                    yield tokens, lengths, output_logits, None
+                    yield tokens, lengths, None, None
 
             else:
                 if parallel_state.is_pipeline_first_stage():
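
The gist of the patch: all log-probability bookkeeping is gated behind a compute_logprob flag that is threaded from the inference config through generate(), synced_generate(), send/receive_generate_info(), and sample_sequence_batch(). When the flag is off, the sampling loop slices out only the last-position logits (made contiguous before the tensor-parallel gather) instead of gathering the full output and log-softmaxing every position, and it never builds or broadcasts the output_logits / full_logits tensors. Below is a minimal, self-contained PyTorch sketch of that idea; the helper names and toy sizes are illustrative only and are not NeMo APIs.

import torch
import torch.nn.functional as F


def next_token_and_logprobs(logits: torch.Tensor):
    # Log-prob path: materializes a full [batch, seq, vocab] log-softmax so
    # per-token log-probabilities can be returned alongside the next token.
    log_probs = F.log_softmax(logits.float(), dim=-1)
    next_token = log_probs[:, -1, :].argmax(dim=-1)  # greedy pick, for simplicity
    return next_token, log_probs


def next_token_only(logits: torch.Tensor):
    # No-log-prob path: only the last position is needed to pick the next token,
    # so the extra memory is a [batch, vocab] slice (made contiguous up front,
    # mirroring the .contiguous() call the patch adds before the gather).
    last = logits[:, -1, :].contiguous()
    return last.argmax(dim=-1)


if __name__ == "__main__":
    batch, seq, vocab = 2, 16, 32000  # toy sizes; the saving grows with batch * seq * vocab
    logits = torch.randn(batch, seq, vocab)
    tok_a, _ = next_token_and_logprobs(logits)
    tok_b = next_token_only(logits)
    assert torch.equal(tok_a, tok_b)  # same next token, very different peak memory

One consequence, visible in the megatron_gpt_peft_eval.py hunk above: with compute_logprob disabled, sample_sequence_batch yields (tokens, lengths, None, None), so the eval script writes only the plain sentences and must not try to read per-token log-probs from the response.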