Skip to content

Commit

Permalink
[MLIR] Add read-only reverse mode arg
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 1, 2024
1 parent b96c443 commit 7651e87
Show file tree
Hide file tree
Showing 5 changed files with 370 additions and 282 deletions.
6 changes: 6 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,9 @@ def : MLIRDerivative<"arith", "DivFOp", (Op $x, $y),
],
(CheckedDivF (SubF (SelectIfActive $x, (MulF (Shadow $x), $y), (ConstantFP<"0","arith", "ConstantOp"> $x)), (SelectIfActive $y, (MulF (Shadow $y), $x), (ConstantFP<"0","arith","ConstantOp"> $y))), (MulF $y, $y))
>;

def ExtF : ArithInst<"ExtFOp">;
def TruncF : ArithInst<"TruncFOp">;

def : ReadOnlyIdentityOp<"arith", "TruncFOp", [0], (Op $x), [(ExtF (TypeOf $x), (DiffeRet))]>;
def : ReadOnlyIdentityOp<"arith", "ExtFOp", [0], (Op $x), [(TruncF (TypeOf $x), (DiffeRet))]>;
14 changes: 12 additions & 2 deletions enzyme/Enzyme/MLIR/Implementations/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,21 @@ class ControlFlowOp<string dialect_, string opName_, string impl_> {
string impl = impl_;
}

class MemoryIdentityOp<string dialect_, string opName_, list<int> ptrargs_, list<int> storedargs_ = []> {

def Unimplemented {

}

class MemoryIdentityOp<string dialect_, string opName_, list<int> ptrargs_, list<int> storedargs_ = [], dag patternToMatch=(Unimplemented), list<dag> reverse_ = []> {
string dialect = dialect_;
string opName = opName_;
dag PatternToMatch = patternToMatch;
list<int> ptrargs = ptrargs_;
list<int> storedargs = storedargs_;
list<dag> reverse = reverse_;
}

class ReadOnlyIdentityOp<string dialect_, string opName_, list<int> ptrargs_> : MemoryIdentityOp<dialect_, opName_, ptrargs_>;
class ReadOnlyIdentityOp<string dialect_, string opName_, list<int> ptrargs_, dag patternToMatch=(Unimplemented), list<dag> reverse_ = []> : MemoryIdentityOp<dialect_, opName_, ptrargs_, [], patternToMatch, reverse_>;

class ReturnOp<string dialect_, string opName_> {
string dialect = dialect_;
Expand Down Expand Up @@ -94,6 +101,9 @@ class ConstantFP<string val, string dialect_, string op_, string type_=""> : Ope

def ResultTypes : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, "op->getResultTypes()">;

def TypeOf : Operation</*primal*/0, /*shadow*/0> {
}

class ArithInst<string m> : Inst<m, "arith">;
class MathInst<string m> : Inst<m, "math">;

Expand Down
18 changes: 18 additions & 0 deletions enzyme/test/MLIR/ForwardMode/trunc.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: %eopt --enzyme %s | FileCheck %s

module {
func.func @f(%x : f64) -> f32 {
%y = arith.truncf %x : f64 to f32
return %y : f32
}
func.func @dsq(%x : f64, %dx : f64) -> f32 {
%r = enzyme.fwddiff @f(%x, %dx) { activity=[#enzyme<activity enzyme_dup>] } : (f64, f64) -> (f32)
return %r : f32
}
}

// CHECK: func.func private @fwddiffef(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f32 {
// CHECK-NEXT: %[[dy:.+]] = arith.truncf %[[arg1]] : f64 to f32
// CHECK-NEXT: %[[y:.+]] = arith.truncf %[[arg0]] : f64 to f32
// CHECK-NEXT: return %[[dy]] : f32
// CHECK-NEXT: }
18 changes: 18 additions & 0 deletions enzyme/test/MLIR/ReverseMode/trunc.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s --check-prefix=FIN

module {
func.func @f(%x: f64) -> f32 {
%next = arith.truncf %x : f64 to f32
return %next : f32
}

func.func @dsquare(%x: f64, %dr: f32) -> f64 {
%r = enzyme.autodiff @f(%x, %dr) { activity=[#enzyme<activity enzyme_out>] } : (f64, f32) -> f64
return %r : f64
}
}

// FIN: func.func private @diffef(%[[x:.+]]: f64, %[[dx:.+]]: f32) -> f64 {
// FIN-NEXT: %[[res:.+]] = arith.extf %[[dx]] : f32 to f64
// FIN-NEXT: return %[[res]] : f64
// FIN-NEXT: }
Loading

0 comments on commit 7651e87

Please sign in to comment.