From 8280d220502b12da466b9ec940d55167ce6323de Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Thu, 19 Dec 2024 13:56:25 +0800 Subject: [PATCH] add assert in get_assigned_chunk to check if tensor.shape[dim] is divisible by world_size --- src/para_attn/primitives.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/para_attn/primitives.py b/src/para_attn/primitives.py index 89685f0..3d0fd9e 100644 --- a/src/para_attn/primitives.py +++ b/src/para_attn/primitives.py @@ -119,6 +119,8 @@ def get_assigned_chunk( if idx is None: idx = get_rank(group) world_size = get_world_size(group) + total_size = tensor.shape[dim] + assert total_size % world_size == 0, f"tensor.shape[{dim}]={total_size} is not divisible by world_size={world_size}" return tensor.chunk(world_size, dim=dim)[idx]