Make the jaxpr for jnp.pad in "constant" mode more succinct.

Previously the constant-mode lowering broadcast the pad value to an (nd, 2) array and emitted a slice, reshape, and stablehlo.pad for every (dimension, side) pair; now a scalar pad value lowers to a single stablehlo.pad, and zero-width sides no longer emit pad ops.

Example before:

```
>>> print(jax.jit(lambda x: jnp.pad(x, ((0, 0), (1, 0), (0, 1)), constant_values=7)).lower(jnp.ones((3,4,5))).as_text())
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3x4x5xf32>) -> (tensor<3x5x6xf32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<7> : tensor<i32>
    %0 = call @_pad(%arg0, %c) : (tensor<3x4x5xf32>, tensor<i32>) -> tensor<3x5x6xf32>
    return %0 : tensor<3x5x6xf32>
  }
  func.func private @_pad(%arg0: tensor<3x4x5xf32>, %arg1: tensor<i32>) -> tensor<3x5x6xf32> {
    %0 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<i32>) -> tensor<3x2xi32>
    %1 = stablehlo.convert %0 : (tensor<3x2xi32>) -> tensor<3x2xf32>
    %2 = stablehlo.slice %1 [0:1, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %3 = stablehlo.reshape %2 : (tensor<1x1xf32>) -> tensor<f32>
    %4 = stablehlo.pad %arg0, %3, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
    %5 = stablehlo.slice %1 [0:1, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %6 = stablehlo.reshape %5 : (tensor<1x1xf32>) -> tensor<f32>
    %7 = stablehlo.pad %4, %6, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
    %8 = stablehlo.slice %1 [1:2, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %9 = stablehlo.reshape %8 : (tensor<1x1xf32>) -> tensor<f32>
    %10 = stablehlo.pad %7, %9, low = [0, 1, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x5x5xf32>
    %11 = stablehlo.slice %1 [1:2, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %12 = stablehlo.reshape %11 : (tensor<1x1xf32>) -> tensor<f32>
    %13 = stablehlo.pad %10, %12, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor<f32>) -> tensor<3x5x5xf32>
    %14 = stablehlo.slice %1 [2:3, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %15 = stablehlo.reshape %14 : (tensor<1x1xf32>) -> tensor<f32>
    %16 = stablehlo.pad %13, %15, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor<f32>) -> tensor<3x5x5xf32>
    %17 = stablehlo.slice %1 [2:3, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32>
    %18 = stablehlo.reshape %17 : (tensor<1x1xf32>) -> tensor<f32>
    %19 = stablehlo.pad %16, %18, low = [0, 0, 0], high = [0, 0, 1], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor<f32>) -> tensor<3x5x6xf32>
    return %19 : tensor<3x5x6xf32>
  }
}
```

After:
```
module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3x4x5xf32>) -> (tensor<3x5x6xf32> {jax.result_info = ""}) {
    %c = stablehlo.constant dense<7> : tensor<i32>
    %0 = call @_pad(%arg0, %c) : (tensor<3x4x5xf32>, tensor<i32>) -> tensor<3x5x6xf32>
    return %0 : tensor<3x5x6xf32>
  }
  func.func private @_pad(%arg0: tensor<3x4x5xf32>, %arg1: tensor<i32>) -> tensor<3x5x6xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<i32>) -> tensor<f32>
    %1 = stablehlo.pad %arg0, %0, low = [0, 1, 0], high = [0, 0, 1], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x5x6xf32>
    return %1 : tensor<3x5x6xf32>
  }
}
```
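For reference, the one-liner from the example above written out as a short, self-contained script (a minimal sketch; the input shape and pad widths are taken from the example):

```python
import jax
import jax.numpy as jnp

# Lower (but don't run) a constant-mode pad under jit and print the StableHLO.
f = jax.jit(lambda x: jnp.pad(x, ((0, 0), (1, 0), (0, 1)), constant_values=7))
print(f.lower(jnp.ones((3, 4, 5))).as_text())
```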
hawkinsp committed Nov 14, 2024
1 parent aefe621 commit ad5a062
Showing 1 changed file with 27 additions and 5 deletions: jax/_src/numpy/lax_numpy.py
```diff
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -4048,15 +4048,37 @@ def _check_no_padding(axis_padding: tuple[Any, Any], mode: str):
 
 def _pad_constant(array: Array, pad_width: PadValue[int], constant_values: Array) -> Array:
   nd = ndim(array)
-  constant_values = broadcast_to(constant_values, (nd, 2))
   constant_values = lax_internal._convert_element_type(
       constant_values, array.dtype, dtypes.is_weakly_typed(array))
+  constant_values_nd = ndim(constant_values)
+
+  if constant_values_nd == 0:
+    widths = [(low, high, 0) for (low, high) in pad_width]
+    return lax.pad(array, constant_values, widths)
+
+  if constant_values_nd == 1:
+    if constant_values.shape[-1] == 1:
+      widths = [(low, high, 0) for (low, high) in pad_width]
+      return lax.pad(array, squeeze(constant_values), widths)
+    elif constant_values.shape[-1] == 2:
+      widths = [(low, 0, 0) for (low, _) in pad_width]
+      array = lax.pad(array, constant_values[0], widths)
+      widths = [(0, high, 0) for (_, high) in pad_width]
+      return lax.pad(array, constant_values[1], widths)
+    else:
+      raise ValueError("jnp.pad: constant_values has unsupported shape "
+                       f"{constant_values.shape}. If the shape is 1D or 2D, the "
+                       "last dimension must be of size 1 or 2.")
+
+  constant_values = broadcast_to(constant_values, (nd, 2))
   for i in range(nd):
     widths = [(0, 0, 0)] * nd
-    widths[i] = (pad_width[i][0], 0, 0)
-    array = lax.pad(array, constant_values[i, 0], widths)
-    widths[i] = (0, pad_width[i][1], 0)
-    array = lax.pad(array, constant_values[i, 1], widths)
+    if pad_width[i][0] != 0:
+      widths[i] = (pad_width[i][0], 0, 0)
+      array = lax.pad(array, constant_values[i, 0], widths)
+    if pad_width[i][1] != 0:
+      widths[i] = (0, pad_width[i][1], 0)
+      array = lax.pad(array, constant_values[i, 1], widths)
   return array
```
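As an illustration of the three code paths above — a hypothetical usage sketch, with shapes and values chosen only for demonstration:

```python
import jax.numpy as jnp

x = jnp.ones((3, 4, 5))
pads = ((0, 0), (1, 0), (0, 1))

# Scalar constant_values: the new fast path, a single lax.pad.
y0 = jnp.pad(x, pads, constant_values=7)

# Shape (2,): one value for all low sides, one for all high sides;
# two lax.pad calls in total.
y1 = jnp.pad(x, pads, constant_values=(5, 9))

# Shape (nd, 2): a (low, high) value per dimension; the loop now skips
# sides whose pad width is zero instead of emitting no-op pads.
y2 = jnp.pad(x, pads, constant_values=((0, 0), (5, 5), (0, 9)))
```

Each call returns an array of shape (3, 5, 6), matching the result type in the lowering above.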
