Skip to content

Commit

Permalink
Only run adjust_tensors_for_parallel_ if bsz (batch size) is different
Browse files Browse the repository at this point in the history
  • Loading branch information
calpt committed Nov 24, 2024
1 parent cb07dd4 commit 8421f63
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/adapters/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def adjust_tensors_for_parallel(hidden_states, *tensors):
"""
outputs = []
for tensor in tensors:
if tensor is not None and hidden_states.shape[0] >= tensor.shape[0]:
if tensor is not None and hidden_states.shape[0] > tensor.shape[0]:
repeats = [1] * len(tensor.shape)
repeats[0] = hidden_states.shape[0] // tensor.shape[0]
new_tensor = tensor.repeat(*repeats)
Expand All @@ -249,7 +249,7 @@ def adjust_tensors_for_parallel_(hidden_states, *tensors):
In-place version of adjust_tensors_for_parallel().
"""
for tensor in tensors:
if tensor is not None and hidden_states.shape[0] >= tensor.shape[0]:
if tensor is not None and hidden_states.shape[0] > tensor.shape[0]:
repeats = [1] * len(tensor.shape)
repeats[0] = hidden_states.shape[0] // tensor.shape[0]
new_tensor = tensor.repeat(*repeats)
Expand Down

0 comments on commit 8421f63

Please sign in to comment.