From 4007d8913e1cf7e068f842628d4df74510da706f Mon Sep 17 00:00:00 2001 From: "haoning.wu" Date: Sun, 21 Jan 2024 12:49:03 +0800 Subject: [PATCH] fix masking error --- q_align/train/train_mem.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/q_align/train/train_mem.py b/q_align/train/train_mem.py index a8d5c1a..eba7c84 100644 --- a/q_align/train/train_mem.py +++ b/q_align/train/train_mem.py @@ -370,7 +370,7 @@ def preprocess_v1( if has_image: round_len = len(tokenizer_image_token(rou, tokenizer)) - instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2 + instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 3 else: round_len = len(tokenizer(rou).input_ids) instruction_len = len(tokenizer(parts[0]).input_ids) - 2 @@ -387,7 +387,6 @@ def preprocess_v1( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)" ) - return dict( input_ids=input_ids, labels=targets,