Skip to content

Commit

Permalink
implement per-rate scaling in SSE and CPU kernels (DNA only)
Browse files Browse the repository at this point in the history
  • Loading branch information
amkozlov committed Apr 26, 2017
1 parent 8ce3148 commit 790a201
Show file tree
Hide file tree
Showing 9 changed files with 846 additions and 249 deletions.
99 changes: 97 additions & 2 deletions src/core_derivatives.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
Schloss-Wolfsbrunnenweg 35, D-69118 Heidelberg, Germany
*/

#include <limits.h>
#include "pll.h"

PLL_EXPORT int pll_core_update_sumtable_ti_4x4(unsigned int sites,
unsigned int rate_cats,
const double * parent_clv,
const unsigned char * left_tipchars,
const unsigned int * parent_scaler,
double ** eigenvecs,
double ** inv_eigenvecs,
double ** freqs,
Expand All @@ -45,9 +47,46 @@ PLL_EXPORT int pll_core_update_sumtable_ti_4x4(unsigned int sites,

unsigned int states = 4;

unsigned int min_scaler;
unsigned int * rate_scalings = NULL;
int per_rate_scaling = (attrib & PLL_ATTRIB_RATE_SCALERS) ? 1 : 0;

/* powers of scale threshold for undoing the scaling */
double scale_minlh[PLL_SCALE_RATE_MAXDIFF];
if (per_rate_scaling)
{
rate_scalings = (unsigned int*) calloc(rate_cats, sizeof(unsigned int));

double scale_factor = 1.0;
for (i = 0; i < PLL_SCALE_RATE_MAXDIFF; ++i)
{
scale_factor *= PLL_SCALE_THRESHOLD;
scale_minlh[i] = scale_factor;
}
}

/* build sumtable */
for (n = 0; n < sites; n++)
{
if (per_rate_scaling)
{
/* compute minimum per-rate scaler -> common per-site scaler */
min_scaler = UINT_MAX;
for (i = 0; i < rate_cats; ++i)
{
rate_scalings[i] = (parent_scaler) ? parent_scaler[n*rate_cats+i] : 0;
if (rate_scalings[i] < min_scaler)
min_scaler = rate_scalings[i];
}

/* compute relative capped per-rate scalers */
for (i = 0; i < rate_cats; ++i)
{
rate_scalings[i] = PLL_MIN(rate_scalings[i] - min_scaler,
PLL_SCALE_RATE_MAXDIFF);
}
}

for (i = 0; i < rate_cats; ++i)
{
t_eigenvecs = eigenvecs[i];
Expand All @@ -67,13 +106,19 @@ PLL_EXPORT int pll_core_update_sumtable_ti_4x4(unsigned int sites,
tipstate >>= 1;
}
sum[j] = lefterm * righterm;

if (rate_scalings && rate_scalings[i] > 0)
sum[j] *= scale_minlh[rate_scalings[i]-1];
}

t_clvc += states;
sum += states;
}
}

if (rate_scalings)
free(rate_scalings);

return PLL_SUCCESS;
}

Expand Down Expand Up @@ -109,10 +154,13 @@ PLL_EXPORT int pll_core_update_sumtable_ii(unsigned int states,
rate_cats,
parent_clv,
child_clv,
parent_scaler,
child_scaler,
eigenvecs,
inv_eigenvecs,
freqs,
sumtable);
sumtable,
attrib);
}
#endif
#ifdef HAVE_AVX
Expand Down Expand Up @@ -150,9 +198,47 @@ PLL_EXPORT int pll_core_update_sumtable_ii(unsigned int states,
}
#endif

unsigned int min_scaler;
unsigned int * rate_scalings = NULL;
int per_rate_scaling = (attrib & PLL_ATTRIB_RATE_SCALERS) ? 1 : 0;

/* powers of scale threshold for undoing the scaling */
double scale_minlh[PLL_SCALE_RATE_MAXDIFF];
if (per_rate_scaling)
{
rate_scalings = (unsigned int*) calloc(rate_cats, sizeof(unsigned int));

double scale_factor = 1.0;
for (i = 0; i < PLL_SCALE_RATE_MAXDIFF; ++i)
{
scale_factor *= PLL_SCALE_THRESHOLD;
scale_minlh[i] = scale_factor;
}
}

