Skip to content

Commit

Permalink
fix pred_interactions
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuber21 committed Sep 18, 2023
1 parent b7e0738 commit cc21b23
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -199,36 +199,36 @@ protected:

template <bool hasUnorderedFeatures, bool hasAnyMissing>
// Compute SHAP interaction values for rows [0, nRowsData) using trees [iTree, iTree + nTrees).
// `nominal` holds the per-row nominal (plain) prediction, required to recover the correct bias term.
inline services::Status predictContributionInteractions(size_t iTree, size_t nTrees, size_t nRowsData, const algorithmFPType * x,
                                                        const algorithmFPType * nominal, algorithmFPType * res, const DimType & dim);

template <bool hasAnyMissing>
// Dispatch helper: resolves the hasUnorderedFeatures template parameter at runtime
// from _featHelper and forwards all arguments (including the per-row nominal
// predictions used for the bias term) to the fully-specialized implementation.
inline services::Status predictContributionInteractions(size_t iTree, size_t nTrees, size_t nRowsData, const algorithmFPType * x,
                                                        const algorithmFPType * nominal, algorithmFPType * res, const DimType & dim)
{
    if (_featHelper.hasUnorderedFeatures())
    {
        return predictContributionInteractions<true, hasAnyMissing>(iTree, nTrees, nRowsData, x, nominal, res, dim);
    }
    else
    {
        return predictContributionInteractions<false, hasAnyMissing>(iTree, nTrees, nRowsData, x, nominal, res, dim);
    }
}

// TODO: Add vectorBlockSize templates, similar to predict
// template <size_t vectorBlockSize>
// Entry point for SHAP interaction values: checks the data block for missing
// values once, then dispatches to the hasAnyMissing-specialized overload.
// `nominal` carries the per-row nominal prediction needed for the bias term.
inline services::Status predictContributionInteractions(size_t iTree, size_t nTrees, size_t nRowsData, const algorithmFPType * x,
                                                        const algorithmFPType * nominal, algorithmFPType * res, const DimType & dim)
{
    const size_t nColumnsData = dim.nCols;
    const bool hasAnyMissing  = checkForMissing(x, nTrees, nRowsData, nColumnsData);
    if (hasAnyMissing)
    {
        return predictContributionInteractions<true>(iTree, nTrees, nRowsData, x, nominal, res, dim);
    }
    else
    {
        return predictContributionInteractions<false>(iTree, nTrees, nRowsData, x, nominal, res, dim);
    }
}

Expand Down Expand Up @@ -352,15 +352,17 @@ void PredictRegressionTask<algorithmFPType, cpu>::predictContributions(size_t iT

const gbt::internal::GbtDecisionTree * currentTree = _aTree[currentTreeIndex];
const void * endAddr = static_cast<void *>(&(*uniquePathData.end()));
printf("\n\n\n--> depth = %d | uniquePathData.end() = %p\n\n", depth, endAddr);
gbt::treeshap::treeShap<algorithmFPType, hasUnorderedFeatures, hasAnyMissing>(currentTree, currentX, phi, nColumnsData, &_featHelper,
uniquePathData.data(), condition, conditionFeature);
printf("treeShap is done\n");
}

