Skip to content

Commit

Permalink
22584: Improves ablation check in !TrainCreateCases to prevent undesired warning (#406)
Browse files Browse the repository at this point in the history

Since `!TrainCreateCases` is called in `train` before
`!AutoAnalyzeIfNeeded`, when both auto-ablation and auto-analyze are
enabled, a warning about not having analyzed with case weights can be
raised even if all parameters are set correctly.

This is not observed in larger use cases like the Engine recipe or
benchmarks because it only occurs when the train batch size is
$\geq$ both the auto-analyze threshold and the auto-ablation minimum
model size.
  • Loading branch information
jdbeel authored Jan 21, 2025
1 parent 2da60f6 commit 5ba4b09
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 12 deletions.
27 changes: 18 additions & 9 deletions howso/train.amlg
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,9 @@
;split by batches of cases until next analyze
(while (< input_case_index (size cases))

(if (and thresholds_enabled (not skip_ablation))
;Only compute prediction stats for ablation if thresholds are enabled, we're
; not skipping ablation, and we have hyperparameters to use for the computation.
(if (and thresholds_enabled (not skip_ablation) (size !hyperparameterMetadataMap))
(seq
(assign (assoc
prev_prediction_stats_map new_prediction_stats_map
Expand Down Expand Up @@ -836,6 +838,7 @@
(and
(>= (+ !dataMassChangeSinceLastDataReduction batch_size) !autoAblationMaxNumCases)
(not skip_reduce_data)
(size !hyperparameterMetadataMap)
)
(call reduce_data (assoc
abs_threshold_map !autoAblationAbsThresholdMap
Expand Down Expand Up @@ -879,8 +882,12 @@
(let
(assoc
indices_to_train
(if thresholds_satisfied
;If one or more thresholds has been satisfied, just train all of the cases in this batch.
;If one or more thresholds has been satisfied, or the Trainee has not yet been analyzed,
; just train all of the cases in this batch.
(if (or
thresholds_satisfied
(not (size !hyperparameterMetadataMap))
)
(indices cases)
;Otherwise, do the normal ablation filtering.
||(filter
Expand Down Expand Up @@ -1031,12 +1038,14 @@

(accum_to_entities (assoc !dataMassChangeSinceLastAnalyze mass_to_accumulate))

;recompute influence weights entropy
(call !ComputeAndStoreInfluenceWeightEntropies (assoc
features features
weight_feature accumulate_weight_feature
use_case_weights (true)
))
;recompute influence weights entropy only after we've analyzed.
(if (size !hyperparameterMetadataMap)
(call !ComputeAndStoreInfluenceWeightEntropies (assoc
features features
weight_feature accumulate_weight_feature
use_case_weights (true)
))
)
)
)

Expand Down
8 changes: 6 additions & 2 deletions unit_tests/ut_h_basic_ablation.amlg
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@
#unit_test (direct_assign_to_entities (assoc unit_test (load "unit_test.amlg")))
(call (load "unit_test_howso.amlg") (assoc name "ut_h_basic_ablation.amlg"))

;Set a non-default hyperparameter map since ablation checks for non-default
; hyperparameters.
(call_entity "howso" "set_params" (assoc
default_hyperparameter_map (assoc "k" 2 "p" 0.4 "dt" -1)
hyperparameter_map
{targetless { "A.B.C.D.E." { ".none"
{k 2 p 0.4 dt -1}
}}}
))

(declare (assoc
Expand Down Expand Up @@ -91,7 +96,6 @@
obs (call !get_num_cases)
))


(print "conviction value of new case: "
(call_entity "howso" "react_group" (assoc features (append context_features action_features ) new_cases (list (list (list 1 2 1 3 3)))) ) "\n"
)
Expand Down
2 changes: 1 addition & 1 deletion unit_tests/ut_h_reduce_data.amlg
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@
features (append context_labels action_labels)
cases training_data
session "iris_session"
skip_auto_analyze (true)
skip_auto_analyze (false)
))
(call_entity "howso" "analyze" (assoc
use_case_weights (true)
Expand Down

0 comments on commit 5ba4b09

Please sign in to comment.