NOTE: this patch was reconstructed from a copy whose line breaks had been collapsed and whose
template arguments had been stripped. Indentation depth and the position of blank context lines
are best guesses; verify them against the target tree before applying.
Relative to that copy, BitArrayIntegerSet::IterateOverIntersection is corrected: the density
heuristic divides the element count (not the bucket count) by the bucket count, the dense loop
passes the integer (not the bucket number) to func and honors up_to_index, and an empty range
returns early. The length of the second IntegerSet.h hunk changes accordingly, and the post-image
blob hash on the IntegerSet.h index line no longer matches the patched content.

diff --git a/src/Amalgam/IntegerSet.h b/src/Amalgam/IntegerSet.h
index 1c17309b..3a396fe4 100644
--- a/src/Amalgam/IntegerSet.h
+++ b/src/Amalgam/IntegerSet.h
@@ -432,8 +432,10 @@ class BitArrayIntegerSet
 		size_t indices_per_bucket = num_indices / num_buckets;
 		if(indices_per_bucket >= 48)
 		{
+			//calculate last bucket in case less than total size
+			size_t end_buckets = (end_index + 63) / 64;
 			for(size_t bucket = 0, index = 0;
-				bucket < num_buckets; bucket++, index++)
+				bucket < end_buckets; bucket++, index++)
 			{
 				uint64_t bucket_bits = bitBucket[bucket];
 				for(size_t bit = 0; bit < 64; bit++)
@@ -465,6 +467,72 @@ class BitArrayIntegerSet
 		}
 	}
 
+	//iterates over all of the integers that are in both bais_1 and bais_2 and are less than up_to_index
+	// as efficiently as possible, passing each of them into func
+	template<typename IntegerFunction>
+	static inline void IterateOverIntersection(
+		BitArrayIntegerSet &bais_1, BitArrayIntegerSet &bais_2,
+		IntegerFunction func, size_t up_to_index = std::numeric_limits<size_t>::max())
+	{
+		auto &sparser_bais = (bais_1.size() < bais_2.size() ? bais_1 : bais_2);
+		auto &denser_bais = (bais_1.size() < bais_2.size() ? bais_2 : bais_1);
+
+		size_t sparser_end_integer = sparser_bais.GetEndInteger();
+		size_t sparser_num_buckets = (sparser_end_integer + 63) / 64;
+
+		//an index can only be in the intersection if it is below up_to_index and the end of both sets
+		size_t end_index = std::min(up_to_index, sparser_end_integer);
+		end_index = std::min(end_index, denser_bais.GetEndInteger());
+
+		//nothing to visit; this also guarantees sparser_num_buckets is nonzero for the division below
+		if(end_index == 0)
+			return;
+
+		//there are three loops optimized for different densities, high, medium high, and sparse
+		//the heuristics have been tuned by performance testing across a couple of CPU architectures
+		//and different data sets
+		size_t sparser_indices_per_bucket = sparser_bais.size() / sparser_num_buckets;
+		if(sparser_indices_per_bucket >= 48)
+		{
+			//calculate last bucket in case less than total size
+			size_t end_buckets = (end_index + 63) / 64;
+			for(size_t bucket = 0, index = 0; bucket < end_buckets; bucket++)
+			{
+				uint64_t bucket_bits = (sparser_bais.bitBucket[bucket] & denser_bais.bitBucket[bucket]);
+				//index advances with every bit so that it is always 64 * bucket + bit
+				for(size_t bit = 0; bit < 64 && index < end_index; bit++, index++)
+				{
+					uint64_t mask = (1ULL << bit);
+					if(bucket_bits & mask)
+						func(index);
+				}
+			}
+		}
+		else if(sparser_indices_per_bucket >= 32)
+		{
+			for(size_t index = 0; index < end_index; index++)
+			{
+				if(sparser_bais.ContainsWithoutMaximumIndexCheck(index)
+					&& denser_bais.ContainsWithoutMaximumIndexCheck(index))
+					func(index);
+			}
+		}
+		else //use the iterator, which is more efficient when sparse
+		{
+			//NOTE(review): assumes dereferencing the iterator once it is past the last element
+			// yields a value >= GetEndInteger(), which is what terminates this loop -- confirm
+			auto iter = sparser_bais.begin();
+			size_t index = *iter;
+			while(index < end_index)
+			{
+				if(denser_bais.ContainsWithoutMaximumIndexCheck(index))
+					func(index);
+				++iter;
+				index = *iter;
+			}
+		}
+	}
+
 	//sets bucket and bit to the values pointing to the first id in the hash,
 	// or the first element if it is empty
 	inline void FindFirst(size_t &bucket, size_t &bit)
