Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Benchmark #1617

Merged
merged 3 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 81 additions & 23 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7063,6 +7063,7 @@ Constraints::allSolutions(SCEVExpander &Exp, llvm::Type *T, Instruction *IP,
return {};
}

constexpr bool SparseDebug = false;
std::shared_ptr<const Constraints>
getSparseConditions(bool &legal, Value *val,
std::shared_ptr<const Constraints> defaultFloat,
Expand All @@ -7077,11 +7078,13 @@ getSparseConditions(bool &legal, Value *val,
auto res = lhs->andB(rhs, ctx);
assert(res);
assert(ctx.seen.size() == 0);
llvm::errs() << " getSparse(and, " << *I << "), lhs(" << *I->getOperand(0)
<< ") = " << *lhs << "\n";
llvm::errs() << " getSparse(and, " << *I << "), rhs(" << *I->getOperand(1)
<< ") = " << *rhs << "\n";
llvm::errs() << " getSparse(and, " << *I << ") = " << *res << "\n";
if (SparseDebug) {
llvm::errs() << " getSparse(and, " << *I << "), lhs("
<< *I->getOperand(0) << ") = " << *lhs << "\n";
llvm::errs() << " getSparse(and, " << *I << "), rhs("
<< *I->getOperand(1) << ") = " << *rhs << "\n";
llvm::errs() << " getSparse(and, " << *I << ") = " << *res << "\n";
}
return res;
}

Expand All @@ -7092,11 +7095,13 @@ getSparseConditions(bool &legal, Value *val,
auto rhs = getSparseConditions(legal, I->getOperand(1),
Constraints::none(), I, ctx);
auto res = lhs->orB(rhs, ctx);
llvm::errs() << " getSparse(or, " << *I << "), lhs(" << *I->getOperand(0)
<< ") = " << *lhs << "\n";
llvm::errs() << " getSparse(or, " << *I << "), rhs(" << *I->getOperand(1)
<< ") = " << *rhs << "\n";
llvm::errs() << " getSparse(or, " << *I << ") = " << *res << "\n";
if (SparseDebug) {
llvm::errs() << " getSparse(or, " << *I << "), lhs("
<< *I->getOperand(0) << ") = " << *lhs << "\n";
llvm::errs() << " getSparse(or, " << *I << "), rhs("
<< *I->getOperand(1) << ") = " << *rhs << "\n";
llvm::errs() << " getSparse(or, " << *I << ") = " << *res << "\n";
}
return res;
}

Expand All @@ -7108,9 +7113,12 @@ getSparseConditions(bool &legal, Value *val,
getSparseConditions(legal, I->getOperand(1 - i),
defaultFloat->notB(ctx), scope, ctx);
auto res = pres->notB(ctx);
llvm::errs() << " getSparse(not, " << *I << "), prev ("
<< *I->getOperand(0) << ") = " << *pres << "\n";
llvm::errs() << " getSparse(not, " << *I << ") = " << *res << "\n";
if (SparseDebug) {
llvm::errs() << " getSparse(not, " << *I << "), prev ("
<< *I->getOperand(0) << ") = " << *pres << "\n";
llvm::errs() << " getSparse(not, " << *I << ") = " << *res
<< "\n";
}
return res;
}
}
Expand All @@ -7120,8 +7128,10 @@ getSparseConditions(bool &legal, Value *val,
auto L = ctx.loopToSolve;
auto lhs = ctx.SE.getSCEVAtScope(icmp->getOperand(0), L);
auto rhs = ctx.SE.getSCEVAtScope(icmp->getOperand(1), L);
llvm::errs() << " lhs: " << *lhs << "\n";
llvm::errs() << " rhs: " << *rhs << "\n";
if (SparseDebug) {
llvm::errs() << " lhs: " << *lhs << "\n";
llvm::errs() << " rhs: " << *rhs << "\n";
}

auto sub1 = ctx.SE.getMinusSCEV(lhs, rhs);

Expand All @@ -7145,8 +7155,10 @@ getSparseConditions(bool &legal, Value *val,
auto res = Constraints::make_compare(
div, icmp->getPredicate() == ICmpInst::ICMP_EQ,
add->getLoop(), ctx);
llvm::errs()
<< " getSparse(icmp, " << *I << ") = " << *res << "\n";
if (SparseDebug) {
llvm::errs()
<< " getSparse(icmp, " << *I << ") = " << *res << "\n";
}
return res;
}
}
Expand All @@ -7172,7 +7184,9 @@ getSparseConditions(bool &legal, Value *val,
// cmp x, 1.0 -> false/true
if (auto fcmp = dyn_cast<FCmpInst>(I)) {
auto res = defaultFloat;
llvm::errs() << " getSparse(fcmp, " << *I << ") = " << *res << "\n";
if (SparseDebug) {
llvm::errs() << " getSparse(fcmp, " << *I << ") = " << *res << "\n";
}
return res;

if (fcmp->getPredicate() == CmpInst::FCMP_OEQ ||
Expand Down Expand Up @@ -7263,13 +7277,16 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
// Full simplification
while (!Q.empty()) {
auto cur = Q.pop_back_val();
/*
std::set<Instruction *> prev;
for (auto v : Q)
prev.insert(v);
// llvm::errs() << "\n\n\n\n" << F << "\n";
llvm::errs() << "cur: " << *cur << "\n";
*/
auto changed = fixSparse_inner(cur, F, Q, DT, SE, LI, DL);
(void)changed;
/*
if (changed) {
llvm::errs() << "changed: " << *changed << "\n";

Expand All @@ -7278,6 +7295,7 @@ void fixSparseIndices(llvm::Function &F, llvm::FunctionAnalysisManager &FAM,
llvm::errs() << " + " << *I << "\n";
// llvm::errs() << F << "\n\n";
}
*/
}

// llvm::errs() << " post fix inner " << F << "\n";
Expand Down Expand Up @@ -7872,6 +7890,7 @@ void replaceToDense(llvm::CallBase *CI, bool replaceAll, llvm::Function *F,
args.push_back(diff);
for (size_t i = argstart; i < num_args; i++)
args.push_back(CI->getArgOperand(i));

if (load_fn->getFunctionType()->getNumParams() != args.size()) {
auto fnName = load_fn->getName();
auto found_numargs = load_fn->getFunctionType()->getNumParams();
Expand All @@ -7893,7 +7912,7 @@ void replaceToDense(llvm::CallBase *CI, bool replaceAll, llvm::Function *F,
*args[i]->getType(), " found ",
load_fn->getFunctionType()->params()[i]);
tocontinue = true;
break;
args[i] = UndefValue::get(args[i]->getType());
}
}
if (tocontinue)
Expand All @@ -7902,8 +7921,18 @@ void replaceToDense(llvm::CallBase *CI, bool replaceAll, llvm::Function *F,
CallInst *call = B.CreateCall(load_fn, args);
call->setDebugLoc(LI->getDebugLoc());
Value *tmp = call;
if (tmp->getType() != LI->getType())
tmp = B.CreateBitCast(tmp, LI->getType());
if (tmp->getType() != LI->getType()) {
if (CastInst::castIsValid(Instruction::BitCast, tmp, LI->getType()))
tmp = B.CreateBitCast(tmp, LI->getType());
else {
auto fnName = load_fn->getName();
EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
" incorrect return type of loader function ", fnName,
" expected ", *LI->getType(), " found ",
*call->getType());
tmp = UndefValue::get(LI->getType());
}
}
LI->replaceAllUsesWith(tmp);

if (load_fn->hasFnAttribute(Attribute::AlwaysInline)) {
Expand All @@ -7927,15 +7956,44 @@ void replaceToDense(llvm::CallBase *CI, bool replaceAll, llvm::Function *F,
EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
" first argument of store function must be the type of "
"the store found fn arg type ",
sty, " expected ", args0ty);
*sty, " expected ", *args0ty);
args[0] = UndefValue::get(sty);
}
}
args.push_back(diff);
for (size_t i = argstart; i < num_args; i++)
args.push_back(CI->getArgOperand(i));

if (store_fn->getFunctionType()->getNumParams() != args.size()) {
auto fnName = store_fn->getName();
auto found_numargs = store_fn->getFunctionType()->getNumParams();
auto expected_numargs = args.size();
EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
" incorrect number of arguments to store function ", fnName,
" expected ", expected_numargs, " found ", found_numargs,
" - ", *store_fn->getFunctionType());
continue;
} else {
bool tocontinue = false;
for (size_t i = 0; i < args.size(); i++) {
if (store_fn->getFunctionType()->getParamType(i) !=
args[i]->getType()) {
auto fnName = store_fn->getName();
EmitFailure("IllegalSparse", CI->getDebugLoc(), CI,
" incorrect type of argument ", i,
" to storeer function ", fnName, " expected ",
*args[i]->getType(), " found ",
store_fn->getFunctionType()->params()[i]);
tocontinue = true;
args[i] = UndefValue::get(args[i]->getType());
}
}
if (tocontinue)
continue;
}
auto call = B.CreateCall(store_fn, args);
call->setDebugLoc(SI->getDebugLoc());
if (load_fn->hasFnAttribute(Attribute::AlwaysInline)) {
if (store_fn->hasFnAttribute(Attribute::AlwaysInline)) {
InlineFunctionInfo IFI;
InlineFunction(*call, IFI);
}
Expand Down
70 changes: 27 additions & 43 deletions enzyme/test/Integration/Sparse/eigen_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,41 +150,6 @@ static void gradient_ip(const T *__restrict__ pos0, const size_t num_faces, cons
enzyme_dup, x, out);
}


