diff --git a/dataset/cityscapes/filter_dataset.py b/dataset/cityscapes/filter_dataset.py
index 715cc92..29073f3 100644
--- a/dataset/cityscapes/filter_dataset.py
+++ b/dataset/cityscapes/filter_dataset.py
@@ -1,12 +1,12 @@
 # coding: utf-8
 __author__ = 'RocketFlash: https://github.com/RocketFlash'
 
-from tqdm import tqdm
 from scipy.io import loadmat
 from shutil import copyfile, rmtree
 from pathlib import Path
 import os
 from config import data_filtering as cfg
+import argparse
 
 '''
 class_label =0: ignore regions (fake humans, e.g. people on posters, reflections etc.)
@@ -41,7 +41,7 @@ def filter_data(mat_file_path, allowed_classes=['pedestrian'],
 
     filtered_file_names = []
 
-    for img_idx in tqdm(range(len(mat)), leave=True, position=0):
+    for img_idx in tqdm(range(len(mat)), position=0):
         img_anno = mat[img_idx][0, 0]
 
         img_name_with_ext = img_anno[1][0]
@@ -62,6 +62,15 @@
 
 
 if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--notebook', action='store_true', help='true if run it in jupyter notebook')
+    args = parser.parse_args()
+
+    if args.notebook:
+        from tqdm.notebook import tqdm
+    else:
+        from tqdm import tqdm
+
     print('START DATASET FILTERING')
     if cfg.dataset_type=='citypersons':
         DATASET_ROOT = cfg.dataset_root
diff --git a/dataset/cityscapes/generate_and_filter_dataset.sh b/dataset/cityscapes/generate_and_filter_dataset.sh
index 77f1162..c8391db 100755
--- a/dataset/cityscapes/generate_and_filter_dataset.sh
+++ b/dataset/cityscapes/generate_and_filter_dataset.sh
@@ -1,3 +1,9 @@
 #!/bin/bash
-python generate_dataset.py
-python filter_dataset.py
\ No newline at end of file
+
+if [ -n "$1" ]; then
+    python generate_dataset.py --notebook
+    python filter_dataset.py --notebook
+else
+    python generate_dataset.py
+    python filter_dataset.py
+fi
\ No newline at end of file
diff --git a/dataset/cityscapes/generate_dataset.py b/dataset/cityscapes/generate_dataset.py
index be99fca..948a7f4 100644
--- a/dataset/cityscapes/generate_dataset.py
+++ b/dataset/cityscapes/generate_dataset.py
@@ -5,9 +5,9 @@
 import cv2
 import numpy as np
 from pathlib import Path
-from tqdm import tqdm
 from shutil import rmtree
 from config import data_generation as cfg
+import argparse
 
 NAME_TO_ID = {
     'pedestrian': 24
@@ -22,19 +22,19 @@ def generate_object_dataset_cityscapes(annotations_path, images_path, save_dir,
     postfix_instance = 'gtFine_instanceIds'
     obj_id = NAME_TO_ID[object_name]
 
-    for split_dir in tqdm(split_dirs, position=0, leave=True, desc="Total splits"):
+    for split_dir in tqdm(split_dirs, position=0, desc="Total splits"):
         SPLIT_ANNOS = annotations_path / split_dir
         SPLIT_IMAGES = images_path / split_dir
 
         CITY_DIRS = SPLIT_ANNOS.glob('*/')
         CITY_NAMES = [city_path.name for city_path in CITY_DIRS]
 
-        for city_name in tqdm(CITY_NAMES, position=1, leave=True, desc=f'{split_dir:5} split cities: '):
+        for city_name in tqdm(CITY_NAMES, position=1, desc=f'{split_dir:5} split cities: '):
             city_images_path = SPLIT_IMAGES / city_name
             city_annos_path = SPLIT_ANNOS / city_name
 
             city_images = city_images_path.glob('*.png')
 
-            for city_image in tqdm(list(city_images), leave=True, position=2, desc=f'{city_name:15} city images: '):
+            for city_image in tqdm(list(city_images), position=2, desc=f'{city_name:15} city images: '):
                 mask_label_path = city_annos_path / city_image.name.replace(postfix_image, postfix_label)
                 mask_instance_path = city_annos_path / city_image.name.replace(postfix_image, postfix_instance)
@@ -76,7 +76,14 @@ def generate_object_dataset_cityscapes(annotations_path, images_path, save_dir,
             cv2.imwrite(str(obj_region_image_path), obj_rgba)
 
 if __name__ == '__main__':
-
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--notebook', action='store_true', help='true if run it in jupyter notebook')
+    args = parser.parse_args()
+
+    if args.notebook:
+        from tqdm.notebook import tqdm
+    else:
+        from tqdm import tqdm
     print('START DATASET GENERATION')
     if cfg.dataset_type=='cityscapes':
         DATASET_ROOT = cfg.dataset_root