Skip to content

Commit

Permalink
Allow custom importing of files and syntactic sugar
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 22, 2024
1 parent 638ac37 commit 990d404
Show file tree
Hide file tree
Showing 7 changed files with 373 additions and 6 deletions.
15 changes: 15 additions & 0 deletions enzyme/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,20 @@ gentbl(
],
)

# Generates IncludeUtils.inc by running :enzyme-tblgen with the
# -gen-header-strings option over Enzyme/Clang/include_utils.td.
gentbl(
name = "include-utils",
# (tblgen option, output file) pairs produced by this rule.
tbl_outs = [(
"-gen-header-strings",
"IncludeUtils.inc",
)],
tblgen = ":enzyme-tblgen",
td_file = "Enzyme/Clang/include_utils.td",
td_srcs = ["Enzyme/Clang/include_utils.td"],
deps = [
":enzyme-tblgen",
],
)

cc_library(
name = "EnzymeStatic",
srcs = glob(
Expand All @@ -167,6 +181,7 @@ cc_library(
data = ["@llvm-project//clang:builtin_headers_gen"],
visibility = ["//visibility:public"],
deps = [
"include-utils",
":binop-derivatives",
":blas-attributor",
":blas-derivatives",
Expand Down
6 changes: 6 additions & 0 deletions enzyme/Enzyme/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ add_public_tablegen_target(BlasDeclarationsIncGen)
add_public_tablegen_target(BlasTAIncGen)
add_public_tablegen_target(BlasDiffUseIncGen)

# Generate IncludeUtils.inc from Clang/include_utils.td using the
# -gen-header-strings tablegen option, and expose the generation step as the
# IncludeUtilsIncGen target so other targets can declare a dependency on it.
set(LLVM_TARGET_DEFINITIONS Clang/include_utils.td)
enzyme_tablegen(IncludeUtils.inc -gen-header-strings)
add_public_tablegen_target(IncludeUtilsIncGen)

include_directories(${CMAKE_CURRENT_BINARY_DIR})

set(LLVM_LINK_COMPONENTS Demangle)
Expand Down Expand Up @@ -74,6 +78,7 @@ if (${Clang_FOUND})
LLVM
)
target_compile_definitions(ClangEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPASS)
add_dependencies(ClangEnzyme-${LLVM_VERSION_MAJOR} IncludeUtilsIncGen)
endif()
add_llvm_library( LLDEnzyme-${LLVM_VERSION_MAJOR}
${ENZYME_SRC} Clang/EnzymePassLoader.cpp
Expand Down Expand Up @@ -107,6 +112,7 @@ if (${Clang_FOUND})
clang
)
target_compile_definitions(ClangEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPASS)
add_dependencies(ClangEnzyme-${LLVM_VERSION_MAJOR} IncludeUtilsIncGen)
endif()
add_llvm_library( LLDEnzyme-${LLVM_VERSION_MAJOR}
${ENZYME_SRC} Clang/EnzymePassLoader.cpp
Expand Down
33 changes: 33 additions & 0 deletions enzyme/Enzyme/Clang/EnzymeClang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@

#include "../Utils.h"

#include "IncludeUtils.inc"

using namespace clang;

#if LLVM_VERSION_MAJOR >= 18
Expand Down Expand Up @@ -134,6 +136,37 @@ class EnzymePlugin final : public clang::ASTConsumer {
Builder.defineMacro("ENZYME_VERSION_PATCH",
std::to_string(ENZYME_VERSION_PATCH));
CI.getPreprocessor().setPredefines(Predefines.str());

auto baseFS = CI.getFileManager().getVirtualFileSystemPtr();
llvm::vfs::OverlayFileSystem* fuseFS(
new llvm::vfs::OverlayFileSystem(baseFS));
IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> fs(
new llvm::vfs::InMemoryFileSystem());

struct tm y2k = {};

y2k.tm_hour = 0;
y2k.tm_min = 0;
y2k.tm_sec = 0;
y2k.tm_year = 100;
y2k.tm_mon = 0;
y2k.tm_mday = 1;
time_t timer = mktime(&y2k);
for (const auto &pair : include_headers) {
fs->addFile(StringRef(pair[0]), timer,
llvm::MemoryBuffer::getMemBuffer(StringRef(pair[1]), StringRef(pair[0]), /*RequiresNullTerminator*/true));

}

fuseFS->pushOverlay(fs);
fuseFS->pushOverlay(baseFS);
CI.getFileManager().setVirtualFileSystem(fuseFS);

auto DE = CI.getFileManager().getDirectoryRef("/enzymeroot");
assert(DE);
auto DL = DirectoryLookup(*DE, SrcMgr::C_User,
/*isFramework=*/false);
CI.getPreprocessor().getHeaderSearchInfo().AddSearchPath(DL, /*isAngled=*/true);
}
~EnzymePlugin() {}
void HandleTranslationUnit(ASTContext &context) override {}
Expand Down
255 changes: 250 additions & 5 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,12 +429,258 @@ static Optional<StringRef> recursePhiReads(PHINode *val)
return finalMetadata;
}

// Forward declaration: getBaseAndOffset calls simplifyLoad before its
// definition appears. A valSz of 0 makes simplifyLoad derive the size in bytes
// from the type of the value it is given.
Value *simplifyLoad(Value *LI, size_t valSz = 0);

// Walk `ptr` back to the alloca it is derived from. On success the alloca is
// returned and `offset` holds the accumulated constant byte distance from the
// start of that alloca to `ptr`.
//
// Casts are looked through, GEPs contribute their constant offset, and loads
// are followed when simplifyLoad can name the single value they read. Any
// other value, or a GEP whose offset is variable or negative, yields nullptr;
// `offset` may then hold a partial sum and should not be used.
AllocaInst *getBaseAndOffset(Value *ptr, size_t &offset) {
  offset = 0;
  for (Value *cur = ptr;;) {
    // Reached the underlying stack slot: done.
    if (auto *base = dyn_cast<AllocaInst>(cur))
      return base;

    if (auto *castI = dyn_cast<CastInst>(cur)) {
      // Casts do not move the pointer.
      cur = castI->getOperand(0);
    } else if (auto *gep = dyn_cast<GetElementPtrInst>(cur)) {
      auto &DL = gep->getModule()->getDataLayout();
      const auto width = sizeof(size_t) * 8;
      MapVector<Value *, APInt> VariableOffsets;
      APInt Offset(width, 0);
      const bool ok = collectOffset(cast<GEPOperator>(gep), DL, width,
                                    VariableOffsets, Offset);
      // Only a fully constant, non-negative displacement can be tracked.
      if (!ok || !VariableOffsets.empty() || Offset.isNegative())
        return nullptr;
      offset += Offset.getZExtValue();
      cur = gep->getOperand(0);
    } else if (auto *load = dyn_cast<LoadInst>(cur)) {
      // The pointer was itself read from memory; keep going only if that
      // read resolves to one known value.
      Value *forwarded = simplifyLoad(load);
      if (!forwarded)
        return nullptr;
      cur = forwarded;
    } else {
      return nullptr;
    }
  }
}

// Collect every instruction that uses AI, either directly or through a chain
// of casts and constant-offset GEPs. Each result is a tuple of
// <user instruction, the value it uses, byte offset of that value from AI>.
// Casts and GEPs with a constant non-negative offset are not reported
// themselves; their users are visited instead. A GEP whose offset cannot be
// resolved that way is reported as an ordinary user.
SmallVector<std::tuple<Instruction *, Value *, size_t>, 1>
findAllUsersOf(Value *AI) {
  SmallVector<std::tuple<Instruction *, Value *, size_t>, 1> found;

  // Worklist of (derived pointer, byte offset from AI) still to be scanned.
  SmallVector<std::pair<Value *, size_t>, 1> worklist;
  worklist.emplace_back(AI, 0);

  while (!worklist.empty()) {
    auto item = worklist.pop_back_val();
    Value *cur = item.first;
    size_t curOff = item.second;

    for (auto *user : cur->users()) {
      if (auto *castI = dyn_cast<CastInst>(user)) {
        // Same address, different type: scan the cast's users.
        worklist.emplace_back(castI, curOff);
        continue;
      }

      if (auto *gep = dyn_cast<GetElementPtrInst>(user)) {
        auto &DL = gep->getModule()->getDataLayout();
        const auto width = sizeof(size_t) * 8;
        MapVector<Value *, APInt> VariableOffsets;
        APInt Offset(width, 0);
        const bool ok = collectOffset(cast<GEPOperator>(gep), DL, width,
                                      VariableOffsets, Offset);

        const bool trackable =
            ok && VariableOffsets.empty() && !Offset.isNegative();
        if (trackable) {
          worklist.emplace_back(gep, curOff + Offset.getZExtValue());
          continue;
        }
        // Otherwise fall through and report the GEP itself as a user.
      }

      found.emplace_back(cast<Instruction>(user), cur, curOff);
    }
  }
  return found;
}

// Given a pointer, find all values of size `valSz` which could be loaded from
// that pointer when indexed at `offset`. If it is impossible to guarantee that
// the set contains all such values, set `legal` to false.
//
// ptr0   - alloca whose contents are being reasoned about.
// offset - byte offset, from the start of ptr0, of the hypothetical load.
// valSz  - size in bytes of the hypothetical load.
// legal  - only ever written with false; the caller must initialize it to
//          true. When it is false the returned list is incomplete and must
//          not be relied upon.
//
// Returns the value operands of the stores that write exactly the bytes
// [offset, offset + valSz).
SmallVector<Value *, 1> getAllLoadedValuesFrom(AllocaInst *ptr0, size_t offset,
                                               size_t valSz, bool &legal) {
  SmallVector<Value *, 1> options;

  auto todo = findAllUsersOf(ptr0);
  // Guards against scanning the same (user, pointer, offset) tuple twice,
  // which can happen once users of other pointers are appended below.
  std::set<std::tuple<Instruction *, Value *, size_t>> seen;

  while (todo.size()) {
    auto pair = todo.pop_back_val();
    if (!seen.insert(pair).second)
      continue;
    Instruction *U = std::get<0>(pair);
    Value *ptr = std::get<1>(pair);
    size_t suboff = std::get<2>(pair);

    // Read only users do not set the memory inside of ptr
    if (isa<LoadInst>(U)) {
      continue;
    }
    // A memory transfer whose destination (operand 0) is not ptr does not
    // write through ptr.
    if (auto MTI = dyn_cast<MemTransferInst>(U))
      if (MTI->getOperand(0) != ptr) {
        continue;
      }
    // Instructions that neither write memory nor produce a value cannot
    // change what is stored behind ptr.
    if (!U->mayWriteToMemory() && U->getType()->isVoidTy())
      continue;

    if (auto SI = dyn_cast<StoreInst>(U)) {
      auto &DL = SI->getParent()->getParent()->getParent()->getDataLayout();

      // We are storing into the ptr
      if (SI->getPointerOperand() == ptr) {
        // Size of the store in bytes, rounded up.
        auto storeSz =
            (DL.getTypeStoreSizeInBits(SI->getValueOperand()->getType()) + 7) /
            8;
        // The store ends at or before the first byte of the load.
        if (storeSz + suboff <= offset)
          continue;
        // The store starts at or after the end of the load.
        if (offset + valSz <= suboff)
          continue;

        // The stored value is what the load would read only when the store
        // covers exactly the loaded bytes: same start and same size. A
        // same-sized store at a different offset overlaps only partially, so
        // it must not be reported; it falls through to `legal = false` below.
        if (suboff == offset && valSz == storeSz) {
          options.push_back(SI->getValueOperand());
          continue;
        }
      }

      // ptr itself is the value being stored, i.e. it is captured. If it is
      // stored into a slot of another alloca, collect every value that may be
      // read back out of that slot and scan the users of those values too.
      if (SI->getValueOperand() == ptr) {
        if (suboff == 0) {
          size_t mid_offset = 0;
          if (auto AI2 =
                  getBaseAndOffset(SI->getPointerOperand(), mid_offset)) {
            bool sublegal = true;
            auto ptrSz = (DL.getTypeStoreSizeInBits(ptr->getType()) + 7) / 8;
            auto subPtrs =
                getAllLoadedValuesFrom(AI2, mid_offset, ptrSz, sublegal);
            if (!sublegal) {
              legal = false;
              return options;
            }
            for (auto subPtr : subPtrs) {
              for (const auto &pair3 : findAllUsersOf(subPtr)) {
                todo.emplace_back(pair3);
              }
            }
            continue;
          }
        }
      }
    }

    // If we copy into the ptr at a location that includes the offset, consider
    // all sub uses
    if (auto MTI = dyn_cast<MemTransferInst>(U)) {
      if (auto CI = dyn_cast<ConstantInt>(MTI->getLength())) {
        // Only a constant-length copy that starts at the beginning of the
        // alloca and covers the whole loaded range is handled.
        if (MTI->getOperand(0) == ptr && suboff == 0 &&
            CI->getValue().uge(offset + valSz)) {
          size_t midoffset = 0;
          auto AI2 = getBaseAndOffset(MTI->getOperand(1), midoffset);
          if (!AI2) {
            legal = false;
            return options;
          }
          // The source must also be the start of its alloca so that byte
          // offsets line up between source and destination.
          if (midoffset != 0) {
            legal = false;
            return options;
          }
          for (const auto &pair3 : findAllUsersOf(AI2)) {
            todo.emplace_back(pair3);
          }
          continue;
        }
      }
    }

    // Any user not proven harmless above may modify the memory in a way that
    // is not modelled here.
    legal = false;
    return options;
  }

  return options;
}

// Perform a mem2reg/sroa-style lookup to identify the innermost value that V
// represents.
//
// V     - a load, or an extractvalue whose indices are all zero applied to a
//         load; any other value yields nullptr.
// valSz - number of bytes of interest; 0 means "use the store size of V's
//         type".
//
// Returns the single value that every matching store writes, after
// recursively simplifying each candidate, or nullptr when the memory cannot
// be fully analyzed or the candidates do not collapse to exactly one value.
//
// NOTE(review): the recursion keeps no visited set; confirm callers cannot
// reach a load whose candidate values lead back to the same load.
Value *simplifyLoad(Value *V, size_t valSz) {
  if (auto *EVI = dyn_cast<ExtractValueInst>(V)) {
    bool anyNonZero = false;
    for (auto idx : EVI->getIndices())
      anyNonZero |= (idx != 0);

    if (valSz == 0) {
      auto &DL = EVI->getModule()->getDataLayout();
      valSz = (DL.getTypeStoreSizeInBits(EVI->getType()) + 7) / 8;
    }

    // Only the leading element of the aggregate sits at the load's address.
    if (anyNonZero)
      return nullptr;
    auto *aggLoad = dyn_cast<LoadInst>(EVI->getAggregateOperand());
    return aggLoad ? simplifyLoad(aggLoad, valSz) : nullptr;
  }

  auto *LI = dyn_cast<LoadInst>(V);
  if (!LI)
    return nullptr;

  if (valSz == 0) {
    auto &DL = LI->getModule()->getDataLayout();
    valSz = (DL.getTypeStoreSizeInBits(LI->getType()) + 7) / 8;
  }

  size_t offset = 0;
  AllocaInst *base = getBaseAndOffset(LI->getPointerOperand(), offset);
  if (!base)
    return nullptr;

  bool legal = true;
  auto candidates = getAllLoadedValuesFrom(base, offset, valSz, legal);
  if (!legal)
    return nullptr;

  // Reduce every candidate as far as possible, then require agreement.
  std::set<Value *> distinct;
  for (Value *cand : candidates) {
    Value *inner = simplifyLoad(cand, valSz);
    distinct.insert(inner ? inner : cand);
  }
  return distinct.size() == 1 ? *distinct.begin() : nullptr;
}

#if LLVM_VERSION_MAJOR > 16
std::optional<StringRef> getMetadataName(llvm::Value *res)
#else
Optional<StringRef> getMetadataName(llvm::Value *res)
#endif
{
if (auto S = simplifyLoad(res))
return getMetadataName(S);

if (auto av = dyn_cast<MetadataAsValue>(res)) {
return cast<MDString>(av->getMetadata())->getString();
} else if ((isa<LoadInst>(res) || isa<CastInst>(res)) &&
Expand Down Expand Up @@ -463,12 +709,11 @@ Optional<StringRef> getMetadataName(llvm::Value *res)
return gv->getName();
} else if (auto gv = dyn_cast<AllocaInst>(res)) {
return gv->getName();
} else {
if (isa<PHINode>(res)) {
return recursePhiReads(cast<PHINode>(res));
}
return {};
} else if (isa<PHINode>(res)) {
return recursePhiReads(cast<PHINode>(res));
}

return {};
}

static Value *adaptReturnedVector(Value *ret, Value *diffret,
Expand Down
Loading

0 comments on commit 990d404

Please sign in to comment.