From 001fcac31faa6133d0c4a1aec828ed7781939506 Mon Sep 17 00:00:00 2001 From: MengqingCao Date: Sat, 11 May 2024 02:27:14 +0000 Subject: [PATCH] fix stopping_criteria check --- src/open_clip/coca_model.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/open_clip/coca_model.py b/src/open_clip/coca_model.py index abba8b53b..eb4e88726 100644 --- a/src/open_clip/coca_model.py +++ b/src/open_clip/coca_model.py @@ -23,6 +23,8 @@ RepetitionPenaltyLogitsProcessor, MinLengthLogitsProcessor, MaxLengthCriteria, + StopStringCriteria, + EosTokenCriteria, StoppingCriteriaList ) @@ -298,7 +300,12 @@ def generate( cur_len += 1 - if stopping_criteria(out, None).any(): + is_done = False + if EosTokenCriteria in stopping_criteria or StopStringCriteria in stopping_criteria: + is_done = stopping_criteria(out, None).all() + else: + is_done = stopping_criteria(out, None).any() + if is_done: break if num_dims == 1: @@ -439,7 +446,14 @@ def _generate_beamsearch( # increase cur_len cur_len = cur_len + 1 - if beam_scorer.is_done or stopping_criteria(input_ids, None).any(): + is_done = False + if EosTokenCriteria in stopping_criteria or StopStringCriteria in stopping_criteria: + is_done = stopping_criteria(input_ids, None).all() + else: + is_done = stopping_criteria(input_ids, None).any() + if is_done: + break + if beam_scorer.is_done or is_done: break final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None