diff --git a/src/adapters/composition.py b/src/adapters/composition.py index 48a6bc8ac..df78966b5 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -234,7 +234,7 @@ def adjust_tensors_for_parallel(hidden_states, *tensors): """ outputs = [] for tensor in tensors: - if tensor is not None and hidden_states.shape[0] >= tensor.shape[0]: + if tensor is not None and hidden_states.shape[0] > tensor.shape[0]: repeats = [1] * len(tensor.shape) repeats[0] = hidden_states.shape[0] // tensor.shape[0] new_tensor = tensor.repeat(*repeats) @@ -249,7 +249,7 @@ def adjust_tensors_for_parallel_(hidden_states, *tensors): In-place version of adjust_tensors_for_parallel(). """ for tensor in tensors: - if tensor is not None and hidden_states.shape[0] >= tensor.shape[0]: + if tensor is not None and hidden_states.shape[0] > tensor.shape[0]: repeats = [1] * len(tensor.shape) repeats[0] = hidden_states.shape[0] // tensor.shape[0] new_tensor = tensor.repeat(*repeats)