From 254a80ca2438c79cf6ae0a44bb51c64b13dada25 Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Wed, 22 May 2024 20:13:19 -0700 Subject: [PATCH] Add jit script KJT benchmarks Summary: Add benchmarks for jit scripted KJT methods Reviewed By: gnahzg Differential Revision: D57701618 --- .../tests/keyed_jagged_tensor_benchmark.py | 86 ++++++++++++++----- 1 file changed, 65 insertions(+), 21 deletions(-) diff --git a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark.py b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark.py index 24324395e..1cc393b22 100644 --- a/torchrec/sparse/tests/keyed_jagged_tensor_benchmark.py +++ b/torchrec/sparse/tests/keyed_jagged_tensor_benchmark.py @@ -18,7 +18,7 @@ from torchrec.distributed.test_utils.test_model import ModelInput from torchrec.modules.embedding_configs import EmbeddingBagConfig -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor def generate_kjt( @@ -69,15 +69,17 @@ def wrapped_func( kjt: KeyedJaggedTensor, test_func: Callable[[KeyedJaggedTensor], object], fn_kwargs: Dict[str, Any], + jit_script: bool, ) -> Callable[..., object]: def fn() -> object: return test_func(kjt, **fn_kwargs) - return fn + return fn if jit_script else torch.jit.script(fn) def benchmark_kjt( - method_name: str, + test_name: str, + test_func: Callable[..., object], kjt: KeyedJaggedTensor, num_repeat: int, num_warmup: int, @@ -86,21 +88,15 @@ def benchmark_kjt( mean_pooling_factor: int, fn_kwargs: Dict[str, Any], is_static_method: bool, + jit_script: bool, ) -> None: - test_name = method_name - - # pyre-ignore - def test_func(kjt: KeyedJaggedTensor, **kwargs): - return getattr(KeyedJaggedTensor if is_static_method else kjt, method_name)( - **kwargs - ) for _ in range(num_warmup): - test_func(kjt, **fn_kwargs) + test_func(**fn_kwargs) times = [] for _ in range(num_repeat): - time_elapsed = timeit.timeit(wrapped_func(kjt, test_func, fn_kwargs), number=1) + time_elapsed = timeit.timeit(lambda: test_func(**fn_kwargs), number=1) # remove length_per_key and offset_per_key cache for fairer comparison kjt.unsync() times.append(time_elapsed) @@ -112,7 +108,7 @@ def test_func(kjt: KeyedJaggedTensor, **kwargs): ) print( - f" {test_name : <{35}} | B: {batch_size : <{8}} | F: {num_features : <{8}} | Mean Pooling Factor: {mean_pooling_factor : <{8}} | Runtime (P50): {result.runtime_percentile(50, interpolation='linear'):5f} ms | Runtime (P90): {result.runtime_percentile(90, interpolation='linear'):5f} ms" + f" {test_name : <{35}} | JIT Script: {'Yes' if jit_script else 'No' : <{8}} | B: {batch_size : <{8}} | F: {num_features : <{8}} | Mean Pooling Factor: {mean_pooling_factor : <{8}} | Runtime (P50): {result.runtime_percentile(50, interpolation='linear'):5f} ms | Runtime (P90): {result.runtime_percentile(90, interpolation='linear'):5f} ms" ) @@ -148,6 +144,31 @@ def gen_dist_split_input( return (kjt_lengths, kjt_values, batch_size_per_rank, recat) +@torch.jit.script +def permute(kjt: KeyedJaggedTensor, indices: List[int]) -> KeyedJaggedTensor: + return kjt.permute(indices) + + +@torch.jit.script +def todict(kjt: KeyedJaggedTensor) -> Dict[str, JaggedTensor]: + return kjt.to_dict() + + +@torch.jit.script +def split(kjt: KeyedJaggedTensor, segments: List[int]) -> List[KeyedJaggedTensor]: + return kjt.split(segments) + + +@torch.jit.script +def getitem(kjt: KeyedJaggedTensor, key: str) -> JaggedTensor: + return kjt[key] + + +@torch.jit.script +def dist_splits(kjt: KeyedJaggedTensor, key_splits: List[int]) -> List[List[int]]: + return kjt.dist_splits(key_splits) + + def bench( num_repeat: int, num_warmup: int, @@ -184,12 +205,13 @@ def bench( tables, batch_size, num_workers, num_features, mean_pooling_factor, device ) - benchmarked_methods: List[Tuple[str, Dict[str, Any], bool]] = [ - ("permute", {"indices": permute_indices}, False), - ("to_dict", {}, False), - ("split", {"segments": splits}, False), - ("__getitem__", {"key": key}, False), - ("dist_splits", {"key_splits": splits}, False), + # pyre-ignore[33] + benchmarked_methods: List[Tuple[str, Dict[str, Any], bool, Callable[..., Any]]] = [ + ("permute", {"indices": permute_indices}, False, permute), + ("to_dict", {}, False, todict), + ("split", {"segments": splits}, False, split), + ("__getitem__", {"key": key}, False, getitem), + ("dist_splits", {"key_splits": splits}, False, dist_splits), ( "dist_init", { @@ -206,12 +228,33 @@ def bench( "stride_per_rank": strides_per_rank, }, True, # is static method + torch.jit.script(KeyedJaggedTensor.dist_init), ), ] - for method_name, fn_kwargs, is_static_method in benchmarked_methods: + for method_name, fn_kwargs, is_static_method, jit_func in benchmarked_methods: + test_func = getattr(KeyedJaggedTensor if is_static_method else kjt, method_name) + benchmark_kjt( + test_name=method_name, + test_func=test_func, + kjt=kjt, + num_repeat=num_repeat, + num_warmup=num_warmup, + num_features=num_features, + batch_size=batch_size, + mean_pooling_factor=mean_pooling_factor, + fn_kwargs=fn_kwargs, + is_static_method=is_static_method, + jit_script=False, + ) + + if not is_static_method: + # Explicitly pass in KJT for instance methods + fn_kwargs = {"kjt": kjt, **fn_kwargs} + benchmark_kjt( - method_name=method_name, + test_name=method_name, + test_func=jit_func, kjt=kjt, num_repeat=num_repeat, num_warmup=num_warmup, @@ -220,6 +263,7 @@ def bench( mean_pooling_factor=mean_pooling_factor, fn_kwargs=fn_kwargs, is_static_method=is_static_method, + jit_script=True, )