Skip to content

Commit

Permalink
Add a small number of utils to make MLIR easier and nicer to use.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668170852
  • Loading branch information
Google-ML-Automation authored and jax authors committed Sep 6, 2024
1 parent 671acef commit afcf5c2
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 0 deletions.
38 changes: 38 additions & 0 deletions jaxlib/mlir/utils/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

licenses(["notice"])

package(
default_applicable_licenses = [],
default_visibility = ["//jax:internal"],
)

# TODO(mvoz):Break up into smaller targets, utils is a catchall name.
cc_library(
name = "mlir_utils",
hdrs = ["mlir_utils.h"],
# compatible with libtpu
features = ["-use_header_modules"],
visibility = [
"//platforms/xla/service:__subpackages__",
],
deps = [
"//jaxlib/mosaic:tpu_dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFDialect",
],
)
150 changes: 150 additions & 0 deletions jaxlib/mlir/utils/mlir_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#ifndef JAX_JAXLIB_MLIR_UTILS_MLIR_UTILS_H_
#define JAX_JAXLIB_MLIR_UTILS_MLIR_UTILS_H_

// Helper functions and utilities to make MLIR less verbose, easier to hold.

#include <cstdint>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include <string>

#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/IR/SCF.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineExpr.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineMap.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Builders.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinAttributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/MLIRContext.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/ArrayRef.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Value.h"

namespace jax {
namespace mlir {

struct MLIRIterationBound {
std::string name;
int64_t bound_size;
::mlir::tpu::DimensionSemantics dimension_semantics;
::mlir::AffineExpr ae;
};

class MLIRIterationContext {
public:
MLIRIterationContext(::mlir::ImplicitLocOpBuilder builder,
::mlir::MLIRContext* context,
std::vector<MLIRIterationBound> iteration_bounds)
: context_(context) {
std::vector<int64_t> bound_sizes;
bound_sizes.reserve(iteration_bounds.size());

std::vector<::mlir::Attribute> dimension_semantics;
dimension_semantics.reserve(iteration_bounds.size());

affine_exprs_.reserve(iteration_bounds.size());

int i = 0;
for (const auto& bound : iteration_bounds) {
name_to_iteration_bound_[bound.name] = i;
dimension_semantics.emplace_back(::mlir::tpu::DimensionSemanticsAttr::get(
context, bound.dimension_semantics));
bound_sizes.emplace_back(bound.bound_size);
affine_exprs_.emplace_back(bound.ae);
i++;
}

iteration_bounds_ =
builder.getDenseI64ArrayAttr(::llvm::ArrayRef(bound_sizes));
dimension_semantics_ = builder.getArrayAttr(
::llvm::ArrayRef<::mlir::Attribute>(dimension_semantics));
}

template <typename... Names>
::mlir::AffineMapAttr getAffineMapAttr(Names... names) {
std::vector<::mlir::AffineExpr> affine_exprs;
for (const auto& name : {names...}) {
// TODO(mvoz): ensure the name is there
auto idx = name_to_iteration_bound_.at(name);
auto val = affine_exprs_[idx];
affine_exprs.push_back(affine_exprs_[idx]);
}
return ::mlir::AffineMapAttr::get(::mlir::AffineMap::get(
/*dimCount=*/iteration_bounds_.size(),
/*symbolCount=*/0, affine_exprs, context_));
}

::mlir::DenseI64ArrayAttr getIterationBounds() const {
return iteration_bounds_;
}

::mlir::ArrayAttr getDimensionSemantics() const {
return dimension_semantics_;
}

private:
::mlir::DenseI64ArrayAttr iteration_bounds_;
::mlir::ArrayAttr dimension_semantics_;
std::unordered_map<std::string, int> name_to_iteration_bound_;
::mlir::MLIRContext* context_;

::mlir::SmallVector<::mlir::AffineExpr> affine_exprs_;
};

class MLIRHelper {
public:
MLIRHelper(::mlir::ImplicitLocOpBuilder builder) : builder_(builder) {}

template <typename Op, typename L>
::mlir::scf::IfOp if_op(Op predicate, L mlir_if_then_block) {
return builder_.create<::mlir::scf::IfOp>(predicate.getLoc(), predicate,
mlir_if_then_block);
}

template <typename LT, typename RT>
::mlir::arith::AndIOp and_op(LT lhs, RT rhs) {
return builder_.create<::mlir::arith::AndIOp>(lhs, rhs);
}

::mlir::arith::SelectOp select(const ::mlir::Value& mask, const ::mlir::Value& lhs, const ::mlir::Value& rhs) {
return builder_.create<::mlir::arith::SelectOp>(mask, lhs, rhs);
}

template <typename LT, typename RT, std::enable_if_t<!std::is_scalar_v<RT>>>
::mlir::arith::CmpIOp eq(const LT& lhs, const RT& rhs) {
return builder_.create<::mlir::arith::CmpIOp>(
::mlir::arith::CmpIPredicate::eq, lhs, rhs);
}

template <typename LT>
::mlir::arith::CmpIOp eq(const LT& lhs, int rhs_scalar) {
// TODO: Type check the scalar, route to correct mlir attr, not always index
auto rhs = builder_.create<::mlir::arith::ConstantOp>(
builder_.getIndexAttr(rhs_scalar));
return builder_.create<::mlir::arith::CmpIOp>(
::mlir::arith::CmpIPredicate::eq, lhs, rhs);
}

::mlir::arith::CmpIOp sge(const ::mlir::Value lhs, const ::mlir::Value rhs) {
return builder_.create<::mlir::arith::CmpIOp>(
::mlir::arith::CmpIPredicate::sge, lhs, rhs);
}

::mlir::arith::CmpIOp sle(const ::mlir::Value lhs, const ::mlir::Value rhs) {
return builder_.create<::mlir::arith::CmpIOp>(
::mlir::arith::CmpIPredicate::sle, lhs, rhs);
}

::mlir::arith::CmpIOp slt(const ::mlir::Value lhs, const ::mlir::Value rhs) {
return builder_.create<::mlir::arith::CmpIOp>(
::mlir::arith::CmpIPredicate::slt, lhs, rhs);
}

private:
::mlir::ImplicitLocOpBuilder builder_;
};
} // namespace mlir
} // namespace jax

#endif // JAX_JAXLIB_MLIR_UTILS_MLIR_UTILS_H_

0 comments on commit afcf5c2

Please sign in to comment.