Skip to content

Commit

Permalink
18891: Adds end-to-end surprisal as distance to all distance opcodes …
Browse files Browse the repository at this point in the history
…and queries, MINOR (#54)
  • Loading branch information
howsohazard authored Jan 23, 2024
1 parent 50de95f commit aaa27bc
Show file tree
Hide file tree
Showing 12 changed files with 265 additions and 214 deletions.
16 changes: 8 additions & 8 deletions docs/language.js

Large diffs are not rendered by default.

286 changes: 167 additions & 119 deletions src/Amalgam/GeneralizedDistance.h

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions src/Amalgam/SeparableBoxFilterDataStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1011,13 +1011,13 @@ double SeparableBoxFilterDataStore::PopulatePartialSumsWithSimilarFeatureValue(G
auto value_found = column->stringIdValueToIndices.find(value.stringID);
if(value_found != end(column->stringIdValueToIndices))
{
double term = dist_params.ComputeDistanceTermNonNominalExactMatch(query_feature_index, high_accuracy);
double term = dist_params.ComputeDistanceTermContinuousExactMatch(query_feature_index, high_accuracy);
AccumulatePartialSums(*(value_found->second), query_feature_index, term);
}
}

//the next closest string will have an edit distance of 1
return dist_params.ComputeDistanceTermNonNominalNonCyclicNonNullRegular(1.0, query_feature_index, high_accuracy);
return dist_params.ComputeDistanceTermContinuousNonCyclicNonNullRegular(1.0, query_feature_index, high_accuracy);
}
else if(effective_feature_type == GeneralizedDistance::EFDT_CONTINUOUS_CODE)
{
Expand All @@ -1035,7 +1035,7 @@ double SeparableBoxFilterDataStore::PopulatePartialSumsWithSimilarFeatureValue(G
}

//next most similar code must be at least a distance of 1 edit away
return dist_params.ComputeDistanceTermNonNominalNonCyclicNonNullRegular(1.0, query_feature_index, high_accuracy);
return dist_params.ComputeDistanceTermContinuousNonCyclicNonNullRegular(1.0, query_feature_index, high_accuracy);
}
//else effective_feature_type == EFDT_CONTINUOUS_NUMERIC or EFDT_CONTINUOUS_UNIVERSALLY_NUMERIC

