Commit
- added `nos/neuron/device.py` `NeuronDevice` with tests
Showing 8 changed files with 245 additions and 0 deletions.
## Embeddings Service

Start the server via:

```bash
nos serve up -c serve.yaml --http
```

Optionally, you can provide the `inf2` runtime flag, but this is automatically inferred.

```bash
nos serve up -c serve.yaml --http --runtime inf2
```

### Run the tests

```bash
pytest -sv ./tests/test_embeddings_inf2_client.py
```

### Call the service

You can also call the service via the REST API directly:

```bash
curl \
  -X POST http://<service-ip>:8000/v1/infer \
  -H 'Content-Type: application/json' \
  -d '{
    "model_id": "BAAI/bge-small-en-v1.5",
    "inputs": {
      "texts": ["fox jumped over the moon"]
    }
  }'
```
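The same call can also be made from Python over gRPC using the `nos.client` API exercised by the client test in this commit. A minimal sketch, assuming the server's default gRPC port (50051) is reachable:

```python
import numpy as np

from nos.client import Client

# Connect to the running NOS server over gRPC and wait until it is ready
client = Client("[::]:50051")
assert client.WaitForServer()

# Bind the registered embeddings model by its model id
model = client.Module("BAAI/bge-small-en-v1.5")

# Embed a batch of texts; the service returns a numpy array of embeddings
embeddings = model(texts=["fox jumped over the moon"])
assert isinstance(embeddings, np.ndarray)
```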
examples/inf2/embeddings/job-inf2-embeddings-deployment.yaml
# Usage: sky launch -c <cluster-name> job-inf2-embeddings-deployment.yaml
# image_id: ami-09c62125a680f0ead # us-east-2
# image_id: ami-0d4155c8606f16f5b # us-west-1
# image_id: ami-096319086cc3d5f23 # us-west-2

file_mounts:
  /app: .

resources:
  cloud: aws
  region: us-west-2
  instance_type: inf2.xlarge
  image_id: ami-096319086cc3d5f23  # us-west-2
  disk_size: 256
  ports:
    - 8000

setup: |
  sudo apt-get install -y docker-compose-plugin
  cd /app && python3 -m venv .venv && source .venv/bin/activate
  pip install git+https://github.com/spillai/nos.git pytest
run: |
  source /app/.venv/bin/activate
  cd /app && NOS_LOGGING_LEVEL=DEBUG nos serve up -c serve.yaml --http
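For reference, a concrete form of the `sky launch` usage noted in the header comment above; the cluster name `inf2-embeddings` is only illustrative, and the launch is run from the example directory so that `file_mounts` maps it to `/app` on the cluster:

```bash
# Provision an inf2.xlarge instance on AWS and bring up the embeddings service
cd examples/inf2/embeddings
sky launch -c inf2-embeddings job-inf2-embeddings-deployment.yaml
```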
"""Embeddings model accelerated with AWS Neuron (using optimum-neuron).""" | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
from typing import Iterable, List, Union | ||
|
||
import torch | ||
|
||
from nos.constants import NOS_CACHE_DIR | ||
from nos.hub import HuggingFaceHubConfig | ||
from nos.neuron.device import NeuronDevice | ||
|
||
|
||
@dataclass(frozen=True) | ||
class EmbeddingConfig(HuggingFaceHubConfig): | ||
"""Embeddings model configuration.""" | ||
|
||
batch_size: int = 1 | ||
"""Batch size for the model.""" | ||
|
||
sequence_length: int = 384 | ||
"""Sequence length for the model.""" | ||
|
||
|
||
class EmbeddingServiceInf2: | ||
configs = { | ||
"BAAI/bge-small-en-v1.5": EmbeddingConfig( | ||
model_name="BAAI/bge-small-en-v1.5", | ||
), | ||
} | ||
|
||
def __init__(self, model_name: str = "BAAI/bge-small-en-v1.5"): | ||
from optimum.neuron import NeuronModelForSentenceTransformers | ||
from transformers import AutoTokenizer | ||
|
||
from nos.logging import logger | ||
|
||
NeuronDevice.setup_environment() | ||
try: | ||
self.cfg = EmbeddingServiceInf2.configs[model_name] | ||
except KeyError: | ||
raise ValueError(f"Invalid model_name: {model_name}, available models: {self.configs.keys()}") | ||
|
||
# Load model from cache if available, otherwise load from HF and compile | ||
# (cache is specific to model_name, batch_size and sequence_length) | ||
cache_dir = ( | ||
NOS_CACHE_DIR / "neuron" / f"{self.cfg.model_name}-bs-{self.cfg.batch_size}-sl-{self.cfg.sequence_length}" | ||
) | ||
if Path(cache_dir).exists(): | ||
logger.info(f"Loading model from {cache_dir}") | ||
self.model = NeuronModelForSentenceTransformers.from_pretrained(str(cache_dir)) | ||
logger.info(f"Loaded model from {cache_dir}") | ||
else: | ||
input_shapes = { | ||
"batch_size": self.cfg.batch_size, | ||
"sequence_length": self.cfg.sequence_length, | ||
} | ||
self.model = NeuronModelForSentenceTransformers.from_pretrained( | ||
self.cfg.model_name, export=True, **input_shapes | ||
) | ||
self.model.save_pretrained(str(cache_dir)) | ||
logger.info(f"Saved model to {cache_dir}") | ||
self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name) | ||
self.logger = logger | ||
self.logger.info(f"Loaded neuron model: {self.cfg.model_name}") | ||
|
||
@torch.inference_mode() | ||
def __call__( | ||
self, | ||
texts: Union[str, List[str]], | ||
) -> Iterable[str]: | ||
"""Embed text with the model.""" | ||
if isinstance(texts, str): | ||
texts = [texts] | ||
inputs = self.tokenizer( | ||
texts, | ||
padding=True, | ||
return_tensors="pt", | ||
) | ||
outputs = self.model(**inputs) | ||
return outputs.sentence_embedding.cpu().numpy() |
images:
  embeddings-inf2:
    base: autonomi/nos:latest-inf2
    env:
      NOS_LOGGING_LEVEL: DEBUG
      NOS_NEURON_CORES: 2
    run:
      - python -m pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com
      - pip install sentence-transformers

models:
  BAAI/bge-small-en-v1.5:
    model_cls: EmbeddingServiceInf2
    model_path: models/embeddings_inf2.py
    default_method: __call__
    runtime_env: embeddings-inf2
import numpy as np


def test_embeddings():
    from models.embeddings_inf2 import EmbeddingServiceInf2

    model = EmbeddingServiceInf2()
    texts = "What is the meaning of life?"
    response = model(texts=texts)
    assert response is not None
    assert isinstance(response, np.ndarray)
    print(response.shape)
examples/inf2/embeddings/tests/test_embeddings_inf2_client.py
import pytest


@pytest.mark.parametrize("model_id", ["BAAI/bge-small-en-v1.5"])
def test_embeddings_client(model_id):
    import numpy as np

    from nos.client import Client

    # Create a client
    client = Client("[::]:50051")
    assert client.WaitForServer()

    # Load the embeddings model
    model = client.Module(model_id)

    # Embed text with the model
    texts = "What is the meaning of life?"
    response = model(texts=texts)
    assert response is not None
    assert isinstance(response, np.ndarray)
nos/neuron/device.py
import os
from dataclasses import dataclass

import torch_neuronx

from nos.constants import NOS_CACHE_DIR
from nos.logging import logger


@dataclass
class NeuronDevice:
    """Neuron device environment."""

    _instance: "NeuronDevice" = None

    @classmethod
    def get(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @staticmethod
    def device_count() -> int:
        try:
            return torch_neuronx.xla_impl.data_parallel.device_count()
        except (RuntimeError, AssertionError):
            return 0

    @staticmethod
    def setup_environment() -> None:
        """Setup neuron environment."""
        for k, v in os.environ.items():
            if "NEURON" in k:
                logger.debug(f"{k}={v}")
        cores: int = int(os.getenv("NOS_NEURON_CORES", 2))
        logger.info(f"Setting up neuron env with {cores} cores")
        cache_dir = NOS_CACHE_DIR / "neuron"
        os.environ["NEURONX_CACHE"] = "on"
        os.environ["NEURONX_DUMP_TO"] = str(cache_dir)
        os.environ["NEURON_RT_NUM_CORES"] = str(cores)
        os.environ["NEURON_RT_VISIBLE_CORES"] = ",".join([str(i) for i in range(cores)])
        os.environ["NEURON_CC_FLAGS"] = "--model-type=transformer-inference"
import pytest

from nos.common.runtime import is_torch_neuron_available


pytestmark = pytest.mark.skipif(not is_torch_neuron_available(), reason="Requires torch_neuron")


def test_neuron_device():
    from nos.neuron.device import NeuronDevice

    neuron_env = NeuronDevice.get()
    assert neuron_env.device_count() > 0
    assert neuron_env.setup_environment() is None