Skip to content

Commit

Permalink
Update MLIR pieces to LLVM 18
Browse files Browse the repository at this point in the history
  • Loading branch information
pengmai committed Oct 25, 2023
1 parent 8930f72 commit e1bf997
Show file tree
Hide file tree
Showing 23 changed files with 42 additions and 41 deletions.
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
Expand Down
14 changes: 8 additions & 6 deletions enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,8 @@ class SparseBackwardActivityAnalysis

void visitBranchOperand(OpOperand &operand) override {}

void visitCallOperand(OpOperand &operand) override {}

void
visitOperation(Operation *op, ArrayRef<BackwardValueActivity *> operands,
ArrayRef<const BackwardValueActivity *> results) override {
Expand Down Expand Up @@ -475,10 +477,10 @@ class DenseForwardActivityAnalysis
});
} else if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
// linalg.yield stores to the corresponding value.
for (OpOperand *dpsInit : linalgOp.getDpsInitOperands()) {
if (dpsInit->get() == value) {
for (OpOperand &dpsInit : linalgOp.getDpsInitsMutable()) {
if (dpsInit.get() == value) {
int64_t resultIndex =
dpsInit->getOperandNumber() - linalgOp.getNumDpsInputs();
dpsInit.getOperandNumber() - linalgOp.getNumDpsInputs();
Value yieldOperand =
linalgOp.getBlock()->getTerminator()->getOperand(resultIndex);
auto *valueState =
Expand Down Expand Up @@ -619,10 +621,10 @@ class DenseBackwardActivityAnalysis
});
} else if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
if (after.activeDataFlowsOut(alloc)) {
for (OpOperand *dpsInit : linalgOp.getDpsInitOperands()) {
if (dpsInit->get() == value) {
for (OpOperand &dpsInit : linalgOp.getDpsInitsMutable()) {
if (dpsInit.get() == value) {
int64_t resultIndex =
dpsInit->getOperandNumber() - linalgOp.getNumDpsInputs();
dpsInit.getOperandNumber() - linalgOp.getNumDpsInputs();
Value yieldOperand =
linalgOp.getBlock()->getTerminator()->getOperand(
resultIndex);
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Dialect/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef ENZYME_DIALECT_H
#define ENZYME_DIALECT_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Dialect.h"

#include "Dialect/EnzymeOpsDialect.h.inc"
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ include "mlir/IR/SymbolInterfaces.td"
include "mlir/IR/EnumAttr.td"

include "mlir/IR/OpBase.td"
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"

include "mlir/IR/AttrTypeBase.td"

include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/FunctionInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Dialect/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
// #include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Dialect/Ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "mlir/IR/DialectImplementation.h"
#include "mlir/Bytecode/BytecodeOpInterface.h"

#define GET_OP_CLASSES
#include "Dialect/EnzymeOps.h.inc"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class PointerTypeInterface
public:
mlir::Value createNullValue(mlir::Type self, OpBuilder &builder,
Location loc) const {
return builder.create<LLVM::NullOp>(loc, self);
return builder.create<LLVM::ZeroOp>(loc, self);
}

Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,13 @@ struct GenericOpInterfaceReverse
dims.push_back(dim);
}
}
for (OpOperand *output : newOp.getDpsInitOperands()) {
auto shape = cast<MemRefType>(output->get().getType()).getShape();
for (Value output : newOp.getDpsInits()) {
auto shape = cast<MemRefType>(output.getType()).getShape();
for (unsigned i = 0; i < shape.size(); i++) {
auto dimI =
cacheBuilder.create<arith::ConstantIndexOp>(op->getLoc(), i);
auto dim = cacheBuilder.create<memref::DimOp>(op->getLoc(),
output->get(), dimI);
auto dim =
cacheBuilder.create<memref::DimOp>(op->getLoc(), output, dimI);
dims.push_back(dim);
}
}
Expand All @@ -135,12 +135,12 @@ struct GenericOpInterfaceReverse
SmallVector<utils::IteratorType> iteratorTypes{
linalgOp.getNumLoops(), utils::IteratorType::parallel};

for (OpOperand *output : linalgOp.getDpsInitOperands()) {
if (!gutils->hasInvertPointer(output->get())) {
for (OpOperand &output : linalgOp.getDpsInitsMutable()) {
if (!gutils->hasInvertPointer(output.get())) {
continue;
}
indexingMaps.push_back(linalgOp.getMatchingIndexingMap(output));
Value out = gutils->invertPointerM(output->get(), builder);
indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&output));
Value out = gutils->invertPointerM(output.get(), builder);
Value view = invertMemref(out, builder, op->getLoc());
outputs.push_back(view);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ struct ForOpInterface
}
}
SmallVector<mlir::Value> nArgs;
for (auto r :
llvm::zip(forOp.getIterOperands(), forOp.getRegionIterArgs())) {
for (auto r : llvm::zip(forOp.getInitArgs(), forOp.getRegionIterArgs())) {
// TODO only if used
nArgs.push_back(gutils->getNewFromOriginal(std::get<0>(r)));
if (!gutils->isConstantValue(std::get<1>(r)))
Expand Down Expand Up @@ -105,7 +104,6 @@ struct ForOpInterfaceReverse
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto forOp = cast<scf::ForOp>(op);
auto newForOp = cast<scf::ForOp>(gutils->getNewFromOriginal(op));

SmallVector<Value> nArgs;
for (Value v : forOp.getResults()) {
Expand Down Expand Up @@ -140,7 +138,7 @@ struct ForOpInterfaceReverse
repFor.getRegion().insertArgument((unsigned)0, indexType, forOp.getLoc());

for (const auto &[iterOperand, adjResult] :
llvm::zip(forOp.getIterOperands(), repFor.getResults())) {
llvm::zip(forOp.getInitArgs(), repFor.getResults())) {
if (gutils->hasInvertPointer(iterOperand)) {
auto autoDiffType = cast<AutoDiffTypeInterface>(iterOperand.getType());
Value before = gutils->invertPointerM(iterOperand, builder);
Expand Down Expand Up @@ -176,7 +174,7 @@ struct ForOpInterfaceReverse

void createShadowValues(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) const {
auto forOp = cast<scf::ForOp>(op);
// auto forOp = cast<scf::ForOp>(op);
}
};

Expand Down
4 changes: 2 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/CloneFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#include "Dialect/Ops.h"
#include "Interfaces/AutoDiffOpInterface.h"
#include "Interfaces/AutoDiffTypeInterface.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"

Expand All @@ -18,7 +18,7 @@
#include "CloneFunction.h"
#include "EnzymeLogic.h"

#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/IR/IRMapping.h"

using namespace mlir;
Expand Down
4 changes: 2 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
#include "Interfaces/AutoDiffTypeInterface.h"
#include "Interfaces/GradientUtils.h"
#include "Interfaces/GradientUtilsReverse.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

// TODO: this shouldn't depend on specific dialects except Enzyme.
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
Expand Down Expand Up @@ -57,7 +57,7 @@ void createTerminator(MDiffeGradientUtils *gutils, mlir::Block *oBB,

SmallVector<NamedAttribute> attrs(newInst->getAttrs());
for (auto &attr : attrs) {
if (attr.getName() == "operand_segment_sizes")
if (attr.getName() == "operandSegmentSizes")
attr.setValue(nBuilder.getDenseI32ArrayAttr(segSizes));
}

Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/IR/IRMapping.h"

#include "../../TypeAnalysis/TypeAnalysis.h"
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "Dialect/Ops.h"
#include "Interfaces/AutoDiffOpInterface.h"
#include "Interfaces/AutoDiffTypeInterface.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"

Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "Interfaces/AutoDiffTypeInterface.h"
#include "Interfaces/CloneFunction.h"

#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"

Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "Interfaces/EnzymeLogic.h"

#include "Analysis/ActivityAnalysis.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/IR/IRMapping.h"

namespace mlir {
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "Dialect/Ops.h"
#include "Interfaces/AutoDiffOpInterface.h"
#include "Interfaces/AutoDiffTypeInterface.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"

Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/IR/IRMapping.h"

#include "CloneFunction.h"
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "PassDetails.h"
#include "Passes/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"

Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Passes/EnzymeToMemRef.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ struct LoweredCache {
.getResult(0);
}
static std::optional<LoweredCache>
getFromEnzymeCache(Location loc, TypeConverter *typeConverter,
getFromEnzymeCache(Location loc, const TypeConverter *typeConverter,
Value enzymeCache, OpBuilder &b) {
assert(enzymeCache.getType().isa<enzyme::CacheType>());
auto cacheType = enzymeCache.getType().cast<enzyme::CacheType>();
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/CallInterfaces.h"

#include "llvm/ADT/TypeSwitch.h"
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

using namespace mlir;

Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/MLIR/ActivityAnalysis/allocator.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f16, dense<16> :
%2 = llvm.mlir.constant("d_reduce_max(%i)=%f\0A\00") : !llvm.array<21 x i8>
%3 = llvm.mlir.addressof @".str.1" : !llvm.ptr
%4 = llvm.mlir.constant(0 : i32) : i32
%5 = llvm.call @_Z17__enzyme_autodiffPvPdS0_i(%0, %1) : (!llvm.ptr, f64) -> f64
%6 = llvm.call @printf(%3, %4, %5) : (!llvm.ptr, i32, f64) -> i32
%5 = llvm.call @_Z17__enzyme_autodiffPvPdS0_i(%0, %1) vararg(!llvm.func<f64 (...)>) : (!llvm.ptr, f64) -> f64
%6 = llvm.call @printf(%3, %4, %5) vararg(!llvm.func<i32 (ptr, ...)>) : (!llvm.ptr, i32, f64) -> i32
llvm.return %4 : i32
}
llvm.func @printf(!llvm.ptr {llvm.nocapture, llvm.readonly}, ...) -> i32
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/ActivityAnalysis/string.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<!llvm.ptr, dense<
llvm.store %5, %16 {alignment = 8 : i64, tbaa = [#tbaa_tag2]} : i64, !llvm.ptr
%17 = llvm.getelementptr inbounds %10[%1, 2, 1, %8] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<"class.std::__cxx11::basic_string", (struct<"struct.std::__cxx11::basic_string<char>::_Alloc_hider", (ptr)>, i64, struct<"union.anon", (i64, array<8 x i8>)>)>
llvm.store %9, %17 {alignment = 1 : i64, tbaa = [#tbaa_tag]} : i8, !llvm.ptr
%18 = llvm.call @printf(%14) : (!llvm.ptr) -> i32
%18 = llvm.call @printf(%14) vararg(!llvm.func<i32 (ptr, ...)>) : (!llvm.ptr) -> i32
%19 = llvm.load %15 {alignment = 8 : i64, tbaa = [#tbaa_tag3], tag = "loaded"} : !llvm.ptr -> !llvm.ptr
%20 = llvm.icmp "eq" %19, %14 : !llvm.ptr
llvm.cond_br %20, ^bb2, ^bb1
Expand Down

0 comments on commit e1bf997

Please sign in to comment.