/* build sumtable */
for (n = 0; n < sites; n++)
{
if (per_rate_scaling)
{
/* compute minimum per-rate scaler -> common per-site scaler */
min_scaler = UINT_MAX;
for (i = 0; i < rate_cats; ++i)
{
rate_scalings[i] = (parent_scaler) ? parent_scaler[n*rate_cats+i] : 0;
rate_scalings[i] += (child_scaler) ? child_scaler[n*rate_cats+i] : 0;
if (rate_scalings[i] < min_scaler)
min_scaler = rate_scalings[i];
}

/* compute relative capped per-rate scalers */
for (i = 0; i < rate_cats; ++i)
{
rate_scalings[i] = PLL_MIN(rate_scalings[i] - min_scaler,
PLL_SCALE_RATE_MAXDIFF);
}
}

for (i = 0; i < rate_cats; ++i)
{
t_eigenvecs = eigenvecs[i];
Expand All @@ -169,13 +255,19 @@ PLL_EXPORT int pll_core_update_sumtable_ii(unsigned int states,
righterm += t_eigenvecs[j * states + k] * t_clvc[k];
}
sum[j] = lefterm * righterm;

if (rate_scalings && rate_scalings[i] > 0)
sum[j] *= scale_minlh[rate_scalings[i]-1];
}
t_clvc += states;
t_clvp += states;
sum += states;
}
}

if (rate_scalings)
free(rate_scalings);

return PLL_SUCCESS;
}

