diff --git a/src/learn_to_pick/base.py b/src/learn_to_pick/base.py index a449815..e1a85aa 100644 --- a/src/learn_to_pick/base.py +++ b/src/learn_to_pick/base.py @@ -165,10 +165,9 @@ def __init__( featurizer: Featurizer, formatter: Callable, vw_logger: VwLogger, - *args: Any, **kwargs: Any, ): - super().__init__(*args, **kwargs) + super().__init__(**kwargs) self.model_repo = model_repo self.vw_cmd = vw_cmd self.workspace = self.model_repo.load(vw_cmd) diff --git a/src/learn_to_pick/pick_best.py b/src/learn_to_pick/pick_best.py index 1682010..abf70d6 100644 --- a/src/learn_to_pick/pick_best.py +++ b/src/learn_to_pick/pick_best.py @@ -143,17 +143,19 @@ def _dotproducts(self, context, actions): for j in range(len(action_names)) } - def _generic_namespace(self, featurized): + @staticmethod + def _generic_namespace(featurized): result = base.SparseFeatures() for ns in featurized.sparse.keys(): if "default_ft" in featurized.sparse[ns]: result[ns] = featurized.sparse[ns]["default_ft"] return result - def _generic_namespaces(self, context, actions): - context["@"] = self._generic_namespace(context) + @staticmethod + def _generic_namespaces(context, actions): + context["@"] = PickBestFeaturizer._generic_namespace(context) for a in actions: - a["#"] = self._generic_namespace(a) + a["#"] = PickBestFeaturizer._generic_namespace(a) def featurize( self, event: PickBestEvent @@ -163,7 +165,7 @@ def featurize( if self.auto_embed: self._dotproducts(context, actions) - self._generic_namespaces(context, actions) + PickBestFeaturizer._generic_namespaces(context, actions) return context, actions, event.selected @@ -183,9 +185,6 @@ def vw_cb_formatter( class PickBestRandomPolicy(base.Policy[PickBestEvent]): - def __init__(self): - ... - def predict(self, event: PickBestEvent) -> List[Tuple[int, float]]: num_items = len(event.to_select_from) return [(i, 1.0 / num_items) for i in range(num_items)] @@ -265,14 +264,6 @@ def _call_after_predict_before_scoring( next_inputs = inputs.copy() - # only one key, value pair in event.to_select_from - value = next(iter(event.to_select_from.values())) - v = ( - value[event.selected.index] - if event.selected - else event.to_select_from.values() - ) - picked = {} for k, v in event.to_select_from.items(): picked[k] = v[event.selected.index]