Skip to content

Commit

Permalink
Use Case: Enhancing LLM Serving with Torch Compiled RAG on AWS Gravit…
Browse files Browse the repository at this point in the history
…on (#3276)

* RAG based LLM usecase

* RAG based LLM usecase

* Changes for deploying RAG

* Updated README

* Added main blog

* Added main blog assets

* Added main blog assets

* Added use case to index html

* Added benchmark config

* Minor edits to README

* Added new MD for Gen AI usecases

* Added link to GV3 tutorial

* Addressed review comments

* Update examples/usecases/RAG_based_LLM_serving/README.md

Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com>

* Update examples/usecases/RAG_based_LLM_serving/README.md

Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com>

* Addressed review comments

---------

Co-authored-by: Matthias Reso <13337103+mreso@users.noreply.github.com>
  • Loading branch information
agunapal and mreso authored Aug 2, 2024
1 parent 24e2492 commit 3f40180
Show file tree
Hide file tree
Showing 20 changed files with 766 additions and 12 deletions.
7 changes: 7 additions & 0 deletions docs/genai_use_cases.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# TorchServe GenAI use cases and showcase

This document shows interesting usecases with TorchServe for Gen AI deployments.

## [Enhancing LLM Serving with Torch Compiled RAG on AWS Graviton](https://pytorch.org/serve/enhancing_llm_serving_compile_rag.html)

In this blog, we show how to deploy a RAG Endpoint using TorchServe, increase throughput using `torch.compile` and improve the response generated by the Llama Endpoint. We also show how the RAG endpoint can be deployed on CPU using AWS Graviton, while the Llama endpoint is still deployed on a GPU. This kind of microservices-based RAG solution efficiently utilizes compute resources, resulting in potential cost savings for customers.
7 changes: 7 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ What's going on in TorchServe?
:link: use_cases.html
:tags: Examples

.. customcarditem::
:header: TorchServe GenAI Use Cases
:card_description: Showcasing GenAI deployment scenarios and use cases
:image: https://raw.githubusercontent.com/pytorch/serve/master/examples/LLM/llama/images/llama.png
:link: genai_use_cases.html
:tags: Use Cases

.. customcarditem::
:header: Performance
:card_description: Guides and best practices on how to improve perfromance when working with TorchServe
Expand Down
3 changes: 2 additions & 1 deletion docs/sphinx/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,6 @@ docset: html
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
cp ../../SECURITY.md ../Security.md
cp ../../SECURITY.md ../security.md
cp ../../examples//usecases/RAG_based_LLM_serving/README.md ../enhancing_llm_serving_compile_rag.md
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def __init__(self):
self.max_new_tokens = None
self.tokenizer = None
self.initialized = False
self.quant_config = None
self.return_full_text = True

def initialize(self, ctx: Context):
"""In this initialize function, the HF large model is loaded and
Expand All @@ -37,6 +39,10 @@ def initialize(self, ctx: Context):
self.max_new_tokens = int(ctx.model_yaml_config["handler"]["max_new_tokens"])
model_name = ctx.model_yaml_config["handler"]["model_name"]
model_path = f'{model_dir}/{ctx.model_yaml_config["handler"]["model_path"]}'
self.return_full_text = ctx.model_yaml_config["handler"].get(
"return_full_text", True
)
quantization = ctx.model_yaml_config["handler"].get("quantization", True)
seed = int(ctx.model_yaml_config["handler"]["manual_seed"])
torch.manual_seed(seed)

Expand All @@ -45,22 +51,23 @@ def initialize(self, ctx: Context):
self.tokenizer.padding_side = "left"
logger.info("Model %s loaded tokenizer successfully", ctx.model_name)

if self.tokenizer.vocab_size >= 128000:
quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
else:
quant_config = BitsAndBytesConfig(load_in_8bit=True)
if quantization:
if self.tokenizer.vocab_size >= 128000:
self.quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
else:
self.quant_config = BitsAndBytesConfig(load_in_8bit=True)

self.model = AutoModelForCausalLM.from_pretrained(
model_path,
device_map="balanced",
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
quantization_config=quant_config,
quantization_config=self.quant_config,
trust_remote_code=True,
)
self.device = next(iter(self.model.parameters())).device
Expand Down Expand Up @@ -115,9 +122,12 @@ def inference(self, input_batch):
"""
outputs = self.model.generate(
**input_batch,
max_length=self.max_new_tokens,
max_new_tokens=self.max_new_tokens,
)

if not self.return_full_text:
outputs = outputs[:, input_batch["input_ids"].shape[1] :]

inferences = self.tokenizer.batch_decode(
outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
Expand Down
104 changes: 104 additions & 0 deletions examples/usecases/RAG_based_LLM_serving/Deploy.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Deploy Llama & RAG using TorchServe

## Contents
* [Deploy Llama](#deploy-llama)
* [Download Llama](#download-model)
* [Generate MAR file](#generate-mar-file)
* [Add MAR to model store](#add-the-mar-file-to-model-store)
* [Start TorchServe](#start-torchserve)
* [Query Llama](#query-llama)
* [Deploy RAG](#deploy-rag)
* [Download embedding model](#download-embedding-model)
* [Generate MAR file](#generate-mar-file-1)
* [Add MAR to model store](#add-the-mar-file-to-model-store-1)
* [Start TorchServe](#start-torchserve-1)
* [Query Llama](#query-rag)
* [End-to-End](#)

### Deploy Llama

### Download Llama

Follow [this instruction](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) to get permission

Login with a Hugging Face account
```
huggingface-cli login
# or using an environment variable
huggingface-cli login --token $HUGGINGFACE_TOKEN
```bash
python ../../large_models/Huggingface_accelerate/Download_model.py --model_path model --model_name meta-llama/Meta-Llama-3-8B-Instruct
```
Model will be saved in the following path, `model/models--meta-llama--Meta-Llama-3-8B-Instruct`.

### Generate MAR file

Add the downloaded path to " model_path:" in `model-config.yaml` and run the following.

```
torch-model-archiver --model-name llama3-8b-instruct --version 1.0 --handler ../../large_models/Huggingface_accelerate/llama/custom_handler.py --config-file llama-config.yaml -r ../../large_models/Huggingface_accelerate/llama/requirements.txt --archive-format no-archive
```

### Add the mar file to model store

```bash
mkdir model_store
mv llama3-8b-instruct model_store
mv model model_store/llama3-8b-instruct
```

### Start TorchServe

```bash
torchserve --start --ncs --ts-config ../../large_models/Huggingface_accelerate/llama/config.properties --model-store model_store --models llama3-8b-instruct --disable-token-auth --enable-model-api
```
### Query Llama

```bash
python query_llama.py
```

### Deploy RAG

### Download embedding model

```
python ../../large_models/Huggingface_accelerate/Download_model.py --model_name sentence-transformers/all-mpnet-base-v2
```
Model is download to `model/models--sentence-transformers--all-mpnet-base-v2`

### Generate MAR file

Add the downloaded path to " model_path:" in `rag-config.yaml` and run the following
```
torch-model-archiver --model-name rag --version 1.0 --handler rag_handler.py --config-file rag-config.yaml --extra-files="hf_custom_embeddings.py" -r requirements.txt --archive-format no-archive
```

### Add the mar file to model store

```bash
mkdir -p model_store
mv rag model_store
mv model model_store/rag
```

### Start TorchServe
```
torchserve --start --ncs --ts-config config.properties --model-store model_store --models rag --disable-token-auth --enable-model-api
```

### Query RAG

```bash
python query_rag.py
```

### RAG + LLM

Send the query to RAG to get the context, send the response to Llama to get more accurate results

```bash
python query_rag_llama.py
```
Loading

0 comments on commit 3f40180

Please sign in to comment.