Skip to content

Commit

Permalink
Merge branch 'main' of github.com:VowpalWabbit/learn_to_pick into byom
Browse files Browse the repository at this point in the history
  • Loading branch information
cheng-tan committed Nov 20, 2023
2 parents b093e1b + 213e931 commit 45e273e
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 254 deletions.
2 changes: 1 addition & 1 deletion notebooks/news_recommendation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
" chosen_article = picked[\"article\"]\n",
" user = event.based_on[\"user\"]\n",
" time_of_day = event.based_on[\"time_of_day\"]\n",
" score = self.get_score(user[0], time_of_day[0], chosen_article)\n",
" score = self.get_score(user, time_of_day, chosen_article)\n",
" return score"
]
},
Expand Down
95 changes: 41 additions & 54 deletions src/learn_to_pick/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,66 +20,66 @@
from learn_to_pick.model_repository import ModelRepository
from learn_to_pick.vw_logger import VwLogger
from learn_to_pick.features import Featurized, DenseFeatures, SparseFeatures
from enum import Enum

if TYPE_CHECKING:
import vowpal_wabbit_next as vw

logger = logging.getLogger(__name__)


class _BasedOn:
def __init__(self, value: Any):
self.value = value
class Role(Enum):
CONTEXT = 1
ACTIONS = 2

def __str__(self) -> str:
return str(self.value)

__repr__ = __str__


def BasedOn(anything: Any) -> _BasedOn:
return _BasedOn(anything)


class _ToSelectFrom:
def __init__(self, value: Any):
class _Roled:
def __init__(self, value: Any, role: Role):
self.value = value
self.role = role

def __str__(self) -> str:
return str(self.value)

__repr__ = __str__


def ToSelectFrom(anything: Any) -> _ToSelectFrom:
def BasedOn(anything: Any) -> _Roled:
return _Roled(anything, Role.CONTEXT)


def ToSelectFrom(anything: Any) -> _Roled:
if not isinstance(anything, list):
raise ValueError("ToSelectFrom must be a list to select from")
return _ToSelectFrom(anything)
return _Roled(anything, Role.ACTIONS)


class _Embed:
def __init__(self, value: Any, keep: bool = False):
class _Input:
def __init__(self, value: Any, keep: bool = True, embed: bool = False):
self.value = value
self.keep = keep
self.embed = embed

def __str__(self) -> str:
return str(self.value)

@staticmethod
def create(value: Any, *args, **kwargs):
if isinstance(value, _Roled):
return _Roled(_Input.create(value.value, *args, **kwargs), value.role)
if isinstance(value, list):
return [_Input.create(v, *args, **kwargs) for v in value]
if isinstance(value, dict):
return {k: _Input.create(v, *args, **kwargs) for k, v in value.items()}
if isinstance(value, _Input): # should we swap? it will allow overwriting
return value
return _Input(value, *args, **kwargs)

__repr__ = __str__


def Embed(anything: Any, keep: bool = False) -> Any:
if isinstance(anything, _ToSelectFrom):
return ToSelectFrom(Embed(anything.value, keep=keep))
elif isinstance(anything, _BasedOn):
return BasedOn(Embed(anything.value, keep=keep))
if isinstance(anything, list):
return [Embed(v, keep=keep) for v in anything]
elif isinstance(anything, dict):
return {k: Embed(v, keep=keep) for k, v in anything.items()}
elif isinstance(anything, _Embed):
return anything
return _Embed(anything, keep=keep)
return _Input.create(anything, keep=keep, embed=True)


def EmbedAndKeep(anything: Any) -> Any:
Expand All @@ -93,26 +93,13 @@ def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Exam
return [parser.parse_line(line) for line in input_str.split("\n")]


def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
to_select_from = {
k: inputs[k].value
for k in inputs.keys()
if isinstance(inputs[k], _ToSelectFrom)
}

if not to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)

based_on = {
k: inputs[k].value if isinstance(inputs[k].value, list) else inputs[k].value
for k in inputs.keys()
if isinstance(inputs[k], _BasedOn)
def filter_inputs(inputs: Dict[str, Any], role: Role) -> Dict[str, Any]:
return {
k: v.value
for k, v in inputs.items()
if isinstance(v, _Roled) and v.role == role
}

return based_on, to_select_from


# end helper functions

Expand Down Expand Up @@ -470,8 +457,7 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:
)
except Exception as e:
logger.info(
f"The selection scorer was not able to score, \
and the chain was not able to adjust to this response, error: {e}"
f"The selection scorer was not able to score, and the chain was not able to adjust to this response, error: {e}"
)