diff --git a/src/Amalgam/SeparableBoxFilterDataStore.cpp b/src/Amalgam/SeparableBoxFilterDataStore.cpp
index e611c87b..aa2357ea 100644
--- a/src/Amalgam/SeparableBoxFilterDataStore.cpp
+++ b/src/Amalgam/SeparableBoxFilterDataStore.cpp
@@ -1050,14 +1050,14 @@ double SeparableBoxFilterDataStore::PopulatePartialSumsWithSimilarFeatureValue(R
 				return unknown_unknown_term;
 
 			if(unknown_unknown_term < known_unknown_term || known_unknown_term == 0.0)
-				AccumulatePartialSums(column->nullIndices, query_feature_index, unknown_unknown_term);
+				AccumulatePartialSums(enabled_indices, column->nullIndices, query_feature_index, unknown_unknown_term);
 
 			if(known_unknown_term < unknown_unknown_term || unknown_unknown_term == 0.0)
 			{
 				BitArrayIntegerSet &known_unknown_indices = parametersAndBuffers.potentialMatchesSet;
 				known_unknown_indices = enabled_indices;
 				column->nullIndices.EraseTo(known_unknown_indices);
-				AccumulatePartialSums(known_unknown_indices, query_feature_index, known_unknown_term);
+				AccumulatePartialSums(enabled_indices, known_unknown_indices, query_feature_index, known_unknown_term);
 			}
 
 			double larget_term_not_computed = std::max(known_unknown_term, unknown_unknown_term);
@@ -1073,7 +1073,7 @@ double SeparableBoxFilterDataStore::PopulatePartialSumsWithSimilarFeatureValue(R
 		}
 		else //nonsymmetric nominal -- need to compute
 		{
-			AccumulatePartialSums(column->nullIndices, query_feature_index, unknown_unknown_term);
+			AccumulatePartialSums(enabled_indices, column->nullIndices, query_feature_index, unknown_unknown_term);
 
 			double nonmatch_dist_term = r_dist_eval.ComputeDistanceTermNominalNonNullSmallestNonmatch(query_feature_index, high_accuracy);
 			//if the next closest match is larger, no need to compute any more values
@@ -1106,7 +1106,7 @@ double SeparableBoxFilterDataStore::PopulatePartialSumsWithSimilarFeatureValue(R
 		|| r_dist_eval.distEvaluator->IsKnownToUnknownDistanceLessThanOrEqualToExactMatch(query_feature_index))
 	{
 		double known_unknown_term = r_dist_eval.distEvaluator->ComputeDistanceTermKnownToUnknown(query_feature_index, high_accuracy);
-		AccumulatePartialSums(column->nullIndices, query_feature_index, known_unknown_term);
+		AccumulatePartialSums(enabled_indices, column->nullIndices, query_feature_index, known_unknown_term);
 	}
 
 	//if nominal, only need to compute the exact match
@@ -1129,7 +1129,8 @@ double SeparableBoxFilterDataStore::PopulatePartialSumsWithSimilarFeatureValue(R
 			if(value_found != end(column->valueCodeSizeToIndices))
 			{
 				auto &entity_indices = *(value_found->second);
-				ComputeAndAccumulatePartialSums(r_dist_eval, entity_indices, query_feature_index, absolute_feature_index, high_accuracy);
+				ComputeAndAccumulatePartialSums(r_dist_eval, enabled_indices, entity_indices,
+					query_feature_index, absolute_feature_index, high_accuracy);
 			}
 		}
 		//else value_type == ENIVT_NULL and already covered above
@@ -1203,7 +1204,8 @@ double SeparableBoxFilterDataStore::PopulatePartialSumsWithSimilarFeatureValue(R
 		if(value_found != end(column->valueCodeSizeToIndices))
 		{
 			auto &entity_indices = *(value_found->second);
-			ComputeAndAccumulatePartialSums(r_dist_eval, entity_indices, query_feature_index, absolute_feature_index, high_accuracy);
+			ComputeAndAccumulatePartialSums(r_dist_eval, enabled_indices, entity_indices,
+				query_feature_index, absolute_feature_index, high_accuracy);
 		}
 
 		if(feature_type == GeneralizedDistanceEvaluator::FDT_NOMINAL_CODE)
diff --git a/src/Amalgam/SeparableBoxFilterDataStore.h b/src/Amalgam/SeparableBoxFilterDataStore.h
index b2903be0..08dfb603 100644
--- a/src/Amalgam/SeparableBoxFilterDataStore.h
+++ b/src/Amalgam/SeparableBoxFilterDataStore.h
@@ -540,7 +540,8 @@ class SeparableBoxFilterDataStore
 	//computes each partial sum and adds the term to the partial sums associated for each id in entity_indices for query_feature_index
 	//returns the number of entities indices accumulated
 	size_t ComputeAndAccumulatePartialSums(RepeatedGeneralizedDistanceEvaluator &r_dist_eval,
-		SortedIntegerSet &entity_indices, size_t query_feature_index, size_t absolute_feature_index, bool high_accuracy)
+		BitArrayIntegerSet &enabled_indices, SortedIntegerSet &entity_indices,
+		size_t query_feature_index, size_t absolute_feature_index, bool high_accuracy)
 	{
 		size_t num_entity_indices = entity_indices.size();
 		size_t max_index = num_entity_indices;
@@ -568,6 +569,8 @@ class SeparableBoxFilterDataStore
 		for(int64_t i = 0; i < static_cast<int64_t>(max_index); i++)
 		{
 			const auto entity_index = entity_indices_vector[i];
+			if(!enabled_indices.contains(entity_index))
+				continue;
 
 			//get value
 			auto other_value_type = column_data->GetIndexValueType(entity_index);
@@ -627,53 +630,80 @@ class SeparableBoxFilterDataStore
 			}
 		}
 
+		//return an estimate (upper bound) of the number accumulated
 		return num_entity_indices;
 	}
 
 	//adds term to the partial sums associated for each id in entity_indices for query_feature_index
 	//returns the number of entities indices accumulated
-	inline size_t AccumulatePartialSums(BitArrayIntegerSet &entity_indices, size_t query_feature_index, double term)
+	inline size_t AccumulatePartialSums(BitArrayIntegerSet &enabled_indices, BitArrayIntegerSet &entity_indices,
+		size_t query_feature_index, double term)
 	{
 		size_t num_entity_indices = entity_indices.size();
 		if(num_entity_indices == 0)
 			return 0;
 
+		//see if the extra logic overhead for performing an intersection is worth doing
+		//for the reduced cost of fewer memory writes
+		size_t num_enabled_indices = enabled_indices.size();
+
 		auto &partial_sums = parametersAndBuffers.partialSums;
 		const auto accum_location = partial_sums.GetAccumLocation(query_feature_index);
 		size_t max_element = partial_sums.numInstances;
 
 		if(term != 0.0)
 		{
-			entity_indices.IterateOver(
-				[&partial_sums, &accum_location, term]
-				(size_t entity_index)
-				{
-					partial_sums.Accum(entity_index, accum_location, term);
-				},
-				max_element);
+			if(num_enabled_indices <= num_entity_indices / 8)
+				BitArrayIntegerSet::IterateOverIntersection(enabled_indices, entity_indices,
+					[&partial_sums, &accum_location, term]
+					(size_t entity_index)
+					{
+						partial_sums.Accum(entity_index, accum_location, term);
+					},
+					max_element);
+			else
+				entity_indices.IterateOver(
+					[&partial_sums, &accum_location, term]
+					(size_t entity_index)
+					{
+						partial_sums.Accum(entity_index, accum_location, term);
+					},
+					max_element);
 		}
 		else
 		{
-			entity_indices.IterateOver(
-				[&partial_sums, &accum_location]
-				(size_t entity_index)
-				{
-					partial_sums.AccumZero(entity_index, accum_location);
-				},
-				max_element);
+			if(num_enabled_indices <= num_entity_indices / 8)
+				BitArrayIntegerSet::IterateOverIntersection(enabled_indices, entity_indices,
+					[&partial_sums, &accum_location]
+					(size_t entity_index)
+					{
+						partial_sums.AccumZero(entity_index, accum_location);
+					},
+					max_element);
+			else
+				entity_indices.IterateOver(
+					[&partial_sums, &accum_location]
+					(size_t entity_index)
+					{
+						partial_sums.AccumZero(entity_index, accum_location);
+					},
+					max_element);
 		}
 
-		return entity_indices.size();
+		//return an estimate (upper bound) of the number accumulated
+		return std::min(num_enabled_indices, num_entity_indices);
 	}
 
-	//adds term to the partial sums associated for each id in entity_indices for query_feature_index
+	//adds term to the partial sums associated for each id in both enabled_indices and entity_indices
+	// for query_feature_index
 	//returns the number of entities indices accumulated
-	inline size_t AccumulatePartialSums(EfficientIntegerSet &entity_indices, size_t query_feature_index, double term)
+	inline size_t AccumulatePartialSums(BitArrayIntegerSet &enabled_indices, EfficientIntegerSet &entity_indices,
+		size_t query_feature_index, double term)
 	{
 		if(entity_indices.IsSisContainer())
 			return AccumulatePartialSums(entity_indices.GetSisContainer(), query_feature_index, term);
 		else
-			return AccumulatePartialSums(entity_indices.GetBaisContainer(), query_feature_index, term);
+			return AccumulatePartialSums(enabled_indices, entity_indices.GetBaisContainer(), query_feature_index, term);
 	}
 
 	//accumulates the partial sums for the specified value