Shape polymorphism with an image downscale #15995
Hi, I'd opened a prior discussion that got a bit long, so I'm restarting with a more specific question. I want to export a model with jax2tf that has polymorphic shapes and an image resize. Specifically, I want to use a float instead of an integer to produce a downscaled image. Here's a sample (largely copied from here):

```python
import jax
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

# Batch of 1, symbolic height and width, 3 channels.
polymorphic_image_shape = jax2tf.shape_poly.PolyShape(
    1, 'height', 'width', 3
)

def apply_fn(image):
    b, h, w, c = image.shape
    h = h / 2  # halve the height (true division)
    return jax.image.resize(image, (b, h, w, c), method="linear")

apply_fn_jitted = jax.jit(apply_fn, backend='cpu')
apply_fn_tf = tf.function(
    jax2tf.convert(
        apply_fn_jitted,
        with_gradient=False,
        enable_xla=True,
        polymorphic_shapes=[polymorphic_image_shape]),
    autograph=False,
    jit_compile=True)

image_f = np.random.uniform(size=(1, 4, 4, 3))
apply_fn_tf(image_f)
```

This fails with:
If I switch the … Is there any way to solve this?
At some level this is independent of `jax2tf`: dimensions must be integers. That is, can you `jit` the function that you are trying to convert with `jax2tf`? You probably still want to use an integer and use `h // 2` for downscaling. Once you start using division, you are approaching the limits of the reasoning that shape polymorphism can do when comparing dimension values, and you may run into errors that are specific to shape polymorphism. If you are prepared to ensure that you will always use this function on images with even height, then you can write `2 * height` in lieu of `height` in the polymorphic shape specification.
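A minimal sketch of the integer-division fix in plain JAX, with concrete shapes only (no jax2tf and no shape polymorphism involved), assuming an NHWC float image. The function name `downscale_half` is hypothetical. It illustrates the "can you `jit` it?" check: the target shape passed to `jax.image.resize` must be integral, so `h // 2` works where `h / 2` (which gives e.g. `2.0`) does not.

```python
import jax
import jax.numpy as jnp

def downscale_half(image):
    """Halve the height of an NHWC image batch with linear interpolation.

    Hypothetical helper. Floor division keeps the target shape integral;
    for odd heights the result is rounded down (5 -> 2).
    """
    b, h, w, c = image.shape
    return jax.image.resize(image, (b, h // 2, w, c), method="linear")

downscale_half_jitted = jax.jit(downscale_half)

image = jnp.ones((1, 4, 4, 3), dtype=jnp.float32)
out = downscale_half_jitted(image)
print(out.shape)  # (1, 2, 4, 3)
```

This only checks the concrete-shape case. Per the reply, converting it with jax2tf would additionally use `2 * height` in place of `height` in the polymorphic shape specification, so that halving the height is always exact; that conversion step is not exercised here.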