template<typename T>
__attribute__((always_inline))
static T ident_load(unsigned long long offset, size_t i) {
return (offset / sizeof(T) == i) ? T(1) : T(0);
}


template<typename T>
__attribute__((always_inline))
static void err_store(T val, unsigned long long offset, size_t i) {
assert(0 && "store is not legal");
}


template<typename T>
__attribute__((always_inline))
static T zero_load(unsigned long long offset, size_t i, std::vector<Triple<T>> &hess) {
return T(0);
}


__attribute__((enzyme_sparse_accumulate))
void inner_store(size_t offset, size_t i, float val, std::vector<Triple<float>> &hess) {
hess.push_back(Triple<float>(offset, i, val));
}

template<typename T>
__attribute__((always_inline))
static void csr_store(T val, unsigned long long offset, size_t i, std::vector<Triple<T>> &hess) {
if (val == 0.0) return;
offset /= sizeof(T);
inner_store(offset, i, val, hess);
}

template<typename T>
__attribute__((noinline))
std::vector<Triple<T>> hessian(const T*__restrict__ pos0, size_t num_faces, const int* faces, const T*__restrict__ x, size_t x_pts)
Expand Down Expand Up @@ -217,13 +182,20 @@ std::vector<Triple<T>> hessian(const T*__restrict__ pos0, size_t num_faces, cons
enzyme_const, pos02,
enzyme_const, num_faces,
enzyme_const, faces,
enzyme_dup, x2, __enzyme_todense<T*>(ident_load<T>, err_store<T>, i),
enzyme_dupnoneed, nullptr, __enzyme_todense<T*>(zero_load<T>, csr_store<T>, i, &hess));
enzyme_dup, x2, __enzyme_todense<T*>(ident_load<T>, ident_store<T>, i),
enzyme_dupnoneed, nullptr, __enzyme_todense<T*>(sparse_load<T>, sparse_store<T>, i, &hess));
return hess;
}

