Skip to content

Commit

Permalink
Add jit script KJT benchmarks
Browse files Browse the repository at this point in the history
Summary: Add benchmarks for TorchScript-scripted (`torch.jit.script`) KeyedJaggedTensor (KJT) methods

Reviewed By: gnahzg

Differential Revision: D57701618
  • Loading branch information
sarckk authored and facebook-github-bot committed May 23, 2024
1 parent a4b0602 commit 254a80c
Showing 1 changed file with 65 additions and 21 deletions.
86 changes: 65 additions & 21 deletions torchrec/sparse/tests/keyed_jagged_tensor_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from torchrec.distributed.test_utils.test_model import ModelInput
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor


def generate_kjt(
def wrapped_func(
    kjt: KeyedJaggedTensor,
    test_func: Callable[[KeyedJaggedTensor], object],
    fn_kwargs: Dict[str, Any],
    jit_script: bool,
) -> Callable[..., object]:
    """Bind ``kjt`` and ``fn_kwargs`` into a zero-argument callable for timing.

    Args:
        kjt: The KeyedJaggedTensor passed as the first argument to ``test_func``.
        test_func: Callable under test; invoked as ``test_func(kjt, **fn_kwargs)``.
        fn_kwargs: Keyword arguments forwarded to ``test_func`` on every call.
        jit_script: When True, the returned callable is passed through
            ``torch.jit.script``; when False, the plain Python closure is
            returned.

    Returns:
        A callable taking no arguments that runs ``test_func`` once.
    """

    def fn() -> object:
        return test_func(kjt, **fn_kwargs)

    # The flag previously selected the opposite branch (scripting when
    # jit_script was False); script only when the caller asks for it.
    # NOTE(review): TorchScript may reject a closure that uses ``**`` keyword
    # expansion -- confirm this path compiles before relying on it.
    return torch.jit.script(fn) if jit_script else fn


def benchmark_kjt(
method_name: str,
test_name: str,
test_func: Callable[..., object],
kjt: KeyedJaggedTensor,
num_repeat: int,
num_warmup: int,
Expand All @@ -86,21 +88,15 @@ def benchmark_kjt(
mean_pooling_factor: int,
fn_kwargs: Dict[str, Any],
is_static_method: bool,
jit_script: bool,
) -> None:
test_name = method_name

# pyre-ignore
def test_func(kjt: KeyedJaggedTensor, **kwargs):
return getattr(KeyedJaggedTensor if is_static_method else kjt, method_name)(
**kwargs
)

for _ in range(num_warmup):
test_func(kjt, **fn_kwargs)
test_func(**fn_kwargs)

times = []
for _ in range(num_repeat):
time_elapsed = timeit.timeit(wrapped_func(kjt, test_func, fn_kwargs), number=1)
time_elapsed = timeit.timeit(lambda: test_func(**fn_kwargs), number=1)
# remove length_per_key and offset_per_key cache for fairer comparison
kjt.unsync()
times.append(time_elapsed)
Expand All @@ -112,7 +108,7 @@ def test_func(kjt: KeyedJaggedTensor, **kwargs):
)

print(
f" {test_name : <{35}} | B: {batch_size : <{8}} | F: {num_features : <{8}} | Mean Pooling Factor: {mean_pooling_factor : <{8}} | Runtime (P50): {result.runtime_percentile(50, interpolation='linear'):5f} ms | Runtime (P90): {result.runtime_percentile(90, interpolation='linear'):5f} ms"
f" {test_name : <{35}} | JIT Script: {'Yes' if jit_script else 'No' : <{8}} | B: {batch_size : <{8}} | F: {num_features : <{8}} | Mean Pooling Factor: {mean_pooling_factor : <{8}} | Runtime (P50): {result.runtime_percentile(50, interpolation='linear'):5f} ms | Runtime (P90): {result.runtime_percentile(90, interpolation='linear'):5f} ms"
)


Expand Down Expand Up @@ -148,6 +144,31 @@ def gen_dist_split_input(
return (kjt_lengths, kjt_values, batch_size_per_rank, recat)


@torch.jit.script
def permute(kjt: KeyedJaggedTensor, indices: List[int]) -> KeyedJaggedTensor:
    """Scripted free-function form of ``kjt.permute(indices)``.

    Takes the KJT as an explicit argument so the instance method can be
    invoked through a TorchScript-compiled function.
    """
    permuted = kjt.permute(indices)
    return permuted


@torch.jit.script
def todict(kjt: KeyedJaggedTensor) -> Dict[str, JaggedTensor]:
    """Scripted free-function form of ``kjt.to_dict()``.

    Takes the KJT as an explicit argument so the instance method can be
    invoked through a TorchScript-compiled function.
    """
    as_dict = kjt.to_dict()
    return as_dict


@torch.jit.script
def split(kjt: KeyedJaggedTensor, segments: List[int]) -> List[KeyedJaggedTensor]:
    """Scripted free-function form of ``kjt.split(segments)``.

    Takes the KJT as an explicit argument so the instance method can be
    invoked through a TorchScript-compiled function.
    """
    pieces = kjt.split(segments)
    return pieces


@torch.jit.script
def getitem(kjt: KeyedJaggedTensor, key: str) -> JaggedTensor:
    """Scripted free-function form of the indexing expression ``kjt[key]``.

    Takes the KJT as an explicit argument so the lookup can be invoked
    through a TorchScript-compiled function.
    """
    entry = kjt[key]
    return entry


@torch.jit.script
def dist_splits(kjt: KeyedJaggedTensor, key_splits: List[int]) -> List[List[int]]:
    """Scripted free-function form of ``kjt.dist_splits(key_splits)``.

    Takes the KJT as an explicit argument so the instance method can be
    invoked through a TorchScript-compiled function.
    """
    per_split = kjt.dist_splits(key_splits)
    return per_split


def bench(
num_repeat: int,
num_warmup: int,
Expand Down Expand Up @@ -184,12 +205,13 @@ def bench(
tables, batch_size, num_workers, num_features, mean_pooling_factor, device
)

benchmarked_methods: List[Tuple[str, Dict[str, Any], bool]] = [
("permute", {"indices": permute_indices}, False),
("to_dict", {}, False),
("split", {"segments": splits}, False),
("__getitem__", {"key": key}, False),
("dist_splits", {"key_splits": splits}, False),
# pyre-ignore[33]
benchmarked_methods: List[Tuple[str, Dict[str, Any], bool, Callable[..., Any]]] = [
("permute", {"indices": permute_indices}, False, permute),
("to_dict", {}, False, todict),
("split", {"segments": splits}, False, split),
("__getitem__", {"key": key}, False, getitem),
("dist_splits", {"key_splits": splits}, False, dist_splits),
(
"dist_init",
{
Expand All @@ -206,12 +228,33 @@ def bench(
"stride_per_rank": strides_per_rank,
},
True, # is static method
torch.jit.script(KeyedJaggedTensor.dist_init),
),
]

for method_name, fn_kwargs, is_static_method in benchmarked_methods:
for method_name, fn_kwargs, is_static_method, jit_func in benchmarked_methods:
test_func = getattr(KeyedJaggedTensor if is_static_method else kjt, method_name)
benchmark_kjt(
test_name=method_name,
test_func=test_func,
kjt=kjt,
num_repeat=num_repeat,
num_warmup=num_warmup,
num_features=num_features,
batch_size=batch_size,
mean_pooling_factor=mean_pooling_factor,
fn_kwargs=fn_kwargs,
is_static_method=is_static_method,
jit_script=False,
)

if not is_static_method:
# Explicitly pass in KJT for instance methods
fn_kwargs = {"kjt": kjt, **fn_kwargs}

benchmark_kjt(
method_name=method_name,
test_name=method_name,
test_func=jit_func,
kjt=kjt,
num_repeat=num_repeat,
num_warmup=num_warmup,
Expand All @@ -220,6 +263,7 @@ def bench(
mean_pooling_factor=mean_pooling_factor,
fn_kwargs=fn_kwargs,
is_static_method=is_static_method,
jit_script=True,
)


Expand Down

0 comments on commit 254a80c

Please sign in to comment.