for (int iFeature = 0; iFeature < nColumnsData; ++iFeature)
if (condition == 0)
{
phi[biasTermIndex] -= phi[iFeature];
// find bias term by leveraging bias = nominal - sum_i phi_i
for (int iFeature = 0; iFeature < nColumnsData; ++iFeature)
{
phi[biasTermIndex] -= phi[iFeature];
}
}
}
}
Expand All @@ -379,12 +381,16 @@ void PredictRegressionTask<algorithmFPType, cpu>::predictContributions(size_t iT
template <typename algorithmFPType, CpuType cpu>
template <bool hasUnorderedFeatures, bool hasAnyMissing>
services::Status PredictRegressionTask<algorithmFPType, cpu>::predictContributionInteractions(size_t iTree, size_t nTrees, size_t nRowsData,
const algorithmFPType * x, algorithmFPType * res,
const algorithmFPType * x,
const algorithmFPType * nominal, algorithmFPType * res,
const DimType & dim)
{
Status st;
const size_t nColumnsData = dim.nCols;
const size_t nColumnsPhi = nColumnsData + 1;
const size_t nColumnsData = dim.nCols;
const size_t nColumnsPhi = nColumnsData + 1;
const size_t biasTermIndex = nColumnsPhi - 1;

const size_t interactionMatrixSize = nColumnsPhi * nColumnsPhi;

// Allocate buffer for 3 matrices for algorithmFPType of size (nRowsData, nColumnsData)
const size_t elementsInMatrix = nRowsData * nColumnsPhi;
Expand All @@ -396,33 +402,47 @@ services::Status PredictRegressionTask<algorithmFPType, cpu>::predictContributio
return st;
}

// Initialize buffer
service_memset_seq<algorithmFPType, cpu>(buffer, algorithmFPType(0), 3 * elementsInMatrix);

// Get pointers into the buffer for our three matrices
algorithmFPType * contribsDiag = buffer + 0 * elementsInMatrix;
algorithmFPType * contribsOff = buffer + 1 * elementsInMatrix;
algorithmFPType * contribsOn = buffer + 2 * elementsInMatrix;

// Initialize nominal buffer
service_memset_seq<algorithmFPType, cpu>(contribsDiag, algorithmFPType(0), elementsInMatrix);

// Copy nominal values (for bias term) to the condition == 0 buffer
PRAGMA_IVDEP
PRAGMA_VECTOR_ALWAYS
for (size_t i = 0; i < nRowsData; ++i)
{
contribsDiag[i * nColumnsPhi + biasTermIndex] = nominal[i];
}

predictContributions(iTree, nTrees, nRowsData, x, contribsDiag, 0, 0, dim);
for (size_t i = 0; i < nColumnsPhi + 1; ++i)
for (size_t i = 0; i < nColumnsPhi; ++i)
{
// initialize/reset the on/off buffers
service_memset_seq<algorithmFPType, cpu>(contribsOff, algorithmFPType(0), 2 * elementsInMatrix);

predictContributions(iTree, nTrees, nRowsData, x, contribsOff, -1, i, dim);
predictContributions(iTree, nTrees, nRowsData, x, contribsOn, 1, i, dim);

for (size_t j = 0; j < nRowsData; ++j)
{
for (size_t k = 0; k < nColumnsPhi + 1; ++k)
const unsigned o_offset = j * interactionMatrixSize + i * nColumnsPhi;
const unsigned c_offset = j * nColumnsPhi;
res[o_offset + i] = 0;
for (size_t k = 0; k < nColumnsPhi; ++k)
{
// fill in the diagonal with additive effects, and off-diagonal with the interactions
if (k == i)
{
res[i] += contribsDiag[k];
res[o_offset + i] += contribsDiag[c_offset + k];
}
else
{
res[k] = (contribsOn[k] - contribsOff[k]) / 2.0;
res[i] -= res[k];
res[o_offset + k] = (contribsOn[c_offset + k] - contribsOff[c_offset + k]) / 2.0f;
res[o_offset + i] -= res[o_offset + k];
}
}
}
Expand Down Expand Up @@ -483,11 +503,20 @@ services::Status PredictRegressionTask<algorithmFPType, cpu>::runInternal(servic
WriteOnlyRows<algorithmFPType, cpu> resRow(result, iStartRow, nRowsToProcess);
DAAL_CHECK_BLOCK_STATUS_THR(resRow);

// TODO: Might need the nominal prediction to account for bias terms
// predict(iTree, nTreesToUse, nRowsToProcess, xBD.get(), resRow.get(), dim, resultNColumns);
// nominal values are required to calculate the correct bias term
algorithmFPType * nominal = static_cast<algorithmFPType *>(daal_malloc(nRowsToProcess * sizeof(algorithmFPType)));
if (!nominal)
{
safeStat.add(ErrorMemoryAllocationFailed);
return;
}
service_memset_seq<algorithmFPType, cpu>(nominal, algorithmFPType(0), nRowsToProcess);
predict(iTree, nTreesToUse, nRowsToProcess, xBD.get(), nominal, dim, 1);

// TODO: support tree weights
safeStat |= predictContributionInteractions(iTree, nTreesToUse, nRowsToProcess, xBD.get(), resRow.get(), dim);
safeStat |= predictContributionInteractions(iTree, nTreesToUse, nRowsToProcess, xBD.get(), nominal, resRow.get(), dim);

daal_free(nominal);
}
else
{
Expand Down
34 changes: 7 additions & 27 deletions cpp/daal/src/algorithms/dtrees/gbt/treeshap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,30 @@ namespace internal
{

// extend our decision path with a fraction of one and zero extensions
void treeShapExtendPath(PathElement * uniquePath, size_t uniqueDepth, float zeroFraction, float oneFraction, FeatureIndexType featureIndex)
void extendPath(PathElement * uniquePath, size_t uniqueDepth, float zeroFraction, float oneFraction, int featureIndex)
{
uniquePath[uniqueDepth].featureIndex = featureIndex;
uniquePath[uniqueDepth].zeroFraction = zeroFraction;
uniquePath[uniqueDepth].oneFraction = oneFraction;
uniquePath[uniqueDepth].partialWeight = (uniqueDepth == 0 ? 1.0f : 0.0f);

const float constant = 1.0f / static_cast<float>(uniqueDepth + 1);
for (int i = uniqueDepth - 1; i >= 0; i--)
{
uniquePath[i + 1].partialWeight += oneFraction * uniquePath[i].partialWeight * (i + 1) / static_cast<float>(uniqueDepth + 1);
uniquePath[i].partialWeight = zeroFraction * uniquePath[i].partialWeight * (uniqueDepth - i) / static_cast<float>(uniqueDepth + 1);
uniquePath[i + 1].partialWeight += oneFraction * uniquePath[i].partialWeight * (i + 1) * constant;
uniquePath[i].partialWeight = zeroFraction * uniquePath[i].partialWeight * (uniqueDepth - i) * constant;
}
}

// undo a previous extension of the decision path
void treeShapUnwindPath(PathElement * uniquePath, size_t uniqueDepth, size_t pathIndex)
void unwindPath(PathElement * uniquePath, size_t uniqueDepth, size_t pathIndex)
{
printf("treeShapUnwindPath: Going through path elements\n");
printf("uniquePath = %p\n", uniquePath);
printf("uniqueDepth = %lu\n", uniqueDepth);
printf("pathIndex = %lu\n", pathIndex);
printf("---- start\n");
printf("%p\n", uniquePath + pathIndex);

const float oneFraction = uniquePath[pathIndex].oneFraction;
const float zeroFraction = uniquePath[pathIndex].zeroFraction;

printf("%p\n", uniquePath + uniqueDepth);
float nextOnePortion = uniquePath[uniqueDepth].partialWeight;
float nextOnePortion = uniquePath[uniqueDepth].partialWeight;

for (int i = uniqueDepth - 1; i >= 0; --i)
{
printf("%p\n", uniquePath + i);
if (oneFraction != 0)
{
const float tmp = uniquePath[i].partialWeight;
Expand All @@ -59,27 +50,18 @@ void treeShapUnwindPath(PathElement * uniquePath, size_t uniqueDepth, size_t pat

for (size_t i = pathIndex; i < uniqueDepth; ++i)
{
printf("%p <- %p\n", uniquePath + i, uniquePath + i + 1);
uniquePath[i].featureIndex = uniquePath[i + 1].featureIndex;
uniquePath[i].zeroFraction = uniquePath[i + 1].zeroFraction;
uniquePath[i].oneFraction = uniquePath[i + 1].oneFraction;
}
}

// determine what the total permutation weight would be if we unwound a previous extension in the decision path
float treeShapUnwoundPathSum(const PathElement * uniquePath, size_t uniqueDepth, size_t pathIndex)
float unwoundPathSum(const PathElement * uniquePath, size_t uniqueDepth, size_t pathIndex)
{
printf("treeShapUnwoundPathSum: Going through path elements\n");
printf("uniquePath = %p\n", uniquePath);
printf("uniqueDepth = %lu\n", uniqueDepth);
printf("pathIndex = %lu\n", pathIndex);
printf("---- start\n");
printf("%p\n", uniquePath + pathIndex);

const float oneFraction = uniquePath[pathIndex].oneFraction;
const float zeroFraction = uniquePath[pathIndex].zeroFraction;

printf("%p\n", uniquePath + uniqueDepth);
float nextOnePortion = uniquePath[uniqueDepth].partialWeight;
float total = 0;
// if (oneFraction != 0)
Expand Down Expand Up @@ -109,8 +91,6 @@ float treeShapUnwoundPathSum(const PathElement * uniquePath, size_t uniqueDepth,

for (int i = uniqueDepth - 1; i >= 0; --i)
{
printf("%p\n", uniquePath + i);

if (oneFraction != 0)
{
const float tmp = nextOnePortion * (uniqueDepth + 1) / static_cast<float>((i + 1) * oneFraction);
Expand Down
Loading

0 comments on commit cc21b23

Please sign in to comment.