Skip to content

Commit

Permalink
get_context_actions to featurized
Browse files Browse the repository at this point in the history
  • Loading branch information
ataymano committed Nov 16, 2023
1 parent 76e04cb commit 7b34c51
Showing 1 changed file with 22 additions and 25 deletions.
47 changes: 22 additions & 25 deletions src/learn_to_pick/pick_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,29 +43,6 @@ def __init__(
self.to_select_from = to_select_from
self.based_on = based_on

def context(self, model) -> base.Featurized:
return base.embed(self.based_on or {}, model)

def actions(self, model) -> List[base.Featurized]:
to_select_from_var_name, to_select_from = next(
iter(self.to_select_from.items()), (None, None)
)

action_embs = (
(
base.embed(to_select_from, model, to_select_from_var_name)
if self.to_select_from
else None
)
if to_select_from
else None
)
if not action_embs:
raise ValueError(
"Context and to_select_from must be provided in the inputs dictionary"
)
return action_embs


class VwTxt:
@staticmethod
Expand Down Expand Up @@ -157,11 +134,31 @@ def _generic_namespaces(context, actions):
for a in actions:
a["#"] = PickBestFeaturizer._generic_namespace(a)

def get_context_actions(self, event) -> Tuple[base.Featurized, List[base.Featurized]]:
context = base.embed(event.based_on or {}, self.model)
to_select_from_var_name, to_select_from = next(
iter(event.to_select_from.items()), (None, None)
)

actions = (
(
base.embed(to_select_from, self.model, to_select_from_var_name)
if event.to_select_from
else None
)
if to_select_from
else None
)
if not actions:
raise ValueError(
"Context and to_select_from must be provided in the inputs dictionary"
)
return context, actions

def featurize(
self, event: PickBestEvent
) -> Tuple[base.Featurized, List[base.Featurized], PickBestSelected]:
context = event.context(self.model)
actions = event.actions(self.model)
context, actions = self.get_context_actions(event)

if self.auto_embed:
self._dotproducts(context, actions)
Expand Down

0 comments on commit 7b34c51

Please sign in to comment.