New inf2 embeddings service example #537

Merged · 1 commit · Jan 31, 2024
34 changes: 34 additions & 0 deletions examples/inf2/embeddings/README.md
@@ -0,0 +1,34 @@
## Embeddings Service

Start the server via:
```bash
nos serve up -c serve.yaml --http
```

Optionally, you can pass the `inf2` runtime explicitly via the `--runtime` flag; if omitted, the runtime is inferred automatically.

```bash
nos serve up -c serve.yaml --http --runtime inf2
```

### Run the tests

```bash
pytest -sv ./tests/test_embeddings_inf2_client.py
```

### Call the service

You can also call the service via the REST API directly:

```bash
curl \
-X POST http://<service-ip>:8000/v1/infer \
-H 'Content-Type: application/json' \
-d '{
"model_id": "BAAI/bge-small-en-v1.5",
"inputs": {
"texts": ["fox jumped over the moon"]
}
}'
```
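
The same request can be sent from Python. Below is a minimal sketch using the `requests` package (an assumption on our part: `requests` is installed, and `<service-ip>` is replaced with your server address); it mirrors the curl call above.

```python
import requests

# Same endpoint, model id, and payload as the curl example above.
response = requests.post(
    "http://<service-ip>:8000/v1/infer",
    headers={"Content-Type": "application/json"},
    json={
        "model_id": "BAAI/bge-small-en-v1.5",
        "inputs": {"texts": ["fox jumped over the moon"]},
    },
)
response.raise_for_status()
print(response.json())
```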
26 changes: 26 additions & 0 deletions examples/inf2/embeddings/job-inf2-embeddings-deployment.yaml
@@ -0,0 +1,26 @@
# Usage: sky launch -c <cluster-name> job-inf2-embeddings-deployment.yaml
# image_id: ami-09c62125a680f0ead # us-east-2
# image_id: ami-0d4155c8606f16f5b # us-west-1
# image_id: ami-096319086cc3d5f23 # us-west-2

file_mounts:
/app: .

resources:
cloud: aws
region: us-west-2
instance_type: inf2.xlarge
image_id: ami-096319086cc3d5f23 # us-west-2
disk_size: 256
ports:
- 8000

setup: |
sudo apt-get install -y docker-compose-plugin

cd /app && python3 -m venv .venv && source .venv/bin/activate
pip install git+https://github.com/spillai/nos.git pytest

run: |
source /app/.venv/bin/activate
cd /app && NOS_LOGGING_LEVEL=DEBUG nos serve up -c serve.yaml --http
80 changes: 80 additions & 0 deletions examples/inf2/embeddings/models/embeddings_inf2.py
@@ -0,0 +1,80 @@
"""Embeddings model accelerated with AWS Neuron (using optimum-neuron)."""
from dataclasses import dataclass
from pathlib import Path
from typing import List, Union

import numpy as np
import torch

from nos.constants import NOS_CACHE_DIR
from nos.hub import HuggingFaceHubConfig
from nos.neuron.device import NeuronDevice


@dataclass(frozen=True)
class EmbeddingConfig(HuggingFaceHubConfig):
"""Embeddings model configuration."""

batch_size: int = 1
"""Batch size for the model."""

sequence_length: int = 384
"""Sequence length for the model."""


class EmbeddingServiceInf2:
configs = {
"BAAI/bge-small-en-v1.5": EmbeddingConfig(
model_name="BAAI/bge-small-en-v1.5",
),
}

def __init__(self, model_name: str = "BAAI/bge-small-en-v1.5"):
from optimum.neuron import NeuronModelForSentenceTransformers
from transformers import AutoTokenizer

from nos.logging import logger

NeuronDevice.setup_environment()
        try:
            self.cfg = EmbeddingServiceInf2.configs[model_name]
        except KeyError:
            raise ValueError(f"Invalid model_name: {model_name}, available models: {list(self.configs.keys())}")

# Load model from cache if available, otherwise load from HF and compile
# (cache is specific to model_name, batch_size and sequence_length)
cache_dir = (
NOS_CACHE_DIR / "neuron" / f"{self.cfg.model_name}-bs-{self.cfg.batch_size}-sl-{self.cfg.sequence_length}"
)
        if cache_dir.exists():
logger.info(f"Loading model from {cache_dir}")
self.model = NeuronModelForSentenceTransformers.from_pretrained(str(cache_dir))
logger.info(f"Loaded model from {cache_dir}")
else:
input_shapes = {
"batch_size": self.cfg.batch_size,
"sequence_length": self.cfg.sequence_length,
}
self.model = NeuronModelForSentenceTransformers.from_pretrained(
self.cfg.model_name, export=True, **input_shapes
)
self.model.save_pretrained(str(cache_dir))
logger.info(f"Saved model to {cache_dir}")
self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name)
self.logger = logger
self.logger.info(f"Loaded neuron model: {self.cfg.model_name}")

    @torch.inference_mode()
    def __call__(
        self,
        texts: Union[str, List[str]],
    ) -> np.ndarray:
"""Embed text with the model."""
if isinstance(texts, str):
texts = [texts]
inputs = self.tokenizer(
texts,
padding=True,
return_tensors="pt",
)
outputs = self.model(**inputs)
return outputs.sentence_embedding.cpu().numpy()
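
For a quick local smoke test of this class on an inf2 host, a minimal sketch (assumes the Neuron SDK, `optimum-neuron`, and `sentence-transformers` are installed; it mirrors `tests/test_embeddings_inf2.py` below):

```python
# Hypothetical local check: constructing the model exports and compiles it
# for Neuron and caches it under NOS_CACHE_DIR, so the first run may take a while.
from models.embeddings_inf2 import EmbeddingServiceInf2

model = EmbeddingServiceInf2("BAAI/bge-small-en-v1.5")
embeddings = model(texts=["fox jumped over the moon"])
print(embeddings.shape)  # (1, embedding_dim), e.g. (1, 384) for this model
```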
16 changes: 16 additions & 0 deletions examples/inf2/embeddings/serve.yaml
@@ -0,0 +1,16 @@
images:
embeddings-inf2:
base: autonomi/nos:latest-inf2
env:
NOS_LOGGING_LEVEL: DEBUG
NOS_NEURON_CORES: 2
run:
- python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com
- pip install sentence-transformers

models:
BAAI/bge-small-en-v1.5:
model_cls: EmbeddingServiceInf2
model_path: models/embeddings_inf2.py
default_method: __call__
runtime_env: embeddings-inf2
12 changes: 12 additions & 0 deletions examples/inf2/embeddings/tests/test_embeddings_inf2.py
@@ -0,0 +1,12 @@
import numpy as np


def test_embeddings():
from models.embeddings_inf2 import EmbeddingServiceInf2

model = EmbeddingServiceInf2()
texts = "What is the meaning of life?"
response = model(texts=texts)
assert response is not None
assert isinstance(response, np.ndarray)
print(response.shape)
21 changes: 21 additions & 0 deletions examples/inf2/embeddings/tests/test_embeddings_inf2_client.py
@@ -0,0 +1,21 @@
import pytest


@pytest.mark.parametrize("model_id", ["BAAI/bge-small-en-v1.5"])
def test_embeddings_client(model_id):
import numpy as np

from nos.client import Client

# Create a client
client = Client("[::]:50051")
assert client.WaitForServer()

# Load the embeddings model
model = client.Module(model_id)

# Embed text with the model
texts = "What is the meaning of life?"
response = model(texts=texts)
assert response is not None
assert isinstance(response, np.ndarray)
42 changes: 42 additions & 0 deletions nos/neuron/device.py
@@ -0,0 +1,42 @@
import os
from dataclasses import dataclass
from typing import Optional

import torch_neuronx

from nos.constants import NOS_CACHE_DIR
from nos.logging import logger


@dataclass
class NeuronDevice:
"""Neuron device environment."""

    _instance: Optional["NeuronDevice"] = None

@classmethod
def get(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance

@staticmethod
def device_count() -> int:
try:
return torch_neuronx.xla_impl.data_parallel.device_count()
except (RuntimeError, AssertionError):
return 0

@staticmethod
def setup_environment() -> None:
"""Setup neuron environment."""
for k, v in os.environ.items():
if "NEURON" in k:
logger.debug(f"{k}={v}")
cores: int = int(os.getenv("NOS_NEURON_CORES", 2))
logger.info(f"Setting up neuron env with {cores} cores")
cache_dir = NOS_CACHE_DIR / "neuron"
os.environ["NEURONX_CACHE"] = "on"
os.environ["NEURONX_DUMP_TO"] = str(cache_dir)
os.environ["NEURON_RT_NUM_CORES"] = str(cores)
os.environ["NEURON_RT_VISIBLE_CORES"] = ",".join([str(i) for i in range(cores)])
os.environ["NEURON_CC_FLAGS"] = "--model-type=transformer-inference"
14 changes: 14 additions & 0 deletions tests/neuron/test_neuron_device.py
@@ -0,0 +1,14 @@
import pytest

from nos.common.runtime import is_torch_neuron_available


pytestmark = pytest.mark.skipif(not is_torch_neuron_available(), reason="Requires torch_neuron")


def test_neuron_device():
from nos.neuron.device import NeuronDevice

neuron_env = NeuronDevice.get()
assert neuron_env.device_count() > 0
assert neuron_env.setup_environment() is None