Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add complex sqrt #1778

Merged
merged 4 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
}
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&BuilderZ));
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
ss << "\n";
TR.dump(ss);
Expand Down Expand Up @@ -1032,7 +1032,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
}
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&BuilderZ));
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
ss << "\n";
TR.dump(ss);
Expand All @@ -1058,7 +1058,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
<< " size: " << storeSize;
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&BuilderZ));
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I, ss.str());
}
Expand Down Expand Up @@ -1136,7 +1136,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
<< " size: " << storeSize;
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&BuilderZ));
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I, ss.str());
}
Expand Down Expand Up @@ -1445,7 +1445,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
ss << "Cannot deduce adding type (cast) of " << I;
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(&I), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&Builder2));
TR.analyzer, nullptr, wrap(&Builder2));
return;
} else {
ss << "\n";
Expand Down Expand Up @@ -2012,7 +2012,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
<< " size: " << size0 << " TT: " << TT.str();
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&IVI), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&Builder2));
TR.analyzer, nullptr, wrap(&Builder2));
} else {
EmitFailure("CannotDeduceType", IVI.getDebugLoc(), &IVI,
ss.str());
Expand Down Expand Up @@ -2097,7 +2097,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
<< " TT: " << TT.str();
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&IVI), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&Builder2));
TR.analyzer, nullptr, wrap(&Builder2));
} else {
EmitFailure("CannotDeduceType", IVI.getDebugLoc(), &IVI,
ss.str());
Expand Down Expand Up @@ -3133,7 +3133,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
ss << "Cannot deduce type of memset " << MS;
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&MS), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&BuilderZ));
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
ss << "\n";
TR.dump(ss);
Expand Down Expand Up @@ -3439,7 +3439,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
ss << "Cannot deduce type of copy " << MTI;
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&MTI), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&BuilderZ));
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
ss << "\n";
ss << *gutils->oldFunc << "\n";
Expand Down
4 changes: 2 additions & 2 deletions enzyme/Enzyme/DiffeGradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM,
ss << "}\n";
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(val), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&BuilderM));
TR.analyzer, nullptr, wrap(&BuilderM));
return addedSelects;
} else {
TR.dump(ss);
Expand Down Expand Up @@ -534,7 +534,7 @@ DiffeGradientUtils::addToDiffe(Value *val, Value *dif, IRBuilder<> &BuilderM,
<< "\n";
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(val), ErrorType::NoType,
&TR.analyzer, nullptr, wrap(&BuilderM));
TR.analyzer, nullptr, wrap(&BuilderM));
return addedSelects;
} else {
DebugLoc loc;
Expand Down
10 changes: 10 additions & 0 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1789,6 +1789,16 @@ class EnzymeBase {
csts.push_back(ConstantFP::get(e, 1.0));
}
args.push_back(ConstantStruct::get(ST, csts));
} else if (auto AT = dyn_cast<ArrayType>(fn->getReturnType())) {
SmallVector<Constant *, 2> csts(
AT->getNumElements(), ConstantFP::get(AT->getElementType(), 1.0));
args.push_back(ConstantArray::get(AT, csts));
} else {
auto RT = fn->getReturnType();
EmitFailure("EnzymeCallingError", CI->getDebugLoc(), CI,
"Differential return required for call ", *CI,
" but one of type ", *RT, " could not be auto deduced");
return false;
}
}

Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3364,7 +3364,7 @@ void createInvertedTerminator(DiffeGradientUtils *gutils,
<< " sz: " << size << "\n";
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(orig), ErrorType::NoType,
&gutils->TR.analyzer, nullptr, wrap(&Builder));
gutils->TR.analyzer, nullptr, wrap(&Builder));
continue;
} else {
ss << "\n";
Expand Down
15 changes: 15 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4907,6 +4907,21 @@ Type *GradientUtils::getShadowType(Type *ty) {
return getShadowType(ty, width);
}

Type *GradientUtils::extractMeta(Type *T, ArrayRef<unsigned> off) {
for (auto idx : off) {
if (auto AT = dyn_cast<ArrayType>(T)) {
T = AT->getElementType();
continue;
}
if (auto ST = dyn_cast<StructType>(T)) {
T = ST->getElementType(idx);
continue;
}
assert(false && "could not sub index into type");
}
return T;
}

Value *GradientUtils::extractMeta(IRBuilder<> &Builder, Value *Agg,
unsigned off, const Twine &name) {
return extractMeta(Builder, Agg, ArrayRef<unsigned>({off}), name);
Expand Down
3 changes: 3 additions & 0 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,9 @@ class GradientUtils : public CacheUtility {
const llvm::Twine &name = "",
bool fallback = true);

//! Helper routine to get the type of an extraction
static llvm::Type *extractMeta(llvm::Type *T, llvm::ArrayRef<unsigned> off);

static llvm::Value *recursiveFAdd(llvm::IRBuilder<> &B, llvm::Value *lhs,
llvm::Value *rhs,
llvm::ArrayRef<unsigned> lhs_off = {},
Expand Down
7 changes: 7 additions & 0 deletions enzyme/Enzyme/InstructionDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,13 @@ def : IntrPattern<(Op $x),
(ForwardFromSummedReverse)
>;


def : CallPattern<(Op (Op $x, $y):$z),
["cmplx_sqrt"],
[(Select (And (FCmpUEQ $x, (ConstantFP<"0"> $x)), (FCmpUEQ $y, (ConstantFP<"0"> $y))), (ConstantCFP<"0", "0"> $z), (Conj (CFDiv (Conj (DiffeRet)), (CFMul (ConstantCFP<"2", "0"> $z), (Call<(SameFunc), [ReadNone,NoUnwind]> $z)))))],
(Select (And (FCmpUEQ $x, (ConstantFP<"0"> $x)), (FCmpUEQ $y, (ConstantFP<"0"> $y))), (ConstantCFP<"0", "0"> $z), (CFDiv (Shadow $z), (CFMul (ConstantCFP<"2", "0"> $z), (Call<(SameFunc), [ReadNone,NoUnwind]> $z))))
>;

def : IntrPattern<(Op $x, $y),
[["pow"]],
[
Expand Down
20 changes: 20 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/cmplx_sqrt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s

declare [2 x double] @cmplx_sqrt([2 x double] %x)

define [2 x double] @tester([2 x double] %x) {
entry:
%y = call [2 x double] @cmplx_sqrt([2 x double] %x)
ret [2 x double] %y
}

define [2 x double] @test_derivative([2 x double] %x) {
entry:
%0 = tail call [2 x double] (...) @__enzyme_autodiff([2 x double] ([2 x double])* nonnull @tester, metadata !"enzyme_active_return", [2 x double] %x)
ret [2 x double] %0
}

declare [2 x double] @__enzyme_autodiff(...)

; CHECK: define internal { [2 x double] } @diffetester([2 x double] %x, [2 x double] %differeturn)
Loading
Loading