Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
ataymano committed Nov 15, 2023
1 parent 966590d commit 36fd0aa
Show file tree
Hide file tree
Showing 6 changed files with 334 additions and 227 deletions.
18 changes: 9 additions & 9 deletions src/learn_to_pick/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Type,
TypeVar,
Union,
Callable
Callable,
)

from learn_to_pick.metrics import MetricsTrackerAverage, MetricsTrackerRollingWindow
Expand Down Expand Up @@ -183,9 +183,7 @@ def predict(self, event: TEvent) -> Any:
import vowpal_wabbit_next as vw

text_parser = vw.TextFormatParser(self.workspace)
return self.workspace.predict_one(
_parse_lines(text_parser, self.format(event))
)
return self.workspace.predict_one(_parse_lines(text_parser, self.format(event)))

def learn(self, event: TEvent) -> None:
import vowpal_wabbit_next as vw
Expand Down Expand Up @@ -489,18 +487,20 @@ def run(self, *args, **kwargs) -> Dict[str, Any]:


def _embed_string_type(
item: Union[str, _Embed], model: Any, namespace: str) -> Featurized:
item: Union[str, _Embed], model: Any, namespace: str
) -> Featurized:
"""Helper function to embed a string or an _Embed object."""
import re

result = Featurized()
if isinstance(item, _Embed):
result[namespace] = DenseFeatures(model.encode(item.value))
if item.keep:
keep_str = item.value.replace(" ", "_")
result[namespace] = {'raw': re.sub(r"[\t\n\r\f\v]+", " ", keep_str)}
result[namespace] = {"raw": re.sub(r"[\t\n\r\f\v]+", " ", keep_str)}
elif isinstance(item, str):
encoded = item.replace(" ", "_")
result[namespace] = {'raw': re.sub(r"[\t\n\r\f\v]+", " ", encoded)}
result[namespace] = {"raw": re.sub(r"[\t\n\r\f\v]+", " ", encoded)}
else:
raise ValueError(f"Unsupported type {type(item)} for embedding")

