diff --git a/src/learn_to_pick/pick_best.py b/src/learn_to_pick/pick_best.py index c4883c1..5b787d1 100644 --- a/src/learn_to_pick/pick_best.py +++ b/src/learn_to_pick/pick_best.py @@ -43,29 +43,6 @@ def __init__( self.to_select_from = to_select_from self.based_on = based_on - def context(self, model) -> base.Featurized: - return base.embed(self.based_on or {}, model) - - def actions(self, model) -> List[base.Featurized]: - to_select_from_var_name, to_select_from = next( - iter(self.to_select_from.items()), (None, None) - ) - - action_embs = ( - ( - base.embed(to_select_from, model, to_select_from_var_name) - if self.to_select_from - else None - ) - if to_select_from - else None - ) - if not action_embs: - raise ValueError( - "Context and to_select_from must be provided in the inputs dictionary" - ) - return action_embs - class VwTxt: @staticmethod @@ -157,11 +134,31 @@ def _generic_namespaces(context, actions): for a in actions: a["#"] = PickBestFeaturizer._generic_namespace(a) + def get_context_actions(self, event) -> Tuple[base.Featurized, List[base.Featurized]]: + context = base.embed(event.based_on or {}, self.model) + to_select_from_var_name, to_select_from = next( + iter(event.to_select_from.items()), (None, None) + ) + + actions = ( + ( + base.embed(to_select_from, self.model, to_select_from_var_name) + if event.to_select_from + else None + ) + if to_select_from + else None + ) + if not actions: + raise ValueError( + "Context and to_select_from must be provided in the inputs dictionary" + ) + return context, actions + def featurize( self, event: PickBestEvent ) -> Tuple[base.Featurized, List[base.Featurized], PickBestSelected]: - context = event.context(self.model) - actions = event.actions(self.model) + context, actions = self.get_context_actions(event) if self.auto_embed: self._dotproducts(context, actions)