Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the jaxpr for jnp.pad in "constant" mode more succinct. #24896

Merged
merged 1 commit into from
Nov 14, 2024

Commits on Nov 14, 2024

  1. Make the jaxpr for jnp.pad in "constant" mode more succinct.

    Example before:
    
    ```
    $ print(jax.jit(lambda x: jnp.pad(x, ((0, 0), (1, 0), (0, 1)), constant_values=7)).lower(jnp.ones((3,4,5))).as_text())
    module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
      func.func public @main(%arg0: tensor<3x4x5xf32>) -> (tensor<3x5x6xf32> {jax.result_info = ""}) {
        %c = stablehlo.constant dense<7> : tensor<i32>
        %0 = call @_pad(%arg0, %c) : (tensor<3x4x5xf32>, tensor<i32>) -> tensor<3x5x6xf32>
        return %0 : tensor<3x5x6xf32>
      }
      func.func private @_pad(%arg0: tensor<3x4x5xf32>, %arg1: tensor<i32>) -> tensor<3x5x6xf32> {
        %0 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<i32>) -> tensor<3x2xi32>
        %1 = stablehlo.convert %0 : (tensor<3x2xi32>) -> tensor<3x2xf32>
        %2 = stablehlo.slice %1 [0:1, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32>
        %3 = stablehlo.reshape %2 : (tensor<1x1xf32>) -> tensor<f32>
        %4 = stablehlo.pad %arg0, %3, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
        %5 = stablehlo.slice %1 [0:1, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32>
        %6 = stablehlo.reshape %5 : (tensor<1x1xf32>) -> tensor<f32>
        %7 = stablehlo.pad %4, %6, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x4x5xf32>
        %8 = stablehlo.slice %1 [1:2, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32>
        %9 = stablehlo.reshape %8 : (tensor<1x1xf32>) -> tensor<f32>
        %10 = stablehlo.pad %7, %9, low = [0, 1, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x5x5xf32>
        %11 = stablehlo.slice %1 [1:2, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32>
        %12 = stablehlo.reshape %11 : (tensor<1x1xf32>) -> tensor<f32>
        %13 = stablehlo.pad %10, %12, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor<f32>) -> tensor<3x5x5xf32>
        %14 = stablehlo.slice %1 [2:3, 0:1] : (tensor<3x2xf32>) -> tensor<1x1xf32>
        %15 = stablehlo.reshape %14 : (tensor<1x1xf32>) -> tensor<f32>
        %16 = stablehlo.pad %13, %15, low = [0, 0, 0], high = [0, 0, 0], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor<f32>) -> tensor<3x5x5xf32>
        %17 = stablehlo.slice %1 [2:3, 1:2] : (tensor<3x2xf32>) -> tensor<1x1xf32>
        %18 = stablehlo.reshape %17 : (tensor<1x1xf32>) -> tensor<f32>
        %19 = stablehlo.pad %16, %18, low = [0, 0, 0], high = [0, 0, 1], interior = [0, 0, 0] : (tensor<3x5x5xf32>, tensor<f32>) -> tensor<3x5x6xf32>
        return %19 : tensor<3x5x6xf32>
      }
    }
    ```
    
    After:
    ```
    module @jit__lambda_ attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
      func.func public @main(%arg0: tensor<3x4x5xf32>) -> (tensor<3x5x6xf32> {jax.result_info = ""}) {
        %c = stablehlo.constant dense<7> : tensor<i32>
        %0 = call @_pad(%arg0, %c) : (tensor<3x4x5xf32>, tensor<i32>) -> tensor<3x5x6xf32>
        return %0 : tensor<3x5x6xf32>
      }
      func.func private @_pad(%arg0: tensor<3x4x5xf32>, %arg1: tensor<i32>) -> tensor<3x5x6xf32> {
        %0 = stablehlo.convert %arg1 : (tensor<i32>) -> tensor<f32>
        %1 = stablehlo.pad %arg0, %0, low = [0, 1, 0], high = [0, 0, 1], interior = [0, 0, 0] : (tensor<3x4x5xf32>, tensor<f32>) -> tensor<3x5x6xf32>
        return %1 : tensor<3x5x6xf32>
      }
    }
    ```
    hawkinsp committed Nov 14, 2024
    Configuration menu
    Copy the full SHA
    ad5a062 View commit details
    Browse the repository at this point in the history