diff --git a/.packaging/build_tarballs.jl b/.packaging/build_tarballs.jl index f52e40d3998..9a76a181acd 100644 --- a/.packaging/build_tarballs.jl +++ b/.packaging/build_tarballs.jl @@ -22,7 +22,9 @@ sources = [ # These are the platforms we will build for by default, unless further # platforms are passed in on the command line -platforms = expand_cxxstring_abis(supported_platforms(; experimental=true)) +platforms = expand_cxxstring_abis(supported_platforms()) +# Exclude aarch64 FreeBSD for the time being +filter!(p -> !(Sys.isfreebsd(p) && arch(p) == "aarch64"), platforms) # Bash recipe for building across all platforms script = raw""" diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 32d7abe9f30..0405de12c32 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -2902,168 +2902,205 @@ class AdjointGenerator : public llvm::InstVisitor { } assert(size != 0); - auto &DL = gutils->newFunc->getParent()->getDataLayout(); - auto vd = TR.query(MS.getOperand(0)).Data0().ShiftIndices(DL, 0, size, 0); + // Offsets of the form Optional, segment start, segment size + std::vector> toIterate; - if (!vd.isKnownPastPointer()) { - // If unknown type results, and zeroing known undef allocation, consider - // integers - if (auto CI = dyn_cast(MS.getOperand(1))) - if (CI->isZero()) { - auto root = getBaseObject(MS.getOperand(0)); - bool writtenTo = false; - bool undefMemory = - isa(root) || isAllocationCall(root, gutils->TLI); - if (auto arg = dyn_cast(root)) - if (arg->hasStructRetAttr()) - undefMemory = true; - if (undefMemory) { - Instruction *cur = MS.getPrevNode(); - while (cur) { - if (cur == root) - break; - if (auto MCI = dyn_cast(MS.getOperand(2))) { - if (auto II = dyn_cast(cur)) { - // If the start of the lifetime for more memory than being - // memset, its valid. - if (II->getIntrinsicID() == Intrinsic::lifetime_start) { - if (getBaseObject(II->getOperand(1)) == root) { - if (auto CI2 = dyn_cast(II->getOperand(0))) { - if (MCI->getValue().ule(CI2->getValue())) - break; + // Special handling mechanism to bypass TA limitations by supporting + // arbitrary sized types. + if (auto MD = hasMetadata(&MS, "enzyme_truetype")) { + toIterate = parseTrueType(MD, Mode, false); + } else { + auto &DL = gutils->newFunc->getParent()->getDataLayout(); + auto vd = TR.query(MS.getOperand(0)).Data0().ShiftIndices(DL, 0, size, 0); + + if (!vd.isKnownPastPointer()) { + // If unknown type results, and zeroing known undef allocation, consider + // integers + if (auto CI = dyn_cast(MS.getOperand(1))) + if (CI->isZero()) { + auto root = getBaseObject(MS.getOperand(0)); + bool writtenTo = false; + bool undefMemory = + isa(root) || isAllocationCall(root, gutils->TLI); + if (auto arg = dyn_cast(root)) + if (arg->hasStructRetAttr()) + undefMemory = true; + if (undefMemory) { + Instruction *cur = MS.getPrevNode(); + while (cur) { + if (cur == root) + break; + if (auto MCI = dyn_cast(MS.getOperand(2))) { + if (auto II = dyn_cast(cur)) { + // If the start of the lifetime for more memory than being + // memset, its valid. + if (II->getIntrinsicID() == Intrinsic::lifetime_start) { + if (getBaseObject(II->getOperand(1)) == root) { + if (auto CI2 = + dyn_cast(II->getOperand(0))) { + if (MCI->getValue().ule(CI2->getValue())) + break; + } } + cur = cur->getPrevNode(); + continue; } - cur = cur->getPrevNode(); - continue; } } + if (cur->mayWriteToMemory()) { + writtenTo = true; + break; + } + cur = cur->getPrevNode(); } - if (cur->mayWriteToMemory()) { - writtenTo = true; - break; - } - cur = cur->getPrevNode(); - } - if (!writtenTo) { - vd = TypeTree(BaseType::Pointer); - vd.insert({-1}, BaseType::Integer); + if (!writtenTo) { + vd = TypeTree(BaseType::Pointer); + vd.insert({-1}, BaseType::Integer); + } } } - } - } + } - if (!vd.isKnownPastPointer()) { - // If unknown type results, consider the intersection of all incoming. - if (isa(MS.getOperand(0)) || isa(MS.getOperand(0))) { - SmallVector todo = {MS.getOperand(0)}; - bool set = false; - SmallSet seen; - TypeTree vd2; - while (todo.size()) { - Value *cur = todo.back(); - todo.pop_back(); - if (seen.count(cur)) - continue; - seen.insert(cur); - if (auto PN = dyn_cast(cur)) { - for (size_t i = 0, end = PN->getNumIncomingValues(); i < end; i++) { - todo.push_back(PN->getIncomingValue(i)); + if (!vd.isKnownPastPointer()) { + // If unknown type results, consider the intersection of all incoming. + if (isa(MS.getOperand(0)) || + isa(MS.getOperand(0))) { + SmallVector todo = {MS.getOperand(0)}; + bool set = false; + SmallSet seen; + TypeTree vd2; + while (todo.size()) { + Value *cur = todo.back(); + todo.pop_back(); + if (seen.count(cur)) + continue; + seen.insert(cur); + if (auto PN = dyn_cast(cur)) { + for (size_t i = 0, end = PN->getNumIncomingValues(); i < end; + i++) { + todo.push_back(PN->getIncomingValue(i)); + } + continue; } - continue; - } - if (auto S = dyn_cast(cur)) { - todo.push_back(S->getTrueValue()); - todo.push_back(S->getFalseValue()); - continue; - } - if (auto CE = dyn_cast(cur)) { - if (CE->isCast()) { - todo.push_back(CE->getOperand(0)); + if (auto S = dyn_cast(cur)) { + todo.push_back(S->getTrueValue()); + todo.push_back(S->getFalseValue()); continue; } - } - if (auto CI = dyn_cast(cur)) { - todo.push_back(CI->getOperand(0)); - continue; - } - if (isa(cur)) - continue; - if (auto CI = dyn_cast(cur)) - if (CI->isZero()) + if (auto CE = dyn_cast(cur)) { + if (CE->isCast()) { + todo.push_back(CE->getOperand(0)); + continue; + } + } + if (auto CI = dyn_cast(cur)) { + todo.push_back(CI->getOperand(0)); continue; - auto curTT = TR.query(cur).Data0().ShiftIndices(DL, 0, size, 0); - if (!set) - vd2 = curTT; - else - vd2 &= curTT; - set = true; + } + if (isa(cur)) + continue; + if (auto CI = dyn_cast(cur)) + if (CI->isZero()) + continue; + auto curTT = TR.query(cur).Data0().ShiftIndices(DL, 0, size, 0); + if (!set) + vd2 = curTT; + else + vd2 &= curTT; + set = true; + } + vd = vd2; } - vd = vd2; } - } - if (!vd.isKnownPastPointer()) { - if (looseTypeAnalysis) { + if (!vd.isKnownPastPointer()) { + if (looseTypeAnalysis) { #if LLVM_VERSION_MAJOR < 17 - if (auto CI = dyn_cast(MS.getOperand(0))) { - if (auto PT = dyn_cast(CI->getSrcTy())) { - auto ET = PT->getPointerElementType(); - while (1) { - if (auto ST = dyn_cast(ET)) { - if (ST->getNumElements()) { - ET = ST->getElementType(0); + if (auto CI = dyn_cast(MS.getOperand(0))) { + if (auto PT = dyn_cast(CI->getSrcTy())) { + auto ET = PT->getPointerElementType(); + while (1) { + if (auto ST = dyn_cast(ET)) { + if (ST->getNumElements()) { + ET = ST->getElementType(0); + continue; + } + } + if (auto AT = dyn_cast(ET)) { + ET = AT->getElementType(); continue; } + break; } - if (auto AT = dyn_cast(ET)) { - ET = AT->getElementType(); - continue; + if (ET->isFPOrFPVectorTy()) { + vd = TypeTree(ConcreteType(ET->getScalarType())).Only(0, &MS); + goto known; + } + if (ET->isPointerTy()) { + vd = TypeTree(BaseType::Pointer).Only(0, &MS); + goto known; + } + if (ET->isIntOrIntVectorTy()) { + vd = TypeTree(BaseType::Integer).Only(0, &MS); + goto known; } - break; - } - if (ET->isFPOrFPVectorTy()) { - vd = TypeTree(ConcreteType(ET->getScalarType())).Only(0, &MS); - goto known; - } - if (ET->isPointerTy()) { - vd = TypeTree(BaseType::Pointer).Only(0, &MS); - goto known; } - if (ET->isIntOrIntVectorTy()) { - vd = TypeTree(BaseType::Integer).Only(0, &MS); - goto known; + } +#endif + if (auto gep = dyn_cast(MS.getOperand(0))) { + if (auto AT = dyn_cast(gep->getSourceElementType())) { + if (AT->getElementType()->isIntegerTy()) { + vd = TypeTree(BaseType::Integer).Only(0, &MS); + goto known; + } } } + EmitWarning("CannotDeduceType", MS, + "failed to deduce type of memset ", MS); + vd = TypeTree(BaseType::Pointer).Only(0, &MS); + goto known; } -#endif - if (auto gep = dyn_cast(MS.getOperand(0))) { - if (auto AT = dyn_cast(gep->getSourceElementType())) { - if (AT->getElementType()->isIntegerTy()) { - vd = TypeTree(BaseType::Integer).Only(0, &MS); - goto known; + std::string str; + raw_string_ostream ss(str); + ss << "Cannot deduce type of memset " << MS; + EmitNoTypeError(str, MS, gutils, BuilderZ); + return; + } + known:; + { + unsigned start = 0; + while (1) { + unsigned nextStart = size; + + auto dt = vd[{-1}]; + for (size_t i = start; i < size; ++i) { + bool Legal = true; + dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal); + if (!Legal) { + nextStart = i; + break; } } + if (!dt.isKnown()) { + TR.dump(); + llvm::errs() << " vd:" << vd.str() << " start:" << start + << " size: " << size << " dt:" << dt.str() << "\n"; + } + assert(dt.isKnown()); + toIterate.emplace_back(dt.isFloat(), start, nextStart - start); + + if (nextStart == size) + break; + start = nextStart; } - EmitWarning("CannotDeduceType", MS, "failed to deduce type of memset ", - MS); - vd = TypeTree(BaseType::Pointer).Only(0, &MS); - goto known; } - std::string str; - raw_string_ostream ss(str); - ss << "Cannot deduce type of memset " << MS; - EmitNoTypeError(str, MS, gutils, BuilderZ); - return; } - known:; #if 0 unsigned dstalign = dstAlign.valueOrOne().value(); unsigned srcalign = srcAlign.valueOrOne().value(); #endif - unsigned start = 0; - Value *op1 = gutils->getNewFromOriginal(MS.getArgOperand(1)); Value *new_size = gutils->getNewFromOriginal(MS.getArgOperand(2)); Value *op3 = nullptr; @@ -3071,32 +3108,14 @@ class AdjointGenerator : public llvm::InstVisitor { op3 = gutils->getNewFromOriginal(MS.getOperand(3)); } - while (1) { - unsigned nextStart = size; - - auto dt = vd[{-1}]; - for (size_t i = start; i < size; ++i) { - bool Legal = true; - dt.checkedOrIn(vd[{(int)i}], /*PointerIntSame*/ true, Legal); - if (!Legal) { - nextStart = i; - break; - } - } - if (!dt.isKnown()) { - TR.dump(); - llvm::errs() << " vd:" << vd.str() << " start:" << start - << " size: " << size << " dt:" << dt.str() << "\n"; - } - assert(dt.isKnown()); - + for (auto &&[secretty, seg_start, seg_size] : toIterate) { Value *length = new_size; - if (nextStart != size) { - length = ConstantInt::get(new_size->getType(), nextStart); + if (seg_start != std::get<1>(toIterate.back())) { + length = ConstantInt::get(new_size->getType(), seg_start + seg_size); } - if (start != 0) + if (seg_start != 0) length = BuilderZ.CreateSub( - length, ConstantInt::get(new_size->getType(), start)); + length, ConstantInt::get(new_size->getType(), seg_start)); #if 0 unsigned subdstalign = dstalign; @@ -3118,7 +3137,6 @@ class AdjointGenerator : public llvm::InstVisitor { Value *shadow_dst = gutils->invertPointerM(MS.getOperand(0), BuilderZ); // TODO ponder forward split mode - Type *secretty = dt.isFloat(); if (!secretty && ((Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) || (Mode == DerivativeMode::ReverseModeCombined && forwardsShadow) || @@ -3130,9 +3148,9 @@ class AdjointGenerator : public llvm::InstVisitor { ValueType::Primal, ValueType::Primal}, BuilderZ, /*lookup*/ false); auto rule = [&](Value *op0) { - if (start != 0) { - Value *idxs[] = { - ConstantInt::get(Type::getInt32Ty(op0->getContext()), start)}; + if (seg_start != 0) { + Value *idxs[] = {ConstantInt::get( + Type::getInt32Ty(op0->getContext()), seg_start)}; op0 = BuilderZ.CreateInBoundsGEP(Type::getInt8Ty(op0->getContext()), op0, idxs); } @@ -3167,9 +3185,9 @@ class AdjointGenerator : public llvm::InstVisitor { op3l = gutils->lookupM(op3l, BuilderZ); length = gutils->lookupM(length, Builder2); auto rule = [&](Value *op0) { - if (start != 0) { - Value *idxs[] = { - ConstantInt::get(Type::getInt32Ty(op0->getContext()), start)}; + if (seg_start != 0) { + Value *idxs[] = {ConstantInt::get( + Type::getInt32Ty(op0->getContext()), seg_start)}; op0 = Builder2.CreateInBoundsGEP(Type::getInt8Ty(op0->getContext()), op0, idxs); } @@ -3206,10 +3224,6 @@ class AdjointGenerator : public llvm::InstVisitor { applyChainRule(Builder2, rule, gutils->lookupM(shadow_dst, Builder2)); } - - if (nextStart == size) - break; - start = nextStart; } } @@ -3275,111 +3289,181 @@ class AdjointGenerator : public llvm::InstVisitor { return; } - auto &DL = gutils->newFunc->getParent()->getDataLayout(); - auto vd = TR.query(orig_dst).Data0().ShiftIndices(DL, 0, size, 0); - vd |= TR.query(orig_src).Data0().ShiftIndices(DL, 0, size, 0); - for (size_t i = 0; i < MTI.getNumOperands(); i++) - if (MTI.getOperand(i) == orig_dst) - if (MTI.getAttributes().hasParamAttr(i, "enzyme_type")) { - auto attr = MTI.getAttributes().getParamAttr(i, "enzyme_type"); - auto TT = TypeTree::parse(attr.getValueAsString(), MTI.getContext()); - vd |= TT.Data0().ShiftIndices(DL, 0, size, 0); - break; - } + // Offsets of the form Optional, segment start, segment size + std::vector> toIterate; + IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&MTI)); - bool errorIfNoType = true; - if ((Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) && - (!gutils->isConstantValue(orig_src) && !gutils->runtimeActivity)) { - errorIfNoType = false; - } + // Special handling mechanism to bypass TA limitations by supporting + // arbitrary sized types. + if (auto MD = hasMetadata(&MTI, "enzyme_truetype")) { + toIterate = parseTrueType(MD, Mode, + !gutils->isConstantValue(orig_src) && + !gutils->runtimeActivity); + } else { + auto &DL = gutils->newFunc->getParent()->getDataLayout(); + auto vd = TR.query(orig_dst).Data0().ShiftIndices(DL, 0, size, 0); + vd |= TR.query(orig_src).Data0().ShiftIndices(DL, 0, size, 0); + for (size_t i = 0; i < MTI.getNumOperands(); i++) + if (MTI.getOperand(i) == orig_dst) + if (MTI.getAttributes().hasParamAttr(i, "enzyme_type")) { + auto attr = MTI.getAttributes().getParamAttr(i, "enzyme_type"); + auto TT = + TypeTree::parse(attr.getValueAsString(), MTI.getContext()); + vd |= TT.Data0().ShiftIndices(DL, 0, size, 0); + break; + } - IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&MTI)); + bool errorIfNoType = true; + if ((Mode == DerivativeMode::ForwardMode || + Mode == DerivativeMode::ForwardModeError) && + (!gutils->isConstantValue(orig_src) && !gutils->runtimeActivity)) { + errorIfNoType = false; + } - if (!vd.isKnownPastPointer()) { - if (looseTypeAnalysis) { - for (auto val : {orig_dst, orig_src}) { + if (!vd.isKnownPastPointer()) { + if (looseTypeAnalysis) { + for (auto val : {orig_dst, orig_src}) { #if LLVM_VERSION_MAJOR < 17 - if (auto CI = dyn_cast(val)) { - if (auto PT = dyn_cast(CI->getSrcTy())) { - auto ET = PT->getPointerElementType(); - while (1) { - if (auto ST = dyn_cast(ET)) { - if (ST->getNumElements()) { - ET = ST->getElementType(0); + if (auto CI = dyn_cast(val)) { + if (auto PT = dyn_cast(CI->getSrcTy())) { + auto ET = PT->getPointerElementType(); + while (1) { + if (auto ST = dyn_cast(ET)) { + if (ST->getNumElements()) { + ET = ST->getElementType(0); + continue; + } + } + if (auto AT = dyn_cast(ET)) { + ET = AT->getElementType(); continue; } + break; } - if (auto AT = dyn_cast(ET)) { - ET = AT->getElementType(); - continue; + if (ET->isFPOrFPVectorTy()) { + vd = + TypeTree(ConcreteType(ET->getScalarType())).Only(0, &MTI); + goto known; + } + if (ET->isPointerTy()) { + vd = TypeTree(BaseType::Pointer).Only(0, &MTI); + goto known; + } + if (ET->isIntOrIntVectorTy()) { + vd = TypeTree(BaseType::Integer).Only(0, &MTI); + goto known; } - break; - } - if (ET->isFPOrFPVectorTy()) { - vd = TypeTree(ConcreteType(ET->getScalarType())).Only(0, &MTI); - goto known; - } - if (ET->isPointerTy()) { - vd = TypeTree(BaseType::Pointer).Only(0, &MTI); - goto known; } - if (ET->isIntOrIntVectorTy()) { - vd = TypeTree(BaseType::Integer).Only(0, &MTI); - goto known; + } +#endif + if (auto gep = dyn_cast(val)) { + if (auto AT = dyn_cast(gep->getSourceElementType())) { + if (AT->getElementType()->isIntegerTy()) { + vd = TypeTree(BaseType::Integer).Only(0, &MTI); + goto known; + } } } } -#endif - if (auto gep = dyn_cast(val)) { - if (auto AT = dyn_cast(gep->getSourceElementType())) { - if (AT->getElementType()->isIntegerTy()) { - vd = TypeTree(BaseType::Integer).Only(0, &MTI); + // If the type is known, but outside of the known range + // (but the memcpy size is a variable), attempt to use + // the first type out of range as the memcpy type. + if (size == 1 && !isa(new_size)) { + for (auto ptr : {orig_dst, orig_src}) { + vd = TR.query(ptr).Data0().ShiftIndices(DL, 0, -1, 0); + if (vd.isKnownPastPointer()) { + ConcreteType mv(BaseType::Unknown); + size_t minInt = 0xFFFFFFFF; + for (const auto &pair : vd.getMapping()) { + if (pair.first.size() != 1) + continue; + if (minInt < (size_t)pair.first[0]) + continue; + minInt = pair.first[0]; + mv = pair.second; + } + assert(mv != BaseType::Unknown); + vd.insert({0}, mv); goto known; } } } + if (errorIfNoType) + EmitWarning("CannotDeduceType", MTI, + "failed to deduce type of copy ", MTI); + vd = TypeTree(BaseType::Pointer).Only(0, &MTI); + goto known; } - // If the type is known, but outside of the known range - // (but the memcpy size is a variable), attempt to use - // the first type out of range as the memcpy type. - if (size == 1 && !isa(new_size)) { - for (auto ptr : {orig_dst, orig_src}) { - vd = TR.query(ptr).Data0().ShiftIndices(DL, 0, -1, 0); - if (vd.isKnownPastPointer()) { - ConcreteType mv(BaseType::Unknown); - size_t minInt = 0xFFFFFFFF; - for (const auto &pair : vd.getMapping()) { - if (pair.first.size() != 1) - continue; - if (minInt < (size_t)pair.first[0]) - continue; - minInt = pair.first[0]; - mv = pair.second; + if (errorIfNoType) { + std::string str; + raw_string_ostream ss(str); + ss << "Cannot deduce type of copy " << MTI; + EmitNoTypeError(str, MTI, gutils, BuilderZ); + vd = TypeTree(BaseType::Integer).Only(0, &MTI); + } else { + vd = TypeTree(BaseType::Pointer).Only(0, &MTI); + } + } + + known:; + { + + unsigned start = 0; + while (1) { + unsigned nextStart = size; + + auto dt = vd[{-1}]; + for (size_t i = start; i < size; ++i) { + bool Legal = true; + auto tmp = dt; + auto next = vd[{(int)i}]; + tmp.checkedOrIn(next, /*PointerIntSame*/ true, Legal); + // Prevent fusion of {Anything, Float} since anything is an int rule + // but float requires zeroing. + if ((dt == BaseType::Anything && + (next != BaseType::Anything && next.isKnown())) || + (next == BaseType::Anything && + (dt != BaseType::Anything && dt.isKnown()))) + Legal = false; + if (!Legal) { + if (Mode == DerivativeMode::ForwardMode || + Mode == DerivativeMode::ForwardModeError) { + // if both are floats (of any type), forward mode is the same. + // + [potentially zero if const, otherwise copy] + // if both are int/pointer (of any type), also the same + // + copy + // if known non-constant, also the same + // + copy + if ((dt.isFloat() == nullptr) == + (vd[{(int)i}].isFloat() == nullptr)) { + Legal = true; + } + if (!gutils->isConstantValue(orig_src) && + !gutils->runtimeActivity) { + Legal = true; + } } - assert(mv != BaseType::Unknown); - vd.insert({0}, mv); - goto known; - } + if (!Legal) { + nextStart = i; + break; + } + } else + dt = tmp; + } + if (!dt.isKnown()) { + TR.dump(); + llvm::errs() << " vd:" << vd.str() << " start:" << start + << " size: " << size << " dt:" << dt.str() << "\n"; } + assert(dt.isKnown()); + toIterate.emplace_back(dt.isFloat(), start, nextStart - start); + + if (nextStart == size) + break; + start = nextStart; } - if (errorIfNoType) - EmitWarning("CannotDeduceType", MTI, "failed to deduce type of copy ", - MTI); - vd = TypeTree(BaseType::Pointer).Only(0, &MTI); - goto known; - } - if (errorIfNoType) { - std::string str; - raw_string_ostream ss(str); - ss << "Cannot deduce type of copy " << MTI; - EmitNoTypeError(str, MTI, gutils, BuilderZ); - vd = TypeTree(BaseType::Integer).Only(0, &MTI); - } else { - vd = TypeTree(BaseType::Pointer).Only(0, &MTI); } } - known:; // llvm::errs() << "MIT: " << MTI << "|size: " << size << " vd: " << // vd.str() << "\n"; @@ -3387,8 +3471,6 @@ class AdjointGenerator : public llvm::InstVisitor { unsigned dstalign = dstAlign.valueOrOne().value(); unsigned srcalign = srcAlign.valueOrOne().value(); - unsigned start = 0; - bool backwardsShadow = false; bool forwardsShadow = true; for (auto pair : gutils->backwardsOnlyShadows) { @@ -3402,73 +3484,26 @@ class AdjointGenerator : public llvm::InstVisitor { } } - while (1) { - unsigned nextStart = size; - - auto dt = vd[{-1}]; - for (size_t i = start; i < size; ++i) { - bool Legal = true; - auto tmp = dt; - auto next = vd[{(int)i}]; - tmp.checkedOrIn(next, /*PointerIntSame*/ true, Legal); - // Prevent fusion of {Anything, Float} since anything is an int rule - // but float requires zeroing. - if ((dt == BaseType::Anything && - (next != BaseType::Anything && next.isKnown())) || - (next == BaseType::Anything && - (dt != BaseType::Anything && dt.isKnown()))) - Legal = false; - if (!Legal) { - if (Mode == DerivativeMode::ForwardMode || - Mode == DerivativeMode::ForwardModeError) { - // if both are floats (of any type), forward mode is the same. - // + [potentially zero if const, otherwise copy] - // if both are int/pointer (of any type), also the same - // + copy - // if known non-constant, also the same - // + copy - if ((dt.isFloat() == nullptr) == - (vd[{(int)i}].isFloat() == nullptr)) { - Legal = true; - } - if (!gutils->isConstantValue(orig_src) && - !gutils->runtimeActivity) { - Legal = true; - } - } - if (!Legal) { - nextStart = i; - break; - } - } else - dt = tmp; - } - if (!dt.isKnown()) { - TR.dump(); - llvm::errs() << " vd:" << vd.str() << " start:" << start - << " size: " << size << " dt:" << dt.str() << "\n"; - } - assert(dt.isKnown()); - + for (auto &&[floatTy, seg_start, seg_size] : toIterate) { Value *length = new_size; - if (nextStart != size) { - length = ConstantInt::get(new_size->getType(), nextStart); + if (seg_start != std::get<1>(toIterate.back())) { + length = ConstantInt::get(new_size->getType(), seg_start + seg_size); } - if (start != 0) + if (seg_start != 0) length = BuilderZ.CreateSub( - length, ConstantInt::get(new_size->getType(), start)); + length, ConstantInt::get(new_size->getType(), seg_start)); unsigned subdstalign = dstalign; // todo make better alignment calculation if (dstalign != 0) { - if (start % dstalign != 0) { + if (seg_start % dstalign != 0) { dstalign = 1; } } unsigned subsrcalign = srcalign; // todo make better alignment calculation if (srcalign != 0) { - if (start % srcalign != 0) { + if (seg_start % srcalign != 0) { srcalign = 1; } } @@ -3486,8 +3521,8 @@ class AdjointGenerator : public llvm::InstVisitor { if (shadow_src == nullptr) shadow_src = gutils->getNewFromOriginal(orig_src); SubTransferHelper( - gutils, Mode, dt.isFloat(), ID, subdstalign, subsrcalign, - /*offset*/ start, gutils->isConstantValue(orig_dst), shadow_dst, + gutils, Mode, floatTy, ID, subdstalign, subsrcalign, + /*offset*/ seg_start, gutils->isConstantValue(orig_dst), shadow_dst, gutils->isConstantValue(orig_src), shadow_src, /*length*/ length, /*volatile*/ isVolatile, &MTI, /*allowForward*/ forwardsShadow, /*shadowsLookedup*/ false, @@ -3508,13 +3543,13 @@ class AdjointGenerator : public llvm::InstVisitor { if (ddst->getType()->isIntegerTy()) ddst = BuilderZ.CreateIntToPtr(ddst, getInt8PtrTy(ddst->getContext())); - if (start != 0) { + if (seg_start != 0) { ddst = BuilderZ.CreateConstInBoundsGEP1_64( - Type::getInt8Ty(ddst->getContext()), ddst, start); + Type::getInt8Ty(ddst->getContext()), ddst, seg_start); } CallInst *call; // TODO add gutils->runtimeActivity (correctness) - if (dt.isFloat() && gutils->isConstantValue(orig_src)) { + if (floatTy && gutils->isConstantValue(orig_src)) { call = BuilderZ.CreateMemSet( ddst, ConstantInt::get(Type::getInt8Ty(ddst->getContext()), 0), length, salign, isVolatile); @@ -3522,9 +3557,9 @@ class AdjointGenerator : public llvm::InstVisitor { if (dsrc->getType()->isIntegerTy()) dsrc = BuilderZ.CreateIntToPtr(dsrc, getInt8PtrTy(dsrc->getContext())); - if (start != 0) { + if (seg_start != 0) { dsrc = BuilderZ.CreateConstInBoundsGEP1_64( - Type::getInt8Ty(ddst->getContext()), dsrc, start); + Type::getInt8Ty(ddst->getContext()), dsrc, seg_start); } if (ID == Intrinsic::memmove) { call = BuilderZ.CreateMemMove(ddst, dalign, dsrc, salign, length); @@ -3552,10 +3587,6 @@ class AdjointGenerator : public llvm::InstVisitor { applyChainRule(BuilderZ, fwd_rule, shadow_dst, shadow_src); else applyChainRule(BuilderZ, rev_rule, shadow_dst, shadow_src); - - if (nextStart == size) - break; - start = nextStart; } eraseIfUnused(MTI); diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 9ca27452729..6bb49cdef3b 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -3685,6 +3685,69 @@ void EmitNoTypeError(const std::string &message, llvm::Instruction &inst, } } +std::vector> +parseTrueType(const llvm::MDNode *md, DerivativeMode Mode, bool const_src) { + std::vector> parsed; + for (size_t i = 0; i < md->getNumOperands(); i += 2) { + ConcreteType base( + llvm::cast(md->getOperand(i))->getString(), + md->getContext()); + auto size = llvm::cast( + llvm::cast(md->getOperand(i + 1)) + ->getValue()) + ->getSExtValue(); + parsed.emplace_back(base, size); + } + + std::vector> toIterate; + size_t idx = 0; + while (idx < parsed.size()) { + + auto dt = parsed[idx].first; + size_t start = parsed[idx].second; + size_t end = 0x0fffffff; + for (idx = idx + 1; idx < parsed.size(); ++idx) { + bool Legal = true; + auto tmp = dt; + auto next = parsed[idx].first; + tmp.checkedOrIn(next, /*PointerIntSame*/ true, Legal); + // Prevent fusion of {Anything, Float} since anything is an int rule + // but float requires zeroing. + if ((dt == BaseType::Anything && + (next != BaseType::Anything && next.isKnown())) || + (next == BaseType::Anything && + (dt != BaseType::Anything && dt.isKnown()))) + Legal = false; + if (!Legal) { + if (Mode == DerivativeMode::ForwardMode || + Mode == DerivativeMode::ForwardModeError) { + // if both are floats (of any type), forward mode is the same. + // + [potentially zero if const, otherwise copy] + // if both are int/pointer (of any type), also the same + // + copy + // if known non-constant, also the same + // + copy + if ((parsed[idx].first.isFloat() == nullptr) == + (parsed[idx - 1].first.isFloat() == nullptr)) { + Legal = true; + } + if (const_src) { + Legal = true; + } + } + if (!Legal) { + end = parsed[idx].second; + break; + } + } else + dt = tmp; + } + assert(dt.isKnown()); + toIterate.emplace_back(dt.isFloat(), start, end - start); + } + return toIterate; +} + void dumpModule(llvm::Module *mod) { llvm::errs() << *mod << "\n"; } void dumpValue(llvm::Value *val) { llvm::errs() << *val << "\n"; } diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 12717d9ff45..9b66730d14d 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -687,6 +687,9 @@ std::optional extractBLAS(llvm::StringRef in); llvm::Optional extractBLAS(llvm::StringRef in); #endif +std::vector> +parseTrueType(const llvm::MDNode *, DerivativeMode, bool const_src); + /// Create function for type that performs the derivative memcpy on floating /// point memory llvm::Function *getOrInsertDifferentialFloatMemcpy( diff --git a/enzyme/test/Enzyme/ReverseMode/memcpy-truetype.ll b/enzyme/test/Enzyme/ReverseMode/memcpy-truetype.ll new file mode 100644 index 00000000000..fee4d19a8e7 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/memcpy-truetype.ll @@ -0,0 +1,82 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -enzyme-preopt=false -S | FileCheck %s + +; Function Attrs: nounwind uwtable +define dso_local void @memcpy_float(i8* nocapture %dst, i8* nocapture readonly %src) #0 { +entry: + tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 1 %dst, i8* align 1 %src, i64 100000, i1 false), !enzyme_truetype !0 + ret void +} + +; Function Attrs: argmemonly nounwind +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) #1 + +; Function Attrs: nounwind uwtable +define dso_local void @dmemcpy_float(i8* %dst, i8* %dstp, i8* %src, i64 %n) local_unnamed_addr #0 { +entry: + tail call void (...) @__enzyme_autodiff.f64(void (i8*, i8*)* nonnull @memcpy_float, metadata !"enzyme_dup", i8* %dst, i8* %dstp, metadata !"enzyme_dup", i8* %src, i8* %src) #3 + ret void +} + +declare void @__enzyme_autodiff.f64(...) local_unnamed_addr + + +attributes #0 = { nounwind uwtable } +attributes #1 = { argmemonly nounwind } +attributes #2 = { noinline nounwind uwtable } + +!0 = !{!"Float@float", i64 0, !"Integer", i64 8, !"Float@float", i64 50000, !"Integer", i64 50008} + + +; CHECK: define internal void @diffememcpy_float(i8* nocapture %dst, i8* nocapture %"dst'", i8* nocapture readonly %src, i8* nocapture %"src'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = getelementptr inbounds i8, i8* %"dst'", i64 8 +; CHECK-NEXT: %1 = getelementptr inbounds i8, i8* %"src'", i64 8 +; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 1 %0, i8* align 1 %1, i64 49992, i1 false) +; CHECK-NEXT: %2 = getelementptr inbounds i8, i8* %"dst'", i64 50008 +; CHECK-NEXT: %3 = getelementptr inbounds i8, i8* %"src'", i64 50008 +; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 1 %2, i8* align 1 %3, i64 49992, i1 false) +; CHECK-NEXT: tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 1 %dst, i8* align 1 %src, i64 100000, i1 false) #{{[0-9]+}}, !enzyme_truetype +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: ; preds = %entry +; CHECK-NEXT: %4 = bitcast i8* %"dst'" to float* +; CHECK-NEXT: %5 = bitcast i8* %"src'" to float* +; CHECK-NEXT: br label %for.body.i + +; CHECK: for.body.i: ; preds = %for.body.i, %invertentry +; CHECK-NEXT: %idx.i = phi i64 [ 0, %invertentry ], [ %idx.next.i, %for.body.i ] +; CHECK-NEXT: %dst.i.i = getelementptr inbounds float, float* %4, i64 %idx.i +; CHECK-NEXT: %dst.i.l.i = load float, float* %dst.i.i, align 1 +; CHECK-NEXT: store float 0.000000e+00, float* %dst.i.i, align 1 +; CHECK-NEXT: %src.i.i = getelementptr inbounds float, float* %5, i64 %idx.i +; CHECK-NEXT: %src.i.l.i = load float, float* %src.i.i, align 1 +; CHECK-NEXT: %6 = fadd fast float %src.i.l.i, %dst.i.l.i +; CHECK-NEXT: store float %6, float* %src.i.i, align 1 +; CHECK-NEXT: %idx.next.i = add nuw i64 %idx.i, 1 +; CHECK-NEXT: %7 = icmp eq i64 2, %idx.next.i +; CHECK-NEXT: br i1 %7, label %__enzyme_memcpyadd_floatda1sa1.exit, label %for.body.i + +; CHECK: __enzyme_memcpyadd_floatda1sa1.exit: ; preds = %for.body.i +; CHECK-NEXT: %8 = getelementptr inbounds i8, i8* %"dst'", i64 50000 +; CHECK-NEXT: %9 = bitcast i8* %8 to float* +; CHECK-NEXT: %10 = getelementptr inbounds i8, i8* %"src'", i64 50000 +; CHECK-NEXT: %11 = bitcast i8* %10 to float* +; CHECK-NEXT: br label %for.body.i7 + +; CHECK: for.body.i7: ; preds = %for.body.i7, %__enzyme_memcpyadd_floatda1sa1.exit +; CHECK-NEXT: %idx.i1 = phi i64 [ 0, %__enzyme_memcpyadd_floatda1sa1.exit ], [ %idx.next.i6, %for.body.i7 ] +; CHECK-NEXT: %dst.i.i2 = getelementptr inbounds float, float* %9, i64 %idx.i1 +; CHECK-NEXT: %dst.i.l.i3 = load float, float* %dst.i.i2, align 1 +; CHECK-NEXT: store float 0.000000e+00, float* %dst.i.i2, align 1 +; CHECK-NEXT: %src.i.i4 = getelementptr inbounds float, float* %11, i64 %idx.i1 +; CHECK-NEXT: %src.i.l.i5 = load float, float* %src.i.i4, align 1 +; CHECK-NEXT: %12 = fadd fast float %src.i.l.i5, %dst.i.l.i3 +; CHECK-NEXT: store float %12, float* %src.i.i4, align 1 +; CHECK-NEXT: %idx.next.i6 = add nuw i64 %idx.i1, 1 +; CHECK-NEXT: %13 = icmp eq i64 2, %idx.next.i6 +; CHECK-NEXT: br i1 %13, label %__enzyme_memcpyadd_floatda1sa1.exit8, label %for.body.i7 + +; CHECK: __enzyme_memcpyadd_floatda1sa1.exit8: ; preds = %for.body.i7 +; CHECK-NEXT: ret void +; CHECK-NEXT: }