def _resize_for_rectangle_crop(self, arr):
    """Resize ``arr`` to cover the target rectangle, then crop it to size.

    The target size is ``(self.height, self.width)``. The input is scaled
    with bicubic interpolation, preserving its aspect ratio, so that one
    spatial side matches the target exactly and the other is at least as
    large. The surplus on the larger side is then cropped away.

    Args:
        arr: Frame tensor whose axes 2 and 3 are read as height and width
            (the caller passes ``[F, C, H, W]``).

    Returns:
        The resized tensor cropped to ``self.height`` x ``self.width``.

    Raises:
        NotImplementedError: If ``self.video_reshape_mode`` is not one of
            ``"random"``, ``"none"`` or ``"center"``.
    """
    target_h, target_w = self.height, self.width
    mode = self.video_reshape_mode

    src_h, src_w = arr.shape[2], arr.shape[3]
    if src_w / src_h > target_w / target_h:
        # Source is wider than the target: pin the height, let width overflow.
        new_size = [target_h, int(src_w * target_h / src_h)]
    else:
        # Source is taller (or same aspect): pin the width, let height overflow.
        new_size = [int(src_h * target_w / src_w), target_w]
    arr = resize(arr, size=new_size, interpolation=InterpolationMode.BICUBIC)

    resized_h, resized_w = arr.shape[2], arr.shape[3]
    # NOTE(review): squeeze(0) only has an effect when the leading axis has
    # length 1 (a single-frame clip), in which case the frame axis is
    # dropped -- confirm that downstream code expects this.
    arr = arr.squeeze(0)

    # Surplus pixels along each spatial axis after the resize.
    slack_h = resized_h - target_h
    slack_w = resized_w - target_w

    if mode == "center":
        top, left = slack_h // 2, slack_w // 2
    elif mode in ("random", "none"):
        # "none" is handled exactly like "random": a uniformly random offset.
        top = np.random.randint(0, slack_h + 1)
        left = np.random.randint(0, slack_w + 1)
    else:
        raise NotImplementedError

    return TT.functional.crop(arr, top=top, left=left, height=target_h, width=target_w)