Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync oneflow pro interpolate like op #10421

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions oneflow/core/autograd/gradient_funcs/upsample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("upsample", Upsample);

struct UpsampleNearest2DCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
bool has_like_input = false;
double height_scale = 0.0;
double width_scale = 0.0;
std::vector<int64_t> output_size;
Expand All @@ -105,6 +106,8 @@ class UpsampleNearest2D : public OpExprGradFunction<UpsampleNearest2DCaptureStat
}
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
ctx->SaveTensorForBackward(inputs.at(0));
ctx->has_like_input = inputs.size() == 2;
if (ctx->has_like_input) { ctx->SaveTensorForBackward(inputs.at(1)); }
return Maybe<void>::Ok();
}

Expand All @@ -114,9 +117,15 @@ class UpsampleNearest2D : public OpExprGradFunction<UpsampleNearest2DCaptureStat
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
in_grads->resize(1);
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest2DGrad(
JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale,
ctx->output_size, ctx->data_format));
if (ctx->has_like_input) {
const std::shared_ptr<oneflow::one::Tensor>& like = ctx->SavedTensors().at(1);
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest2DGrad(
JUST(oneflow::VectorAt(out_grads, 0)), x, like, ctx->data_format));
} else {
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest2DGrad(
JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale,
ctx->output_size, ctx->data_format));
}

return Maybe<void>::Ok();
}
Expand Down Expand Up @@ -227,6 +236,7 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_linear_1d", UpsampleLinear1D);

struct UpsampleNearest1DCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
bool has_like_input = false;
double scale_factor = 0.0;
std::vector<int64_t> output_size;
std::string data_format;
Expand All @@ -238,7 +248,7 @@ class UpsampleNearest1D : public OpExprGradFunction<UpsampleNearest1DCaptureStat

Maybe<void> Capture(UpsampleNearest1DCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
CHECK_GE_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
Expand All @@ -249,6 +259,8 @@ class UpsampleNearest1D : public OpExprGradFunction<UpsampleNearest1DCaptureStat
}
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
ctx->SaveTensorForBackward(inputs.at(0));
ctx->has_like_input = inputs.size() == 2;
if (ctx->has_like_input) { ctx->SaveTensorForBackward(inputs.at(1)); }
return Maybe<void>::Ok();
}

Expand All @@ -258,9 +270,15 @@ class UpsampleNearest1D : public OpExprGradFunction<UpsampleNearest1DCaptureStat
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
in_grads->resize(1);
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(
functional::UpsampleNearest1DGrad(JUST(oneflow::VectorAt(out_grads, 0)), x,
ctx->scale_factor, ctx->output_size, ctx->data_format));
if (ctx->has_like_input) {
const std::shared_ptr<oneflow::one::Tensor>& like = ctx->SavedTensors().at(1);
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest1DGrad(
JUST(oneflow::VectorAt(out_grads, 0)), x, like, ctx->data_format));
} else {
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(
functional::UpsampleNearest1DGrad(JUST(oneflow::VectorAt(out_grads, 0)), x,
ctx->scale_factor, ctx->output_size, ctx->data_format));
}

return Maybe<void>::Ok();
}
Expand Down Expand Up @@ -322,6 +340,7 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_bicubic_2d", UpsampleBicubic2D);

struct UpsampleNearest3DCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
bool has_like_input = false;
double depth_scale = 0.0;
double height_scale = 0.0;
double width_scale = 0.0;
Expand All @@ -348,6 +367,10 @@ class UpsampleNearest3D : public OpExprGradFunction<UpsampleNearest3DCaptureStat
}
ctx->data_format = JUST(composed_attrs.GetAttr<std::string>("data_format"));
ctx->SaveTensorForBackward(inputs.at(0));
ctx->has_like_input = inputs.size() == 2;
if (ctx->has_like_input) { ctx->SaveTensorForBackward(inputs.at(1)); }
ctx->has_like_input = inputs.size() == 2;
if (ctx->has_like_input) { ctx->SaveTensorForBackward(inputs.at(1)); }
return Maybe<void>::Ok();
}

Expand All @@ -357,9 +380,15 @@ class UpsampleNearest3D : public OpExprGradFunction<UpsampleNearest3DCaptureStat
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
const std::shared_ptr<oneflow::one::Tensor>& x = ctx->SavedTensors().at(0);
in_grads->resize(1);
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest3DGrad(
JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->depth_scale, ctx->height_scale,
ctx->width_scale, ctx->output_size, ctx->data_format));
if (ctx->has_like_input) {
const std::shared_ptr<oneflow::one::Tensor>& like = ctx->SavedTensors().at(1);
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest3DGrad(
JUST(oneflow::VectorAt(out_grads, 0)), x, like, ctx->data_format));
} else {
JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest3DGrad(
JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->depth_scale, ctx->height_scale,
ctx->width_scale, ctx->output_size, ctx->data_format));
}

