Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanups in featurization #35

Merged
merged 8 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions src/learn_to_pick/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,26 +93,21 @@ def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Exam
return [parser.parse_line(line) for line in input_str.split("\n")]


def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
to_select_from = {
k: inputs[k].value
def get_based_on(inputs: Dict[str, Any]) -> Dict:
return {
k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value
for k in inputs.keys()
if isinstance(inputs[k], _ToSelectFrom)
if isinstance(inputs[k], _BasedOn)
}

if not to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)

based_on = {
k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value
def get_to_select_from(inputs: Dict[str, Any]) -> Dict:
return {
k: inputs[k].value
for k in inputs.keys()
if isinstance(inputs[k], _BasedOn)
if isinstance(inputs[k], _ToSelectFrom)
}

return based_on, to_select_from


# end helper functions

Expand Down
76 changes: 30 additions & 46 deletions src/learn_to_pick/pick_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,36 +35,19 @@ class PickBestEvent(base.Event[PickBestSelected]):
def __init__(
self,
inputs: Dict[str, Any],
to_select_from: Dict[str, Any],
based_on: Dict[str, Any],
selected: Optional[PickBestSelected] = None,
):
super().__init__(inputs=inputs, selected=selected or PickBestSelected())
self.to_select_from = to_select_from
self.based_on = based_on

def context(self, model) -> base.Featurized:
return base.embed(self.based_on or {}, model)

def actions(self, model) -> List[base.Featurized]:
to_select_from_var_name, to_select_from = next(
iter(self.to_select_from.items()), (None, None)
)

action_embs = (
(
base.embed(to_select_from, model, to_select_from_var_name)
if self.to_select_from
else None
self.to_select_from = base.get_to_select_from(inputs)
self.based_on = base.get_based_on(inputs)
if not self.to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)
if to_select_from
else None
)
if not action_embs:
if len(self.to_select_from) > 1:
raise ValueError(
"Context and to_select_from must be provided in the inputs dictionary"
"Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from."
)
return action_embs


class VwTxt:
Expand All @@ -77,9 +60,9 @@ def _sparse_2_str(values: base.SparseFeatures) -> str:
def _to_str(v):
import numbers

return v if isinstance(v, numbers.Number) else f"={v}"
return f":{v}" if isinstance(v, numbers.Number) else f"={v}"

return " ".join([f"{k}:{_to_str(v)}" for k, v in values.items()])
return " ".join([f"{k}{_to_str(v)}" for k, v in values.items()])

@staticmethod
def featurized_2_str(obj: base.Featurized) -> str:
Expand Down Expand Up @@ -157,11 +140,29 @@ def _generic_namespaces(context, actions):
for a in actions:
a["#"] = PickBestFeaturizer._generic_namespace(a)

def get_context_and_actions(
self, event
) -> Tuple[base.Featurized, List[base.Featurized]]:
context = base.embed(event.based_on or {}, self.model)
to_select_from_var_name, to_select_from = next(
iter(event.to_select_from.items()), (None, None)
)

actions = (
(
base.embed(to_select_from, self.model, to_select_from_var_name)
if event.to_select_from
else None
)
if to_select_from
else None
)
return context, actions

def featurize(
self, event: PickBestEvent
) -> Tuple[base.Featurized, List[base.Featurized], PickBestSelected]:
context = event.context(self.model)
actions = event.actions(self.model)
context, actions = self.get_context_and_actions(event)

if self.auto_embed:
self._dotproducts(context, actions)
Expand Down Expand Up @@ -224,24 +225,7 @@ class PickBest(base.RLLoop[PickBestEvent]):
"""

def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
if not actions:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)

if len(list(actions.values())) > 1:
raise ValueError(
"Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from."
)

if not context:
raise ValueError(
"No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on."
)

event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
return event
return PickBestEvent(inputs=inputs)

def _call_after_predict_before_scoring(
self,
Expand Down
35 changes: 17 additions & 18 deletions tests/unit_tests/test_pick_best_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,15 @@ def test_multiple_ToSelectFrom_throws() -> None:
)


def test_missing_basedOn_from_throws() -> None:
def test_missing_basedOn_from_dont_throw() -> None:
pick = learn_to_pick.PickBest.create(
llm=fake_llm_caller,
featurizer=learn_to_pick.PickBestFeaturizer(
auto_embed=False, model=MockEncoder()
),
)
actions = ["0", "1", "2"]
with pytest.raises(ValueError):
pick.run(action=learn_to_pick.ToSelectFrom(actions))
pick.run(action=learn_to_pick.ToSelectFrom(actions))


def test_ToSelectFrom_not_a_list_throws() -> None:
Expand Down Expand Up @@ -169,10 +168,10 @@ def test_everything_embedded() -> None:

expected = "\n".join(
[
f"shared |User_dense {encoded_ctx_str_1} |User_sparse default_ft:={ctx_str_1}",
f"|action_dense {action_dense} |action_sparse default_ft:={str1}",
f"|action_dense {action_dense} |action_sparse default_ft:={str2}",
f"|action_dense {action_dense} |action_sparse default_ft:={str3}",
f"shared |User_dense {encoded_ctx_str_1} |User_sparse default_ft={ctx_str_1}",
f"|action_dense {action_dense} |action_sparse default_ft={str1}",
f"|action_dense {action_dense} |action_sparse default_ft={str2}",
f"|action_dense {action_dense} |action_sparse default_ft={str3}",
]
) # noqa

Expand All @@ -198,10 +197,10 @@ def test_default_auto_embedder_is_off() -> None:

expected = "\n".join(
[
f"shared |User_sparse default_ft:={ctx_str_1}",
f"|action_sparse default_ft:={str1}",
f"|action_sparse default_ft:={str2}",
f"|action_sparse default_ft:={str3}",
f"shared |User_sparse default_ft={ctx_str_1}",
f"|action_sparse default_ft={str1}",
f"|action_sparse default_ft={str2}",
f"|action_sparse default_ft={str3}",
]
) # noqa

Expand All @@ -227,10 +226,10 @@ def test_default_w_embeddings_off() -> None:

expected = "\n".join(
[
f"shared |User_sparse default_ft:={ctx_str_1}",
f"|action_sparse default_ft:={str1}",
f"|action_sparse default_ft:={str2}",
f"|action_sparse default_ft:={str3}",
f"shared |User_sparse default_ft={ctx_str_1}",
f"|action_sparse default_ft={str1}",
f"|action_sparse default_ft={str2}",
f"|action_sparse default_ft={str3}",
]
) # noqa

Expand Down Expand Up @@ -258,9 +257,9 @@ def test_default_w_embeddings_on() -> None:

expected = "\n".join(
[
f"shared |User_sparse default_ft:={ctx_str_1} |@_sparse User:={ctx_str_1}",
f"|action_sparse default_ft:={str1} |{dot_prod} |#_sparse action:={str1} ",
f"|action_sparse default_ft:={str2} |{dot_prod} |#_sparse action:={str2} ",
f"shared |User_sparse default_ft={ctx_str_1} |@_sparse User={ctx_str_1}",
f"|action_sparse default_ft={str1} |{dot_prod} |#_sparse action={str1} ",
f"|action_sparse default_ft={str2} |{dot_prod} |#_sparse action={str2} ",
]
) # noqa

Expand Down
Loading
Loading