Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 685751531
Change-Id: I7548d73eeffdaa68b8c684fb6e4009e610837c08
  • Loading branch information
Akshaya Purohit authored and copybara-github committed Oct 14, 2024
1 parent 6b4d9e9 commit cbcc62e
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions qkeras/qseparable_conv2d_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import tensorflow as tf
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import InputSpec
from tf_keras.utils import conv_utils

from .qconvolutional import deconv_output_length
from .quantizers import get_quantizer
from tensorflow.python.eager import context
from tensorflow.python.keras import constraints
Expand Down Expand Up @@ -138,21 +138,23 @@ def _get_output_size(self, inputs, output_padding, padding, strides,
out_pad_h, out_pad_w = output_padding

# Infer the dynamic output shape:
out_height = conv_utils.deconv_output_length(
out_height = deconv_output_length(
height,
kernel_h,
padding=padding,
output_padding=out_pad_h,
stride=stride_h,
dilation=dilation_h)
dilation=dilation_h,
)

out_width = conv_utils.deconv_output_length(
out_width = deconv_output_length(
width,
kernel_w,
padding=padding,
output_padding=out_pad_w,
stride=stride_w,
dilation=dilation_w)
dilation=dilation_w,
)

return (batch_size, out_height, out_width, kernel_h, kernel_w)

Expand Down Expand Up @@ -233,15 +235,15 @@ def compute_final_output_shape(
# Pointwise convolution maps input channels to output filters.
output_shape[c_axis] = self.filters

output_shape[h_axis] = conv_utils.deconv_output_length(
output_shape[h_axis] = deconv_output_length(
output_shape[h_axis],
kernel_h,
padding=self.padding,
output_padding=out_pad_h,
stride=stride_h,
dilation=self.dilation_rate[0],
)
output_shape[w_axis] = conv_utils.deconv_output_length(
output_shape[w_axis] = deconv_output_length(
output_shape[w_axis],
kernel_w,
padding=self.padding,
Expand Down

0 comments on commit cbcc62e

Please sign in to comment.