Expand All @@ -1052,9 +1052,9 @@ double SeparableBoxFilterDataStore::PopulatePartialSumsWithSimilarFeatureValue(G

double term = 0.0;
if(exact_index_found)
term = dist_params.ComputeDistanceTermNonNominalExactMatch(query_feature_index, high_accuracy);
term = dist_params.ComputeDistanceTermContinuousExactMatch(query_feature_index, high_accuracy);
else
term = dist_params.ComputeDistanceTermNonNominalNonNullRegular(
term = dist_params.ComputeDistanceTermContinuousNonNullRegular(
value.number - column->sortedNumberValueEntries[value_index]->value.number, query_feature_index, high_accuracy);

size_t num_entities_computed = AccumulatePartialSums(column->sortedNumberValueEntries[value_index]->indicesWithValue, query_feature_index, term);
Expand Down Expand Up @@ -1203,7 +1203,7 @@ double SeparableBoxFilterDataStore::PopulatePartialSumsWithSimilarFeatureValue(G
break;
}

term = dist_params.ComputeDistanceTermNonNominalNonNullRegular(next_closest_diff, query_feature_index, high_accuracy);
term = dist_params.ComputeDistanceTermContinuousNonNullRegular(next_closest_diff, query_feature_index, high_accuracy);
num_entities_computed += AccumulatePartialSums(
column->sortedNumberValueEntries[next_closest_index]->indicesWithValue, query_feature_index, term);

Expand Down
10 changes: 5 additions & 5 deletions src/Amalgam/SeparableBoxFilterDataStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class SeparableBoxFilterDataStore
{
double max_diff = columnData[absolute_feature_index]->GetMaxDifferenceTermFromValue(
dist_params.featureParams[query_feature_index], value_type, value);
return dist_params.ComputeDistanceTermNonNominalNonNullRegular(max_diff, query_feature_index, high_accuracy);
return dist_params.ComputeDistanceTermContinuousNonNullRegular(max_diff, query_feature_index, high_accuracy);
}

//gets the matrix cell index for the specified index
Expand Down Expand Up @@ -737,7 +737,7 @@ class SeparableBoxFilterDataStore
case GeneralizedDistance::EFDT_CONTINUOUS_UNIVERSALLY_NUMERIC:
{
const size_t column_index = target_label_indices[query_feature_index];
return dist_params.ComputeDistanceTermNonNominalNonCyclicOneNonNullRegular(
return dist_params.ComputeDistanceTermContinuousNonCyclicOneNonNullRegular(
target_values[query_feature_index].number - GetValue(entity_index, column_index).number,
query_feature_index, high_accuracy);
}
Expand All @@ -754,7 +754,7 @@ class SeparableBoxFilterDataStore
const size_t column_index = target_label_indices[query_feature_index];
auto &column_data = columnData[column_index];
if(column_data->numberIndices.contains(entity_index))
return dist_params.ComputeDistanceTermNonNominalNonCyclicOneNonNullRegular(
return dist_params.ComputeDistanceTermContinuousNonCyclicOneNonNullRegular(
target_values[query_feature_index].number - GetValue(entity_index, column_index).number,
query_feature_index, high_accuracy);
else
Expand All @@ -766,7 +766,7 @@ class SeparableBoxFilterDataStore
const size_t column_index = target_label_indices[query_feature_index];
auto &column_data = columnData[column_index];
if(column_data->numberIndices.contains(entity_index))
return dist_params.ComputeDistanceTermNonNominalOneNonNullRegular(
return dist_params.ComputeDistanceTermContinuousOneNonNullRegular(
target_values[query_feature_index].number - GetValue(entity_index, column_index).number,
query_feature_index, high_accuracy);
else
Expand Down Expand Up @@ -922,7 +922,7 @@ class SeparableBoxFilterDataStore
else
effective_feature_type = GeneralizedDistance::EFDT_CONTINUOUS_NUMERIC_PRECOMPUTED;

dist_params.ComputeAndStoreInternedNumberValuesAndDistanceTerms(query_feature_index, position_value_numeric, &column_data->internedNumberIndexToNumberValue);
dist_params.ComputeAndStoreInternedNumberValuesAndDistanceTerms(position_value_numeric, query_feature_index, &column_data->internedNumberIndexToNumberValue);
}
else
{
Expand Down
23 changes: 15 additions & 8 deletions src/Amalgam/amlg_code/full_test.amlg
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,15 @@
;should print 3ish
(print "35 " (generalized_distance (list 1 1) (list "continuous_code" "nominal_string") (list 0 5) (null) 1 (list (list 1.5 2 3 4 5) "s") (list (list 1 2 3) "s") ) "\n")

;surprisal
;should both be 0
(print "36 " (generalized_distance (list 1 1) (list "continuous_numeric" "continuous_numeric") (null) (list 0.5 0.5) 1 (list 1 1) (list 1 1) (null) (true) ) "\n" )
(print "37 " (generalized_distance (list 1 1) (list "nominal_numeric" "nominal_numeric") (null) (list 0.5 0.5) 1 (list 1 1) (list 1 1) (null) (true) ) "\n" )

;surprisal
(print "38 " (generalized_distance (list 1 1) (list "continuous_numeric" "continuous_numeric") (null) (list 0.5 0.5) 1 (list 1 1) (list 2 2) (null) (true) ) "\n" )
(print "39 " (generalized_distance (list 1 1) (list "nominal_numeric" "nominal_numeric") (list 2 2) (list 0.25 0.25) 1 (list 1 1) (list 2 2) (null) (true) ) "\n" )

(print "--entropy--\n")
(print (entropy (list 0.5 0.5)) "\n")
(print (entropy (list 0.5 0.5) (list 0.25 0.75) -1 1) "\n")
Expand Down Expand Up @@ -3889,7 +3898,6 @@

;should be:
;(list "vert0" "vert1" "vert2" "vert3")
;(list 0.049787068367863944 0.049787068367863944 0.01831563888873418 0.006737946999085467)
(print "probabilities: "
(compute_on_contained_entities "SurprisalTransformContainer" (list
(query_nearest_generalized_distance
Expand All @@ -3899,9 +3907,9 @@
(null) ; context_weights
(list "continuous_numeric") ; types
(null) ; attributes
(null) ; context_deviations
(list 0.25) ; context_deviations
1 ; p_parameter
"surprisal_to_prob" ; dwe = 1 means return computed distance to each case
"surprisal_to_prob" ; distance transform
(null) ; weight
(rand)
(null)
Expand All @@ -3913,7 +3921,6 @@

;should be
;(list "vert0" "vert2" "vert3" "vert1")
;(list 0.09709538455906153 0.01831563888873418 0.006737946999085467 0)
(print "weighted probabilities: "
(compute_on_contained_entities "SurprisalTransformContainer" (list
(query_nearest_generalized_distance
Expand All @@ -3923,9 +3930,9 @@
(null) ; context_weights
(list "continuous_numeric") ; types
(null) ; attributes
(null) ; context_deviations
(list 0.25) ; context_deviations
1 ; p_parameter
"surprisal_to_prob" ; dwe = 1 means return computed distance to each case
"surprisal_to_prob" ; distance transform
"weight" ; weight
(rand)
(null)
Expand All @@ -3941,12 +3948,12 @@

;should be approx 2.123
(print "surprisal contribution: " (compute_on_contained_entities "SurprisalTransformContainer" (list
(compute_entity_distance_contributions 4 (list "x") (list "testvert") (null) (null) (null) (null) 1 "surprisal_to_prob" (null) "fixed_seed" (null) "precise")
(compute_entity_distance_contributions 4 (list "x") (list "testvert") (null) (null) (null) (list 0.25) 1 "surprisal_to_prob" (null) "fixed_seed" (null) "precise")
)))

;should be approx 2.123
(print "weighted surprisal contribution: " (compute_on_contained_entities "SurprisalTransformContainer" (list
(compute_entity_distance_contributions 4 (list "x") (list "testvert") (null) (null) (null) (null) 1 "surprisal_to_prob" "weight" "fixed_seed" (null) "precise")
(compute_entity_distance_contributions 4 (list "x") (list "testvert") (null) (null) (null) (list 0.25) 1 "surprisal_to_prob" "weight" "fixed_seed" (null) "precise")
)))

(print "--concurrency tests--\n")
Expand Down
26 changes: 1 addition & 25 deletions src/Amalgam/amlg_code/test.amlg
Original file line number Diff line number Diff line change
@@ -1,28 +1,4 @@
(seq
(create_entities "BoxConvictionTestContainer" (null) )
(print "17 " (generalized_distance (null) (list "nominal_numeric") (list 1) (null) 1 (list 1 2 3) (list 10 2 4) ) "\n")

(create_entities (list "BoxConvictionTestContainer" "vert0") (lambda
(null ##x 0 ##y 0 ##weight 2)
) )

(create_entities (list "BoxConvictionTestContainer" "vert1") (lambda
(null ##x 0 ##y 1 ##weight 1)
) )

(create_entities (list "BoxConvictionTestContainer" "vert2") (lambda
(null ##x 1 ##y 0 ##weight 1)
) )

(create_entities (list "BoxConvictionTestContainer" "vert3") (lambda
(null ##x 2 ##y 1 ##weight 1)
) )

;should print:
;dc: (list
;(list "vert0" "vert1" "vert2" "vert3")
;(list 1 1 1 1.4142135623730951)
;)
(print "dc: " (compute_on_contained_entities "BoxConvictionTestContainer" (list
(compute_entity_distance_contributions 1 (list "x" "y") (list "vert3") (null) (null) (null) (null) 2.0 -1 (null) "fixed_seed" (null) "recompute_precise" (true))
)))
)
4 changes: 2 additions & 2 deletions src/Amalgam/entity/EntityQueries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ EvaluableNodeReference EntityQueryCondition::GetMatchingEntities(Entity *contain
}

//transform distances as appropriate
EntityQueriesStatistics::DistanceTransform<Entity *> distance_transform(transformSuprisalToProb,
EntityQueriesStatistics::DistanceTransform<Entity *> distance_transform(distParams.computeSurprisal,
distanceWeightExponent, weightLabel != StringInternPool::NOT_A_STRING_ID,
[this](Entity *e, double &weight_value) { return e->GetValueAtLabelAsNumber(weightLabel, weight_value); });

Expand Down Expand Up @@ -775,7 +775,7 @@ EvaluableNodeReference EntityQueryCondition::GetMatchingEntities(Entity *contain
entity_values.push_back(DistanceReferencePair<Entity *>(GetConditionDistanceMeasure(matching_entities[i], high_accuracy), matching_entities[i]));

//transform distances as appropriate
EntityQueriesStatistics::DistanceTransform<Entity *> distance_transform(transformSuprisalToProb,
EntityQueriesStatistics::DistanceTransform<Entity *> distance_transform(distParams.computeSurprisal,
distanceWeightExponent, weightLabel != StringInternPool::NOT_A_STRING_ID,
[this](Entity *e, double &weight_value) { return e->GetValueAtLabelAsNumber(weightLabel, weight_value); });

Expand Down
3 changes: 0 additions & 3 deletions src/Amalgam/entity/EntityQueries.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ class EntityQueryCondition
//only applicable when distParams.computeSurprisal is false
double distanceWeightExponent;

//if true, the values will be transformed from surprisal to probability; if false, will perform a distance transform
bool transformSuprisalToProb;

//if ENT_QUERY_SELECT has a start offset
bool hasStartOffset;

Expand Down
4 changes: 2 additions & 2 deletions src/Amalgam/entity/EntityQueryBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,15 +330,15 @@ namespace EntityQueryBuilder
cur_condition->distParams.pValue = p_value;

//value transforms for whatever is measured as "distance"
cur_condition->transformSuprisalToProb = false;
cur_condition->distanceWeightExponent = 1.0;
cur_condition->distParams.computeSurprisal = false;
if(ocn.size() > DISTANCE_VALUE_TRANSFORM)
{
EvaluableNode *dwe_param = ocn[DISTANCE_VALUE_TRANSFORM];
if(!EvaluableNode::IsNull(dwe_param))
{
if(dwe_param->GetType() == ENT_STRING && dwe_param->GetStringIDReference() == ENBISI_surprisal_to_prob)
cur_condition->transformSuprisalToProb = true;
cur_condition->distParams.computeSurprisal = true;
else //try to convert to number
cur_condition->distanceWeightExponent = EvaluableNode::ToNumber(dwe_param, 1.0);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Amalgam/entity/EntityQueryCaches.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ void EntityQueryCaches::GetMatchingEntities(EntityQueryCondition *cond, BitArray
weight_column = sbfds.GetColumnIndexFromLabelId(cond->weightLabel);

auto get_weight = sbfds.GetNumberValueFromEntityIndexFunction(weight_column);
EntityQueriesStatistics::DistanceTransform<size_t> distance_transform(cond->transformSuprisalToProb,
EntityQueriesStatistics::DistanceTransform<size_t> distance_transform(cond->distParams.computeSurprisal,
cond->distanceWeightExponent, use_entity_weights, get_weight);

//if first, need to populate with all entities
Expand Down
8 changes: 6 additions & 2 deletions src/Amalgam/interpreter/InterpreterOpcodesMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1015,9 +1015,9 @@ EvaluableNodeReference Interpreter::InterpretNode_ENT_GENERALIZED_DISTANCE(Evalu

//get value_names if applicable
std::vector<StringInternPool::StringID> value_names;
if(ocn.size() > 8)
if(ocn.size() > 7)
{
EvaluableNodeReference value_names_node = InterpretNodeForImmediateUse(ocn[8]);
EvaluableNodeReference value_names_node = InterpretNodeForImmediateUse(ocn[7]);
if(!EvaluableNode::IsNull(value_names_node))
{
//extract the names for each value into value_names
Expand All @@ -1034,6 +1034,10 @@ EvaluableNodeReference Interpreter::InterpretNode_ENT_GENERALIZED_DISTANCE(Evalu
evaluableNodeManager->FreeNodeTreeIfPossible(value_names_node);
}

dist_params.computeSurprisal = false;
if(ocn.size() > 8)
dist_params.computeSurprisal = InterpretNodeIntoBoolValue(ocn[8], false);

//get the origin and destination
std::vector<EvaluableNodeImmediateValue> location;
std::vector<EvaluableNodeImmediateValueType> location_types;
Expand Down
Loading

0 comments on commit aaa27bc

Please sign in to comment.