diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
index 5fee182fb1ee7..21bb96b631d22 100644
--- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
+++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
@@ -395,10 +395,12 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
   }
 
   CudnnReduceDescriptor reduce_desc;
-  if (std::is_same<T, MLFloat16>::value)
+  if (std::is_same<T, MLFloat16>::value) {
     ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, CudnnTensor::GetDataType<float>(), ReduceTensorIndices));
-  else
+  } else {
     ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, cudnn_type_X, ReduceTensorIndices));
+  }
+
   const auto one = Consts<CudaT>::One;
   const auto zero = Consts<CudaT>::Zero;
   CudnnTensor input_tensor;
@@ -437,7 +439,11 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
     } else {
       // Reduce max -- Max/Min will output indices data
       CudnnReduceDescriptor reduce_max_desc;
-      ORT_RETURN_IF_ERROR(reduce_max_desc.Set(CUDNN_REDUCE_TENSOR_MAX, cudnn_type_X, CUDNN_REDUCE_TENSOR_NO_INDICES));
+      cudnnDataType_t cudnn_reduce_max_type = cudnn_type_X;
+      if (std::is_same<T, MLFloat16>::value) {
+        cudnn_reduce_max_type = CUDNN_DATA_FLOAT;
+      }
+      ORT_RETURN_IF_ERROR(reduce_max_desc.Set(CUDNN_REDUCE_TENSOR_MAX, cudnn_reduce_max_type, CUDNN_REDUCE_TENSOR_NO_INDICES));
       size_t indices_bytes_max = 0;
       CUDNN_RETURN_IF_ERROR(cudnnGetReductionIndicesSize(cuda_ep.PerThreadCudnnHandle(), reduce_max_desc, input_tensor,
                                                          output_tensor, &indices_bytes_max));
diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc
index eb6ea80457d5f..48b9635f43e09 100644
--- a/orttraining/orttraining/core/graph/gradient_builder.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder.cc
@@ -905,6 +905,40 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
   return result;
 }
 
+// Reference computation is pytorch's logsumexp_backward.
+// O(0) = log(reduceSum(exp(x_i))), so dO/dx_i = exp(x_i) / reduceSum(exp(x_i)).
+// Self_Sub_Result = I(0) - O(0) = x_i - log(reduceSum(exp(x_i))) = log(exp(x_i) / reduceSum(exp(x_i))),
+// hence GI(0) = GO(0) * Exp(Self_Sub_Result).
+// The gradient computation reuses the forward op's input and output, so it is a recomputation candidate.
+IMPLEMENT_GRADIENT_BUILDER(GetReduceLogSumExpGradient) {
+  std::vector<NodeDef> result;
+  auto attributes = SrcNodeAttributes();
+  bool keepdims = true;
+  if (attributes.find("keepdims") != attributes.end() &&
+      attributes.at("keepdims").has_i()) {
+    keepdims = static_cast<bool>(attributes.at("keepdims").i());
+  }
+
+  ArgDef grad = GO(0);
+  if (!keepdims && attributes.find("axes") != attributes.end()) {
+    std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+    grad = IA("Unsqueezed_Grad");
+    result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
+
+    result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)}));
+    result.push_back(NodeDef("Sub", {I(0), IA("Unsqueezed_Output")}, {IA("Self_Sub_Result")}));
+  } else {
+    result.push_back(NodeDef("Sub", {I(0), O(0)}, {IA("Self_Sub_Result")}));
+  }
+
+  result.push_back(NodeDef("Exp", {IA("Self_Sub_Result")}, {IA("Self_Sub_Result_Exp")}));
+
+  result.push_back(NodeDef("Mul", {IA("Self_Sub_Result_Exp"), grad}, {GI(0)}));
+
+  return result;
+}
+
 IMPLEMENT_GRADIENT_BUILDER(GetReduceSumGradient) {
   std::vector<NodeDef> result;
   auto attributes = SrcNodeAttributes();
diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h
index 9a32e421bf5b3..819c800820ed0 100644
--- a/orttraining/orttraining/core/graph/gradient_builder.h
+++ b/orttraining/orttraining/core/graph/gradient_builder.h
@@ -25,6 +25,7 @@ DECLARE_GRADIENT_BUILDER(GetMulGradient)
 DECLARE_GRADIENT_BUILDER(GetDivGradient)
 DECLARE_GRADIENT_BUILDER(GetReduceMeanGradient)
 DECLARE_GRADIENT_BUILDER(GetReduceSumGradient)
+DECLARE_GRADIENT_BUILDER(GetReduceLogSumExpGradient)
 DECLARE_GRADIENT_BUILDER(GetPowGradient)
 DECLARE_GRADIENT_BUILDER(GetConcatGradient)
 DECLARE_GRADIENT_BUILDER(GetReshapeGradient)
diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc
index 94b4e1e096992..7631ce25eb311 100644
--- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc
@@ -51,6 +51,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
   REGISTER_GRADIENT_BUILDER("Pow", GetPowGradient);
   REGISTER_GRADIENT_BUILDER("ReduceMean", GetReduceMeanGradient);
   REGISTER_GRADIENT_BUILDER("ReduceSum", GetReduceSumGradient);
+  REGISTER_GRADIENT_BUILDER("ReduceLogSumExp", GetReduceLogSumExpGradient);
   REGISTER_GRADIENT_BUILDER("Add", GetAddSubGradient);
   REGISTER_GRADIENT_BUILDER("Sub", GetAddSubGradient);
   REGISTER_GRADIENT_BUILDER("Mul", GetMulGradient);
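
Taken together, the builder above realizes the identity dX = dY * exp(X - logsumexp(X)), i.e. dY times softmax(X) over the reduced axes; the Unsqueeze nodes restore the reduced axes so the Sub broadcasts when keepdims=0. A minimal sketch (illustrative only, not part of the patch; it assumes PyTorch, which ort_trainer.py already depends on) that checks this identity against autograd:

    import torch

    x = torch.randn(4, 3, 2, dtype=torch.double, requires_grad=True)
    y = torch.logsumexp(x, dim=(0, 2), keepdim=True)   # forward output O(0), keepdims=1
    dy = torch.randn_like(y)                           # incoming gradient GO(0)

    (y * dy).sum().backward()                          # autograd reference for GI(0)
    with torch.no_grad():
        manual_dx = dy * torch.exp(x - y)              # Mul(Exp(Sub(I(0), O(0))), GO(0))
    print(torch.allclose(x.grad, manual_dx))           # expected: True

With keepdims=0 the same check applies once y and dy are unsqueezed back to rank 3, which is exactly what the two Unsqueeze nodes in the builder do.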
diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py
index a69c41e312a9d..b350e3c578c77 100644
--- a/orttraining/orttraining/python/ort_trainer.py
+++ b/orttraining/orttraining/python/ort_trainer.py
@@ -629,6 +629,8 @@ def __init__(self, model, loss_fn, model_desc, training_optimizer_name, map_opti
         self.world_size = world_size
         self.use_mixed_precision = use_mixed_precision
 
+        self.original_model_state_keys = list(model.state_dict().keys()) if hasattr(model, 'state_dict') else []
+
         self.session = None
         self.device_ = device
         self.gradient_accumulation_steps = gradient_accumulation_steps
@@ -773,7 +775,11 @@ def state_dict(self):
             if n.name not in torch_state:
                 torch_state[n.name] = torch.from_numpy(numpy_helper.to_array(n))
 
-        return torch_state
+        # Need to remove redundant initializers and name suffixes to map back to the original torch state names.
+        torch_state_to_return = {key: torch_state[key] for key in self.original_model_state_keys if key in torch_state} \
+            if self.original_model_state_keys \
+            else torch_state
+        return torch_state_to_return
 
     def load_state_dict(self, state_dict, strict=False):
         # Note: It may happen ONNX model has not yet been initialized
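
The state_dict change filters the ONNX-side state down to the parameter names recorded from the user's original torch model, so the returned dict can be loaded straight back into that model. A small illustrative sketch of the pattern (the module and the extra key below are hypothetical, not taken from the patch):

    import torch

    model = torch.nn.Linear(4, 2)                    # stands in for the user's torch model
    original_keys = list(model.state_dict().keys())  # what __init__ now records

    # Pretend this came back from the ONNX graph: original params plus an extra initializer.
    onnx_state = dict(model.state_dict())
    onnx_state["weight_fp16"] = onnx_state["weight"].half()   # hypothetical redundant entry

    filtered = {k: onnx_state[k] for k in original_keys if k in onnx_state}
    model.load_state_dict(filtered, strict=True)     # succeeds; the extra key is filtered out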
diff --git a/orttraining/orttraining/test/gradient/gradient_op_test_utils.h b/orttraining/orttraining/test/gradient/gradient_op_test_utils.h
index ad75d061627da..26f56ddbe0fbf 100644
--- a/orttraining/orttraining/test/gradient/gradient_op_test_utils.h
+++ b/orttraining/orttraining/test/gradient/gradient_op_test_utils.h
@@ -7,6 +7,9 @@
 namespace onnxruntime {
 namespace test {
 
+using TestDataVector = std::tuple<std::vector<std::vector<TensorInfo>>,             // Input data
+                                  std::vector<std::vector<TensorInfo>>,             // Output data
+                                  std::vector<std::vector<onnx::AttributeProto>>>;  // Attributes
 
 class GradientOpTester : public OpTester {
  public:
@@ -39,3 +42,4 @@ class GradientOpTester : public OpTester {
 };
 }  // namespace test
 }  // namespace onnxruntime
+
diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
index b5f43f263fa89..03095320c3cae 100644
--- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc
+++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -38,6 +38,70 @@ static bool IsErrorWithinTolerance(float error, float tolerance) {
 #define EXPECT_IS_TINY(max_error) \
   EXPECT_IS_TINIER_THAN(max_error, 1.5e-2f)
 
+static void RunReductionTests(const OpDef& op_def) {
+  TestDataVector test_data(
+      // Input X
+      {
+          {{4, 3, 2}},
+          {{4, 3, 2}},
+          {{4, 3, 2}},
+          {{4, 3, 2}},
+          {{4, 3, 2}},
+          {{4, 3, 2}},
+          {{4, 3, 2}},
+          {{4, 3, 2}},
+      },
+      // Expected output Y
+      {
+          {{1, 1, 1}},
+          {{}},
+          {{1, 3, 1}},
+          {{2}},
+          {{4, 1, 2}},
+          {{4, 3}},
+          {{4, 1, 2}},
+          {{4}}
+      },
+      // Attributes
+      {
+          // default
+          {},
+          // axes = [0, 1, 2], keepdims = 0
+          {MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
+           MakeAttribute("keepdims", int64_t(0))},
+          // axes = [0, 2], keepdims = 1
+          {MakeAttribute("axes", std::vector<int64_t>{0, 2})},
+          // axes = [0, 1], keepdims = 0
+          {MakeAttribute("axes", std::vector<int64_t>{0, 1}),
+           MakeAttribute("keepdims", int64_t(0))},
+          // axes = [1], keepdims = 1
+          {MakeAttribute("axes", std::vector<int64_t>{1}),
+           MakeAttribute("keepdims", int64_t(1))},
+          // axes = [2], keepdims = 0
+          {MakeAttribute("axes", std::vector<int64_t>{2}),
+           MakeAttribute("keepdims", int64_t(0))},
+          // axes = [-2], keepdims = 1
+          {MakeAttribute("axes", std::vector<int64_t>{-2}),
+           MakeAttribute("keepdims", int64_t(1))},
+          // axes = [-2, -1], keepdims = 0
+          {MakeAttribute("axes", std::vector<int64_t>{-2, -1}),
+           MakeAttribute("keepdims", int64_t(0))}
+      });
+
+  GradientChecker<float, float, float> gradient_checker;
+
+  float max_error;
+
+  for (size_t i = 0; i < std::get<0>(test_data).size(); i++) {
+    max_error = 0;
+    gradient_checker.ComputeGradientError(op_def, std::get<0>(test_data)[i],
+                                          std::get<1>(test_data)[i], &max_error,
+                                          std::get<2>(test_data)[i]);
+    EXPECT_IS_TINY(max_error);
+  }
+}
+
 template <typename T>
 static void GenerateRandomDataWithOneHot(
     std::vector<std::vector<float>>& x_datas,
@@ -426,149 +490,24 @@ TEST(GradientCheckerTest, GemmGrad) {
 }
 
 TEST(GradientCheckerTest, ReduceMeanGrad) {
-  float max_error;
-  GradientChecker<float, float, float> gradient_checker;
   // Attribute axes supports negative values from opset 11.
   OpDef op_def{"ReduceMean", kOnnxDomain, 11};
 
-  // default
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 1, 1}}, &max_error);
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // TODO: Fix forward kernel behavior for default axes
-  // default axes, keepdims = 0
-  /*
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{}}, &max_error,
-                                          {MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-  */
-
-  // axes = [0, 1, 2], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [0, 2], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 3, 1}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 2})});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [0, 1], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 1}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [1], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{1}),
-                                           MakeAttribute("keepdims", int64_t(1))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [2], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 3}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{2}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [-2], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{-2}),
-                                           MakeAttribute("keepdims", int64_t(1))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [-2, -1], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{-2, -1}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
+  RunReductionTests(op_def);
 }
 
 TEST(GradientCheckerTest, ReduceSumGrad) {
-  float max_error;
-  GradientChecker<float, float, float> gradient_checker;
   // Attribute axes supports negative values from opset 11.
   OpDef op_def{"ReduceSum", kOnnxDomain, 11};
 
-  // default
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 1, 1}}, &max_error);
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [0, 1, 2], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [0, 2], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 3, 1}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 2})});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [0, 1], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 1}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [1], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{1}),
-                                           MakeAttribute("keepdims", int64_t(1))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [2], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 3}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{2}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [-2], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{-2}),
-                                           MakeAttribute("keepdims", int64_t(1))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [-1, -3], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{3}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{-1, -3}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
+  RunReductionTests(op_def);
+}
+
+TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
+  // Attribute axes supports negative values from opset 11.
+  OpDef op_def{"ReduceLogSumExp", kOnnxDomain, 11};
+
+  RunReductionTests(op_def);
 }
 
 #ifndef USE_CUDA
@@ -1960,3 +1899,4 @@ TEST(GradientCheckerTest, ExpandGrad) {
 }  // namespace onnxruntime
 #endif  // NDEBUG
+
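
As a cross-check on the shared test data, here is a small numpy sketch (not part of the patch) confirming that the expected output shapes in RunReductionTests follow from each axes/keepdims combination for a {4, 3, 2} input:

    import numpy as np

    x = np.random.rand(4, 3, 2)
    cases = [
        (None, True, (1, 1, 1)),   # default: reduce all axes, keepdims defaults to 1
        ((0, 1, 2), False, ()),    # scalar output
        ((0, 2), True, (1, 3, 1)),
        ((0, 1), False, (2,)),
        ((1,), True, (4, 1, 2)),
        ((2,), False, (4, 3)),
        ((-2,), True, (4, 1, 2)),
        ((-2, -1), False, (4,)),
    ]
    for axes, keepdims, expected in cases:
        assert np.sum(x, axis=axes, keepdims=keepdims).shape == expected
    print("all expected shapes match")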