Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ataymano committed Nov 15, 2023
1 parent 56d69f0 commit 43ad297
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 18 deletions.
3 changes: 1 addition & 2 deletions src/learn_to_pick/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,9 @@ def __init__(
featurizer: Featurizer,
formatter: Callable,
vw_logger: VwLogger,
*args: Any,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
super().__init__(**kwargs)
self.model_repo = model_repo
self.vw_cmd = vw_cmd
self.workspace = self.model_repo.load(vw_cmd)
Expand Down
23 changes: 7 additions & 16 deletions src/learn_to_pick/pick_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,17 +143,19 @@ def _dotproducts(self, context, actions):
for j in range(len(action_names))
}

def _generic_namespace(self, featurized):
@staticmethod
def _generic_namespace(featurized):
result = base.SparseFeatures()
for ns in featurized.sparse.keys():
if "default_ft" in featurized.sparse[ns]:
result[ns] = featurized.sparse[ns]["default_ft"]
return result

def _generic_namespaces(self, context, actions):
context["@"] = self._generic_namespace(context)
@staticmethod
def _generic_namespaces(context, actions):
context["@"] = PickBestFeaturizer._generic_namespace(context)
for a in actions:
a["#"] = self._generic_namespace(a)
a["#"] = PickBestFeaturizer._generic_namespace(a)

def featurize(
self, event: PickBestEvent
Expand All @@ -163,7 +165,7 @@ def featurize(

if self.auto_embed:
self._dotproducts(context, actions)
self._generic_namespaces(context, actions)
PickBestFeaturizer._generic_namespaces(context, actions)

return context, actions, event.selected

Expand All @@ -183,9 +185,6 @@ def vw_cb_formatter(


class PickBestRandomPolicy(base.Policy[PickBestEvent]):
def __init__(self):
...

def predict(self, event: PickBestEvent) -> List[Tuple[int, float]]:
num_items = len(event.to_select_from)
return [(i, 1.0 / num_items) for i in range(num_items)]
Expand Down Expand Up @@ -265,14 +264,6 @@ def _call_after_predict_before_scoring(

next_inputs = inputs.copy()

# only one key, value pair in event.to_select_from
value = next(iter(event.to_select_from.values()))
v = (
value[event.selected.index]
if event.selected
else event.to_select_from.values()
)

picked = {}
for k, v in event.to_select_from.items():
picked[k] = v[event.selected.index]
Expand Down

0 comments on commit 43ad297

Please sign in to comment.