[fix] context_length_estimate type and max_rolling_batch default for inf2
sindhuvahinis authored and tosterberg committed Dec 9, 2023
1 parent 6ad3084 commit 9388e94
Showing 5 changed files with 49 additions and 33 deletions.
@@ -14,7 +14,7 @@
import logging
import os
import re
from typing import Optional
from typing import Optional, Union, List

from pydantic import validator, root_validator
from enum import IntEnum, Enum
@@ -50,7 +50,7 @@ class TransformerNeuronXProperties(Properties):
load_in_8bit: Optional[bool] = False
low_cpu_mem_usage: bool = False
load_split_model: bool = False
context_length_estimate: Optional[dict] = None
context_length_estimate: Optional[List[int]] = None
amp: Optional[str] = None
quantize: Optional[TnXQuantizeMethods] = None
compiled_graph_path: Optional[str] = None
@@ -62,7 +62,10 @@ def set_neuron_optimal_env(cls, level):

@validator('context_length_estimate', pre=True)
def parse_context_length(cls, context_length_estimate):
return json.loads(context_length_estimate)
return [
int(context_length)
for context_length in context_length_estimate.split(',')
]

@validator('rolling_batch', pre=True)
def validate_rolling_batch(cls, rolling_batch: str) -> str:
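For context, the new validator turns a comma-separated string into a list of integers. A minimal standalone sketch of that behavior, written as a plain function rather than the pydantic validator it actually lives in:

```python
# Sketch of the parsing behavior introduced above; the function name is
# illustrative, the real code runs inside a pydantic validator.
from typing import List

def parse_context_length(raw: str) -> List[int]:
    # "256, 512, 1024" -> [256, 512, 1024]; int() tolerates surrounding spaces.
    # A non-numeric entry such as "invalid" raises ValueError, which pydantic
    # surfaces as a validation error (exercised by the new test below).
    return [int(part) for part in raw.split(',')]

assert parse_context_length('256, 512, 1024') == [256, 512, 1024]
assert parse_context_length('256') == [256]
```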
41 changes: 25 additions & 16 deletions engines/python/setup/djl_python/tests/test_properties_manager.py
@@ -105,18 +105,12 @@ def test_tnx_all_configs(self):
# TODO: Replace with actual example of context_length_estimate

properties = {
"n_positions":
"2048",
"load_split_model":
"true",
"load_in_8bit":
"true",
"compiled_graph_path":
"s3://test/bucket/folder",
"low_cpu_mem_usage":
"true",
'context_length_estimate':
'{"context_length": "128", "variable_size": "12"}'
"n_positions": "2048",
"load_split_model": "true",
"load_in_8bit": "true",
"compiled_graph_path": "s3://test/bucket/folder",
"low_cpu_mem_usage": "true",
'context_length_estimate': '256, 512, 1024'
}
tnx_configs = TransformerNeuronXProperties(**common_properties,
**properties)
@@ -128,10 +122,18 @@
self.assertTrue(tnx_configs.load_in_8bit)
self.assertTrue(tnx_configs.low_cpu_mem_usage)

self.assertDictEqual(tnx_configs.context_length_estimate, {
'context_length': '128',
'variable_size': '12'
})
self.assertListEqual(tnx_configs.context_length_estimate,
[256, 512, 1024])

# tests context length estimate as integer
def test_tnx_cle_int(context_length_estimate):
properties['context_length_estimate'] = context_length_estimate
configs = TransformerNeuronXProperties(**common_properties,
**properties)
self.assertEqual(configs.context_length_estimate, [256])
del properties['context_length_estimate']

test_tnx_cle_int('256')

def test_tnx_configs_error_case(self):
properties = {
@@ -152,8 +154,15 @@ def test_non_existent_directory(directory):
TransformerNeuronXProperties(**common_properties, **properties)
del properties['compiled_graph_path']

def test_invalid_context_length(context_length_estimate):
properties['context_length_estimate'] = context_length_estimate
with self.assertRaises(ValueError):
TransformerNeuronXProperties(**common_properties, **properties)
del properties['context_length_estimate']

test_url_not_s3_uri("https://random.url.address/")
test_non_existent_directory("not_a_directory")
test_invalid_context_length("invalid")

def test_trtllm_configs(self):
properties = {
11 changes: 7 additions & 4 deletions engines/python/setup/djl_python/transformers_neuronx.py
@@ -52,10 +52,11 @@

class NeuronXSampleAdapter(HuggingFaceGenerationModelAdapter):

def __init__(self, _config, _model):
def __init__(self, _config, _model, batch_size):
super().__init__(_config, _model)
self.model_type = _config.model_type
self.sample_options = ["start_ids", "top_k"]
self.batch_size = batch_size
if self.model_type == "llama":
self.sample_options = self.sample_options + [
"top_p", "eos_token_override", "temperature", "streamer"
@@ -199,12 +200,14 @@ def initialize(self, properties):
self.load_model(model_config.model_type)

# HuggingFace compatible generate model and Neuron custom sample method
self.model = NeuronXSampleAdapter(model_config, self.model)
self.model = NeuronXSampleAdapter(model_config, self.model,
self.tnx_configs.batch_size)
self.initialized = True
if self.tnx_configs.rolling_batch != "disable":
rolling_batch_size = min(self.tnx_configs.batch_size,
self.tnx_configs.max_rolling_batch_size)
self.rolling_batch = NeuronRollingBatch(
self.model, self.tokenizer,
self.tnx_configs.max_rolling_batch_size,
self.model, self.tokenizer, rolling_batch_size,
self.tnx_configs.n_positions)

def parse_input(self, inputs):
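The second half of the fix caps the rolling batch size at the configured batch size before constructing NeuronRollingBatch. A small sketch of that capping logic, with illustrative values only:

```python
# Sketch of the batch-size capping added above; the numbers below are
# illustrative, not the container's actual defaults.
def effective_rolling_batch_size(batch_size: int,
                                 max_rolling_batch_size: int) -> int:
    # The rolling batch can never exceed the batch size the model was
    # loaded with, so the smaller of the two values wins.
    return min(batch_size, max_rolling_batch_size)

assert effective_rolling_batch_size(4, 32) == 4
assert effective_rolling_batch_size(64, 32) == 32
```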
20 changes: 10 additions & 10 deletions serving/docs/configurations_large_model_inference_containers.md
@@ -94,17 +94,17 @@ If you specify Engine to be MPI, rolling_batch to vllm in DeepSpeed container, t

If you are using Neuron container and engine set to Python, the following parameter will be accessible.

| Item | Required | Description | Example value |
|--------------------------------|----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------|
| option.n_positions | No | Input sequence length | Default: `128` |
| option.load_in_8bit | No | Specify this option to quantize your model using the supported quantization methods in TransformerNeuronX | `False`, `True` Default: `False` |
| Item | Required | Description | Example value |
|--------------------------------|----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------|
| option.n_positions | No | Input sequence length | Default: `128` |
| option.load_in_8bit | No | Specify this option to quantize your model using the supported quantization methods in TransformerNeuronX | `False`, `True` Default: `False` |
| Advanced parameters |
| option.unroll | No | Unroll the model graph for compilation. With `unroll=None` compiler will have more opportunities to do optimizations across the layers | Default: `None` |
| option.neuron_optimize_level | No | Neuron runtime compiler optimization level, determines the type of optimizations applied during compilation. The higher optimize level we go, the longer time will spend on compilation. But in exchange, you will get better latency/throughput. Default value is not set (optimize level 2) that have a balance of compilation time and performance | `1`,`2`,`3` Default: `2` |
| option.context_length_estimate | No | Estimated context input length for Llama models. Customer can specify different size bucket to increase the KV cache re-usability. This will help to improve latency | Default: `None` |
| option.low_cpu_mem_usage | No | Reduce CPU memory usage when loading models. | Default: `False` |
| option.load_split_model | No | Toggle to True when using model artifacts that have already been split for neuron compilation/loading. | Default: `False` |
| option.compiled_graph_path | No | Provide an s3 URI, or a local directory that stores the pre-compiled graph for your model (NEFF cache) to skip runtime compilation. | Default: `None` |
| option.unroll | No | Unroll the model graph for compilation. With `unroll=None` compiler will have more opportunities to do optimizations across the layers | Default: `None` |
| option.neuron_optimize_level | No | Neuron runtime compiler optimization level, determines the type of optimizations applied during compilation. The higher optimize level we go, the longer time will spend on compilation. But in exchange, you will get better latency/throughput. Default value is not set (optimize level 2) that have a balance of compilation time and performance | `1`,`2`,`3` Default: `2` |
| option.context_length_estimate | No | Estimated context input length for Llama models. Customer can specify different size bucket to increase the KV cache re-usability. This will help to improve latency | Example: `256,512,1024` (integers separated by comma if multiple values) <br/> Default: `None` |
| option.low_cpu_mem_usage | No | Reduce CPU memory usage when loading models. | Default: `False` |
| option.load_split_model | No | Toggle to True when using model artifacts that have already been split for neuron compilation/loading. | Default: `False` |
| option.compiled_graph_path | No | Provide an s3 URI, or a local directory that stores the pre-compiled graph for your model (NEFF cache) to skip runtime compilation. | Default: `None` |
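For illustration, a hypothetical serving.properties using the comma-separated form documented above; the model location and the specific sizes are placeholders, not values taken from this commit:

```properties
engine=Python
option.model_id=s3://my-bucket/llama-2-7b/
option.n_positions=2048
option.load_split_model=true
option.context_length_estimate=256,512,1024
```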


### TensorRT-LLM
1 change: 1 addition & 0 deletions tests/integration/llm/prepare.py
@@ -389,6 +389,7 @@
"option.n_positions": 512,
"option.model_loading_timeout": 2400,
"option.load_split_model": True,
"option.context_length_estimate": '256, 512, 1024'
},
"opt-1.3b-streaming": {
"option.model_id": "s3://djl-llm/opt-1.3b/",