Expand Down Expand Up @@ -214,11 +306,13 @@ PLL_EXPORT int pll_core_update_sumtable_ti(unsigned int states,
rate_cats,
parent_clv,
left_tipchars,
parent_scaler,
eigenvecs,
inv_eigenvecs,
freqs,
tipmap,
sumtable);
sumtable,
attrib);
}
#endif
#ifdef HAVE_AVX
Expand Down Expand Up @@ -265,6 +359,7 @@ PLL_EXPORT int pll_core_update_sumtable_ti(unsigned int states,
rate_cats,
parent_clv,
left_tipchars,
parent_scaler,
eigenvecs,
inv_eigenvecs,
freqs,
Expand Down
98 changes: 53 additions & 45 deletions src/core_derivatives_avx.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,22 @@ static int core_update_sumtable_ii_4x4_avx(unsigned int sites,
unsigned int states = 4;

unsigned int min_scaler = 0;
unsigned int * rate_scale_factors = NULL;
unsigned int * rate_scalings = NULL;
int per_rate_scaling = (attrib & PLL_ATTRIB_RATE_SCALERS) ? 1 : 0;

/* powers of scale threshold for undoing the scaling */
__m256d v_scale_minlh[PLL_SCALE_RATE_MAXDIFF];
if (per_rate_scaling)
rate_scale_factors = (unsigned int*) calloc(rate_cats, sizeof(unsigned int));
{
rate_scalings = (unsigned int*) calloc(rate_cats, sizeof(unsigned int));

/* powers of scale threshold for undoing the scaling */
__m256d v_scale_minlh[5] = {
_mm256_set1_pd(1.0),
_mm256_set1_pd(PLL_SCALE_THRESHOLD),
_mm256_set1_pd(PLL_SCALE_THRESHOLD * PLL_SCALE_THRESHOLD),
_mm256_set1_pd(PLL_SCALE_THRESHOLD * PLL_SCALE_THRESHOLD *
PLL_SCALE_THRESHOLD),
_mm256_set1_pd(PLL_SCALE_THRESHOLD * PLL_SCALE_THRESHOLD *
PLL_SCALE_THRESHOLD * PLL_SCALE_THRESHOLD)
};
double scale_factor = 1.0;
for (i = 0; i < PLL_SCALE_RATE_MAXDIFF; ++i)
{
scale_factor *= PLL_SCALE_THRESHOLD;
v_scale_minlh[i] = _mm256_set1_pd(scale_factor);
}
}

/* transposed inv_eigenvecs */
double * tt_inv_eigenvecs = (double *) pll_aligned_alloc (
Expand Down Expand Up @@ -95,10 +95,17 @@ static int core_update_sumtable_ii_4x4_avx(unsigned int sites,
min_scaler = UINT_MAX;
for (i = 0; i < rate_cats; ++i)
{
rate_scale_factors[i] = (parent_scaler) ? parent_scaler[n*rate_cats+i] : 0;
rate_scale_factors[i] += (child_scaler) ? child_scaler[n*rate_cats+i] : 0;
if (rate_scale_factors[i] < min_scaler)
min_scaler = rate_scale_factors[i];
rate_scalings[i] = (parent_scaler) ? parent_scaler[n*rate_cats+i] : 0;
rate_scalings[i] += (child_scaler) ? child_scaler[n*rate_cats+i] : 0;
if (rate_scalings[i] < min_scaler)
min_scaler = rate_scalings[i];
}

/* compute relative capped per-rate scalers */
for (i = 0; i < rate_cats; ++i)
{
rate_scalings[i] = PLL_MIN(rate_scalings[i] - min_scaler,
PLL_SCALE_RATE_MAXDIFF);
}
}

Expand Down Expand Up @@ -170,12 +177,9 @@ static int core_update_sumtable_ii_4x4_avx(unsigned int sites,
__m256d v_sum = _mm256_mul_pd (v_lefterm_sum, v_righterm_sum);

/* apply per-rate scalers */
if (per_rate_scaling)
if (rate_scalings && rate_scalings[i] > 0)
{
int scalings = rate_scale_factors[i] - min_scaler > 4 ?
4 : (rate_scale_factors[i] - min_scaler);

v_sum = _mm256_mul_pd(v_sum, v_scale_minlh[scalings]);
v_sum = _mm256_mul_pd(v_sum, v_scale_minlh[rate_scalings[i]-1]);
}

_mm256_store_pd (sum, v_sum);
Expand All @@ -188,8 +192,8 @@ static int core_update_sumtable_ii_4x4_avx(unsigned int sites,

pll_aligned_free (tt_inv_eigenvecs);

if (rate_scale_factors)
free(rate_scale_factors);
if (rate_scalings)
free(rate_scalings);

return PLL_SUCCESS;
}
Expand Down Expand Up @@ -413,22 +417,22 @@ static int core_update_sumtable_ti_4x4_avx(unsigned int sites,
const double * t_eigenvecs_trans;

unsigned int min_scaler = 0;
unsigned int * rate_scale_factors = NULL;
unsigned int * rate_scalings = NULL;
int per_rate_scaling = (attrib & PLL_ATTRIB_RATE_SCALERS) ? 1 : 0;

/* powers of scale threshold for undoing the scaling */
__m256d v_scale_minlh[PLL_SCALE_RATE_MAXDIFF];
if (per_rate_scaling)
rate_scale_factors = (unsigned int*) calloc(rate_cats, sizeof(unsigned int));
{
rate_scalings = (unsigned int*) calloc(rate_cats, sizeof(unsigned int));

/* powers of scale threshold for undoing the scaling */
__m256d v_scale_minlh[5] = {
_mm256_set1_pd(1.0),
_mm256_set1_pd(PLL_SCALE_THRESHOLD),
_mm256_set1_pd(PLL_SCALE_THRESHOLD * PLL_SCALE_THRESHOLD),
_mm256_set1_pd(PLL_SCALE_THRESHOLD * PLL_SCALE_THRESHOLD *
PLL_SCALE_THRESHOLD),
_mm256_set1_pd(PLL_SCALE_THRESHOLD * PLL_SCALE_THRESHOLD *
PLL_SCALE_THRESHOLD * PLL_SCALE_THRESHOLD)
};
double scale_factor = 1.0;
for (i = 0; i < PLL_SCALE_RATE_MAXDIFF; ++i)
{
scale_factor *= PLL_SCALE_THRESHOLD;
v_scale_minlh[i] = _mm256_set1_pd(scale_factor);
}
}

double * eigenvecs_trans = (double *) pll_aligned_alloc (
(states * states * rate_cats) * sizeof(double),
Expand Down Expand Up @@ -513,9 +517,16 @@ static int core_update_sumtable_ti_4x4_avx(unsigned int sites,
min_scaler = UINT_MAX;
for (i = 0; i < rate_cats; ++i)
{
rate_scale_factors[i] = (parent_scaler) ? parent_scaler[n*rate_cats+i] : 0;
if (rate_scale_factors[i] < min_scaler)
min_scaler = rate_scale_factors[i];
rate_scalings[i] = (parent_scaler) ? parent_scaler[n*rate_cats+i] : 0;
if (rate_scalings[i] < min_scaler)
min_scaler = rate_scalings[i];
}

/* compute relative capped per-rate scalers */
for (i = 0; i < rate_cats; ++i)
{
rate_scalings[i] = PLL_MIN(rate_scalings[i] - min_scaler,
PLL_SCALE_RATE_MAXDIFF);
}
}

Expand All @@ -541,12 +552,9 @@ static int core_update_sumtable_ti_4x4_avx(unsigned int sites,
__m256d v_sum = _mm256_mul_pd(v_lefterm, v_righterm);

/* apply per-rate scalers */
if (per_rate_scaling)
if (rate_scalings && rate_scalings[i] > 0)
{
int scalings = rate_scale_factors[i] - min_scaler > 4 ?
4 : (rate_scale_factors[i] - min_scaler);

v_sum = _mm256_mul_pd(v_sum, v_scale_minlh[scalings]);
v_sum = _mm256_mul_pd(v_sum, v_scale_minlh[rate_scalings[i]-1]);
}

_mm256_store_pd(sum, v_sum);
Expand All @@ -561,8 +569,8 @@ static int core_update_sumtable_ti_4x4_avx(unsigned int sites,
pll_aligned_free(eigenvecs_trans);
pll_aligned_free(precomp_left);

if (rate_scale_factors)
free(rate_scale_factors);
if (rate_scalings)
free(rate_scalings);

return PLL_SUCCESS;
}
Expand Down
Loading

0 comments on commit 790a201

Please sign in to comment.