Skip to content

Commit

Permalink
adding binary transform scan and test
Browse files Browse the repository at this point in the history
  • Loading branch information
andyD123 committed Sep 1, 2023
1 parent e099546 commit 058078b
Show file tree
Hide file tree
Showing 2 changed files with 241 additions and 2 deletions.
136 changes: 134 additions & 2 deletions VectorTest/TestScan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ void testScan1(int SZ ,double start)



TEST(TestScanTransform, scanShortVector)
TEST(TestScan, scanShortVector)
{

for (int SZ = 3; SZ < 33; SZ++)
Expand All @@ -137,9 +137,141 @@ TEST(TestScanTransform, scanShortVector)
testScan1(SZ,3.14);
}


}





void testTransformScan1(int SZ, double start)
{


std::vector<Numeric> input(SZ, asNumber(0.0));
std::iota(begin(input), end(input), asNumber(start));


Numeric err = getErr(Numeric(0.));

VecXX testVec(input);
auto SQR = [](auto x) { return x * x; };

auto sqrVec = transform( [](auto x) {return x * x; }, testVec);
std::vector< Numeric> sq = sqrVec;
auto add = [](auto x, auto y) {return x + y; };


for (int j = 0; j < SZ; ++j)
{

auto res = ApplyTransformScan(testVec, add, SQR);

std::vector<Numeric> dbg = res;

std::vector<Numeric> expected;
std::inclusive_scan(cbegin(sq), cend(sq), std::back_inserter(expected));

EXPECT_NEAR(expected[0], res[0], err);

for (int k = 1; k < SZ; k++)
{
auto relErr = err * std::max(Numeric(1.), std::abs(Numeric(expected[k])));
EXPECT_NEAR(expected[k], res[k], relErr);

}
}

}






TEST(TestTransformScanTransform, transformScanShortVector)
{

for (int SZ = 3; SZ < 33; SZ++)
{
testTransformScan1(SZ,0);
}

for (int SZ = 3; SZ < 133; SZ++)
{
testTransformScan1(SZ, 3.14);
}

}







void testTransformScan2(int SZ, double start)
{


std::vector<Numeric> input(SZ, asNumber(0.0));
std::iota(begin(input), end(input), asNumber(start));


Numeric err = getErr(Numeric(0.));

VecXX testVec(input);

VecXX testVec1 = testVec + 1.0;

auto MULT = [](auto x,auto y) { return x * y; };

auto multVec = testVec * testVec1;


std::vector< Numeric> prod = multVec;
auto add = [](auto x, auto y) {return x + y; };


for (int j = 0; j < SZ; ++j)
{

auto res = ApplyTransformScan(testVec, testVec1, add, MULT);

std::vector<Numeric> dbg = res;

std::vector<Numeric> expected;
std::inclusive_scan(cbegin(prod), cend(prod), std::back_inserter(expected));

EXPECT_NEAR(expected[0], res[0], err);

for (int k = 1; k < SZ; k++)
{
auto relErr = err * std::max(Numeric(1.), std::abs(Numeric(expected[k])));
EXPECT_NEAR(expected[k], res[k], relErr);

}
}

}






TEST(TestTransformScanTransform, transformScanShortVectorBinary)
{

for (int SZ = 3; SZ < 33; SZ++)
{
testTransformScan2(SZ, 0);
}

for (int SZ = 3; SZ < 133; SZ++)
{
testTransformScan2(SZ, 3.14);
}

}


107 changes: 107 additions & 0 deletions Vectorisation/VecX/scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,110 @@ Vec<INS_VEC> ApplyTransformScan(const Vec<INS_VEC>& rhs1, OP& oper, TRANSFORM& t
return result;

};


template< typename INS_VEC, typename OP, typename TRANSFORM>
Vec<INS_VEC> ApplyTransformScan(const Vec<INS_VEC>& lhs1, const Vec<INS_VEC>& rhs1, OP& oper, TRANSFORM& transform)
{


check_vector(rhs1);
check_vector(lhs1);
if (isScalar(rhs1))
{
throw std::range_error("ApplyScan called on scalar");
}

Vec<INS_VEC> result(rhs1.size());

auto pRes = result.start();
auto pLhs1 = lhs1.start();
auto pRhs1 = rhs1.start();

constexpr int width = InstructionTraits<INS_VEC>::width;
int step = 4 * width;


INS_VEC runValue = 0.0;

INS_VEC contValue = 0.0;

int sz = rhs1.size();


//https://gfxcourses.stanford.edu/cs149/fall20content/media/dataparallel/08_dataparallel.pdf

constexpr int LAST_ELEM = width - 1;

INS_VEC LHS1;
INS_VEC LHS2;
INS_VEC LHS3;
INS_VEC LHS4;

INS_VEC RHS1;
INS_VEC RHS2;
INS_VEC RHS3;
INS_VEC RHS4;


int i = 0;
for (; i < (sz - step); i += step)
{



RHS1.load_a(pRhs1 + i);
RHS2.load_a(pRhs1 + i + width);
RHS3.load_a(pRhs1 + i + 2 * width);
RHS4.load_a(pRhs1 + i + 3 * width);

LHS1.load_a(pLhs1 + i);
LHS2.load_a(pLhs1 + i + width);
LHS3.load_a(pLhs1 + i + 2 * width);
LHS4.load_a(pLhs1 + i + 3 * width);

auto Res1 = scanN(transform(LHS1,RHS1), runValue, oper);
auto Res2 = scanN(transform(LHS2,RHS2), runValue, oper);
auto Res3 = scanN(transform(LHS3, RHS3), runValue, oper);
auto Res4 = scanN(transform(LHS4, RHS4), runValue, oper);

INS_VEC runValue1 = Res1[LAST_ELEM];
INS_VEC runValue2 = Res2[LAST_ELEM];
INS_VEC runValue3 = Res3[LAST_ELEM];
INS_VEC runValue4 = Res4[LAST_ELEM];

auto sv12 = oper(runValue1, runValue2);
auto sv13 = oper(sv12, runValue3);
auto s14 = oper(sv13, runValue4);

Res1 = oper(Res1, contValue);
Res2 = oper(Res2, oper(runValue1, contValue));
Res3 = oper(Res3, oper(sv12, contValue));
Res4 = oper(Res4, oper(sv13, contValue));


contValue = oper(contValue, s14);

Res1.store_a(pRes + i);
Res2.store_a(pRes + i + width);
Res3.store_a(pRes + i + 2 * width);
Res4.store_a(pRes + i + 3 * width);

}

if (i == 0) // small and need to init first element
{
i++;
result[0] = transform(INS_VEC(lhs1[0]),INS_VEC(rhs1[0]))[0];
};
for (int j = i; j < sz; j++)
{
auto trans_of_rhs1_j = transform(INS_VEC(lhs1[j]),INS_VEC(rhs1[j]))[0];

result[j] = ApplyBinaryOperation1<INS_VEC, OP>(trans_of_rhs1_j, result[j - 1], oper);
}

return result;

};

0 comments on commit 058078b

Please sign in to comment.