int main() {
const size_t x_pts = 1;
int main(int argc, char** argv) {
size_t x_pts = 8;

if (argc >= 2) {
x_pts = atoi(argv[1]);
}

// TODO generate data for more inputs
assert(x_pts == 8);
const float x[] = {0.0, 1.0, 0.0};


Expand All @@ -233,25 +205,37 @@ int main() {
const float pos0[] = {1.0, 2.0, 3.0, 4.0, 3.0, 2.0, 3.0, 1.0, 3.0};

// Call eigenstuffM_simple
struct timeval start, end;
gettimeofday(&start, NULL);
const float resultM = eigenstuffM(pos0, num_faces, faces, x);
printf("Result for eigenstuffM_simple: %f\n", resultM);
gettimeofday(&end, NULL);
printf("Result for eigenstuffM_simple: %f, runtime:%f\n", resultM, tdiff(&start, &end));

// Call eigenstuffL_simple
gettimeofday(&start, NULL);
const float resultL = eigenstuffL(pos0, num_faces, faces, x);
printf("Result for eigenstuffL_simple: %f\n", resultL);
gettimeofday(&end, NULL);
printf("Result for eigenstuffL_simple: %f, runtime:%f\n", resultL, tdiff(&start, &end));

float dx[sizeof(x)/sizeof(x[0])];
for (size_t i=0; i<sizeof(dx)/sizeof(x[0]); i++)
dx[i] = 0;
gradient_ip(pos0, num_faces, faces, x, dx);

if (x_pts < 30) {
for (size_t i=0; i<sizeof(dx)/sizeof(dx[0]); i++)
printf("eigenstuffM grad_vert[%zu]=%f\n", i, dx[i]);

size_t num_elts = sizeof(x)/sizeof(x[0]) * sizeof(x)/sizeof(x[0]);
}

gettimeofday(&start, NULL);
auto hess_x = hessian(pos0, num_faces, faces, x, x_pts);
gettimeofday(&end, NULL);

printf("Number of elements %ld\n", hess_x.size());

printf("Runtime %0.6f\n", tdiff(&start, &end));

if (x_pts <= 8)
for (auto &hess : hess_x) {
printf("i=%lu, j=%lu, val=%f\n", hess.row, hess.col, hess.val);
}
Expand Down
Loading
Loading