event = self._call_after_scoring_before_learning(score=score, event=event)
Expand All @@ -486,14 +472,15 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:


def _embed_string_type(
item: Union[str, _Embed], model: Any, namespace: str
item: Union[str, _Input], model: Any, namespace: str
) -> Featurized:
"""Helper function to embed a string or an _Embed object."""
import re

result = Featurized()
if isinstance(item, _Embed):
result[namespace] = DenseFeatures(model.encode(item.value))
if isinstance(item, _Input):
if item.embed:
result[namespace] = DenseFeatures(model.encode(item.value))
if item.keep:
keep_str = item.value.replace(" ", "_")
result[namespace] = {"default_ft": re.sub(r"[\t\n\r\f\v]+", " ", keep_str)}
Expand Down Expand Up @@ -535,7 +522,7 @@ def _embed_list_type(


def embed(
to_embed: Union[Union[str, _Embed], Dict, List[Union[str, _Embed]], List[Dict]],
to_embed: Union[Union[str, _Input], Dict, List[Union[str, _Input]], List[Dict]],
model: Any,
namespace: Optional[str] = None,
) -> Union[Featurized, List[Featurized]]:
Expand All @@ -549,7 +536,7 @@ def embed(
Returns:
List[Dict[str, str]]: A list of dictionaries where each dictionary has the namespace as the key and the embedded string as the value
"""
if (isinstance(to_embed, _Embed) and isinstance(to_embed.value, str)) or isinstance(
if (isinstance(to_embed, _Input) and isinstance(to_embed.value, str)) or isinstance(
to_embed, str
):
return _embed_string_type(to_embed, model, namespace)
Expand Down
76 changes: 30 additions & 46 deletions src/learn_to_pick/pick_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,36 +35,19 @@ class PickBestEvent(base.Event[PickBestSelected]):
def __init__(
self,
inputs: Dict[str, Any],
to_select_from: Dict[str, Any],
based_on: Dict[str, Any],
selected: Optional[PickBestSelected] = None,
):
super().__init__(inputs=inputs, selected=selected or PickBestSelected())
self.to_select_from = to_select_from
self.based_on = based_on

def context(self, model) -> base.Featurized:
return base.embed(self.based_on or {}, model)

def actions(self, model) -> List[base.Featurized]:
to_select_from_var_name, to_select_from = next(
iter(self.to_select_from.items()), (None, None)
)

action_embs = (
(
base.embed(to_select_from, model, to_select_from_var_name)
if self.to_select_from
else None
self.to_select_from = base.filter_inputs(inputs, base.Role.ACTIONS)
self.based_on = base.filter_inputs(inputs, base.Role.CONTEXT)
if not self.to_select_from:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)
if to_select_from
else None
)
if not action_embs:
if len(self.to_select_from) > 1:
raise ValueError(
"Context and to_select_from must be provided in the inputs dictionary"
"Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from."
)
return action_embs


class VwTxt:
Expand All @@ -77,9 +60,9 @@ def _sparse_2_str(values: base.SparseFeatures) -> str:
def _to_str(v):
import numbers

return v if isinstance(v, numbers.Number) else f"={v}"
return f":{v}" if isinstance(v, numbers.Number) else f"={v}"

return " ".join([f"{k}:{_to_str(v)}" for k, v in values.items()])
return " ".join([f"{k}{_to_str(v)}" for k, v in values.items()])

@staticmethod
def featurized_2_str(obj: base.Featurized) -> str:
Expand Down Expand Up @@ -157,11 +140,29 @@ def _generic_namespaces(context, actions):
for a in actions:
a["#"] = PickBestFeaturizer._generic_namespace(a)

def get_context_and_actions(
self, event
) -> Tuple[base.Featurized, List[base.Featurized]]:
context = base.embed(event.based_on or {}, self.model)
to_select_from_var_name, to_select_from = next(
iter(event.to_select_from.items()), (None, None)
)

actions = (
(
base.embed(to_select_from, self.model, to_select_from_var_name)
if event.to_select_from
else None
)
if to_select_from
else None
)
return context, actions

def featurize(
self, event: PickBestEvent
) -> Tuple[base.Featurized, List[base.Featurized], PickBestSelected]:
context = event.context(self.model)
actions = event.actions(self.model)
context, actions = self.get_context_and_actions(event)

if self.auto_embed:
self._dotproducts(context, actions)
Expand Down Expand Up @@ -224,24 +225,7 @@ class PickBest(base.RLLoop[PickBestEvent]):
"""

def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
if not actions:
raise ValueError(
"No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."
)

if len(list(actions.values())) > 1:
raise ValueError(
"Only one variable using 'ToSelectFrom' can be provided in the inputs for PickBest run() call. Please provide only one variable containing a list to select from."
)

if not context:
raise ValueError(
"No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selected of ToSelectFrom on."
)

event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
return event
return PickBestEvent(inputs=inputs)

def _call_after_predict_before_scoring(
self,
Expand Down
35 changes: 17 additions & 18 deletions tests/unit_tests/test_pick_best_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,15 @@ def test_multiple_ToSelectFrom_throws() -> None:
)


def test_missing_basedOn_from_throws() -> None:
def test_missing_basedOn_from_dont_throw() -> None:
pick = learn_to_pick.PickBest.create(
llm=fake_llm_caller,
featurizer=learn_to_pick.PickBestFeaturizer(
auto_embed=False, model=MockEncoder()
),
)
actions = ["0", "1", "2"]
with pytest.raises(ValueError):
pick.run(action=learn_to_pick.ToSelectFrom(actions))
pick.run(action=learn_to_pick.ToSelectFrom(actions))


def test_ToSelectFrom_not_a_list_throws() -> None:
Expand Down Expand Up @@ -169,10 +168,10 @@ def test_everything_embedded() -> None:

expected = "\n".join(
[
f"shared |User_dense {encoded_ctx_str_1} |User_sparse default_ft:={ctx_str_1}",
f"|action_dense {action_dense} |action_sparse default_ft:={str1}",
f"|action_dense {action_dense} |action_sparse default_ft:={str2}",
f"|action_dense {action_dense} |action_sparse default_ft:={str3}",
f"shared |User_dense {encoded_ctx_str_1} |User_sparse default_ft={ctx_str_1}",
f"|action_dense {action_dense} |action_sparse default_ft={str1}",
f"|action_dense {action_dense} |action_sparse default_ft={str2}",
f"|action_dense {action_dense} |action_sparse default_ft={str3}",
]
) # noqa

Expand All @@ -198,10 +197,10 @@ def test_default_auto_embedder_is_off() -> None:

expected = "\n".join(
[
f"shared |User_sparse default_ft:={ctx_str_1}",
f"|action_sparse default_ft:={str1}",
f"|action_sparse default_ft:={str2}",
f"|action_sparse default_ft:={str3}",
f"shared |User_sparse default_ft={ctx_str_1}",
f"|action_sparse default_ft={str1}",
f"|action_sparse default_ft={str2}",
f"|action_sparse default_ft={str3}",
]
) # noqa

Expand All @@ -227,10 +226,10 @@ def test_default_w_embeddings_off() -> None:

expected = "\n".join(
[
f"shared |User_sparse default_ft:={ctx_str_1}",
f"|action_sparse default_ft:={str1}",
f"|action_sparse default_ft:={str2}",
f"|action_sparse default_ft:={str3}",
f"shared |User_sparse default_ft={ctx_str_1}",
f"|action_sparse default_ft={str1}",
f"|action_sparse default_ft={str2}",
f"|action_sparse default_ft={str3}",
]
) # noqa

Expand Down Expand Up @@ -258,9 +257,9 @@ def test_default_w_embeddings_on() -> None:

expected = "\n".join(
[
f"shared |User_sparse default_ft:={ctx_str_1} |@_sparse User:={ctx_str_1}",
f"|action_sparse default_ft:={str1} |{dot_prod} |#_sparse action:={str1} ",
f"|action_sparse default_ft:={str2} |{dot_prod} |#_sparse action:={str2} ",
f"shared |User_sparse default_ft={ctx_str_1} |@_sparse User={ctx_str_1}",
f"|action_sparse default_ft={str1} |{dot_prod} |#_sparse action={str1} ",
f"|action_sparse default_ft={str2} |{dot_prod} |#_sparse action={str2} ",
]
) # noqa

Expand Down
Loading

0 comments on commit 45e273e

Please sign in to comment.