Expand All @@ -513,7 +513,7 @@ def _embed_dict_type(item: Dict, model: Any) -> Featurized:
for ns, embed_item in item.items():
if isinstance(embed_item, list):
for idx, embed_list_item in enumerate(embed_item):
result.merge(_embed_string_type(embed_list_item, model, f'{ns}_{idx}'))
result.merge(_embed_string_type(embed_list_item, model, f"{ns}_{idx}"))
else:
result.merge(_embed_string_type(embed_item, model, ns))
return result
Expand All @@ -529,7 +529,7 @@ def _embed_list_type(
elif isinstance(embed_item, list):
result.append(Featurized())
for idx, embed_list_item in enumerate(embed_item):
result[-1].merge(_embed_string_type(embed_list_item, model, f'{idx}'))
result[-1].merge(_embed_string_type(embed_list_item, model, f"{idx}"))
else:
result.append(_embed_string_type(embed_item, model, namespace))
return result
Expand Down
13 changes: 10 additions & 3 deletions src/learn_to_pick/features.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Union, Optional, Dict, List
import numpy as np


class SparseFeatures(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -12,7 +13,11 @@ def __init__(self, *args, **kwargs):


class Featurized:
def __init__(
    self,
    sparse: Optional[Dict[str, SparseFeatures]] = None,
    dense: Optional[Dict[str, DenseFeatures]] = None,
):
    """Create a container holding sparse and dense features by namespace.

    Args:
        sparse: Optional mapping of namespace name to sparse features.
        dense: Optional mapping of namespace name to dense features.
    """
    # Any falsy argument (None or an empty mapping) is replaced by a fresh
    # dict, so instances never share a default container.
    self.sparse = sparse if sparse else {}
    self.dense = dense if dense else {}

Expand All @@ -22,8 +27,10 @@ def __setitem__(self, key, value):
elif isinstance(value, List) or isinstance(value, np.ndarray):
self.dense[key] = DenseFeatures(value)
else:
raise ValueError(f'Cannot convert {type(value)} to either DenseFeatures or SparseFeatures')

raise ValueError(
f"Cannot convert {type(value)} to either DenseFeatures or SparseFeatures"
)

def merge(self, other):
self.sparse.update(other.sparse)
self.dense.update(other.dense)
66 changes: 46 additions & 20 deletions src/learn_to_pick/pick_best.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
self.probability = probability
self.score = score


class PickBestEvent(base.Event[PickBestSelected]):
def __init__(
self,
Expand Down Expand Up @@ -65,6 +66,7 @@ def actions(self, model) -> List[base.Featurized]:
)
return action_embs


class VwTxt:
@staticmethod
def _dense_2_str(values: base.DenseFeatures) -> str:
Expand All @@ -74,15 +76,27 @@ def _dense_2_str(values: base.DenseFeatures) -> str:
def _sparse_2_str(values: base.SparseFeatures) -> str:
    """Render sparse features as space-separated ``name:value`` tokens.

    Numeric values are written as-is (``name:value``); every other value is
    prefixed with ``=`` (``name:=value``).
    """
    import numbers

    def _token(name, value):
        if isinstance(value, numbers.Number):
            return f"{name}:{value}"
        return f"{name}:={value}"

    return " ".join([_token(name, value) for name, value in values.items()])

@staticmethod
def featurized_2_str(obj: base.Featurized) -> str:
    """Serialize a featurized object into one space-separated string.

    Every dense namespace is emitted first as ``|<ns>_dense <values>``,
    followed by every sparse namespace as ``|<ns>_sparse <values>``.
    """
    dense_sections = [
        f"|{ns}_dense {VwTxt._dense_2_str(features)}"
        for ns, features in obj.dense.items()
    ]
    sparse_sections = [
        f"|{ns}_sparse {VwTxt._sparse_2_str(features)}"
        for ns, features in obj.sparse.items()
    ]
    return " ".join(dense_sections + sparse_sections)


class PickBestFeaturizer(base.Featurizer[PickBestEvent]):
Expand All @@ -109,54 +123,64 @@ def __init__(
def _dotproducts(self, context, actions):
    """Attach context/action embedding dot products to every action.

    Each sparse namespace that carries a ``"raw"`` entry is encoded with
    ``self.model.encode``. For every action, the dot product between each
    encoded context namespace and each encoded action namespace is stored
    in the action under ``"dotprod"``, keyed ``<context_ns>_<action_ns>``.

    Args:
        context: Featurized context; only its ``sparse`` mapping is read.
        actions: Featurized actions; each one is updated in place.

    Note:
        Assumes ``self.model.encode`` returns vectors of one common length
        so they can be stacked into a matrix -- TODO confirm against the
        model in use.
    """

    def _encode_raw(featurized):
        # Collect the encodings of every namespace that has a "raw" entry.
        encoded = base.Featurized()
        for ns, features in featurized.sparse.items():
            if "raw" in features:
                encoded[ns] = self.model.encode(features["raw"])
        return encoded.dense

    context_dense = _encode_raw(context)
    if not context_dense:
        # np.stack raises ValueError on an empty sequence; with no raw
        # context features there is nothing to take dot products with.
        return

    context_names = list(context_dense)
    context_matrix = np.stack(list(context_dense.values()))

    for action in actions:
        action_dense = _encode_raw(action)
        if not action_dense:
            # Same np.stack limitation: skip actions without raw features.
            continue
        action_names = list(action_dense)
        product = np.dot(context_matrix, np.stack(list(action_dense.values())).T)
        action["dotprod"] = {
            f"{context_name}_{action_name}": product[i, j]
            for i, context_name in enumerate(context_names)
            for j, action_name in enumerate(action_names)
        }

def _generic_namespace(self, featurized):
    """Collect the ``"raw"`` value of each sparse namespace.

    Returns:
        A ``base.SparseFeatures`` mapping namespace name to that
        namespace's ``"raw"`` value; namespaces without ``"raw"`` are
        left out.
    """
    generic = base.SparseFeatures()
    for ns, features in featurized.sparse.items():
        if "raw" not in features:
            continue
        generic[ns] = features["raw"]
    return generic

def _generic_namespaces(self, context, actions):
    """Add the generic namespaces to the context and to every action.

    The context receives its generic features under ``"@"`` and each
    action receives its own under ``"#"``; all objects are updated in place.
    """
    shared_features = self._generic_namespace(context)
    context["@"] = shared_features
    for action in actions:
        action_features = self._generic_namespace(action)
        action["#"] = action_features

def featurize(
    self, event: PickBestEvent
) -> Tuple[base.Featurized, List[base.Featurized], PickBestSelected]:
    """Turn an event into its featurized context, actions and selection.

    When ``self.auto_embed`` is truthy, the dot-product features and the
    generic namespaces are added to the context and actions first.

    Returns:
        A ``(context, actions, selected)`` tuple, where ``selected`` is
        ``event.selected``.
    """
    shared = event.context(self.model)
    candidates = event.actions(self.model)

    if not self.auto_embed:
        return shared, candidates, event.selected

    self._dotproducts(shared, candidates)
    self._generic_namespaces(shared, candidates)
    return shared, candidates, event.selected


def vw_cb_formatter(
    context: base.Featurized, actions: List[base.Featurized], selected: PickBestSelected
) -> str:
    """Format a context, its actions and the selection as multi-line text.

    The first line is ``shared`` followed by the serialized context; each
    following line is one serialized action. When ``selected.score`` is not
    ``None``, the line of the action at ``selected.index`` is prefixed with
    ``<index>:<negated score>:<probability> ``.

    Returns:
        The lines joined with newline characters.
    """
    lines = [f"shared {VwTxt.featurized_2_str(context)}"]

    # One label prefix per action; only the selected action gets a non-empty one.
    prefixes = [""] * len(actions)
    if selected.score is not None:
        prefixes[
            selected.index
        ] = f"{selected.index}:{-selected.score}:{selected.probability} "

    for action, prefix in zip(actions, prefixes):
        lines.append(f"{prefix}{VwTxt.featurized_2_str(action)}")
    return "\n".join(lines)


class PickBestRandomPolicy(base.Policy[PickBestEvent]):
def __init__(self):
Expand Down Expand Up @@ -235,7 +259,9 @@ def _call_after_predict_before_scoring(
sampled_ap = prediction[sampled_index]
sampled_action = sampled_ap[0]
sampled_prob = sampled_ap[1]
event.selected = PickBestSelected(index=sampled_action, probability=sampled_prob)
event.selected = PickBestSelected(
index=sampled_action, probability=sampled_prob
)

next_inputs = inputs.copy()

Expand Down
52 changes: 32 additions & 20 deletions tests/unit_tests/test_pick_best_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,18 @@ def test_everything_embedded() -> None:
str2 = "1"
str3 = "2"
action_dense = "0:1.0 1:0.0"

ctx_str_1 = "context1"
encoded_ctx_str_1 = "0:8.0 1:0.0"

expected = "\n".join([
f"shared |User_dense {encoded_ctx_str_1} |User_sparse raw:={ctx_str_1}",
f"|action_dense {action_dense} |action_sparse raw:={str1}",
f"|action_dense {action_dense} |action_sparse raw:={str2}",
f"|action_dense {action_dense} |action_sparse raw:={str3}"]) # noqa
expected = "\n".join(
[
f"shared |User_dense {encoded_ctx_str_1} |User_sparse raw:={ctx_str_1}",
f"|action_dense {action_dense} |action_sparse raw:={str1}",
f"|action_dense {action_dense} |action_sparse raw:={str2}",
f"|action_dense {action_dense} |action_sparse raw:={str3}",
]
) # noqa

actions = [str1, str2, str3]

Expand All @@ -193,11 +196,14 @@ def test_default_auto_embedder_is_off() -> None:
str3 = "2"
ctx_str_1 = "context1"

expected = "\n".join([
f"shared |User_sparse raw:={ctx_str_1}",
f"|action_sparse raw:={str1}",
f"|action_sparse raw:={str2}",
f"|action_sparse raw:={str3}"]) # noqa
expected = "\n".join(
[
f"shared |User_sparse raw:={ctx_str_1}",
f"|action_sparse raw:={str1}",
f"|action_sparse raw:={str2}",
f"|action_sparse raw:={str3}",
]
) # noqa

actions = [str1, str2, str3]

Expand All @@ -219,11 +225,14 @@ def test_default_w_embeddings_off() -> None:
str3 = "2"
ctx_str_1 = "context1"

expected = "\n".join([
f"shared |User_sparse raw:={ctx_str_1}",
f"|action_sparse raw:={str1}",
f"|action_sparse raw:={str2}",
f"|action_sparse raw:={str3}"]) # noqa
expected = "\n".join(
[
f"shared |User_sparse raw:={ctx_str_1}",
f"|action_sparse raw:={str1}",
f"|action_sparse raw:={str2}",
f"|action_sparse raw:={str3}",
]
) # noqa

actions = [str1, str2, str3]

Expand All @@ -247,10 +256,13 @@ def test_default_w_embeddings_on() -> None:
ctx_str_1 = "context1"
dot_prod = "dotprod_sparse User_action:5.0" # dot prod of [1.0, 2.0] and [1.0, 2.0]

expected = "\n".join([
f"shared |User_sparse raw:={ctx_str_1} |@_sparse User:={ctx_str_1}",
f"|action_sparse raw:={str1} |{dot_prod} |#_sparse action:={str1} ",
f"|action_sparse raw:={str2} |{dot_prod} |#_sparse action:={str2} "]) # noqa
expected = "\n".join(
[
f"shared |User_sparse raw:={ctx_str_1} |@_sparse User:={ctx_str_1}",
f"|action_sparse raw:={str1} |{dot_prod} |#_sparse action:={str1} ",
f"|action_sparse raw:={str2} |{dot_prod} |#_sparse action:={str2} ",
]
) # noqa

actions = [str1, str2]

Expand Down
Loading

0 comments on commit 36fd0aa

Please sign in to comment.