[Op] Conv1dTranspose (#269)
This PR introduces Conv1dTranspose to relax.

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-15-248.us-west-2.compute.internal>
2 people authored and tqchen committed Aug 1, 2023
1 parent cca3bf0 commit 863ac8b
Showing 7 changed files with 682 additions and 0 deletions.
45 changes: 45 additions & 0 deletions include/tvm/relax/attrs/nn.h
@@ -117,6 +117,51 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
}
}; // struct Conv2dAttrs

/*! \brief Attributes used in Conv1DTranspose operator */
struct Conv1DTransposeAttrs : public tvm::AttrsNode<Conv1DTransposeAttrs> {
Array<IntImm> strides;
Array<IntImm> padding;
Array<IntImm> output_padding;
Array<IntImm> dilation;
int groups;
String data_layout;
String kernel_layout;
String out_layout;
DataType out_dtype;

TVM_DECLARE_ATTRS(Conv1DTransposeAttrs, "relax.attrs.Conv1DTransposeAttrs") {
TVM_ATTR_FIELD(strides).describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).describe(
"If padding is non-zero, then the input is implicitly zero-padded. "
"Both symmetric and asymmetric padding are supported: "
"one int : same padding used on both sides; "
"two int : padding width in the order of (left, right).");
TVM_ATTR_FIELD(output_padding).describe("Used to disambiguate the output shape.");
TVM_ATTR_FIELD(dilation).describe(
"Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).describe(
"Number of groups to split the input into for grouped convolution. The number of input and "
"output channels should be divisible by the number of groups.");
TVM_ATTR_FIELD(data_layout)
.describe(
"Dimension ordering of input data. Can be 'NCW', 'NWC', etc. "
"'N', 'C', 'W' stand for batch, channel, and width "
"dimensions respectively. Convolution is applied on the 'W' dimension.");
TVM_ATTR_FIELD(kernel_layout)
.describe(
"Dimension ordering of weight. Can be 'OIW', 'IOW', etc. "
"'O', 'I', 'W' stand for num_filter, input_channel, and width "
"dimensions respectively.");
TVM_ATTR_FIELD(out_layout)
.describe(
"Dimension ordering of output. Can be 'NCW', 'NWC', etc. "
"'N', 'C', 'W' stand for batch, channel, and width "
"dimensions respectively. Defaults to the same layout as the input.");
TVM_ATTR_FIELD(out_dtype).describe(
"Output data type; set to an explicit type under a mixed-precision setting.");
}
}; // struct Conv1DTransposeAttrs

/*! \brief Attributes used in Conv2DTranspose operator */
struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
Array<IntImm> strides;
91 changes: 91 additions & 0 deletions python/tvm/relax/op/nn/nn.py
@@ -220,6 +220,97 @@ def conv2d(
)


def conv1d_transpose(
data: Expr,
weight: Expr,
strides: Union[int, Tuple[int]] = 1,
padding: Union[int, Tuple[int, ...]] = 0,
output_padding: Union[int, Tuple[int]] = 0,
dilation: Union[int, Tuple[int]] = 1,
groups: int = 1,
data_layout: str = "NCW",
kernel_layout: str = "IOW",
out_layout: Optional[str] = None,
out_dtype: Optional[Union[str, DataType]] = None,
) -> Expr:
r"""1D transposed convolution operator.
This operator can be seen as the gradient operator of conv1d.
The output shape can be explained in the simple case when `data_layout == "NCW"` and
`kernel_layout == "IOW"`. Suppose `data` has shape `(N, in_channel, in_w)`, `weight` has
shape `(in_channel, out_channel, weight_w)`, we need to assure that `in_channel % groups == 0`.
The shape of the output will be `(N, out_channel * groups, out_w)`, where
- `out_w = ((in_w - 1) * strides[0] + weight_w - 2 * padding[0] + output_padding[0])`
Parameters
----------
data : relax.Expr
The input data to the operator.
weight : relax.Expr
The weight expressions.
strides : Union[int, Tuple[int]]
The strides of convolution. It is required to have length 1.
padding : Union[int, Tuple[int, ...]]
The padding applied to both sides of the input before convolution.
It is required to have length either 1 or 2.
output_padding : Union[int, Tuple[int, ...]], optional
Used to disambiguate the output shape.
dilation : Union[int, Tuple[int]]
Specifies the dilation rate to be used for dilated convolution.
It is required to have length 1.
groups : int
Number of groups to split the input into for grouped convolution.
The number of input and output channels should be divisible by the number of groups.
data_layout : str
Layout of the input.
kernel_layout : str
Layout of the weight.
out_layout : Optional[str]
Layout of the output. If not specified, it is the same as data_layout.
out_dtype : Optional[Union[str, DataType]]
Specifies the output data type for mixed-precision conv1d_transpose.

Returns
-------
result : relax.Expr
The computed result.
"""
if isinstance(strides, int):
strides = (strides,)
if isinstance(dilation, int):
dilation = (dilation,)
if isinstance(padding, int):
padding = (padding, padding)
if isinstance(output_padding, int):
output_padding = (output_padding,)

return _ffi_api.conv1d_transpose( # type: ignore
data,
weight,
strides,
padding,
output_padding,
dilation,
groups,
data_layout,
kernel_layout,
out_layout,
out_dtype,
)
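
A minimal usage sketch of the new Python API (the shapes are illustrative and assume a TVM build with relax enabled):

import tvm
from tvm import relax

# NCW data of shape (1, 3, 8) and IOW weight of shape (3, 4, 3).
x = relax.Var("x", relax.TensorStructInfo((1, 3, 8), "float32"))
w = relax.Var("w", relax.TensorStructInfo((3, 4, 3), "float32"))

# out_w = (8 - 1) * 2 + 3 - 2 * 1 + 1 = 16, so the inferred output
# shape is (1, 4, 16).
y = relax.op.nn.conv1d_transpose(x, w, strides=2, padding=1, output_padding=1)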


def conv2d_transpose(
data: Expr,
weight: Expr,
40 changes: 40 additions & 0 deletions python/tvm/relax/transform/legalize_ops/nn.py
@@ -109,6 +109,46 @@ def _nn_conv2d(bb: BlockBuilder, call: Call) -> Expr:
)


@register_legalize("relax.nn.conv1d_transpose")
def _nn_conv1d_transpose(bb: BlockBuilder, call: Call) -> Expr:
if call.attrs.out_layout != call.attrs.data_layout:
logging.info(
"TOPI conv1d_transpose does not support different input-output "
"layouts, and thus cannot be legalized by TOPI"
)
return call
if call.attrs.data_layout != "NCW" or call.attrs.kernel_layout != "IOW":
logging.info(
"TOPI conv1d_transpose does not support input layout other than NCW, "
"and kernel layout other than IOW, so cannot be legalized by TOPI"
)
return call
dilation = call.attrs.dilation
if len(dilation) != 1 or dilation[0] != 1:
logging.info(
"TOPI conv1d_transpose does not support dilations other than 1, "
"and thus cannot be legalized by TOPI"
)
return call
if call.attrs.groups != 1:
logging.info(
"TOPI conv1d_transpose does not support groups other than 1, "
"and thus cannot be legalized by TOPI"
)
return call

return bb.call_te(
topi.nn.conv1d_transpose_ncw,
call.args[0],
call.args[1],
stride=call.attrs.strides,
padding=call.attrs.padding,
out_dtype=call.struct_info.dtype,
output_padding=call.attrs.output_padding,
primfunc_name_hint="conv1d_transpose",
)
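
As a sketch of how this legalization can be exercised end to end (assuming the BlockBuilder and LegalizeOps APIs behave as in mainline TVM; the shapes are illustrative):

import tvm
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((1, 3, 8), "float32"))
w = relax.Var("w", relax.TensorStructInfo((3, 4, 3), "float32"))
with bb.function("main", [x, w]):
    y = bb.emit(relax.op.nn.conv1d_transpose(x, w, strides=2, padding=1, output_padding=1))
    bb.emit_func_output(y)

# Replaces the relax op with a call into the TOPI conv1d_transpose_ncw PrimFunc.
mod = relax.transform.LegalizeOps()(bb.get())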


@register_legalize("relax.nn.conv2d_transpose")
def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr:
if call.attrs.out_layout != call.attrs.data_layout:
126 changes: 126 additions & 0 deletions src/relax/op/nn/convolution.cc
@@ -352,6 +352,132 @@ TVM_REGISTER_OP("relax.nn.conv2d")
.set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionConv2d)
.set_attr<Bool>("FPurity", Bool(true));

TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs);

Expr conv1d_transpose(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding,
Array<IntImm> output_padding, Array<IntImm> dilation, int groups,
String data_layout, String kernel_layout, Optional<String> out_layout,
DataType out_dtype) {
padding = GetCompletePadding1D(std::move(padding));

CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, "
"the given number of groups is "
<< groups;
CHECK_EQ(output_padding.size(), 1) << "The input output_padding length is expected to be 1. "
"However, the given output_padding is "
<< output_padding;
CHECK_EQ(strides.size(), 1)
<< "The input strides length is expected to be 1. However, the given strides is " << strides;
CHECK_EQ(dilation.size(), 1)
<< "The input dilation length is expected to be 1. However, the given dilation is "
<< dilation;

auto attrs = make_object<Conv1DTransposeAttrs>();
attrs->strides = ConvertIntImmToInt64(strides);
attrs->padding = ConvertIntImmToInt64(padding);
attrs->output_padding = ConvertIntImmToInt64(output_padding);
attrs->dilation = ConvertIntImmToInt64(dilation);
attrs->groups = groups;
attrs->data_layout = data_layout;
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout.value_or(data_layout));
attrs->out_dtype = std::move(out_dtype);
const Op& op = Op::Get("relax.nn.conv1d_transpose");
return Call(op, {data, weight}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.nn.conv1d_transpose").set_body_typed(conv1d_transpose);

StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder& ctx) {
Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
TensorStructInfo data_sinfo = input_sinfo[0];
TensorStructInfo weight_sinfo = input_sinfo[1];

const auto* attrs = call->attrs.as<Conv1DTransposeAttrs>();
auto [data_layout, data2NCW] = CheckTensorLayout(call, ctx, attrs->data_layout, //
/*tgt_layout=*/"NCW", //
/*tensor_name=*/"data");
auto [weight_layout, weight2IOW] = CheckTensorLayout(call, ctx, attrs->kernel_layout, //
/*tgt_layout=*/"IOW", //
/*tensor_name=*/"kernel");
auto [out_layout, out2NCW] = CheckTensorLayout(call, ctx, attrs->out_layout, //
/*tgt_layout=*/"NCW", //
/*tensor_name=*/"output");
Optional<ShapeExpr> data_shape =
CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
Optional<ShapeExpr> weight_shape =
CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout);

DataType out_dtype = attrs->out_dtype.is_void()
? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo)
: attrs->out_dtype;
if (!data_shape.defined() || !weight_shape.defined()) {
return TensorStructInfo(out_dtype, out_layout.ndim());
}

Array<PrimExpr> data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values);
Array<PrimExpr> weight_IOW_shape = weight2IOW.ForwardShape(weight_shape.value()->values);

arith::Analyzer* analyzer = ctx->GetAnalyzer();
PrimExpr input_channel_data = data_NCW_shape[1];
PrimExpr input_channel_kernel = weight_IOW_shape[0];
if (analyzer->CanProve(input_channel_data != input_channel_kernel)) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "Conv1dTranspose expects the channel size of the data to equal the input channel "
"size of the weight. However, the data channel size is "
<< input_channel_data << " while the weight input channel size is "
<< input_channel_kernel);
} else if (!analyzer->CanProveEqual(input_channel_data, input_channel_kernel)) {
// Todo(relax-team): Trust the input shape at this moment, and revisit
// this condition with runtime shape check
}
if (analyzer->CanProve(floormod(input_channel_kernel, attrs->groups) != 0)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Conv1dTranspose expects the number of input channels to be divisible by "
"the number of groups. However, the number of input channels is "
<< input_channel_kernel << " while the number of groups is " << attrs->groups);
} else if (!analyzer->CanProveEqual(floormod(input_channel_kernel, attrs->groups), 0)) {
// Todo(relax-team): Trust the input shape at this moment, and revisit
// this condition with runtime shape check
}
if (analyzer->CanProve(attrs->output_padding[0]->value >= attrs->strides[0]->value)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Conv1dTranspose expects the output padding to be less than the strides, "
"but the output padding is "
<< attrs->output_padding << " while the strides are " << attrs->strides);
} else if (!analyzer->CanProve(attrs->output_padding[0]->value < attrs->strides[0]->value)) {
// Todo(relax-team): Trust the input padding at this moment, and revisit
// this condition with runtime shape check
}

PrimExpr input_w = data_NCW_shape[2];
PrimExpr kernel_w = weight_IOW_shape[2];
PrimExpr padding_w = attrs->padding[0] + attrs->padding[1];

std::vector<PrimExpr> out_NCW_shape;
out_NCW_shape.resize(3);
out_NCW_shape[0] = data_NCW_shape[0];
out_NCW_shape[1] = weight_IOW_shape[1] * attrs->groups;

PrimExpr out_w = (input_w - 1) * attrs->strides[0] - padding_w +
attrs->dilation[0] * (kernel_w - 1) + attrs->output_padding[0] + 1;
out_NCW_shape[2] = analyzer->Simplify(out_w);

Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
}
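
A quick numeric check of the out_w expression above (plain Python; values chosen to match the illustrative shapes used earlier):

# in_w = 8, strides = (2,), padding = (1, 1), dilation = (1,),
# kernel_w = 3, output_padding = (1,)
out_w = (8 - 1) * 2 - (1 + 1) + 1 * (3 - 1) + 1 + 1
assert out_w == 16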

// TODO(relax-team): implement FInferMixedPrecision and FRelaxInferLayout for conv1d_transpose
// and unit test for mixed_precision
TVM_REGISTER_OP("relax.nn.conv1d_transpose")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_attrs_type<Conv1DTransposeAttrs>()
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv1dTranspose)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.nn.conv2d_transpose */
TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);

11 changes: 11 additions & 0 deletions src/relax/op/nn/convolution.h
@@ -62,6 +62,17 @@ Expr conv2d(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding
Array<IntImm> dilation, int groups, String data_layout, String kernel_layout,
Optional<String> out_layout, DataType out_dtype);

/*!
* \brief One dimensional transposed convolution operator.
*
* This operator is intended to be the backward operator of conv1d. It can be used to calculate the
* gradient of the result of conv1d w.r.t. the input of conv1d.
*/
Expr conv1d_transpose(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding,
Array<IntImm> output_padding, Array<IntImm> dilation, int groups,
String data_layout, String kernel_layout, Optional<String> out_layout,
DataType out_dtype);
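
To see why output_padding is needed to disambiguate the output shape, consider the width round-trip with a strided conv1d (plain Python arithmetic; variable names are illustrative):

in_w, kernel_w, stride, pad = 10, 3, 2, 1
# Forward conv1d width uses floor division, which loses information for
# stride > 1: both in_w = 10 and in_w = 9 give conv_out = 5.
conv_out = (in_w + 2 * pad - kernel_w) // stride + 1
# The transposed convolution alone recovers only the smaller candidate width...
rec_w = (conv_out - 1) * stride - 2 * pad + kernel_w  # 9
# ...and output_padding = 1 selects the larger one.
assert rec_w + 1 == in_w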

/*!
* \brief Two dimensional transposed convolution operator.
*
Expand Down