Skip to content

Commit

Permalink
all cb actions checks to PickBestEvent ctr
Browse files Browse the repository at this point in the history
  • Loading branch information
ataymano committed Nov 16, 2023
1 parent ca7f469 commit 86f3827
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 119 deletions.
33 changes: 11 additions & 22 deletions src/learn_to_pick/pick_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,19 @@ class PickBestEvent(base.Event[PickBestSelected]):
def __init__(
self,
inputs: Dict[str, Any],
to_select_from: Dict[str, Any],
based_on: Dict[str, Any],
selected: Optional[PickBestSelected] = None,
):
super().__init__(inputs=inputs, selected=selected or PickBestSelected())
self.to_select_from = to_select_from
self.based_on = based_on
self.to_select_from = base.get_to_select_from(inputs)
self.based_on = base.get_based_on(inputs)
if not self.to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)
if len(self.to_select_from) > 1:
raise ValueError(
"Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from."
)


class VwTxt:
Expand Down Expand Up @@ -151,10 +157,6 @@ def get_context_actions(
if to_select_from
else None
)
if not actions:
raise ValueError(
"Context and to_select_from must be provided in the inputs dictionary"
)
return context, actions

def featurize(
Expand Down Expand Up @@ -223,20 +225,7 @@ class PickBest(base.RLLoop[PickBestEvent]):
"""

def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
to_select_from = base.get_to_select_from(inputs)
if not to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)
if len(to_select_from) > 1:
raise ValueError(
"Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from."
)
return PickBestEvent(
inputs=inputs,
to_select_from=to_select_from,
based_on=base.get_based_on(inputs),
)
return PickBestEvent(inputs=inputs)

def _call_after_predict_before_scoring(
self,
Expand Down
Loading

0 comments on commit 86f3827

Please sign in to comment.