-
Notifications
You must be signed in to change notification settings - Fork 213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
xcp ref implementation #2895
xcp ref implementation #2895
Changes from 2 commits
478255a
0293b6f
9617b65
988389d
a7a74df
4a60b17
bd6aa89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -26,6 +26,7 @@ | |||||||||||||||||||||||||
#define __SERVICE_STAT_REF_H__ | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
#include "src/externals/service_memory.h" | ||||||||||||||||||||||||||
#include "src/externals/service_blas_ref.h" | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
typedef void (*func_type)(DAAL_INT, DAAL_INT, DAAL_INT, void *); | ||||||||||||||||||||||||||
extern "C" | ||||||||||||||||||||||||||
|
@@ -174,6 +175,56 @@ struct RefStatistics<double, cpu> | |||||||||||||||||||||||||
__int64 method) | ||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||
int errcode = 0; | ||||||||||||||||||||||||||
daal::internal::ref::OpenBlas<double, cpu> blasInst; | ||||||||||||||||||||||||||
double accWtOld = *nPreviousObservations; | ||||||||||||||||||||||||||
double accWt = *nPreviousObservations + nVectors; | ||||||||||||||||||||||||||
DAAL_INT one = 1; | ||||||||||||||||||||||||||
char transa = 'N'; | ||||||||||||||||||||||||||
char transb = 'N'; | ||||||||||||||||||||||||||
double beta = 0.0; | ||||||||||||||||||||||||||
double alpha; | ||||||||||||||||||||||||||
if (accWtOld != 0) | ||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||
double * sumOld = daal::services::internal::service_malloc<double, cpu>(nFeatures, sizeof(double)); | ||||||||||||||||||||||||||
DAAL_CHECK_MALLOC(sumOld); | ||||||||||||||||||||||||||
for (DAAL_INT i = 0; i < nFeatures; ++i) | ||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||
sumOld[i] = sum[i]; | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
// S_old S_old^t/accWtOld | ||||||||||||||||||||||||||
alpha = 1.0 / accWtOld; | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any checks for overflow? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does onedal have some macros to check floating point overflow? |
||||||||||||||||||||||||||
beta = 1.0; | ||||||||||||||||||||||||||
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures); | ||||||||||||||||||||||||||
daal::services::daal_free(sumOld); | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
for (DAAL_INT i = 0; i < nVectors; ++i) | ||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||
for (DAAL_INT j = 0; j < nFeatures; ++j) // if accWtOld = 0, overwrite sum | ||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||
if (accWtOld != 0) | ||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||
sum[j] += data[i * nFeatures + j]; | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||
if (i == 0) | ||||||||||||||||||||||||||
sum[j] = data[i * nFeatures + j]; //overwrite the current sum | ||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||
sum[j] += data[i * nFeatures + j]; | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
// -S S^t/accWt | ||||||||||||||||||||||||||
alpha = -1.0 / accWt; | ||||||||||||||||||||||||||
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
// X X^t | ||||||||||||||||||||||||||
transb = 'T'; | ||||||||||||||||||||||||||
alpha = 1.0; | ||||||||||||||||||||||||||
beta = 1.0; | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would recommend not to use the same variables. It's confusing and also can be dangerous if someone will forgot to change them. |
||||||||||||||||||||||||||
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct, | ||||||||||||||||||||||||||
&nFeatures); | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
return errcode; | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
@@ -285,6 +336,56 @@ struct RefStatistics<float, cpu> | |||||||||||||||||||||||||
__int64 method) | ||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||
int errcode = 0; | ||||||||||||||||||||||||||
daal::internal::ref::OpenBlas<float, cpu> blasInst; | ||||||||||||||||||||||||||
float accWtOld = *nPreviousObservations; | ||||||||||||||||||||||||||
float accWt = *nPreviousObservations + nVectors; | ||||||||||||||||||||||||||
DAAL_INT one = 1; | ||||||||||||||||||||||||||
char transa = 'N'; | ||||||||||||||||||||||||||
char transb = 'N'; | ||||||||||||||||||||||||||
float beta = 0.0; | ||||||||||||||||||||||||||
float alpha; | ||||||||||||||||||||||||||
if (accWtOld != 0) | ||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||
float * sumOld = daal::services::internal::service_malloc<float, cpu>(nFeatures, sizeof(float)); | ||||||||||||||||||||||||||
DAAL_CHECK_MALLOC(sumOld); | ||||||||||||||||||||||||||
for (DAAL_INT i = 0; i < nFeatures; ++i) | ||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||
sumOld[i] = sum[i]; | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
// S_old S_old^t/accWtOld | ||||||||||||||||||||||||||
alpha = 1.0 / accWtOld; | ||||||||||||||||||||||||||
beta = 1.0; | ||||||||||||||||||||||||||
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures); | ||||||||||||||||||||||||||
daal::services::daal_free(sumOld); | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
for (DAAL_INT i = 0; i < nVectors; ++i) | ||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||
for (DAAL_INT j = 0; j < nFeatures; ++j) // if accWtOld = 0, overwrite sum | ||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||
if (accWtOld != 0) | ||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||
sum[j] += data[i * nFeatures + j]; | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||
if (i == 0) | ||||||||||||||||||||||||||
sum[j] = data[i * nFeatures + j]; //overwrite the current sum | ||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||
sum[j] += data[i * nFeatures + j]; | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
// -S S^t/accWt | ||||||||||||||||||||||||||
alpha = -1.0 / accWt; | ||||||||||||||||||||||||||
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
// X X^t | ||||||||||||||||||||||||||
transb = 'T'; | ||||||||||||||||||||||||||
alpha = 1.0; | ||||||||||||||||||||||||||
beta = 1.0; | ||||||||||||||||||||||||||
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct, | ||||||||||||||||||||||||||
&nFeatures); | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
return errcode; | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.