Skip to content

Commit

Permalink
20138: Improves nominal deviation math and fixes bugs around code comparisons with small values (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
howsohazard authored May 2, 2024
1 parent 8b90164 commit 286e662
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 57 deletions.
38 changes: 10 additions & 28 deletions src/Amalgam/GeneralizedDistance.h
Original file line number Diff line number Diff line change
Expand Up @@ -514,16 +514,13 @@ class GeneralizedDistanceEvaluator
{
//need to have at least two classes in existence
double nominal_count = std::max(featureAttribs[index].typeAttributes.nominalCount, 2.0);
//TODO 17631: change to be weighted average: prob of nominal * deviation of random guessing
double prob_max_entropy_match = 1 / nominal_count;

//find probability that the correct class was selected
//can't go below base probability of guessing
double prob_class_given_match = std::max(1 - deviation, prob_max_entropy_match);
double prob_class_given_match = 1 - deviation;

//find the probability that any other class besides the correct class was selected
//divide the probability among the other classes
double prob_class_given_nonmatch = (1 - prob_class_given_match) / (nominal_count - 1);
double prob_class_given_nonmatch = deviation / (nominal_count - 1);

double surprisal_class_given_match = -std::log(prob_class_given_match);
double surprisal_class_given_nonmatch = -std::log(prob_class_given_nonmatch);
Expand All @@ -536,16 +533,11 @@ class GeneralizedDistanceEvaluator
else if(DoesFeatureHaveDeviation(index))
{
double nominal_count = featureAttribs[index].typeAttributes.nominalCount;

// n = number of nominal classes
// match: deviation ^ p * weight
// non match: (deviation + (1 - deviation) / (n - 1)) ^ p * weight
//if there is only one nominal class, the smallest delta value it could be is the specified smallest delta, otherwise it's 1.0
double dist_term = 0;
double dist_term = 1;
//the probability of each other term is spread across all of the different nominal classes,
//so take the remaining probability for all other classes besides the one nonmatch chosen
if(nominal_count > 1)
dist_term = (deviation + (1 - deviation) / (nominal_count - 1));
else
dist_term = 1;
dist_term = 1.0 - (deviation / (nominal_count - 1));

return dist_term;
}
Expand Down Expand Up @@ -605,16 +597,13 @@ class GeneralizedDistanceEvaluator
{
//need to have at least two classes in existence
double nominal_count = std::max(featureAttribs[index].typeAttributes.nominalCount, 2.0);
//TODO 17631: change to be weighted average: prob of nominal * deviation of random guessing
double prob_max_entropy_match = 1 / nominal_count;


//find probability that the correct class was selected
//can't go below base probability of guessing
double prob_class_given_match = std::max(1 - match_deviation, prob_max_entropy_match);
double prob_class_given_match = 1 - match_deviation;

//find the probability that any other class besides the correct class was selected,
//but cannot exceed the probability of a match
double prob_class_given_nonmatch = std::min(1 - nonmatch_deviation, prob_class_given_match);
double prob_class_given_nonmatch = match_deviation / (nominal_count - 1);

if(match)
return ComputeDistanceTermBaseNominalMatchFromMatchProbabilities(index,
Expand Down Expand Up @@ -779,15 +768,8 @@ class GeneralizedDistanceEvaluator
{
if(computeSurprisal)
{
//need to have at least two classes in existence
double nominal_count = std::max(featureAttribs[index].typeAttributes.nominalCount, 2.0);
//TODO 17631: change to be weighted average: prob of nominal * deviation of random guessing
double prob_max_entropy_match = 1 / nominal_count;

//find probability that the correct class was selected
//can't go below base probability of guessing
double prob_class_given_match = std::max(1 - deviation, prob_max_entropy_match);

double prob_class_given_match = 1 - deviation;
diff = -std::log(prob_class_given_match);
}
else //nonsurprisal nominals just use the deviation as provided
Expand Down
14 changes: 14 additions & 0 deletions src/Amalgam/SBFDSColumnData.h
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,20 @@ class SBFDSColumnData
return value;
}

//reports how many unique values this column holds, based on the sizes of its value indices
// value_type selects which values are counted:
//  ENIVT_NUMBER counts only the entries of numberIndices
//  ENIVT_STRING_ID counts only the entries of stringIdIndices
//  any other type, including the default ENIVT_NULL, counts the entries of
//  numberIndices, stringIdIndices, and codeIndices combined
inline size_t GetNumUniqueValues(EvaluableNodeImmediateValueType value_type = ENIVT_NULL)
{
	switch(value_type)
	{
	case ENIVT_NUMBER:
		return numberIndices.size();

	case ENIVT_STRING_ID:
		return stringIdIndices.size();

	default:
		//no single type requested (or a type without its own count), so total all indices
		return numberIndices.size() + stringIdIndices.size() + codeIndices.size();
	}
}

//returns the maximum difference between value and any other value for this column
//if empty, will return infinity
inline double GetMaxDifferenceTerm(GeneralizedDistanceEvaluator::FeatureAttributes &feature_attribs)
Expand Down
21 changes: 9 additions & 12 deletions src/Amalgam/SeparableBoxFilterDataStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -409,15 +409,9 @@ class SeparableBoxFilterDataStore
}

//returns the number of unique values for a column for the given value_type
size_t GetNumUniqueValuesForColumn(size_t column_index, EvaluableNodeImmediateValueType value_type)
inline size_t GetNumUniqueValuesForColumn(size_t column_index, EvaluableNodeImmediateValueType value_type)
{
auto &column_data = columnData[column_index];
if(value_type == ENIVT_NUMBER)
return column_data->numberIndices.size();
else if(value_type == ENIVT_STRING_ID)
return column_data->stringIdIndices.size();
else //return everything else
return GetNumInsertedEntities() - column_data->invalidIndices.size();
return columnData[column_index]->GetNumUniqueValues(value_type);
}

//returns a function that will take in an entity index iterator and reference to a double to store the value and return true if the value is found
Expand Down Expand Up @@ -935,8 +929,8 @@ class SeparableBoxFilterDataStore
}
}

//recomputes column indices for each feature as well as filling in unknowns
inline void PopulateColumnIndicesAndUnknownFeatureValueDifferences(
//sets values in dist_eval corresponding to the columns specified by position_label_ids
inline void PopulateGeneralizedDistanceEvaluatorFromColumnData(
GeneralizedDistanceEvaluator &dist_eval, std::vector<size_t> &position_label_ids)
{
for(size_t query_feature_index = 0; query_feature_index < position_label_ids.size(); query_feature_index++)
Expand All @@ -947,15 +941,18 @@ class SeparableBoxFilterDataStore

auto &feature_attribs = dist_eval.featureAttribs[query_feature_index];
feature_attribs.featureIndex = column->second;
auto &column_data = columnData[feature_attribs.featureIndex];

if(feature_attribs.IsFeatureNominal() && FastIsNaN(feature_attribs.typeAttributes.nominalCount))
feature_attribs.typeAttributes.nominalCount = static_cast<double>(column_data->GetNumUniqueValues());

//if either known or unknown to unknown is missing, need to compute difference
// and store it where it is needed
double unknown_distance_term = 0.0;
if(FastIsNaN(feature_attribs.knownToUnknownDistanceTerm.deviation)
|| FastIsNaN(feature_attribs.unknownToUnknownDistanceTerm.deviation))
{
unknown_distance_term = columnData[feature_attribs.featureIndex]->GetMaxDifferenceTerm(
feature_attribs);
unknown_distance_term = column_data->GetMaxDifferenceTerm(feature_attribs);

if(FastIsNaN(feature_attribs.knownToUnknownDistanceTerm.deviation))
feature_attribs.knownToUnknownDistanceTerm.deviation = unknown_distance_term;
Expand Down
6 changes: 3 additions & 3 deletions src/Amalgam/amlg_code/full_test.amlg
Original file line number Diff line number Diff line change
Expand Up @@ -3535,7 +3535,7 @@
(list 1 1 1 120 1 50.1)
) "\n")

(print "expected: 256.5114466\n")
(print "expected: 515.6547015746778\n")

(print (generalized_distance
;weights
Expand All @@ -3554,7 +3554,7 @@
(list 1 1 1 120 1 50.1)
) "\n")

(print "expected: 8.037178684\n")
(print "expected: 16.936168566704314\n")

(print (generalized_distance
;weights
Expand All @@ -3573,7 +3573,7 @@
(list 1 1 1 120 1 50.1)
) "\n")

(print "expected: 0.14362593\n")
(print "expected: 1.501901120883626e-25\n")

(create_entities "DistanceTestContainer"
(lambda (null))
Expand Down
2 changes: 1 addition & 1 deletion src/Amalgam/entity/EntityQueryCaches.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ void EntityQueryCaches::GetMatchingEntities(EntityQueryCondition *cond, BitArray
if(matching_entities.size() == 0)
return;

sbfds.PopulateColumnIndicesAndUnknownFeatureValueDifferences(cond->distEvaluator, cond->positionLabels);
sbfds.PopulateGeneralizedDistanceEvaluatorFromColumnData(cond->distEvaluator, cond->positionLabels);
cond->distEvaluator.InitializeParametersAndFeatureParams();

if(cond->queryType == ENT_QUERY_NEAREST_GENERALIZED_DISTANCE || cond->queryType == ENT_QUERY_WITHIN_GENERALIZED_DISTANCE)
Expand Down
26 changes: 13 additions & 13 deletions src/Amalgam/out.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1259,7 +1259,7 @@ current_index: 2
interpreter "C:\\Users\\ChristopherHazard\\Desktop\\Howso_repos\\amalgam\\x64\\MT_Release_EXE\\Amalgam.exe"
raaa 2
rwww 1
start_time 1714575485.109338
start_time 1714591759.00049
www 1
x 12
zz 10
Expand Down Expand Up @@ -1302,7 +1302,7 @@ current_index: 2
interpreter "C:\\Users\\ChristopherHazard\\Desktop\\Howso_repos\\amalgam\\x64\\MT_Release_EXE\\Amalgam.exe"
raaa 2
rwww 1
start_time 1714575485.109338
start_time 1714591759.00049
www 1
x 12
zz 10
Expand Down Expand Up @@ -1344,7 +1344,7 @@ current_index: 2
interpreter "C:\\Users\\ChristopherHazard\\Desktop\\Howso_repos\\amalgam\\x64\\MT_Release_EXE\\Amalgam.exe"
raaa 2
rwww 1
start_time 1714575485.109338
start_time 1714591759.00049
www 1
x 12
zz 10
Expand Down Expand Up @@ -1612,7 +1612,7 @@ e:
- .inf

25: (assoc a 1)
current date-time in epoch: 2024-05-01-10.58.05.1487100
current date-time in epoch: 2024-05-01-15.29.19.0478620
2020-06-07 00:22:59
1391230800
1391230800
Expand Down Expand Up @@ -3430,7 +3430,7 @@ deep sets

--set_entity_root_permission--
RootTest
1714575485.261987
1714591759.338789
(true)

RootTest
Expand Down Expand Up @@ -4443,13 +4443,13 @@ a
(assoc a1 5.0990195135927845 a2 2 a3 5.0990195135927845)
(assoc a1 1 a3 1 a4 0)
--accuracy tests--
256.5114465493462
expected: 256.5114466
8.037178667714414
expected: 8.037178684
3.621729247632793e-35
expected: 0.14362593
(assoc point1 256.5114465493462 point2 273.1247383899543 point3 256.5114465493462)
515.6547015746778
expected: 515.6547015746778
16.936168566704314
expected: 16.936168566704314
1.501901120883626e-25
expected: 1.501901120883626e-25
(assoc point1 515.6547015746778 point2 545.8279373833903 point3 515.6547015746778)
distance symmetry tests
(list
(list
Expand Down Expand Up @@ -4676,4 +4676,4 @@ concurrent entity writes successful: (true)

--clean-up test files--
--total execution time--
2.024893045425415
1.975032091140747

0 comments on commit 286e662

Please sign in to comment.