New embeddings service with inf2 (Inferentia2) runtime #28

Closed
34 changes: 34 additions & 0 deletions examples/embeddings-inf2/README.md
@@ -0,0 +1,34 @@
## Embeddings Service

Start the server via:
```bash
nos serve up -c serve.yaml --http
```

Optionally, you can provide the `inf2` runtime flag explicitly, though it is inferred automatically when omitted:

```bash
nos serve up -c serve.yaml --http --runtime inf2
```

### Run the tests

```bash
pytest -sv ./tests/test_embeddings_client.py
```

### Call the service

You can also call the service via the REST API directly:

```bash
curl \
-X POST http://<service-ip>:8000/v1/infer \
-H 'Content-Type: application/json' \
-d '{
"model_id": "BAAI/bge-small-en-v1.5",
"inputs": {
"texts": ["fox jumped over the moon"]
}
}'
```
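
The same call can be made from Python; a minimal sketch using `requests` (the endpoint and payload mirror the curl example above; replace `<service-ip>` with your service address):

```python
import requests

# Hypothetical client-side call; assumes the service from serve.yaml is running.
response = requests.post(
    "http://<service-ip>:8000/v1/infer",
    json={
        "model_id": "BAAI/bge-small-en-v1.5",
        "inputs": {"texts": ["fox jumped over the moon"]},
    },
)
response.raise_for_status()
print(response.json())
```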
27 changes: 27 additions & 0 deletions examples/embeddings-inf2/job-inf2-embeddings-deployment.yaml
@@ -0,0 +1,27 @@
# Usage: sky launch -c <cluster-name> job-inf2-embeddings-deployment.yaml

# Mount this example directory into the instance at /app.
file_mounts:
  /app: .

resources:
cloud: aws
region: us-west-2
instance_type: inf2.xlarge
# image_id: ami-09c62125a680f0ead # us-east-2
# image_id: ami-0d4155c8606f16f5b # us-west-1
image_id: ami-096319086cc3d5f23 # us-west-2
disk_size: 256
ports:
- 8000

setup: |
sudo apt-get install -y docker-compose-plugin

cd /app
python3 -m venv .venv
source .venv/bin/activate
pip install git+https://github.com/spillai/nos.git@spillai/0.2.0-dev pytest

run: |
source /app/.venv/bin/activate
cd /app && NOS_LOGGING_LEVEL=DEBUG nos serve up -c serve.yaml --http
103 changes: 103 additions & 0 deletions examples/embeddings-inf2/models/embeddings_inf2.py
@@ -0,0 +1,103 @@
"""Embeddings model accelerated with AWS Neuron (using optimum-neuron)."""
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Union

import numpy as np
import torch
import torch_neuronx
from nos.constants import NOS_CACHE_DIR
from nos.hub import HuggingFaceHubConfig


@dataclass(frozen=True)
class EmbeddingConfig(HuggingFaceHubConfig):
"""Embeddings model configuration."""

batch_size: int = 1
"""Batch size for the model."""

sequence_length: int = 384
"""Sequence length for the model."""


def get_neuron_device_count() -> int:
    """Return the number of visible Neuron devices, or 0 if none are available."""
    try:
        return torch_neuronx.xla_impl.data_parallel.device_count()
    except (RuntimeError, AssertionError):
        return 0


def _setup_neuron_env():
from nos.logging import logger

    # Log all Neuron-related environment variables for debugging
for k, v in os.environ.items():
if "NEURON" in k:
logger.debug(f"{k}={v}")
cores: int = int(os.getenv("NOS_NEURON_CORES", 2))
logger.info(f"Setting up neuron env with {cores} cores")
cache_dir = NOS_CACHE_DIR / "neuron"
os.environ["NEURONX_CACHE"] = "on"
os.environ["NEURONX_DUMP_TO"] = str(cache_dir)
os.environ["NEURON_RT_NUM_CORES"] = str(cores)
os.environ["NEURON_RT_VISIBLE_CORES"] = ",".join([str(i) for i in range(cores)])
os.environ["NEURON_CC_FLAGS"] = "--model-type=transformer-inference"


class EmbeddingServiceInf2:
configs = {
"BAAI/bge-small-en-v1.5": EmbeddingConfig(
model_name="BAAI/bge-small-en-v1.5",
),
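        # Additional models could be registered here, e.g. (hypothetical, untested):
        #   "BAAI/bge-base-en-v1.5": EmbeddingConfig(model_name="BAAI/bge-base-en-v1.5"),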
}

def __init__(self, model_name: str = "BAAI/bge-small-en-v1.5"):
_setup_neuron_env()

from nos.logging import logger
from optimum.neuron import NeuronModelForSentenceTransformers
from transformers import AutoTokenizer

try:
self.cfg = EmbeddingServiceInf2.configs[model_name]
except KeyError:
            raise ValueError(f"Invalid model_name: {model_name}, available models: {list(self.configs)}") from None

cache_dir = (
NOS_CACHE_DIR / "neuron" / f"{self.cfg.model_name}-bs-{self.cfg.batch_size}-sl-{self.cfg.sequence_length}"
)
        if cache_dir.exists():
logger.info(f"Loading model from {cache_dir}")
self.model = NeuronModelForSentenceTransformers.from_pretrained(str(cache_dir))
logger.info(f"Loaded model from {cache_dir}")
else:
# Load Transformers model and export it to AWS Inferentia2
input_shapes = {
"batch_size": self.cfg.batch_size,
"sequence_length": self.cfg.sequence_length,
}
self.model = NeuronModelForSentenceTransformers.from_pretrained(
self.cfg.model_name, export=True, **input_shapes
)
self.model.save_pretrained(str(cache_dir))
logger.info(f"Saved model to {cache_dir}")
self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name)
self.logger = logger
self.logger.info(f"Loaded neuron model: {self.cfg.model_name}")

    @torch.inference_mode()
    def __call__(
        self,
        texts: Union[str, List[str]],
    ) -> np.ndarray:
        """Embed text(s) with the model and return the sentence embeddings."""
if isinstance(texts, str):
texts = [texts]
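        # Tokenize with in-batch padding, then run the Neuron-compiled forward pass.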
inputs = self.tokenizer(
texts,
padding=True,
return_tensors="pt",
)
outputs = self.model(**inputs)
return outputs.sentence_embedding.cpu().numpy()
16 changes: 16 additions & 0 deletions examples/embeddings-inf2/serve.yaml
@@ -0,0 +1,16 @@
images:
embeddings-inf2:
base: autonomi/nos:latest-inf2
env:
NOS_LOGGING_LEVEL: DEBUG
NOS_NEURON_CORES: 2
run:
- python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com
- pip install sentence-transformers

# Each model entry binds a served model_id to its implementing class and file;
# `runtime_env` selects the image defined above.
models:
BAAI/bge-small-en-v1.5:
model_cls: EmbeddingServiceInf2
model_path: models/embeddings_inf2.py
default_method: __call__
runtime_env: embeddings-inf2
12 changes: 12 additions & 0 deletions examples/embeddings-inf2/tests/test_embeddings.py
@@ -0,0 +1,12 @@
import numpy as np


def test_embeddings():
from models.embeddings_inf2 import EmbeddingServiceInf2

model = EmbeddingServiceInf2()
texts = "What is the meaning of life?"
response = model(texts=texts)
assert response is not None
assert isinstance(response, np.ndarray)
print(response.shape)
20 changes: 20 additions & 0 deletions examples/embeddings-inf2/tests/test_embeddings_client.py
@@ -0,0 +1,20 @@
import pytest


@pytest.mark.parametrize("model_id", ["BAAI/bge-small-en-v1.5"])
def test_embeddings_client(model_id):
import numpy as np
from nos.client import Client

# Create a client
client = Client("[::]:50051")
assert client.WaitForServer()

# Load the embeddings model
model = client.Module(model_id)

# Embed text with the model
texts = "What is the meaning of life?"
response = model(texts=texts)
assert response is not None
assert isinstance(response, np.ndarray)