return Maybe<void>::Ok();
}
Expand Down
53 changes: 35 additions & 18 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1777,27 +1777,35 @@
bind_python: False

- name: "upsample_nearest_1d"
signature:
'Tensor (Tensor x, Double scale_factor=0.0, Int64List[1] output_size=None,
String data_format="channels_first") => UpsampleNearest1D'
signature: [
'Tensor (Tensor x, Double scale_factor=0.0, Int64List[1] output_size=None,
String data_format="channels_first") => UpsampleNearest1D',
'Tensor (Tensor x, Tensor like, String data_format="channels_first") => UpsampleNearest1D'
]
bind_python: True

- name: "upsample_nearest_1d_grad"
signature:
'Tensor (Tensor dy, Tensor x, Double scale_factor=0.0, Int64List[1] output_size=None,
String data_format="channels_first") => UpsampleNearest1DGrad'
signature: [
'Tensor (Tensor dy, Tensor x, Double scale_factor=0.0, Int64List[1] output_size=None,
String data_format="channels_first") => UpsampleNearest1DGrad',
'Tensor (Tensor dy, Tensor x, Tensor like, String data_format="channels_first") => UpsampleNearest1DGrad'
]
bind_python: False

- name: "upsample_nearest_2d"
signature:
'Tensor (Tensor x, Double height_scale=0.0, Double width_scale=0.0, Int64List[2] output_size=None,
String data_format="channels_first") => UpsampleNearest2D'
signature: [
'Tensor (Tensor x, Double height_scale=0.0, Double width_scale=0.0, Int64List[2] output_size=None,
String data_format="channels_first") => UpsampleNearest2D',
'Tensor (Tensor x, Tensor like, String data_format="channels_first") => UpsampleNearest2D'
]
bind_python: True

- name: "upsample_nearest_2d_grad"
signature:
'Tensor (Tensor dy, Tensor x, Double height_scale=0.0, Double width_scale=0.0, Int64List[2] output_size=None,
String data_format="channels_first") => UpsampleNearest2DGrad'
signature: [
'Tensor (Tensor dy, Tensor x, Double height_scale=0.0, Double width_scale=0.0, Int64List[2] output_size=None,
String data_format="channels_first") => UpsampleNearest2DGrad',
'Tensor (Tensor dy, Tensor x, Tensor like, String data_format="channels_first") =>UpsampleNearest2DGrad'
]
bind_python: False

- name: "upsample_bilinear_2d"
Expand Down Expand Up @@ -1825,15 +1833,19 @@
bind_python: False

- name: "upsample_nearest_3d"
signature:
'Tensor (Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Int64List[3] output_size=None,
String data_format="channels_first") => UpsampleNearest3D'
signature: [
'Tensor (Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Int64List[3] output_size=None,
String data_format="channels_first") => UpsampleNearest3D',
'Tensor (Tensor x, Tensor like, String data_format="channels_first") => UpsampleNearest3D'
]
bind_python: True

- name: "upsample_nearest_3d_grad"
signature:
'Tensor (Tensor dy, Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Int64List[3] output_size=None,
String data_format="channels_first") => UpsampleNearest3DGrad'
signature: [
'Tensor (Tensor dy, Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Int64List[3] output_size=None,
String data_format="channels_first") => UpsampleNearest3DGrad',
'Tensor (Tensor dy, Tensor x, Tensor like, String data_format="channels_first") => UpsampleNearest3DGrad'
]
bind_python: False

- name: "upsample_trilinear_3d"
Expand Down Expand Up @@ -3494,3 +3506,8 @@
- name: "fused_clip_grad"
signature: "Tensor (TensorTuple model_diff, Float max_norm, Float norm_type) => FusedClipGrad"
bind_python: True

- name: "fused_group_norm_quantization"
signature: 'TensorTuple[y, y_scale, y_zero_point] (Tensor x, Tensor gamma=None, Tensor beta=None, Bool affine, Int32 num_groups, Double epsilon=1e-5, String data_format="channels_first", String activation="none", String quantization_scheme="affine", Int32 quantization_bit=8, String quantization_formula="oneflow", String quantization_mode="tensorwise") => FusedGroupNormQuantization'
bind_python: True

Loading
Loading