diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index abba8b53b..eb4e88726 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -23,6 +23,8 @@ RepetitionPenaltyLogitsProcessor, MinLengthLogitsProcessor, MaxLengthCriteria, + StopStringCriteria, + EosTokenCriteria, StoppingCriteriaList ) @@ -298,7 +300,12 @@ def generate( cur_len += 1 - if stopping_criteria(out, None).any(): + is_done = False + if EosTokenCriteria in stopping_criteria or StopStringCriteria in stopping_criteria: + is_done = stopping_criteria(out, None).all() + else: + is_done = stopping_criteria(out, None).any() + if is_done: break if num_dims == 1: @@ -439,7 +446,14 @@ def _generate_beamsearch( # increase cur_len cur_len = cur_len + 1 - if beam_scorer.is_done or stopping_criteria(input_ids, None).any(): + is_done = False + if EosTokenCriteria in stopping_criteria or StopStringCriteria in stopping_criteria: + is_done = stopping_criteria(input_ids, None).all() + else: + is_done = stopping_criteria(input_ids, None).any() + if is_done: + break + if beam_scorer.is_done or is_done: break final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None