
Add pytorch policy #31

Merged
18 commits merged on Nov 28, 2023
50 changes: 46 additions & 4 deletions README.md
@@ -29,7 +29,7 @@ Note: all code examples presented here can be found in `notebooks/readme.ipynb`
- Use a custom score function to grade the decision.
- Directly specify the score manually and asynchronously.

The beauty of `learn_to_pick` is its flexibility. Whether you're a fan of VowpalWabbit or prefer PyTorch (coming soon), the library can seamlessly integrate with both, allowing them to be the brain behind your decisions.
The beauty of `learn_to_pick` is its flexibility. Whether you're a fan of VowpalWabbit or prefer PyTorch, the library can seamlessly integrate with both, allowing them to be the brain behind your decisions.

## Installation

@@ -43,6 +43,8 @@ The `PickBest` scenario should be used when:
- Only one option is optimal for a specific criterion or context
- There exists a mechanism to provide feedback on the suitability of the chosen option for that specific criterion

### Scorer

Example usage with the default LLM scorer:

```python
@@ -113,7 +115,46 @@ dummy_score = 1
picker.update_with_delayed_score(dummy_score, result)
```

`PickBest` is highly configurable to work with a VowpalWabbit decision making policy, a PyTorch decision making policy (coming soon), or with a custom user defined decision making policy
### Using a PyTorch policy

Example usage with a PyTorch policy:

```python
import learn_to_pick
from learn_to_pick import PyTorchPolicy

# CustomSelectionScorer is the custom scorer defined in the Advanced Usage section
pytorch_picker = learn_to_pick.PickBest.create(
    policy=PyTorchPolicy(), selection_scorer=CustomSelectionScorer()
)

pytorch_picker.run(
    pick=learn_to_pick.ToSelectFrom(["option1", "option2"]),
    criteria=learn_to_pick.BasedOn("some criteria"),
)
```
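
The scoring options described earlier apply to the PyTorch picker as well. As a sketch (assuming, as in the delayed-score example above, that passing `selection_scorer=None` defers scoring):

```python
pytorch_picker = learn_to_pick.PickBest.create(
    policy=PyTorchPolicy(), selection_scorer=None
)
result = pytorch_picker.run(
    pick=learn_to_pick.ToSelectFrom(["option1", "option2"]),
    criteria=learn_to_pick.BasedOn("some criteria"),
)
pytorch_picker.update_with_delayed_score(1.0, result)
```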

You can always create a custom PyTorch policy by implementing the `Policy` interface:

```python
from typing import Any

from learn_to_pick import Policy
from learn_to_pick.base import TEvent


class CustomPytorchPolicy(Policy):
    def __init__(self, **kwargs: Any):
        ...

    def predict(self, event: TEvent) -> Any:
        ...

    def learn(self, event: TEvent) -> None:
        ...

    def log(self, event: TEvent) -> None:
        ...

    def save(self) -> None:
        ...


pytorch_picker = learn_to_pick.PickBest.create(
    policy=CustomPytorchPolicy(), selection_scorer=CustomSelectionScorer()
)
```
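
For instance, here is a minimal sketch of a complete custom policy that picks uniformly at random. It assumes that `predict` returns a list of `(action index, probability)` pairs, mirroring the VowpalWabbit policy, and that the event exposes its candidates via `to_select_from`:

```python
from typing import Any, List, Tuple

from learn_to_pick import Policy
from learn_to_pick.base import TEvent


class UniformRandomPolicy(Policy):
    """Toy policy: scores every candidate equally and never learns."""

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)

    def predict(self, event: TEvent) -> List[Tuple[int, float]]:
        # assumption: the event stores candidates as a mapping from the
        # ToSelectFrom variable name to the list of options
        candidates = next(iter(event.to_select_from.values()))
        n = len(candidates)
        return [(i, 1.0 / n) for i in range(n)]

    def learn(self, event: TEvent) -> None:
        pass  # a random policy has nothing to learn

    def log(self, event: TEvent) -> None:
        pass

    def save(self) -> None:
        pass
```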

`PickBest` is highly configurable: it can work with a VowpalWabbit decision-making policy, a PyTorch decision-making policy, or a custom user-defined decision-making policy.

The main thing that needs to be decided from the get-go is:

@@ -134,7 +175,8 @@ In all three cases, when a score is calculated or provided, the decision making
## Example Notebooks

- `readme.ipynb` showcases all examples shown in this README
- `news_recommendation.ipynb` showcases a personalization scenario where we have to pick articles for specific users
- `news_recommendation.ipynb` showcases a personalization scenario where we have to pick articles for specific users with a VowpalWabbit policy
- `news_recommendation_pytorch.ipynb` showcases the same personalization scenario with a PyTorch policy
- `prompt_variable_injection.ipynb` showcases learned prompt variable injection and registering callback functionality

### Advanced Usage
@@ -183,7 +225,7 @@ class CustomSelectionScorer(learn_to_pick.SelectionScorer):
# inputs: the inputs to the picker in Dict[str, Any] format
# picked: the selection that was made by the policy
# event: metadata that can be used to determine the score if needed

# scoring logic goes here

dummy_score = 1.0
238 changes: 238 additions & 0 deletions notebooks/news_recommendation_pytorch.ipynb


1 change: 0 additions & 1 deletion setup.py
@@ -1,5 +1,4 @@
from setuptools import setup, find_packages
import os

with open("README.md", "r", encoding="UTF-8") as fh:
long_description = fh.read()
17 changes: 12 additions & 5 deletions src/learn_to_pick/__init__.py
@@ -5,12 +5,9 @@
BasedOn,
Embed,
Featurizer,
ModelRepository,
Policy,
SelectionScorer,
ToSelectFrom,
VwPolicy,
VwLogger,
embed,
)
from learn_to_pick.pick_best import (
@@ -22,6 +19,14 @@
)


from learn_to_pick.vw.policy import VwPolicy
from learn_to_pick.vw.model_repository import ModelRepository
from learn_to_pick.vw.logger import VwLogger

from learn_to_pick.pytorch.policy import PyTorchPolicy
from learn_to_pick.pytorch.feature_embedder import PyTorchFeatureEmbedder


def configure_logger() -> None:
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
@@ -48,9 +53,11 @@ def configure_logger() -> None:
"SelectionScorer",
"AutoSelectionScorer",
"Featurizer",
"ModelRepository",
"Policy",
"PyTorchPolicy",
"PyTorchFeatureEmbedder",
"embed",
"ModelRepository",
"VwPolicy",
"VwLogger",
"embed",
]
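
With this reorganization, all policy components resolve from the package root; a quick sanity check (hypothetical snippet, not part of the diff):

```python
# the new PyTorch exports added by this PR
from learn_to_pick import PyTorchPolicy, PyTorchFeatureEmbedder

# the VW pieces still import from the root, now backed by learn_to_pick.vw
from learn_to_pick import VwPolicy, VwLogger, ModelRepository
```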
53 changes: 1 addition & 52 deletions src/learn_to_pick/base.py
@@ -10,15 +10,12 @@
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
Callable,
)

from learn_to_pick.metrics import MetricsTrackerAverage, MetricsTrackerRollingWindow
from learn_to_pick.model_repository import ModelRepository
from learn_to_pick.vw_logger import VwLogger

from learn_to_pick.features import Featurized, DenseFeatures, SparseFeatures
from enum import Enum

@@ -89,10 +86,6 @@ def EmbedAndKeep(anything: Any) -> Any:
# helper functions


def _parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
    return [parser.parse_line(line) for line in input_str.split("\n")]


def filter_inputs(inputs: Dict[str, Any], role: Role) -> Dict[str, Any]:
    return {
        k: v.value
@@ -144,50 +137,6 @@ def save(self) -> None:
        pass


class VwPolicy(Policy):
    def __init__(
        self,
        model_repo: ModelRepository,
        vw_cmd: List[str],
        featurizer: Featurizer,
        formatter: Callable,
        vw_logger: VwLogger,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.model_repo = model_repo
        self.vw_cmd = vw_cmd
        self.workspace = self.model_repo.load(vw_cmd)
        self.featurizer = featurizer
        self.formatter = formatter
        self.vw_logger = vw_logger

    def format(self, event):
        return self.formatter(*self.featurizer.featurize(event))

    def predict(self, event: TEvent) -> Any:
        import vowpal_wabbit_next as vw

        text_parser = vw.TextFormatParser(self.workspace)
        return self.workspace.predict_one(_parse_lines(text_parser, self.format(event)))

    def learn(self, event: TEvent) -> None:
        import vowpal_wabbit_next as vw

        vw_ex = self.format(event)
        text_parser = vw.TextFormatParser(self.workspace)
        multi_ex = _parse_lines(text_parser, vw_ex)
        self.workspace.learn_one(multi_ex)

    def log(self, event: TEvent) -> None:
        if self.vw_logger.logging_enabled():
            vw_ex = self.format(event)
            self.vw_logger.log(vw_ex)

    def save(self) -> None:
        self.model_repo.save(self.workspace)


class Featurizer(Generic[TEvent], ABC):
    def __init__(self, *args: Any, **kwargs: Any):
        pass
10 changes: 7 additions & 3 deletions src/learn_to_pick/pick_best.py
@@ -7,6 +7,10 @@
import numpy as np

from learn_to_pick import base
from learn_to_pick.vw.policy import VwPolicy
from learn_to_pick.vw.model_repository import ModelRepository
from learn_to_pick.vw.logger import VwLogger


logger = logging.getLogger(__name__)

@@ -333,14 +337,14 @@ def create_policy(

vw_cmd = interactions + vw_cmd

return base.VwPolicy(
model_repo=base.ModelRepository(
return VwPolicy(
model_repo=ModelRepository(
model_save_dir, with_history=True, reset=reset_model
),
vw_cmd=vw_cmd,
featurizer=featurizer,
formatter=formatter,
vw_logger=base.VwLogger(rl_logs),
vw_logger=VwLogger(rl_logs),
)

def _default_policy(self):
69 changes: 69 additions & 0 deletions src/learn_to_pick/pytorch/feature_embedder.py
@@ -0,0 +1,69 @@
from sentence_transformers import SentenceTransformer
import torch
from torch import Tensor

from learn_to_pick import PickBestFeaturizer
from learn_to_pick.base import Event
from learn_to_pick.features import SparseFeatures
from typing import Any, List, Tuple, TypeVar, Union

TEvent = TypeVar("TEvent", bound=Event)


class PyTorchFeatureEmbedder:
    def __init__(self, model: Any = None):
        if model is None:
            model = SentenceTransformer("all-MiniLM-L6-v2")

        self.model = model
        self.featurizer = PickBestFeaturizer(auto_embed=False)

    def encode(self, to_encode: Union[str, List[str]]) -> Tensor:
        embeddings = self.model.encode(to_encode, convert_to_tensor=True)
        normalized = torch.nn.functional.normalize(embeddings)
        return normalized

    def convert_features_to_text(self, sparse_features: SparseFeatures) -> str:
        # e.g. {"user": {"default_ft": "Tom"}, "context": {"default_ft": "rainy"}}
        # becomes "user=Tom context=rainy" (illustrative values)
        results = []
        for ns, obj in sparse_features.items():
            value = obj.get("default_ft", "")
            results.append(f"{ns}={value}")
        return " ".join(results)

    def format(
        self, event: TEvent
    ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]]:
        context_featurized, actions_featurized, selected = self.featurizer.featurize(
            event
        )

        if len(context_featurized.dense) > 0:
            raise NotImplementedError(
                "pytorch policy doesn't support context with dense features"
            )

        for action_featurized in actions_featurized:
            if len(action_featurized.dense) > 0:
                raise NotImplementedError(
                    "pytorch policy doesn't support action with dense features"
                )

        context_sparse = self.encode(
            [self.convert_features_to_text(context_featurized.sparse)]
        )

        actions_sparse = []
        for action_featurized in actions_featurized:
            actions_sparse.append(
                self.convert_features_to_text(action_featurized.sparse)
            )
        actions_sparse = self.encode(actions_sparse).unsqueeze(0)

        # with a score: return a (score, context, chosen action) training triple;
        # without: return (context, all actions) for prediction
        if selected.score is not None:
            return (
                torch.Tensor([[selected.score]]),
                context_sparse,
                actions_sparse[:, selected.index, :].unsqueeze(1),
            )
        else:
            return context_sparse, actions_sparse
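
As a quick, hypothetical sketch of what the embedder produces (384 is the output dimension of the default all-MiniLM-L6-v2 model):

```python
from learn_to_pick import PyTorchFeatureEmbedder

embedder = PyTorchFeatureEmbedder()  # defaults to all-MiniLM-L6-v2

# encode() L2-normalizes the sentence-transformer embeddings
vecs = embedder.encode(["user=Tom context=rainy"])
print(vecs.shape)  # torch.Size([1, 384])
```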
21 changes: 21 additions & 0 deletions src/learn_to_pick/pytorch/igw.py
@@ -0,0 +1,21 @@
import torch
from torch import Tensor
from typing import Tuple


def IGW(fhat: torch.Tensor, gamma: float) -> Tuple[Tensor, Tensor]:
    # Inverse gap weighting: each action's sampling probability shrinks with
    # its gap to the best estimated reward, scaled by the exploration rate gamma.
    from math import sqrt

    fhatahat, ahat = fhat.max(dim=1)
    A = fhat.shape[1]
    gamma *= sqrt(A)
    p = 1 / (A + gamma * (fhatahat.unsqueeze(1) - fhat))
    sump = p.sum(dim=1)
    # the greedy action absorbs the leftover probability mass
    p[range(p.shape[0]), ahat] += torch.clamp(1 - sump, min=0, max=None)
    return torch.multinomial(p, num_samples=1).squeeze(1), ahat


def SamplingIGW(A: Tensor, P: Tensor, gamma: float) -> list:
    exploreind, _ = IGW(P, gamma)
    explore = [ind for _, ind in zip(A, exploreind)]
    return explore
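
A small sketch of `SamplingIGW` in isolation (illustrative scores; larger `gamma` concentrates probability on the greedy action):

```python
import torch
from learn_to_pick.pytorch.igw import SamplingIGW

# estimated rewards for a batch of 2 contexts over 3 actions (made-up values)
P = torch.tensor([[0.1, 0.7, 0.2],
                  [0.5, 0.3, 0.2]])
A = [0, 1]  # SamplingIGW only uses this for its length

picks = SamplingIGW(A, P, gamma=10.0)
print(picks)  # e.g. [tensor(1), tensor(0)]: usually greedy, occasionally exploratory
```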