Skip to content

Commit

Permalink
Add complex sqrt (#1778)
Browse files Browse the repository at this point in the history
* Add complex sqrt

* fix analyzer capture

* fix analyzer error passing convention

* Handle extractions in tablegen
  • Loading branch information
wsmoses authored Apr 1, 2024
1 parent b1f676c commit 5b2f04c
Show file tree
Hide file tree
Showing 9 changed files with 174 additions and 42 deletions.
18 changes: 9 additions & 9 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
}
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&BuilderZ));
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
ss << "\n";
TR.dump(ss);
Expand Down Expand Up @@ -1032,7 +1032,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
}
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&BuilderZ));
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
ss << "\n";
TR.dump(ss);
Expand All @@ -1058,7 +1058,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
<< " size: " << storeSize;
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&BuilderZ));
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I, ss.str());
}
Expand Down Expand Up @@ -1136,7 +1136,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
<< " size: " << storeSize;
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&BuilderZ));
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I, ss.str());
}
Expand Down Expand Up @@ -1445,7 +1445,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
ss << "Cannot deduce adding type (cast) of " << I;
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(&I), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&Builder2));
TR.analyzer, nullptr, wrap(&Builder2));
return;
} else {
ss << "\n";
Expand Down Expand Up @@ -2012,7 +2012,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
<< " size: " << size0 << " TT: " << TT.str();
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&IVI), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&Builder2));
TR.analyzer, nullptr, wrap(&Builder2));
} else {
EmitFailure("CannotDeduceType", IVI.getDebugLoc(), &IVI,
ss.str());
Expand Down Expand Up @@ -2097,7 +2097,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
<< " TT: " << TT.str();
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&IVI), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&Builder2));
TR.analyzer, nullptr, wrap(&Builder2));
} else {
EmitFailure("CannotDeduceType", IVI.getDebugLoc(), &IVI,
ss.str());
Expand Down Expand Up @@ -3133,7 +3133,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
ss << "Cannot deduce type of memset " << MS;
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&MS), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&BuilderZ));
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
ss << "\n";
TR.dump(ss);
Expand Down Expand Up @@ -3439,7 +3439,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
ss << "Cannot deduce type of copy " << MTI;
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&MTI), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&BuilderZ));
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
ss << "\n";
ss << *gutils->oldFunc << "\n";
Expand Down
4 changes: 2 additions & 2 deletions enzyme/Enzyme/DiffeGradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM,
ss << "}\n";
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(val), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&BuilderM));
TR.analyzer, nullptr, wrap(&BuilderM));
return addedSelects;
} else {
TR.dump(ss);
Expand Down Expand Up @@ -534,7 +534,7 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM,
<< "\n";
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(val), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&BuilderM));
TR.analyzer, nullptr, wrap(&BuilderM));
return addedSelects;
} else {
DebugLoc loc;
Expand Down
10 changes: 10 additions & 0 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1802,6 +1802,16 @@ class EnzymeBase {
csts.push_back(ConstantFP::get(e, 1.0));
}
args.push_back(ConstantStruct::get(ST, csts));
} else if (auto AT = dyn_cast<ArrayType>(fn->getReturnType())) {
SmallVector<Constant *, 2> csts(
AT->getNumElements(), ConstantFP::get(AT->getElementType(), 1.0));
args.push_back(ConstantArray::get(AT, csts));
} else {
auto RT = fn->getReturnType();
EmitFailure("EnzymeCallingError", CI->getDebugLoc(), CI,
"Differential return required for call ", *CI,
" but one of type ", *RT, " could not be auto deduced");
return false;
}
}

Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3364,7 +3364,7 @@ void createInvertedTerminator(DiffeGradientUtils *gutils,
<< " sz: " << size << "\n";
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(orig), ErrorType::NoType,
&gutils->TR.analyzer, nullptr, wrap(&Builder));
gutils->TR.analyzer, nullptr, wrap(&Builder));
continue;
} else {
ss << "\n";
Expand Down
15 changes: 15 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4907,6 +4907,21 @@ Type *GradientUtils::getShadowType(Type *ty) {
return getShadowType(ty, width);
}

Type *GradientUtils::extractMeta(Type *T, ArrayRef<unsigned> off) {
for (auto idx : off) {
if (auto AT = dyn_cast<ArrayType>(T)) {
T = AT->getElementType();
continue;
}
if (auto ST = dyn_cast<StructType>(T)) {
T = ST->getElementType(idx);
continue;
}
assert(false && "could not sub index into type");
}
return T;
}

Value *GradientUtils::extractMeta(IRBuilder<> &Builder, Value *Agg,
unsigned off, const Twine &name) {
return extractMeta(Builder, Agg, ArrayRef<unsigned>({off}), name);
Expand Down
3 changes: 3 additions & 0 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,9 @@ class GradientUtils : public CacheUtility {
const llvm::Twine &name = "",
bool fallback = true);

//! Helper routine to get the type of an extraction
static llvm::Type *extractMeta(llvm::Type *T, llvm::ArrayRef<unsigned> off);

static llvm::Value *recursiveFAdd(llvm::IRBuilder<> &B, llvm::Value *lhs,
llvm::Value *rhs,
llvm::ArrayRef<unsigned> lhs_off = {},
Expand Down
7 changes: 7 additions & 0 deletions enzyme/Enzyme/InstructionDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,13 @@ def : IntrPattern<(Op $x),
(ForwardFromSummedReverse)
>;


def : CallPattern<(Op (Op $x, $y):$z),
["cmplx_sqrt"],
[(Select (And (FCmpUEQ $x, (ConstantFP<"0"> $x)), (FCmpUEQ $y, (ConstantFP<"0"> $y))), (ConstantCFP<"0", "0"> $z), (Conj (CFDiv (Conj (DiffeRet)), (CFMul (ConstantCFP<"2", "0"> $z), (Call<(SameFunc), [ReadNone,NoUnwind]> $z)))))],
(Select (And (FCmpUEQ $x, (ConstantFP<"0"> $x)), (FCmpUEQ $y, (ConstantFP<"0"> $y))), (ConstantCFP<"0", "0"> $z), (CFDiv (Shadow $z), (CFMul (ConstantCFP<"2", "0"> $z), (Call<(SameFunc), [ReadNone,NoUnwind]> $z))))
>;

def : IntrPattern<(Op $x, $y),
[["pow"]],
[
Expand Down
20 changes: 20 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/cmplx_sqrt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s

declare [2 x double] @cmplx_sqrt([2 x double] %x)

define [2 x double] @tester([2 x double] %x) {
entry:
%y = call [2 x double] @cmplx_sqrt([2 x double] %x)
ret [2 x double] %y
}

define [2 x double] @test_derivative([2 x double] %x) {
entry:
%0 = tail call [2 x double] (...) @__enzyme_autodiff([2 x double] ([2 x double])* nonnull @tester, metadata !"enzyme_active_return", [2 x double] %x)
ret [2 x double] %0
}

declare [2 x double] @__enzyme_autodiff(...)

; CHECK: define internal { [2 x double] } @diffetester([2 x double] %x, [2 x double] %differeturn)
Loading

0 comments on commit 5b2f04c

Please sign in to comment.