diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 2245113c9c..7f1d217ac0 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -368,7 +368,7 @@ def backward(ctx, grad_output): # grad_weight = grad_output.t().matmul(total_input) from megatron.core.tensor_parallel.weight_grad_store import WeightGradStore WeightGradStore.put(total_input, grad_output, weight, gradientUpdateFunction) - grad_weight = None + grad_weight = weight.grad grad_bias = grad_output.sum(dim=0) if use_bias else None if ctx.sequence_parallel: