Skip to content

Commit

Permalink
Fix stopping_criteria result check in coca_model (#860)
Browse files Browse the repository at this point in the history
* fix stopping_criteria check in coca_model

* fix stopping_criteria check
  • Loading branch information
MengqingCao authored Jun 22, 2024
1 parent 58e4e39 commit 45b43c9
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions src/open_clip/coca_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
RepetitionPenaltyLogitsProcessor,
MinLengthLogitsProcessor,
MaxLengthCriteria,
StopStringCriteria,
EosTokenCriteria,
StoppingCriteriaList
)

Expand Down Expand Up @@ -311,7 +313,12 @@ def generate(

cur_len += 1

if stopping_criteria(out, None):
is_done = False
if EosTokenCriteria in stopping_criteria or StopStringCriteria in stopping_criteria:
is_done = stopping_criteria(out, None).all()
else:
is_done = stopping_criteria(out, None).any()
if is_done:
break

if num_dims == 1:
Expand Down Expand Up @@ -453,7 +460,14 @@ def _generate_beamsearch(

# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or stopping_criteria(input_ids, None):
is_done = False
if EosTokenCriteria in stopping_criteria or StopStringCriteria in stopping_criteria:
is_done = stopping_criteria(input_ids, None).all()
else:
is_done = stopping_criteria(input_ids, None).any()
if is_done:
break
if beam_scorer.is_done or is_done:
break

final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
Expand Down

0 comments on commit 45b43c9

Please sign in to comment.