Skip to content

Commit

Permalink
Merge branch 'main' into rust-bench
Browse files Browse the repository at this point in the history
* main: (49 commits)
  Fix iv of constant (#2141)
  Update benchmarks (#2035)
  Implement tgamma derivative (#2140)
  tgamma error improvement (#2139)
  Improve cache index error message (#2138)
  Fixes warnings and adds missing header guards (#2124)
  mlir: cache and reuse reverse funcs (#2133)
  mlir: implement forward mode for func.call (#2134)
  mlir: Func call reverse diff (#2127)
  Update build_tarballs.jl
  Fix combined temp cache for reverse (#2131)
  Improve runtime activity err message (#2132)
  Fix undef value storage (#2129)
  Adapt to const tblgen (#2128)
  Add gcloaded TT (#2125)
  Fix blas decl updater indexing (#2123)
  Add header files to ClangEnzyme target (#2062)
  Improve unknown function error messages (#2120)
  Fix handle sync (#2122)
  Support more Julia 1.11 functions (#2121)
  ...
  • Loading branch information
jedbrown committed Nov 1, 2024
2 parents 59f866b + de7c147 commit 3c0a0f8
Show file tree
Hide file tree
Showing 80 changed files with 2,101 additions and 644 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ jobs:
fail-fast: false
matrix:
llvm: ["16", "17", "18"]
build: ["Release", "Debug"] # "RelWithDebInfo"
os: [openstack18]
build: ["Release"] #, "Debug" "RelWithDebInfo"
os: [openstack22]
timeout-minutes: 120
steps:
- name: add llvm
run: |
wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add -
sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true
sudo apt-get install -y python3-pip autoconf cmake gcc g++ libtool gfortran libblas-dev llvm-${{ matrix.llvm }}-dev clang-${{ matrix.llvm }} libeigen3-dev libboost-dev
sudo apt-get install -y python3-pip autoconf cmake gcc g++ libtool gfortran libblas-dev llvm-${{ matrix.llvm }}-dev clang-${{ matrix.llvm }} libeigen3-dev libboost-dev libzstd-dev
sudo python3 -m pip install lit pathlib
sudo touch /usr/lib/llvm-${{ matrix.llvm }}/bin/yaml-bench
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/enzyme-mlir.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
- uses: actions/checkout@v4
with:
repository: 'llvm/llvm-project'
ref: '1bc7057a8eb7400dfbb1fc8335efa41abab9884e'
ref: 'eaa7b385368fa7e3dad9b95411d04be55e71494e'
path: 'llvm-project'

- name: Get MLIR commit hash
Expand Down
2 changes: 1 addition & 1 deletion .packaging/build_tarballs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ repo = "https://github.com/EnzymeAD/Enzyme.git"
auto_version = "%ENZYME_VERSION%"
version = VersionNumber(split(auto_version, "/")[end])

llvm_versions = [v"15.0.7", v"16.0.6", v"17.0.6"]
llvm_versions = [v"15.0.7", v"16.0.6", v"17.0.6", v"18.1.7", v"19.1.1"]

# Collection of sources required to build attr
sources = [
Expand Down
1 change: 1 addition & 0 deletions enzyme/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ cc_library(
"@llvm-project//llvm:Scalar",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//llvm:TargetParser",
"@llvm-project//llvm:TransformUtils",
"@llvm-project//llvm:config",
],
Expand Down
4 changes: 3 additions & 1 deletion enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1708,7 +1708,9 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
}
if (funcName == "jl_array_copy" || funcName == "ijl_array_copy" ||
funcName == "jl_idtable_rehash" ||
funcName == "ijl_idtable_rehash") {
funcName == "ijl_idtable_rehash" ||
funcName == "jl_genericmemory_copy_slice" ||
funcName == "ijl_genericmemory_copy_slice") {
// This pointer is inactive if it is either not actively stored to
// and not actively loaded from and the copied input is inactive.
if (directions & DOWN && directions & UP) {
Expand Down
6 changes: 6 additions & 0 deletions enzyme/Enzyme/ActivityAnalysisPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
// results of a given function.
//
//===----------------------------------------------------------------------===//

#ifndef ENZYME_ACTIVITY_ANALYSIS_PRINTER_H
#define ENZYME_ACTIVITY_ANALYSIS_PRINTER_H

#include <llvm/Config/llvm-config.h>

#include "llvm/IR/PassManager.h"
Expand All @@ -46,3 +50,5 @@ class ActivityAnalysisPrinterNewPM final

static bool isRequired() { return true; }
};

#endif // ENZYME_ACTIVITY_ANALYSIS_PRINTER_H
116 changes: 80 additions & 36 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
// LLVM instructions.
//
//===----------------------------------------------------------------------===//

#ifndef ENZYME_ADJOINT_GENERATOR_H
#define ENZYME_ADJOINT_GENERATOR_H

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
Expand Down Expand Up @@ -302,9 +306,12 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
void forwardModeInvertedPointerFallback(llvm::Instruction &I) {
using namespace llvm;

if (gutils->isConstantValue(&I))
return;
auto found = gutils->invertedPointers.find(&I);
if (gutils->isConstantValue(&I)) {
assert(found == gutils->invertedPointers.end());
return;
}

assert(found != gutils->invertedPointers.end());
auto placeholder = cast<PHINode>(&*found->second);
gutils->invertedPointers.erase(found);
Expand All @@ -320,6 +327,8 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {

auto toset = gutils->invertPointerM(&I, Builder2, /*nullShadow*/ true);

assert(toset != placeholder);

gutils->replaceAWithB(placeholder, toset);
placeholder->replaceAllUsesWith(toset);
gutils->erase(placeholder);
Expand Down Expand Up @@ -1137,8 +1146,8 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
} else {
maskL = lookup(mask, Builder2);
Type *tys[] = {valType, orig_ptr->getType()};
auto F = Intrinsic::getDeclaration(gutils->oldFunc->getParent(),
Intrinsic::masked_load, tys);
auto F = getIntrinsicDeclaration(gutils->oldFunc->getParent(),
Intrinsic::masked_load, tys);
Value *alignv =
ConstantInt::get(Type::getInt32Ty(mask->getContext()),
align ? align->value() : 0);
Expand Down Expand Up @@ -2141,18 +2150,6 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
void visitBinaryOperator(llvm::BinaryOperator &BO) {
eraseIfUnused(BO);

size_t size = 1;
if (BO.getType()->isSized())
size = (gutils->newFunc->getParent()->getDataLayout().getTypeSizeInBits(
BO.getType()) +
7) /
8;

if (BO.getType()->isIntOrIntVectorTy() &&
TR.intType(size, &BO, /*errifnotfound*/ false) == BaseType::Pointer) {
return;
}

if (BO.getOpcode() == llvm::Instruction::FDiv &&
(Mode == DerivativeMode::ReverseModeGradient ||
Mode == DerivativeMode::ReverseModeCombined) &&
Expand Down Expand Up @@ -2285,6 +2282,9 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
}

void createBinaryOperatorAdjoint(llvm::BinaryOperator &BO) {
if (gutils->isConstantInstruction(&BO)) {
return;
}
using namespace llvm;

IRBuilder<> Builder2(&BO);
Expand Down Expand Up @@ -2766,8 +2766,19 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
auto rval = EmitNoDerivativeError(ss.str(), BO, gutils, Builder2);
if (!rval)
rval = Constant::getNullValue(gutils->getShadowType(BO.getType()));
if (!gutils->isConstantValue(&BO))
setDiffe(&BO, rval, Builder2);
auto ifound = gutils->invertedPointers.find(&BO);
if (!gutils->isConstantValue(&BO)) {
if (ifound != gutils->invertedPointers.end()) {
auto placeholder = cast<PHINode>(&*ifound->second);
gutils->invertedPointers.erase(ifound);
gutils->replaceAWithB(placeholder, rval);
gutils->erase(placeholder);
gutils->invertedPointers.insert(std::make_pair(
(const Value *)&BO, InvertedPointerVH(gutils, rval)));
}
} else {
assert(ifound == gutils->invertedPointers.end());
}
break;
}
}
Expand Down Expand Up @@ -3108,7 +3119,11 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
op3 = gutils->getNewFromOriginal(MS.getOperand(3));
}

for (auto &&[secretty, seg_start, seg_size] : toIterate) {
for (auto &&[secretty_ref, seg_start_ref, seg_size_ref] : toIterate) {
auto secretty = secretty_ref;
auto seg_start = seg_start_ref;
auto seg_size = seg_size_ref;

Value *length = new_size;
if (seg_start != std::get<1>(toIterate.back())) {
length = ConstantInt::get(new_size->getType(), seg_start + seg_size);
Expand Down Expand Up @@ -3484,7 +3499,11 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
}
}

for (auto &&[floatTy, seg_start, seg_size] : toIterate) {
for (auto &&[floatTy_ref, seg_start_ref, seg_size_ref] : toIterate) {
auto floatTy = floatTy_ref;
auto seg_start = seg_start_ref;
auto seg_size = seg_size_ref;

Value *length = new_size;
if (seg_start != std::get<1>(toIterate.back())) {
length = ConstantInt::get(new_size->getType(), seg_start + seg_size);
Expand Down Expand Up @@ -3781,10 +3800,9 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
case Intrinsic::nvvm_barrier0_or: {
SmallVector<Value *, 1> args = {};
auto cal = cast<CallInst>(Builder2.CreateCall(
Intrinsic::getDeclaration(M, Intrinsic::nvvm_barrier0), args));
cal->setCallingConv(
Intrinsic::getDeclaration(M, Intrinsic::nvvm_barrier0)
->getCallingConv());
getIntrinsicDeclaration(M, Intrinsic::nvvm_barrier0), args));
cal->setCallingConv(getIntrinsicDeclaration(M, Intrinsic::nvvm_barrier0)
->getCallingConv());
cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
return false;
}
Expand All @@ -3796,8 +3814,8 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
case Intrinsic::nvvm_membar_sys: {
SmallVector<Value *, 1> args = {};
auto cal = cast<CallInst>(
Builder2.CreateCall(Intrinsic::getDeclaration(M, ID), args));
cal->setCallingConv(Intrinsic::getDeclaration(M, ID)->getCallingConv());
Builder2.CreateCall(getIntrinsicDeclaration(M, ID), args));
cal->setCallingConv(getIntrinsicDeclaration(M, ID)->getCallingConv());
cal->setDebugLoc(gutils->getNewFromOriginal(I.getDebugLoc()));
return false;
}
Expand All @@ -3810,9 +3828,9 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
lookup(gutils->getNewFromOriginal(orig_ops[1]), Builder2)};
Type *tys[] = {args[1]->getType()};
auto cal = Builder2.CreateCall(
Intrinsic::getDeclaration(M, Intrinsic::lifetime_end, tys), args);
getIntrinsicDeclaration(M, Intrinsic::lifetime_end, tys), args);
cal->setCallingConv(
Intrinsic::getDeclaration(M, Intrinsic::lifetime_end, tys)
getIntrinsicDeclaration(M, Intrinsic::lifetime_end, tys)
->getCallingConv());
return false;
}
Expand Down Expand Up @@ -5474,19 +5492,42 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
}

if (subretused) {
Intrinsic::ID ID = Intrinsic::not_intrinsic;
if (DifferentialUseAnalysis::is_value_needed_in_reverse<
QueryType::Primal>(gutils, &call, Mode, oldUnreachable) &&
!gutils->unnecessaryIntermediates.count(&call)) {

if (!isMemFreeLibMFunction(getFuncNameFromCall(&call), &ID)) {

#if LLVM_VERSION_MAJOR >= 18
auto It = BuilderZ.GetInsertPoint();
It.setHeadBit(true);
BuilderZ.SetInsertPoint(It);
auto It = BuilderZ.GetInsertPoint();
It.setHeadBit(true);
BuilderZ.SetInsertPoint(It);
#endif
cachereplace = BuilderZ.CreatePHI(call.getType(), 1,
call.getName() + "_tmpcacheB");
cachereplace = gutils->cacheForReverse(
BuilderZ, cachereplace,
getIndex(&call, CacheType::Self, BuilderZ));
auto idx = getIndex(&call, CacheType::Self, BuilderZ);
if (idx == IndexMappingError) {
std::string str;
raw_string_ostream ss(str);
ss << "Failed to compute consistent cache index for operation: "
<< call << "\n";
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&call),
ErrorType::InternalError, nullptr, nullptr,
nullptr);
} else {
EmitFailure("GetIndexError", call.getDebugLoc(), &call,
ss.str());
}
} else {
if (Mode == DerivativeMode::ReverseModeCombined)
cachereplace = newCall;
else
cachereplace = BuilderZ.CreatePHI(
call.getType(), 1, call.getName() + "_tmpcacheB");
cachereplace =
gutils->cacheForReverse(BuilderZ, cachereplace, idx);
}
}
} else {
#if LLVM_VERSION_MAJOR >= 18
auto It = BuilderZ.GetInsertPoint();
Expand Down Expand Up @@ -5993,6 +6034,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
eraseIfUnused(call);
}

ifound = gutils->invertedPointers.find(&call);
if (ifound != gutils->invertedPointers.end()) {
auto placeholder = cast<PHINode>(&*ifound->second);
if (invertedReturn && invertedReturn != placeholder) {
Expand Down Expand Up @@ -6406,3 +6448,5 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
subretused);
}
};

#endif // ENZYME_ADJOINT_GENERATOR_H
61 changes: 61 additions & 0 deletions enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,67 @@ def symm: CallBlasPattern<(Op $layout, $side, $uplo, $m, $n, $alpha, $A, $lda, $
)
>;



def syr2: CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $x, $incx, $y, $incy, $A, $lda),
["A"],
[cblas_layout, uplo, len, fp, vinc<["n"]>, vinc<["n"]>, mld<["uplo", "n", "n"]>],
[
/*alpha*/ (AssertingInactiveArg),
/*x*/ (AssertingInactiveArg),
/*y*/ (AssertingInactiveArg),
/*A*/ (AssertingInactiveArg)
]
>;


def symv: CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $A, $lda, $x, $incx, $beta, $y, $incy),
["y"],
[cblas_layout, uplo, len, fp, mld<["uplo", "n", "n"]>, vinc<["n"]>, fp, vinc<["n"]>],
[
/*alpha*/ (Seq<["Ax", "vector", "n"], [], 1>
(BlasCall<"symv"> $layout, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $n, $n), $x, Constant<"0.0">, use<"Ax">, ConstantInt<1>),
(BlasCall<"dot"> $n, (Shadow $y), use<"Ax">, ConstantInt<1>)
),
/*A*/ (Seq<["tmp", "vector", "n"], [], 1>
// Save the diagonal as we shouldn't add syr2 into it
(BlasCall<"copy"> $n, (First (Shadow $A)), (Add $lda, ConstantInt<1>), use<"tmp">, ConstantInt<1>),

(BlasCall<"syr2">
$layout,
$uplo,
$n,
$alpha,
$x,
(Shadow $y),
(Shadow $A)
),
(BlasCall<"copy"> $n, use<"tmp">, ConstantInt<1>, (First (Shadow $A)), (Add $lda, ConstantInt<1>))
),
/*x*/ (BlasCall<"symv"> $layout, $uplo, $n, $alpha, $A, (ld $A, Char<"N">, $lda, $n, $n), (Shadow $y), Constant<"1.0">, (Shadow $x)),
/*beta*/ (BlasCall<"dot"> $n, (Shadow $y), input<"y">),
/*y*/ (BlasCall<"scal"> $n, $beta, (Shadow $y))
],
// FWD: dy = dalpha A x + alpha dA x + alpha A dx + dbeta y + beta dy

(Seq<[], ["beta1"], 1>
// dbeta y
(BlasCall<"axpy"> $n, (Shadow $beta), $y, (Shadow $y)),

// alpha A dx (optional + beta dy)
(BlasCall<"symv"> $layout, $uplo, $n, $alpha, $A, (ld $A, Char<"N">, $lda, $n, $n), (Shadow $x), (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $y)),

// alpha dA x (optional + beta dy)
(BlasCall<"symv"> $layout, $uplo, $n, $alpha, (Shadow $A), $x, (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $y)),

// dalpha A x (optional + beta dy)
(BlasCall<"symv"> $layout, $uplo, $n, (Shadow $alpha), $A, (ld $A, Char<"N">, $lda, $n, $n), $x, (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $y)),

// (beta dy)
(FirstUse<"beta1"> (BlasCall<"scal"> $n, $beta, (Shadow $y)))
)
>;

def syr2k : CallBlasPattern<(Op $layout, $uplo, $trans, $n, $k, $alpha, $A, $lda, $B, $ldb, $beta, $C, $ldc),
["C"],
[cblas_layout, uplo, trans, len, len, fp, mld<["trans", "n", "k"]>, mld<["trans", "n", "k"]>, fp, mld<["n", "n"]>],
Expand Down
Loading

0 comments on commit 3c0a0f8

Please sign in to comment.