Skip to content

Commit

Permalink
xcp ref implementation (#2895)
Browse files Browse the repository at this point in the history
* Added xcp ref implementation

This routine computes the cross-product matrix
of data stored in column-major format, in batches.

For a matrix X of dimensions p x n, the (i,j)-th entry
of the cross-product matrix is

C_ij = \sum_k (x_ik - \mu_i) (x_jk - \mu_j)

where x_ij is the j-th element of the i-th row of the matrix X, and \mu_i is the mean of the i-th row.

Implementation uses the BLAS routine GEMM.

Signed-off-by: Dhanus M Lal <[email protected]>

* refactor and check for malloc fail

Signed-off-by: Dhanus M Lal <[email protected]>

* review changes

Signed-off-by: Dhanus M Lal <[email protected]>

---------

Signed-off-by: Dhanus M Lal <[email protected]>
  • Loading branch information
DhanusML authored Oct 22, 2024
1 parent 4716240 commit 4145d40
Showing 1 changed file with 115 additions and 2 deletions.
117 changes: 115 additions & 2 deletions cpp/daal/src/externals/service_stat_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define __SERVICE_STAT_REF_H__

#include "src/externals/service_memory.h"
#include "src/externals/service_blas_ref.h"

typedef void (*func_type)(DAAL_INT, DAAL_INT, DAAL_INT, void *);
extern "C"
Expand Down Expand Up @@ -99,6 +100,62 @@ struct RefStatistics<double, cpu>
__int64 method)
{
int errcode = 0;
daal::internal::ref::OpenBlas<double, cpu> blasInst;
const double accWtOld = *nPreviousObservations;
const double accWt = *nPreviousObservations + nVectors;
constexpr DAAL_INT one = 1;
if (accWtOld != 0)
{
double * const sumOld = daal::services::internal::service_malloc<double, cpu>(nFeatures, sizeof(double));
DAAL_CHECK_MALLOC(sumOld);
for (DAAL_INT i = 0; i < nFeatures; ++i)
{
sumOld[i] = sum[i];
}
// S_old S_old^t/accWtOld
const double alpha = 1.0 / accWtOld;
const double beta = 1.0;
constexpr char transa = 'N';
constexpr char transb = 'N';
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures);
daal::services::daal_free(sumOld);
}
for (DAAL_INT i = 0; i < nVectors; ++i)
{
for (DAAL_INT j = 0; j < nFeatures; ++j) // if accWtOld = 0, overwrite sum
{
if (accWtOld != 0)
{
sum[j] += data[i * nFeatures + j];
}
else
{
if (i == 0)
sum[j] = data[i * nFeatures + j]; //overwrite the current sum
else
sum[j] += data[i * nFeatures + j];
}
}
}

// -S S^t/accWt
{
const double alpha = -1.0 / accWt;
const double beta = accWtOld != 0 ? 1.0 : 0.0;
constexpr char transa = 'N';
constexpr char transb = 'N';
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures);
}

// X X^t
{
constexpr double alpha = 1.0;
constexpr double beta = 1.0;
constexpr char transa = 'N';
constexpr char transb = 'T';
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct,
&nFeatures);
}

return errcode;
}
Expand All @@ -124,7 +181,7 @@ struct RefStatistics<double, cpu>
// E(x-\mu)^2 = E(x^2) - \mu^2
int errcode = 0;
double * sum = (double *)daal::services::internal::service_calloc<double, cpu>(nFeatures, sizeof(double));
if (!sum) return -4;
DAAL_CHECK_MALLOC(sum);
daal::services::internal::service_memset<double, cpu>(variance, double(0), nFeatures);
DAAL_INT feature_ptr, vec_ptr;
double wtInv = (double)1 / nVectors;
Expand Down Expand Up @@ -210,6 +267,62 @@ struct RefStatistics<float, cpu>
__int64 method)
{
int errcode = 0;
daal::internal::ref::OpenBlas<float, cpu> blasInst;
const float accWtOld = *nPreviousObservations;
const float accWt = *nPreviousObservations + nVectors;
constexpr DAAL_INT one = 1;
if (accWtOld != 0)
{
float * const sumOld = daal::services::internal::service_malloc<float, cpu>(nFeatures, sizeof(float));
DAAL_CHECK_MALLOC(sumOld);
for (DAAL_INT i = 0; i < nFeatures; ++i)
{
sumOld[i] = sum[i];
}
// S_old S_old^t/accWtOld
const float alpha = 1.0 / accWtOld;
const float beta = 1.0;
constexpr char transa = 'N';
constexpr char transb = 'N';
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sumOld, &nFeatures, sumOld, &one, &beta, crossProduct, &nFeatures);
daal::services::daal_free(sumOld);
}
for (DAAL_INT i = 0; i < nVectors; ++i)
{
for (DAAL_INT j = 0; j < nFeatures; ++j) // if accWtOld = 0, overwrite sum
{
if (accWtOld != 0)
{
sum[j] += data[i * nFeatures + j];
}
else
{
if (i == 0)
sum[j] = data[i * nFeatures + j]; //overwrite the current sum
else
sum[j] += data[i * nFeatures + j];
}
}
}

// -S S^t/accWt
{
const float alpha = -1.0 / accWt;
const float beta = accWtOld != 0 ? 1.0 : 0.0;
constexpr char transa = 'N';
constexpr char transb = 'N';
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &one, &alpha, sum, &nFeatures, sum, &one, &beta, crossProduct, &nFeatures);
}

// X X^t
{
constexpr float alpha = 1.0;
constexpr float beta = 1.0;
constexpr char transa = 'N';
constexpr char transb = 'T';
blasInst.xgemm(&transa, &transb, &nFeatures, &nFeatures, &nVectors, &alpha, data, &nFeatures, data, &nFeatures, &beta, crossProduct,
&nFeatures);
}

return errcode;
}
Expand All @@ -235,7 +348,7 @@ struct RefStatistics<float, cpu>
// E(x-\mu)^2 = E(x^2) - \mu^2
int errcode = 0;
float * sum = (float *)daal::services::internal::service_calloc<float, cpu>(nFeatures, sizeof(float));
if (!sum) return -4;
DAAL_CHECK_MALLOC(sum);
daal::services::internal::service_memset<float, cpu>(variance, float(0), nFeatures);
DAAL_INT feature_ptr, vec_ptr;
float wtInv = (float)1 / nVectors;
Expand Down

0 comments on commit 4145d40

Please sign in to comment.