Skip to content

Commit

Permalink
Comment branch
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexsandruss committed Aug 26, 2024
1 parent 524cf98 commit b92c277
Showing 1 changed file with 64 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1047,70 +1047,70 @@ Status PredictClassificationTask<algorithmFPType, cpu>::predictAllPointsByAllTre
const algorithmFPType * const aX = xBD.get();
// TODO: investigate why higher level parallelism for trees causes performance degradation
// (excessive memory and CPU resources usage), especially on systems with high number of cores
if (false)
{
daal::static_tls<algorithmFPType *> tlsData([=]() { return service_scalable_calloc<algorithmFPType, cpu>(_nClasses * nRowsOfRes); });

daal::static_threader_for(numberOfTrees, [&, nCols](const size_t iTree, size_t tid) {
const size_t treeSize = _aTree[iTree]->getNumberOfRows();
const DecisionTreeNode * const aNode = (const DecisionTreeNode *)(*_aTree[iTree]).getArray();
parallelPredict(aX, aNode, treeSize, nBlocks, nCols, _blockSize, residualSize, tlsData.local(tid), iTree);
});

const size_t nThreads = tlsData.nthreads();
const size_t localBlockSize = 256; // TODO: Why can't this be the class value _blockSize?
const size_t nBlocks = nRowsOfRes / localBlockSize + !!(nRowsOfRes % localBlockSize);

daal::threader_for(nBlocks, nBlocks, [&](const size_t iBlock) {
const size_t begin = iBlock * localBlockSize;
const size_t end = services::internal::min<cpu, size_t>(nRowsOfRes, begin + localBlockSize);

services::internal::service_memset_seq<algorithmFPType, cpu>(commonBufVal + begin * _nClasses, algorithmFPType(0),
(end - begin) * _nClasses);

for (size_t tid = 0; tid < nThreads; ++tid)
{
algorithmFPType * buf = tlsData.local(tid);
for (size_t i = begin; i < end; ++i)
{
for (size_t j = 0; j < _nClasses; ++j)
{
commonBufVal[i * _nClasses + j] += buf[i * _nClasses + j];
}
}
}

if (prob != nullptr)
{
for (size_t i = begin; i < end; ++i)
{
algorithmFPType sum(0);

for (size_t j = 0; j < _nClasses; ++j)
{
sum += commonBufVal[i * _nClasses + j];
}
sum = daal::algorithms::dtrees::training::internal::isZero<algorithmFPType, cpu>(sum) ? algorithmFPType(1) : sum;

for (size_t j = 0; j < _nClasses; ++j)
{
commonBufVal[i * _nClasses + j] = commonBufVal[i * _nClasses + j] / sum;
}
}
}

if (res != nullptr)
{
for (size_t i = begin; i < end; ++i)
{
res[i] = algorithmFPType(getMaxClass(commonBufVal + i * _nClasses));
}
}
});

tlsData.reduce([&](algorithmFPType * buf) { service_scalable_free<algorithmFPType, cpu>(buf); });
}
else
// if (numberOfTrees > _minTreesForThreading)
// {
// daal::static_tls<algorithmFPType *> tlsData([=]() { return service_scalable_calloc<algorithmFPType, cpu>(_nClasses * nRowsOfRes); });

// daal::static_threader_for(numberOfTrees, [&, nCols](const size_t iTree, size_t tid) {
// const size_t treeSize = _aTree[iTree]->getNumberOfRows();
// const DecisionTreeNode * const aNode = (const DecisionTreeNode *)(*_aTree[iTree]).getArray();
// parallelPredict(aX, aNode, treeSize, nBlocks, nCols, _blockSize, residualSize, tlsData.local(tid), iTree);
// });

// const size_t nThreads = tlsData.nthreads();
// const size_t localBlockSize = 256; // TODO: Why can't this be the class value _blockSize?
// const size_t nBlocks = nRowsOfRes / localBlockSize + !!(nRowsOfRes % localBlockSize);

// daal::threader_for(nBlocks, nBlocks, [&](const size_t iBlock) {
// const size_t begin = iBlock * localBlockSize;
// const size_t end = services::internal::min<cpu, size_t>(nRowsOfRes, begin + localBlockSize);

// services::internal::service_memset_seq<algorithmFPType, cpu>(commonBufVal + begin * _nClasses, algorithmFPType(0),
// (end - begin) * _nClasses);

// for (size_t tid = 0; tid < nThreads; ++tid)
// {
// algorithmFPType * buf = tlsData.local(tid);
// for (size_t i = begin; i < end; ++i)
// {
// for (size_t j = 0; j < _nClasses; ++j)
// {
// commonBufVal[i * _nClasses + j] += buf[i * _nClasses + j];
// }
// }
// }

// if (prob != nullptr)
// {
// for (size_t i = begin; i < end; ++i)
// {
// algorithmFPType sum(0);

// for (size_t j = 0; j < _nClasses; ++j)
// {
// sum += commonBufVal[i * _nClasses + j];
// }
// sum = daal::algorithms::dtrees::training::internal::isZero<algorithmFPType, cpu>(sum) ? algorithmFPType(1) : sum;

// for (size_t j = 0; j < _nClasses; ++j)
// {
// commonBufVal[i * _nClasses + j] = commonBufVal[i * _nClasses + j] / sum;
// }
// }
// }

// if (res != nullptr)
// {
// for (size_t i = begin; i < end; ++i)
// {
// res[i] = algorithmFPType(getMaxClass(commonBufVal + i * _nClasses));
// }
// }
// });

// tlsData.reduce([&](algorithmFPType * buf) { service_scalable_free<algorithmFPType, cpu>(buf); });
// }
// else
{
services::internal::service_memset<algorithmFPType, cpu>(commonBufVal, algorithmFPType(0), nRowsOfRes * _nClasses);

Expand Down

0 comments on commit b92c277

Please sign in to comment.