diff --git a/IndexKits/README.md b/IndexKits/README.md new file mode 100644 index 0000000..ec90e55 --- /dev/null +++ b/IndexKits/README.md @@ -0,0 +1,148 @@ +# IndexKits + +[TOC] + +## Introduction + +Index Kits (`index_kits`) for streaming Arrow data. + +* Supports creating datasets from configuration files. +* Supports creating index v2 from arrows. +* Supports the creation, reading, and usage of index files. +* Supports the creation, reading, and usage of multi-bucket, multi-resolution indexes. + +`index_kits` has the following advantages: + +* Optimizes memory usage for streaming reads, supporting the reading of data at the billion level. +* Optimizes memory usage and reading speed of Index files. +* Supports the creation of Base/Multireso Index V2 datasets, including data filtering, repeating, and deduplication during creation. +* Supports loading various dataset types (Arrow files, Base Index V2, multiple Base Index V2, Multireso Index V2, multiple Multireso Index V2). +* Uses a unified API interface to access images, text, and other attributes for all dataset types. +* Built-in shuffle, `resize_and_crop`, and returns the coordinates of the crop. + +## Installation + +* **Install from pre-compiled whl (recommended)** + +* Install from source + + ```shell + cd IndexKits + pip install -e . + ``` + +## Usage + +### Loading an Index V2 File + +```python +from index_kits import ArrowIndexV2 + +index_manager = ArrowIndexV2('data.json') + +# You can shuffle the dataset using a seed. If a seed is provided (i.e., the seed is not None), this shuffle will not affect any random state. +index_manager.shuffle(1234, fast=True) + +for i in range(len(index_manager)): + # Get an image (requires the arrow to contain an image/binary column): + pil_image = index_manager.get_image(i) + # Get text (requires the arrow to contain a text_zh column): + text = index_manager.get_attribute(i, column='text_zh') + # Get MD5 (requires the arrow to contain an md5 column): + md5 = index_manager.get_md5(i) + # Get any attribute by specifying the column (must be contained in the arrow): + ocr_num = index_manager.get_attribute(i, column='ocr_num') + # Get multiple attributes at once + item = index_manager.get_data(i, columns=['text_zh', 'md5']) # i: in-dataset index + print(item) + # { + # 'index': 3, # in-json index + # 'in_arrow_index': 3, # in-arrow index + # 'arrow_name': '/HunYuanDiT/dataset/porcelain/00000.arrow', + # 'text_zh': 'Fortune arrives with the auspicious green porcelain tea cup', + # 'md5': '1db68f8c0d4e95f97009d65fdf7b441c' + # } +``` + + +### Loading a Set of Arrow Files + +If you have a small batch of arrow files, you can directly use `IndexV2Builder` to load these arrow files without creating an Index V2 file. + +```python +from index_kits import IndexV2Builder + +index_manager = IndexV2Builder(arrow_files).to_index_v2() +``` + +### Loading a Multi-Bucket (Multi-Resolution) Index V2 File + +When using a multi-bucket (multi-resolution) index file, refer to the following example code, especially the definition of SimpleDataset and the usage of MultiResolutionBucketIndexV2. + +```python +from torch.utils.data import DataLoader, Dataset +import torchvision.transforms as T + +from index_kits import MultiResolutionBucketIndexV2 +from index_kits.sampler import BlockDistributedSampler + +class SimpleDataset(Dataset): + def __init__(self, index_file, batch_size, world_size): + # When using multi-bucket (multi-resolution), batch_size and world_size need to be specified for underlying data alignment. + self.index_manager = MultiResolutionBucketIndexV2(index_file, batch_size, world_size) + + self.flip_norm = T.Compose( + [ + T.RandomHorizontalFlip(), + T.ToTensor(), + T.Normalize([0.5], [0.5]), + ] + ) + + def shuffle(self, seed, fast=False): + self.index_manager.shuffle(seed, fast=fast) + + def __len__(self): + return len(self.index_manager) + + def __getitem__(self, idx): + image = self.index_manager.get_image(idx) + original_size = image.size # (w, h) + target_size = self.index_manager.get_target_size(idx) # (w, h) + image, crops_coords_left_top = self.index_manager.resize_and_crop(image, target_size) + image = self.flip_norm(image) + text = self.index_manager.get_attribute(idx, column='text_zh') + return image, text, (original_size, target_size, crops_coords_left_top) + +batch_size = 8 # batch_size per GPU +world_size = 8 # total number of GPUs +rank = 0 # rank of the current process +num_workers = 4 # customizable based on actual conditions +shuffle = False # must be set to False +drop_last = True # must be set to True + +# Correct batch_size and world_size must be passed in, otherwise, the multi-resolution data cannot be aligned correctly. +dataset = SimpleDataset('data_multireso.json', batch_size=batch_size, world_size=world_size) +# Must use BlockDistributedSampler to ensure the samples in a batch have the same resolution. +sampler = BlockDistributedSampler(dataset, num_replicas=world_size, rank=rank, + shuffle=shuffle, drop_last=drop_last, batch_size=batch_size) +loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, + num_workers=num_workers, pin_memory=True, drop_last=drop_last) + +for epoch in range(10): + # Please use the shuffle method provided by the dataset, not the DataLoader's shuffle parameter. + dataset.shuffle(epoch, fast=True) + for batch in loader: + pass +``` + + +## Fast Shuffle + +When your index v2 file contains hundreds of millions of samples, using the default shuffle can be quite slow. Therefore, it’s recommended to enable `fast=True` mode: + +```python +index_manager.shuffle(seed=1234, fast=True) +``` + +This will globally shuffle without keeping the indices of the same arrow together. Although this may reduce reading speed, it requires a trade-off based on the model’s forward time. \ No newline at end of file diff --git a/IndexKits/bin/idk b/IndexKits/bin/idk new file mode 100644 index 0000000..0ed3441 --- /dev/null +++ b/IndexKits/bin/idk @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 + +import argparse + +from index_kits.dataset.make_dataset_core import startup, make_multireso +from index_kits.common import show_index_info +from index_kits import __version__ + + +def common_args(parser): + parser.add_argument('-t', '--target', type=str, required=True, help='Save path') + + +def get_args(): + parser = argparse.ArgumentParser(description=""" + IndexKits is a tool to build and manage index files for large-scale datasets. + It supports both base index and multi-resolution index. + + Introduction + ------------ + This command line tool provides the following functionalities: + 1. Show index v2 information + 2. Build base index v2 + 3. Build multi-resolution index v2 + + Examples + -------- + 1. Show index v2 information + index_kits show /path/to/index.json + + 2. Build base index v2 + Default usage: + index_kits base -c /path/to/config.yaml -t /path/to/index.json + + Use multiple processes: + index_kits base -c /path/to/config.yaml -t /path/to/index.json -w 40 + + 3. Build multi-resolution index v2 + + Build with a configuration file: + index_kits multireso -c /path/to/config.yaml -t /path/to/index_mb_gt512.json + + Build by specifying arguments without a configuration file: + index_kits multireso --src /path/to/index.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json + + Build by specifying target-ratios: + index_kits multireso --src /path/to/index.json --base-size 512 --target-ratios 1:1 4:3 3:4 16:9 9:16 --min-size 512 -t /path/to/index_mb_gt512.json + + Build with multiple source index files. + index_kits multireso --src /path/to/index1.json /path/to/index2.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json + """, formatter_class=argparse.RawTextHelpFormatter) + sub_parsers = parser.add_subparsers(dest='task', required=True) + + # Show index message + show_parser = sub_parsers.add_parser('show', description=""" + Show base/multireso index v2 information. + + Example + ------- + index_kits show /path/to/index.json + """, formatter_class=argparse.RawTextHelpFormatter) + show_parser.add_argument('src', type=str, help='Path to a base/multireso index file.') + show_parser.add_argument('--arrow-files', action='store_true', help='Show arrow files only.') + show_parser.add_argument('--depth', type=int, default=1, + help='Arrow file depth. Default is 1, the level of last folder in the arrow file path. ' + 'Set it to 0 to show the full path including `xxx/last_folder/*.arrow`.') + + # Single resolution bucket + base_parser = sub_parsers.add_parser('base', description=""" + Build base index v2. + + Example + ------- + index_kits base -c /path/to/config.yaml -t /path/to/index.json + """, formatter_class=argparse.RawTextHelpFormatter) + base_parser.add_argument('-c', '--config', type=str, required=True, help='Configuration file path') + common_args(base_parser) + base_parser.add_argument('-w', '--world-size', type=int, default=1) + base_parser.add_argument('--work-dir', type=str, default='.', help='Work directory') + base_parser.add_argument('--use-cache', action='store_true', help='Use cache to avoid reprocessing. ' + 'Perform merge pkl results directly.') + + # Multi-resolution bucket + mo_parser = sub_parsers.add_parser('multireso', description=""" + Build multi-resolution index v2 + + Example + ------- + Build with a configuration file: + index_kits multireso -c /path/to/config.yaml -t /path/to/index_mb_gt512.json + + Build by specifying arguments without a configuration file: + index_kits multireso --src /path/to/index.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json + + Build by specifying target-ratios: + index_kits multireso --src /path/to/index.json --base-size 512 --target-ratios 1:1 4:3 3:4 16:9 9:16 --min-size 512 -t /path/to/index_mb_gt512.json + + Build with multiple source index files. + index_kits multireso --src /path/to/index1.json /path/to/index2.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json + """, formatter_class=argparse.RawTextHelpFormatter) + mo_parser.add_argument('-c', '--config', type=str, default=None, + help='Configuration file path in a yaml format. Either --config or --src must be provided.') + mo_parser.add_argument('-s', '--src', type=str, nargs='+', default=None, + help='Source index files. Either --config or --src must be provided.') + common_args(mo_parser) + mo_parser.add_argument('--base-size', type=int, default=None, help="Base size. Typically set as 256/512/1024 according to image size you train model.") + mo_parser.add_argument('--reso-step', type=int, default=None, + help="Resolution step. Either reso_step or target_ratios must be provided.") + mo_parser.add_argument('--target-ratios', type=str, nargs='+', default=None, + help="Target ratios. Either reso_step or target_ratios must be provided.") + mo_parser.add_argument('--md5-file', type=str, default=None, + help='You can provide an md5 to height and width file to accelerate the process. ' + 'It is a pickle file that contains a dict, which maps md5 to (height, width) tuple.') + mo_parser.add_argument('--align', type=int, default=16, help="Used when --target-ratios is provided. Align size of source image height and width.") + mo_parser.add_argument('--min-size', type=int, default=0, + help="Minimum size. Images smaller than this size will be ignored.") + + # Common + parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}') + + args = parser.parse_args() + return args + + +if __name__ == '__main__': + args = get_args() + if args.task == 'show': + show_index_info(args.src, + args.arrow_files, + args.depth, + ) + elif args.task == 'base': + startup(args.config, + args.target, + args.world_size, + args.work_dir, + use_cache=args.use_cache, + ) + elif args.task == 'multireso': + make_multireso(args.target, + args.config, + args.src, + args.base_size, + args.reso_step, + args.target_ratios, + args.align, + args.min_size, + args.md5_file, + ) diff --git a/IndexKits/docs/MakeDataset.md b/IndexKits/docs/MakeDataset.md new file mode 100644 index 0000000..ae23e76 --- /dev/null +++ b/IndexKits/docs/MakeDataset.md @@ -0,0 +1,357 @@ +# Use the IndexKits library to create datasets. + +[TOC] + +## Introdution + +The IndexKits library offers a command-line tool `idk` for creating datasets and viewing their statistical information. You can view the instructions by using `idk -h`. Here, it refers to creating an Index V2 format dataset from a series of arrow files. + +## 1. Creating a Base Index V2 Dataset + +### 1.1 Create a Base Index V2 dataset using `idk` + +When creating a Base Index V2 dataset, you need to specify the path to a configuration file using the `-c` parameter and a save path using the `-t` parameter. + +```shell +idk base -c base_config.yaml -t base_dataset.json +``` + +### 1.2 Basic Configuration + +Next, let’s discuss how to write the configuration file. The configuration file is in `yaml` format, and below is a basic example: + +Filename: `base_config.yaml` + +```yaml +source: + - /HunYuanDiT/dataset/porcelain/arrows/00000.arrow +``` + +| Field Name | Type | Description | +|:---------:|:---:|:----------:| +| source | Optional | Arrow List | + + +We provide an example that includes all features and fields in [full_config.yaml](./docs/full_config.yaml). + +### 1.3 Filtering Criteria + +`idk` offers two types of filtering capabilities during the dataset creation process: (1) based on columns in Arrow, and (2) based on MD5 files. + +To enable filtering criteria, add the `filter` field in the configuration file. + +#### 1.3.1 Column Filtering + +To enable column filtering, add the `column` field under the `filter` section. +Multiple column filtering criteria can be applied simultaneously, with the intersection of multiple conditions being taken. + +For example, to select data where both the length and width are greater than or equal to 512, +with the default being 1024 if the length and width are invalid: + +```yaml +filter: + column: + - name: height + type: int + action: ge + target: 512 + default: 1024 + - name: width + type: int + action: ge + target: 512 + default: 1024 +``` + +This filtering condition is equivalent to `table['height'].to_int(default=1024) >= 512 && table['width'].to_int(default=1024) >= 512`. + +Each filtering criterion includes the following fields: + +| Field Name | Type | Description | Value Range | +|:------------------:|:---:|:---------------:|:----------------------:| +| name | Required | Column name in Arrow | Column in Arrow | +| type | Required | Type of Elements in the Column | `int`, `float` or `str` | +| action | Required | Filtering Criteria | See the table below for possible values | +| target | Required | Filtering Criteria | Numeric or String | +| default | Required | Default value when the element is invalid | Numeric or String | +| arrow_file_keyword | Optional | Keywords in the Arrow file path | - | + +Below are the specific meanings of “action” and the optional values under different circumstances:: + +| Action | Description | Type | Action | Description | Type | +|:-------------:|:----------------------------------------------------:|:---------------------:|:------------:|:--------------------------------------------:|:---------------------:| +| eq | equal, `==` | `int`, `float`, `str` | ne | not equal, `!=` | `int`, `float`, `str` | +| gt | great than, `>` | `int`, `float` | lt | less than, `<` | `int`, `float` | +| ge | great or equal, `>=` | `int`, `float` | le | less or equal, `<=` | `int`, `float` | +| len_eq | str length equal, `str.len()==` | `str` | len_ne | str length not equal, `str.len()!=` | `str` | +| len_gt | str length great than, `str.len()>` | `str` | len_lt | str length less than, `str.len()<` | `str` | +| len_ge | str length great or equal, `str.len()>=` | `str` | len_le | str length less or equal, `str.len()<=` | `str` | +| contains | str contains, `str.contains(target)` | `str` | not_contains | str not contains, `str.not_contains(target)` | `str` | +| in | str in, `str.in(target)` | `str` | not_in | str not in, `str.not_in(target)` | `str` | +| lower_last_in | lower str last char in, `str.lower()[-1].in(target)` | `str` | | | | + +#### 1.3.2 MD5 Filtering + +Add an `md5` field under the `filter` section to initiate MD5 filtering. Multiple MD5 filtering criteria can be applied simultaneously, with the intersection of multiple conditions being taken. + +例如: +* `badcase.txt` is a list of MD5s, aiming to filter out the entries listed in these lists. +* `badcase.json` is a dictionary, where the key is the MD5 and the value is a text-related `tag`. The goal is to filter out specific `tags`. + +```shell +filter: + md5: + - name: badcase1 + path: + - badcase1.txt + type: list + action: in + is_valid: false + - name: badcase2 + path: badcase2.json + type: dict + action: eq + target: 'Specified tag' + is_valid: false +``` + +Each filtering criterion includes five fields: + +| Field Name | Type | Description | Value Range | +|:------------------:|:---:|:-----------------------------------------------------------------:|:------------------------------------------------------------------------------------:| +| name | Required | The name of the filtering criterion, which can be customized for ease of statistics. | - | +| path | Required | The path to the filtering file, which can be a single path or multiple paths provided in a list format. Supports `.txt`, `.json`, `.pkl` formats. | - | +| type | Required | The type of records in the filtering file. | `list` or `dict` | +| action | Required | The filtering action. | For `list`: `in`, `not_in`; For `dict`: `eq`, `ne`, `gt`, `lt`, `ge`, `le`| +| target | Optional | The filtering criterion. | Required when type is `dict` | +| is_valid | Required | Whether a hit on action+target is considered valid or invalid. | `true` or `false` | +| arrow_file_keyword | Optional | Keywords in the Arrow file path. | - | + +### 1.4 Advanced Filtering + +`idk` also supports some more advanced filtering functions. + +#### 1.4.1 Filtering Criteria Applied to Part of Arrow Files + +Using the `arrow_file_keyword` parameter, filtering criteria can be applied only to part of the Arrow files. + +For example: + +* The filtering criterion `height>=0` applies only to arrows whose path includes `human`. +* The filtering criterion “keep samples in goodcase.txt” applies only to arrows whose path includes `human`. + +```yaml +filter: + column: + - name: height + type: int + action: ge + target: 512 + default: 1024 + arrow_file_keyword: + - human + + md5: + - name: goodcase + path: goodcase.txt + type: list + action: in + is_valid: true + arrow_file_keyword: + - human +``` + +#### 1.4.2 The “or” Logic in Filtering Criteria + +By default, filtering criteria follow an “and” logic. If you want two or more filtering criteria to follow an “or” logic, you should use the `logical_or` field. The column filtering criteria listed under this field will be combined using an “or” logic. + +```yaml +filter: + column: + - logical_or: + - name: md5 + type: str + action: lower_last_in + target: '02468ace' + default: '' + - name: text_zh + type: str + action: contains + target: 'turtle|rabbit|duck|swan|peacock|hedgehog|wolf|fox|seal|crow|deer' + default: '' +``` + +Special Note: The `logical_or` field is applicable only to the `column` filtering criteria within `filter`. + +1.4.3 Excluding Certain Arrows from the Source + +While wildcards can be used to fetch multiple arrows at once, there might be instances where we want to exclude some of them. This can be achieved through the `exclude` field. Keywords listed under `exclude`, if found in the path of the current group of arrows, will result in those arrows being excluded. + +```yaml +source: + - /HunYuanDiT/dataset/porcelain/arrows/*.arrow: + exclude: + - arrow1 + - arrow2 +``` + +### 1.5 Repeating Samples + +`idk` offers the capability to repeat either all samples or specific samples during dataset creation. There are three types of repeaters: +* Directly repeating the source +* Based on keywords in the Arrow file name (enable repeat conditions by adding the `repeater` field in the configuration file) +* Based on an MD5 file (enable repeat conditions by adding the `repeater` field in the configuration file) + +**Special Note:** The above three conditions can be used simultaneously. If a sample meets multiple repeat conditions, the highest number of repeats will be taken. + +#### 1.5.1 Repeating the Source + +In the source, for the Arrow(s) you want to repeat (which can include wildcards like `*`, `?`), add `repeat: n` to mark the number of times to repeat. + +```yaml +source: + - /HunYuanDiT/dataset/porcelain/arrows/*.arrow: + repeat: 10 +``` + +**Special Note: Add a colon at the end of the Arrow path** + +#### 1.5.2 Arrow Keywords + +Add an `arrow_file_keyword` field under the `repeater` section. + +```yaml +repeater: + arrow_file_keyword: + - repeat: 8 + keyword: + - Lolita anime style + - Minimalist style + - repeat: 5 + keyword: + - Magical Barbie style + - Disney style +``` + +Each repeat condition includes two fields: + + +| Field Name | Type | Description | Value Range | +|:-------:|:---:|:---------------:|:----:| +| repeat | Required | The number of times to repeat | Number | +| keyword | Required | Keywords in the Arrow file path | - | + + +#### 1.5.3 MD5 File + +Add an `md5` field under the `repeater` section. + +```yaml +repeater: + md5: + - name: goodcase1 + path: /HunYuanDiT/dataset/porcelain/md5_repeat_1.json + type: dict + plus: 3 + - name: goodcase2 + path: /HunYuanDiT/dataset/porcelain/md5_repeat_2.json + type: list + repeat: 6 +``` + +Each repeat condition includes the following fields: + + +| Field Name | Type | Description | Value Range | +|:------:|:---:|:----------------------------------------------:|:---------------------------:| +| name | Required | Custom name for the repeat condition | - | +| path | Required | Path to the file containing MD5s | `.txt`, `.json`, or `.pkl` formats | +| type | Required | Type of the MD5 file | `list` or `dict` | +| repeat | Optional | Number of times to repeat, this will override plus | Integer | +| plus | Optional | Adds this value to the value obtained from the MD5 file as the number of repeats | Integer | + +### 1.6 Deduplication + +To deduplicate, add `remove_md5_dup` and set it to `true`. For example: + +```yaml +remove_md5_dup: true +``` + +**Special Note:** Deduplication is performed after repeat conditions, so using them together will make repeat ineffective. + +### 1.7 Creating a Base Index V2 Dataset with Python Code + +```python +from index_kits import IndexV2Builder + +builder = IndexV2Builder(arrow_files) +builder.save('data_v2.json') +``` + + + +## 2. Creating a Multireso Index V2 Dataset + +### 2.1 Creating a Multireso Index V2 Dataset with `idk` +To create a multi-resolution dataset, you can do so through a configuration file: + +```shell +src: + - /HunYuanDiT/dataset/porcelain/jsons/a.json + - /HunYuanDiT/dataset/porcelain/jsons/b.json + - /HunYuanDiT/dataset/porcelain/jsons/c.json +base_size: 512 +reso_step: 32 +min_size: 512 +``` + +The fields are as follows: + +| Field Name | Type | Description | Value Range | +|:-------------:|:---:|:------------------------------------------------------:|:-------------------------------:| +| src | Required | Path(s) to the Base Index V2 file, can be single or multiple | - | +| base_size | Required | The base resolution (n, n) from which to create multiple resolutions | Recommended values: 256/512/1024 | +| reso_step | Optional | The step size for traversing multiple resolutions. Choose either this or target_ratios | Recommended values: 16/32/64 | +| target_ratios | Optional | Target aspect ratios, a list. Choose either this or reso_step | Recommended values: 1:1, 4:3, 3:4, 16:9, 9:16 | +| align | Optional | When using target_ratios, the multiple to align the target resolution to | Recommended value: 16 (2x patchify, 8x VAE) | +| min_size | Optional | The minimum resolution filter for samples when creating from Base to Multireso Index V2 | Recommended values: 256/512/1024 | +| md5_file | Optional | A pre-calculated dictionary of image sizes in pkl format, key is MD5, value is (h, w) | - | + + +### 2.2 Creating a Multireso Index V2 Dataset with Python Code + +First, import the necessary function + +```python +from index_kits import build_multi_resolution_bucket + +md5_hw = None +# If many arrows in your Index V2 lack height and width, you can pre-calculate them and pass through the md5_hw parameter. +# md5_hw = {'c67be1d8f30fd0edcff6ac99b703879f': (720, 1280), ...} +# with open('md5_hw.pkl', 'rb') as f: +# md5_hw = pickle.load(f) + +index_v2_file_path = 'data_v2.json' +index_v2_save_path = 'data_multireso.json' + +# Method 1: Given base_size and reso_step, automatically calculate all resolution buckets. +build_multi_resolution_bucket(base_size=1024, + reso_step=64, + min_size=1024, + src_index_files=index_v2_file_path, + save_file=index_v2_save_path, + md5_hw=md5_hw) + +# Method 2: Given a series of target aspect ratios, automatically calculate all resolution buckets. +build_multi_resolution_bucket(base_size=1024, + target_ratios=["1:1", "3:4", "4:3", "16:9", "9:16"], + align=16, + min_size=1024, + src_index_files=index_v2_file_path, + save_file=index_v2_save_path, + md5_hw=md5_hw) +``` + +**Note:** If both reso_step and target_ratios are provided, target_ratios will be prioritized. + diff --git a/IndexKits/docs/full_config.yaml b/IndexKits/docs/full_config.yaml new file mode 100644 index 0000000..e5a3a6c --- /dev/null +++ b/IndexKits/docs/full_config.yaml @@ -0,0 +1,36 @@ +source: + - /HunyuanDiT/dataset/task1/*.arrow + - /HunyuanDiT/dataset/task2/*.arrow: + repeat: 2 + +remove_md5_dup: false + +filter: + column: + - name: height + type: int + action: ge + target: 512 + default: 1024 + - name: width + type: int + action: ge + target: 512 + default: 1024 + md5: + - name: Badcase + path: /path/to/bad_case_md5.txt + type: list + action: in + is_valid: false + +repeater: + md5: + - name: Hardcase Repeat + path: /path/to/hard_case_md5.json + type: dict + plus: 3 + - name: Goodcase Repeat + path: /path/to/good_case_md5.txt + type: list + repeat: 6 diff --git a/IndexKits/index_kits/__init__.py b/IndexKits/index_kits/__init__.py new file mode 100644 index 0000000..987486d --- /dev/null +++ b/IndexKits/index_kits/__init__.py @@ -0,0 +1,11 @@ +from .bucket import ( + MultiIndexV2, + MultiResolutionBucketIndexV2, MultiMultiResolutionBucketIndexV2, + build_multi_resolution_bucket +) +from .bucket import Resolution, ResolutionGroup +from .indexer import IndexV2Builder, ArrowIndexV2 +from .common import load_index, show_index_info + +__version__ = "0.4.0" + diff --git a/IndexKits/index_kits/bucket.py b/IndexKits/index_kits/bucket.py new file mode 100644 index 0000000..72b63e6 --- /dev/null +++ b/IndexKits/index_kits/bucket.py @@ -0,0 +1,888 @@ +import bisect +import json +import random +from pathlib import Path +from typing import Optional, Dict, List, Callable, Union + +import numpy as np +from tqdm import tqdm +from PIL import Image + +from .indexer import ArrowIndexV2, IndexV2Builder + + +class Resolution(object): + def __init__(self, size, *args): + if isinstance(size, str): + if 'x' in size: + size = size.split('x') + size = (int(size[0]), int(size[1])) + else: + size = int(size) + if len(args) > 0: + size = (size, args[0]) + if isinstance(size, int): + size = (size, size) + + self.h = self.height = size[0] + self.w = self.width = size[1] + self.r = self.ratio = self.height / self.width + + def __getitem__(self, idx): + if idx == 0: + return self.h + elif idx == 1: + return self.w + else: + raise IndexError(f'Index {idx} out of range') + + def __str__(self): + return f'{self.h}x{self.w}' + + +class ResolutionGroup(object): + def __init__(self, base_size=None, step=None, align=1, target_ratios=None, enlarge=1, data=None): + self.enlarge = enlarge + + if data is not None: + self.data = data + mid = len(self.data) // 2 + self.base_size = self.data[mid].h + self.step = self.data[mid].h - self.data[mid - 1].h + else: + self.align = align + self.base_size = base_size + assert base_size % align == 0, f'base_size {base_size} is not divisible by align {align}' + if base_size is not None and not isinstance(base_size, int): + raise ValueError(f'base_size must be None or int, but got {type(base_size)}') + if step is None and target_ratios is None: + raise ValueError(f'Either step or target_ratios must be provided') + if step is not None and step > base_size // 2: + raise ValueError(f'step must be smaller than base_size // 2, but got {step} > {base_size // 2}') + + self.step = step + self.data = self.calc(target_ratios) + + self.ratio = np.array([x.ratio for x in self.data]) + self.attr = ['' for _ in range(len(self.data))] + self.prefix_space = 0 + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + def __repr__(self): + prefix = self.prefix_space * ' ' + prefix_close = (self.prefix_space - 4) * ' ' + res_str = f'ResolutionGroup(base_size={self.base_size}, step={self.step}, data=' + attr_maxlen = max([len(x) for x in self.attr] + [5]) + res_str += f'\n{prefix}ID: height width ratio {" " * max(0, attr_maxlen - 4)}count h/16 w/16 tokens\n{prefix}' + res_str += ('\n' + prefix).join([f'{i:2d}: ({x.h:4d}, {x.w:4d}) {self.ratio[i]:.4f} {self.attr[i]:>{attr_maxlen}s} ' + f'({x.h // 16:3d}, {x.w // 16:3d}) {x.h // 16 * x.w // 16:6d}' + for i, x in enumerate(self.data)]) + res_str += f'\n{prefix_close})' + return res_str + + @staticmethod + def from_list_of_hxw(hxw_list): + data = [Resolution(x) for x in hxw_list] + data = sorted(data, key=lambda x: x.ratio) + return ResolutionGroup(None, data=data) + + def calc(self, target_ratios=None): + if target_ratios is None: + return self._calc_by_step() + else: + return self._calc_by_ratio(target_ratios) + + def _calc_by_ratio(self, target_ratios): + resolutions = [] + for ratio in target_ratios: + if ratio == '1:1': + reso = Resolution(self.base_size, self.base_size) + else: + hr, wr = map(int, ratio.split(':')) + x = int((self.base_size ** 2 * self.enlarge // self.align // self.align / (hr * wr)) ** 0.5) + height = x * hr * self.align + width = x * wr * self.align + reso = Resolution(height, width) + resolutions.append(reso) + + resolutions = sorted(resolutions, key=lambda x_: x_.ratio) + + return resolutions + + def _calc_by_step(self): + min_height = self.base_size // 2 + min_width = self.base_size // 2 + max_height = self.base_size * 2 + max_width = self.base_size * 2 + + resolutions = [Resolution(self.base_size, self.base_size)] + + cur_height, cur_width = self.base_size, self.base_size + while True: + if cur_height >= max_height and cur_width <= min_width: + break + + cur_height = min(cur_height + self.step, max_height) + cur_width = max(cur_width - self.step, min_width) + resolutions.append(Resolution(cur_height, cur_width)) + + cur_height, cur_width = self.base_size, self.base_size + while True: + if cur_height <= min_height and cur_width >= max_width: + break + + cur_height = max(cur_height - self.step, min_height) + cur_width = min(cur_width + self.step, max_width) + resolutions.append(Resolution(cur_height, cur_width)) + + resolutions = sorted(resolutions, key=lambda x: x.ratio) + + return resolutions + + +class Bucket(ArrowIndexV2): + def __init__(self, height, width, *args, **kwargs): + super().__init__(*args, **kwargs) + self.height = height + self.width = width + self.ratio = height / width + self.scale_dist = [] + + def get_scale_by_index(self, index, size_col='hw', shadow=None): + """ + Calculate the scale to resize the image to fit the bucket. + + Parameters + ---------- + index: int + An in-json index. + size_col: str + How to get the size of the image. 'hw' for height and width column, + while 'image' for decoding image binary and get the PIL Image size. + shadow: str + The shadow name. If None, return the main arrow file. If not None, return the shadow arrow file. + + Returns + ------- + scale: float + """ + if size_col == 'hw': + w = int(self.get_attribute_by_index(index, 'width', shadow=shadow)) + h = int(self.get_attribute_by_index(index, 'height', shadow=shadow)) + else: + w, h = self.get_image_by_index(index, shadow=shadow).size + tw, th = self.width, self.height + + tr = th / tw + r = h / w + + scale = th / h if r < tr else tw / w + return scale + + @staticmethod + def from_bucket_index(index_file, align=1, shadow_file_fn=None): + with open(index_file, 'r') as f: + res_dict = json.load(f) + + if not isinstance(res_dict['group_length'], dict): + error_msg = f'`group_length` must be a dict, but got {type(res_dict["group_length"])}' + if isinstance(res_dict['group_length'], list): + raise ValueError(f'{error_msg}\nYou may using a vanilla Index V2 file. Try `ArrowIndexV2` instead.') + else: + raise ValueError(error_msg) + + assert 'indices_file' in res_dict, f'indices_file not found in {index_file}' + assert res_dict['indices_file'] != '', f'indices_file is empty in {index_file}' + + indices_file = Path(index_file).parent / res_dict['indices_file'] + assert Path(indices_file).exists(), f'indices_file {indices_file} not found' + + # Loading indices data + indices_data = np.load(indices_file) + + # Build buckets + buckets = [] + keys = [] + for k, v in indices_data.items(): + data = { + 'data_type': res_dict['data_type'], + 'arrow_files': res_dict['arrow_files'], + 'cum_length': res_dict['cum_length'], + } + + data['indices_file'] = '' + data['indices'] = v + data['group_length'] = res_dict['group_length'][k] + + height, width = map(int, k.split('x')) + bucket = Bucket(height, width, res_dict=data, align=align, shadow_file_fn=shadow_file_fn) + + if len(bucket) > 0: + buckets.append(bucket) + keys.append(k) + + resolutions = ResolutionGroup.from_list_of_hxw(keys) + resolutions.attr = [f'{len(bucket):,d}' for bucket in buckets] + + return buckets, resolutions + + +class MultiIndexV2(object): + """ + Multi-bucket index. Support multi-GPU (either single node or multi-node distributed) training. + + Parameters + ---------- + index_files: list + The index files. + batch_size: int + The batch size of each GPU. Required when using MultiResolutionBucketIndexV2 as base index class. + world_size: int + The number of GPUs. Required when using MultiResolutionBucketIndexV2 as base index class. + sample_strategy: str + The sample strategy. Can be 'uniform' or 'probability'. Default to 'uniform'. + If set to probability, a list of probability must be provided. The length of the list must be the same + as the number of buckets. Each probability value means the sample rate of the corresponding bucket. + probability: list + A list of probability. Only used when sample_strategy=='probability'. + shadow_file_fn: callable or dict + A callable function to map shadow file path to a new path. If None, the shadow file path will not be + changed. If a dict is provided, the keys are the shadow names to call the function, and the values are the + callable functions to map the shadow file path to a new path. If a callable function is provided, the key + is 'default'. + seed: int + Only used when sample_strategy=='probability'. The seed to sample the indices. + """ + buckets: List[ArrowIndexV2] + + def __init__(self, + index_files: List[str], + batch_size: Optional[int] = None, + world_size: Optional[int] = None, + sample_strategy: str = 'uniform', + probability: Optional[List[float]] = None, + shadow_file_fn: Optional[Union[Callable, Dict[str, Callable]]] = None, + seed: Optional[int] = None, + ): + self.buckets = self.load_buckets(index_files, + batch_size=batch_size, world_size=world_size, shadow_file_fn=shadow_file_fn) + + self.sample_strategy = sample_strategy + self.probability = probability + self.check_sample_strategy(sample_strategy, probability) + + self.cum_length = self.calc_cum_length() + + self.sampler = np.random.RandomState(seed) + if sample_strategy == 'uniform': + self.total_length = sum([len(bucket) for bucket in self.buckets]) + self.ind_mapper = np.arange(self.total_length) + elif sample_strategy == 'probability': + self.ind_mapper = self.sample_indices_with_probability() + self.total_length = len(self.ind_mapper) + else: + raise ValueError(f"Not supported sample_strategy {sample_strategy}.") + + def load_buckets(self, index_files, **kwargs): + buckets = [ArrowIndexV2(index_file, **kwargs) for index_file in index_files] + return buckets + + def __len__(self): + return self.total_length + + def check_sample_strategy(self, sample_strategy, probability): + if sample_strategy == 'uniform': + pass + elif sample_strategy == 'probability': + if probability is None: + raise ValueError(f"probability must be provided when sample_strategy is 'probability'.") + assert isinstance(probability, (list, tuple)), \ + f"probability must be a list, but got {type(probability)}" + assert len(self.buckets) == len(probability), \ + f"Length of index_files {len(self.buckets)} != Length of probability {len(probability)}" + else: + raise ValueError(f"Not supported sample_strategy {sample_strategy}.") + + def sample_indices_with_probability(self): + ind_mapper_list = [] + accu = 0 + for bucket, p in zip(self.buckets, self.probability): + if p == 1: + # Just use all indices + indices = np.arange(len(bucket)) + accu + else: + # Use all indices multiple times, and then sample some indices without replacement + repeat_times = int(p) + indices_part1 = np.arange(len(bucket)).repeat(repeat_times) + indices_part2 = self.sampler.choice(len(bucket), int(len(bucket) * (p - repeat_times)), replace=False) + indices = np.sort(np.concatenate([indices_part1, indices_part2])) + accu + ind_mapper_list.append(indices) + accu += len(bucket) + ind_mapper = np.concatenate(ind_mapper_list) + return ind_mapper + + def calc_cum_length(self): + cum_length = [] + length = 0 + for bucket in self.buckets: + length += len(bucket) + cum_length.append(length) + return cum_length + + def shuffle(self, seed=None, fast=False): + if self.sample_strategy == 'probability': + # Notice: In order to resample indices when shuffling, shuffle will not preserve the + # initial sampled indices when loading the index. + pass + + # Shuffle indexes + if seed is not None: + state = random.getstate() + random.seed(seed) + random.shuffle(self.buckets) + random.setstate(state) + else: + random.shuffle(self.buckets) + + self.cum_length = self.calc_cum_length() + + # Shuffle indices in each index + for i, bucket in enumerate(self.buckets): + bucket.shuffle(seed + i, fast=fast) + + # Shuffle ind_mapper + if self.sample_strategy == 'uniform': + self.ind_mapper = np.arange(self.total_length) + elif self.sample_strategy == 'probability': + self.ind_mapper = self.sample_indices_with_probability() + else: + raise ValueError(f"Not supported sample_strategy {self.sample_strategy}.") + if seed is not None: + sampler = np.random.RandomState(seed) + sampler.shuffle(self.ind_mapper) + else: + np.random.shuffle(self.ind_mapper) + + def get_arrow_file(self, ind, **kwargs): + """ + Get arrow file by in-dataset index. + + Parameters + ---------- + ind: int + The in-dataset index. + kwargs: dict + shadow: str + The shadow name. If None, return the main arrow file. If not None, return the shadow arrow file. + + Returns + ------- + arrow_file: str + """ + ind = self.ind_mapper[ind] + i = bisect.bisect_right(self.cum_length, ind) + bias = self.cum_length[i - 1] if i > 0 else 0 + return self.buckets[i].get_arrow_file(ind - bias, **kwargs) + + def get_data(self, ind, columns=None, allow_missing=False, return_meta=True, **kwargs): + """ + Get data by in-dataset index. + + Parameters + ---------- + ind: int + The in-dataset index. + columns: str or list + The columns to be returned. If None, return all columns. + allow_missing: bool + If True, omit missing columns. If False, raise an error if the column is missing. + return_meta: bool + If True, the resulting dict will contain some meta information: + in-json index, in-arrow index, and arrow_name. + kwargs: dict + shadow: str + The shadow name. If None, return the main arrow file. If not None, return the shadow arrow file. + + Returns + ------- + data: dict + A dict containing the data. + """ + ind = self.ind_mapper[ind] + i = bisect.bisect_right(self.cum_length, ind) + bias = self.cum_length[i - 1] if i > 0 else 0 + return self.buckets[i].get_data(ind - bias, columns=columns, allow_missing=allow_missing, + return_meta=return_meta, **kwargs) + + def get_attribute(self, ind, column, **kwargs): + """ + Get single attribute by in-dataset index. + + Parameters + ---------- + ind: int + The in-dataset index. + column: str + The column name. + kwargs: dict + shadow: str + The shadow name. If None, return the main arrow file. If not None, return the shadow arrow file. + + Returns + ------- + attribute: Any + """ + ind = self.ind_mapper[ind] + i = bisect.bisect_right(self.cum_length, ind) + bias = self.cum_length[i - 1] if i > 0 else 0 + return self.buckets[i].get_attribute(ind - bias, column, **kwargs) + + def get_image(self, ind, column='image', ret_type='pil', max_size=-1, **kwargs): + """ + Get image by in-dataset index. + + Parameters + ---------- + ind: int + The in-dataset index. + column: str + [Deprecated] The column name of the image. Default to 'image'. + ret_type: str + The return type. Can be 'pil' or 'numpy'. Default to 'pil'. + max_size: int + If not -1, resize the image to max_size. max_size is the size of long edge. + kwargs: dict + shadow: str + The shadow name. If None, return the main arrow file. If not None, return the shadow arrow file. + + Returns + ------- + image: PIL.Image.Image or np.ndarray + """ + ind = self.ind_mapper[ind] + i = bisect.bisect_right(self.cum_length, ind) + bias = self.cum_length[i - 1] if i > 0 else 0 + return self.buckets[i].get_image(ind - bias, column, ret_type, max_size, **kwargs) + + def get_md5(self, ind, **kwargs): + """ Get md5 by in-dataset index. """ + ind = self.ind_mapper[ind] + i = bisect.bisect_right(self.cum_length, ind) + bias = self.cum_length[i - 1] if i > 0 else 0 + return self.buckets[i].get_md5(ind - bias, **kwargs) + + def get_columns(self, ind, **kwargs): + """ Get columns by in-dataset index. """ + ind = self.ind_mapper[ind] + i = bisect.bisect_right(self.cum_length, ind) + bias = self.cum_length[i - 1] if i > 0 else 0 + return self.buckets[i].get_columns(ind - bias, **kwargs) + + @staticmethod + def resize_and_crop(image, target_size, resample=Image.LANCZOS, crop_type='random'): + """ + Resize image without changing aspect ratio, then crop the center/random part. + + Parameters + ---------- + image: PIL.Image.Image + The input image to be resized and cropped. + target_size: tuple + The target size of the image. + resample: + The resample method. See PIL.Image.Image.resize for details. Default to Image.LANCZOS. + crop_type: str + 'center' or 'random'. If 'center', crop the center part of the image. If 'random', + crop a random part of the image. Default to 'random'. + + Returns + ------- + image: PIL.Image.Image + The resized and cropped image. + crop_pos: tuple + The position of the cropped part. (crop_left, crop_top) + """ + return ArrowIndexV2.resize_and_crop(image, target_size, resample, crop_type) + + +class MultiResolutionBucketIndexV2(MultiIndexV2): + """ + Multi-resolution bucket index. Support multi-GPU (either single node or multi-node distributed) training. + + Parameters + ---------- + index_file: str + The index file of the bucket index. + batch_size: int + The batch size of each GPU. + world_size: int + The number of GPUs. + shadow_file_fn: callable or dict + A callable function to map shadow file path to a new path. If None, the shadow file path will not be + changed. If a dict is provided, the keys are the shadow names to call the function, and the values are the + callable functions to map the shadow file path to a new path. If a callable function is provided, the key + is 'default'. + """ + buckets: List[Bucket] + + def __init__(self, + index_file: str, + batch_size: int, + world_size: int, + shadow_file_fn: Optional[Union[Callable, Dict[str, Callable]]] = None, + ): + align = batch_size * world_size + if align <= 0: + raise ValueError(f'Align size must be positive, but got {align} = {batch_size} x {world_size}') + + self.buckets, self._resolutions = Bucket.from_bucket_index(index_file, + align=align, + shadow_file_fn=shadow_file_fn, + ) + self.arrow_files = self.buckets[0].arrow_files + self._base_size = self._resolutions.base_size + self._step = self._resolutions.step + + self.buckets = sorted(self.buckets, key=lambda x: x.ratio) + self.cum_length = self.calc_cum_length() + + self.total_length = sum([len(bucket) for bucket in self.buckets]) + assert self.total_length % align == 0, f'Total length {self.total_length} is not divisible by align size {align}' + + self.align_size = align + self.batch_size = batch_size + self.world_size = world_size + self.ind_mapper = np.arange(self.total_length) + + @property + def step(self): + return self._step + + @property + def base_size(self): + return self._base_size + + @property + def resolutions(self): + return self._resolutions + + def shuffle(self, seed=None, fast=False): + # Shuffle indexes + if seed is not None: + state = random.getstate() + random.seed(seed) + random.shuffle(self.buckets) + random.setstate(state) + else: + random.shuffle(self.buckets) + + self.cum_length = self.calc_cum_length() + + # Shuffle indices in each index + for i, bucket in enumerate(self.buckets): + bucket.shuffle(seed + i, fast=fast) + + # Shuffle ind_mapper + batch_ind_mapper = np.arange(self.total_length // self.batch_size) * self.batch_size + if seed is not None: + sampler = np.random.RandomState(seed) + sampler.shuffle(batch_ind_mapper) + else: + np.random.shuffle(batch_ind_mapper) + ind_mapper = np.stack([batch_ind_mapper + i for i in range(self.batch_size)], axis=1).reshape(-1) + self.ind_mapper = ind_mapper + + def get_ratio(self, ind, **kwargs): + """ + Get the ratio of the image by in-dataset index. + + Parameters + ---------- + ind: int + The in-dataset index. + + Returns + ------- + width, height, ratio + """ + ind = self.ind_mapper[ind] + width, height = self.get_image(ind, **kwargs).size + return width, height, height / width + + def get_target_size(self, ind): + """ + Get the target size of the image by in-dataset index. + + Parameters + ---------- + ind: int + The in-dataset index. + + Returns + ------- + target_width, target_height + """ + ind = self.ind_mapper[ind] + i = bisect.bisect_right(self.cum_length, ind) + return self.buckets[i].width, self.buckets[i].height + + def scale_distribution(self, save_file=None): + if save_file is not None: + scale_dict = np.load(save_file) + for bucket in self.buckets: + bucket.scale_dist = scale_dict[f'{bucket.height}x{bucket.width}'] + else: + for bucket in tqdm(self.buckets): + for index in tqdm(bucket.indices, leave=False): + scale = bucket.get_scale_by_index(index) + bucket.scale_dist.append(scale) + scale_dict = {f'{bucket.height}x{bucket.width}': bucket.scale_dist for bucket in self.buckets} + + if save_file is not None: + save_file = Path(save_file) + save_file.parent.mkdir(exist_ok=True, parents=True) + np.savez_compressed(save_file, **scale_dict) + + return self + + +class MultiMultiResolutionBucketIndexV2(MultiIndexV2): + buckets: List[MultiResolutionBucketIndexV2] + + @property + def step(self): + return [b.step for b in self.buckets] + + @property + def base_size(self): + return [b.base_size for b in self.buckets] + + @property + def resolutions(self): + return [b.resolutions for b in self.buckets] + + def load_buckets(self, index_files, **kwargs): + self.batch_size = kwargs.get('batch_size', None) + self.world_size = kwargs.get('world_size', None) + if self.batch_size is None or self.world_size is None: + raise ValueError("`batch_size` and `world_size` must be provided when using " + "`MultiMultiResolutionBucketIndexV2`.") + buckets = [ + MultiResolutionBucketIndexV2(index_file, + self.batch_size, + self.world_size, + shadow_file_fn=kwargs.get('shadow_file_fn', None), + ) + for index_file in index_files + ] + return buckets + + def sample_indices_with_probability(self, return_batch_indices=False): + bs = self.batch_size + ind_mapper_list = [] + accu = 0 + for bucket, p in zip(self.buckets, self.probability): + if p == 1: + # Just use all indices + batch_indices = np.arange(len(bucket) // bs) * bs + accu + else: + # Use all indices multiple times, and then sample some indices without replacement + repeat_times = int(p) + indices_part1 = np.arange(len(bucket) // bs).repeat(repeat_times) * bs + indices_part2 = self.sampler.choice(len(bucket) // bs, int(len(bucket) * (p / bs - repeat_times)), + replace=False) * bs + batch_indices = np.sort(np.concatenate([indices_part1, indices_part2])) + accu + + if return_batch_indices: + indices = batch_indices + else: + indices = np.stack([batch_indices + i for i in range(bs)], axis=1).reshape(-1) + ind_mapper_list.append(indices) + accu += len(bucket) + ind_mapper = np.concatenate(ind_mapper_list) + return ind_mapper + + def shuffle(self, seed=None, fast=False): + if self.sample_strategy == 'probability': + # Notice: In order to resample indices when shuffling, shuffle will not preserve the + # initial sampled indices when loading the index. + pass + + # Shuffle indexes + if seed is not None: + state = random.getstate() + random.seed(seed) + random.shuffle(self.buckets) + random.setstate(state) + else: + random.shuffle(self.buckets) + + self.cum_length = self.calc_cum_length() + + # Shuffle indices in each index + for i, bucket in enumerate(self.buckets): + bucket.shuffle(seed + i, fast=fast) + + # Shuffle ind_mapper in batch level + if self.sample_strategy == 'uniform': + batch_ind_mapper = np.arange(self.total_length // self.batch_size) * self.batch_size + elif self.sample_strategy == 'probability': + batch_ind_mapper = self.sample_indices_with_probability(return_batch_indices=True) + else: + raise ValueError(f"Not supported sample_strategy {self.sample_strategy}.") + if seed is not None: + sampler = np.random.RandomState(seed) + sampler.shuffle(batch_ind_mapper) + else: + np.random.shuffle(batch_ind_mapper) + self.ind_mapper = np.stack([batch_ind_mapper + i for i in range(self.batch_size)], axis=1).reshape(-1) + + def get_ratio(self, ind, **kwargs): + """ + Get the ratio of the image by in-dataset index. + + Parameters + ---------- + ind: int + The in-dataset index. + + Returns + ------- + width, height, ratio + """ + ind = self.ind_mapper[ind] + i = bisect.bisect_right(self.cum_length, ind) + bias = self.cum_length[i - 1] if i > 0 else 0 + return self.buckets[i].get_ratio(ind - bias, **kwargs) + + def get_target_size(self, ind): + """ + Get the target size of the image by in-dataset index. + + Parameters + ---------- + ind: int + The in-dataset index. + + Returns + ------- + target_width, target_height + """ + ind = self.ind_mapper[ind] + i = bisect.bisect_right(self.cum_length, ind) + bias = self.cum_length[i - 1] if i > 0 else 0 + return self.buckets[i].get_target_size(ind - bias) + + +def build_multi_resolution_bucket(config_file, + base_size, + src_index_files, + save_file, + reso_step=64, + target_ratios=None, + align=1, + min_size=0, + md5_hw=None, + ): + # Compute base size + resolutions = ResolutionGroup(base_size, step=reso_step, target_ratios=target_ratios, align=align) + print(resolutions) + + save_file = Path(save_file) + save_file.parent.mkdir(exist_ok=True, parents=True) + + if isinstance(src_index_files, str): + src_index_files = [src_index_files] + src_indexes = [] + print(f'Loading indexes:') + for src_index_file in src_index_files: + src_indexes.append(ArrowIndexV2(src_index_file)) + print(f' {src_index_file} | cum_length: {src_indexes[-1].cum_length[-1]} | indices: {len(src_indexes[-1])}') + + if md5_hw is None: + md5_hw = {} + + arrow_files = src_indexes[0].arrow_files[:] # !!!important!!!, copy the list + for src_index in src_indexes[1:]: + arrow_files.extend(src_index.arrow_files[:]) + + cum_length = src_indexes[0].cum_length[:] + for src_index in src_indexes[1:]: + cum_length.extend([x + cum_length[-1] for x in src_index.cum_length]) + print(f'cum_length: {cum_length[-1]}') + + group_length_list = src_indexes[0].group_length[:] + for src_index in src_indexes[1:]: + group_length_list.extend(src_index.group_length[:]) + + total_indices = sum([len(src_index) for src_index in src_indexes]) + total_group_length = sum(group_length_list) + assert total_indices == total_group_length, f'Total indices {total_indices} != Total group length {total_group_length}' + + buckets = [[] for _ in range(len(resolutions))] + cum_length_tmp = 0 + total_index_count = 0 + for src_index, src_index_file in zip(src_indexes, src_index_files): + index_count = 0 + pbar = tqdm(src_index.indices.tolist()) + for i in pbar: + try: + height = int(src_index.get_attribute_by_index(i, 'height')) + width = int(src_index.get_attribute_by_index(i, 'width')) + except Exception as e1: + try: + md5 = src_index.get_attribute_by_index(i, 'md5') + height, width = md5_hw[md5] + except Exception as e2: + try: + width, height = src_index.get_image_by_index(i).size + except Exception as e3: + print(f'Error: {e1} --> {e2} --> {e3}. We will skip this image.') + continue + + if height < min_size or width < min_size: + continue + + ratio = height / width + idx = np.argmin(np.abs(resolutions.ratio - ratio)) + buckets[idx].append(i + cum_length_tmp) + index_count += 1 + print(f"Valid indices {index_count} in {src_index_file}.") + cum_length_tmp += src_index.cum_length[-1] + total_index_count += index_count + print(f'Total indices: {total_index_count}') + + print(f'Making bucket index.') + indices = {} + for i, bucket in tqdm(enumerate(buckets)): + if len(bucket) == 0: + continue + reso = f'{resolutions[i]}' + resolutions.attr[i] = f'{len(bucket):>6d}' + indices[reso] = bucket + + builder = IndexV2Builder(data_type=['multi-resolution-bucket-v2', + f'base_size={base_size}', + f'reso_step={reso_step}', + f'target_ratios={target_ratios}', + f'align={align}', + f'min_size={min_size}', + f'src_files='] + + [f'{src_index_file}' for src_index_file in src_index_files], + arrow_files=arrow_files, + cum_length=cum_length, + indices=indices, + config_file=config_file, + ) + builder.build(save_file) + print(resolutions) + print(f'Build index finished!\n\n' + f' Save path: {Path(save_file).absolute()}\n' + f' Number of indices: {sum([len(v) for k, v in indices.items()])}\n' + f'Number of arrow files: {len(arrow_files)}\n' + ) diff --git a/IndexKits/index_kits/common.py b/IndexKits/index_kits/common.py new file mode 100644 index 0000000..998184f --- /dev/null +++ b/IndexKits/index_kits/common.py @@ -0,0 +1,262 @@ +import json +import numpy as np +from pathlib import Path +from functools import partial + +from .indexer import ArrowIndexV2 +from .bucket import ( + ResolutionGroup, + MultiIndexV2, MultiResolutionBucketIndexV2, MultiMultiResolutionBucketIndexV2, IndexV2Builder +) + + +def load_index(src, + multireso=False, + batch_size=1, + world_size=1, + sample_strategy='uniform', + probability=None, + shadow_file_fn=None, + seed=None, + ): + if isinstance(src, str): + src = [src] + if src[0].endswith('.arrow'): + if multireso: + raise ValueError('Arrow file does not support multiresolution. Please make base index V2 first and then' + 'build multiresolution index.') + idx = IndexV2Builder(src).to_index_v2() + elif src[0].endswith('.json'): + if multireso: + if len(src) == 1: + idx = MultiResolutionBucketIndexV2(src[0], batch_size=batch_size, + world_size=world_size, + shadow_file_fn=shadow_file_fn, + ) + else: + idx = MultiMultiResolutionBucketIndexV2(src, batch_size=batch_size, + world_size=world_size, + sample_strategy=sample_strategy, probability=probability, + shadow_file_fn=shadow_file_fn, seed=seed, + ) + else: + if len(src) == 1: + idx = ArrowIndexV2(src[0], + shadow_file_fn=shadow_file_fn, + ) + else: + idx = MultiIndexV2(src, + sample_strategy=sample_strategy, probability=probability, + shadow_file_fn=shadow_file_fn, seed=seed, + ) + else: + raise ValueError(f'Unknown file type: {src[0]}') + return idx + + +def get_attribute(data, attr_list): + ret_data = {} + for attr in attr_list: + ret_data[attr] = data.get(attr, None) + if ret_data[attr] is None: + raise ValueError(f'Missing {attr} in {data}') + return ret_data + + +def get_optional_attribute(data, attr_list): + ret_data = {} + for attr in attr_list: + ret_data[attr] = data.get(attr, None) + return ret_data + + +def detect_index_type(data): + if isinstance(data['group_length'], dict): + return 'multireso' + else: + return 'base' + +def show_index_info(src, only_arrow_files=False, depth=1): + """ + Show base/multireso index information. + """ + if not Path(src).exists(): + raise ValueError(f'{src} does not exist.') + print(f"Loading index file {src} ...") + with open(src, 'r') as f: + src_data = json.load(f) + print(f"Loaded.") + data = get_attribute(src_data, ['data_type', 'indices_file', 'arrow_files', 'cum_length', + 'group_length', 'indices', 'example_indices']) + opt_data = get_optional_attribute(src_data, ['config_file']) + + # Format arrow_files examples + arrow_files = data['arrow_files'] + if only_arrow_files: + existed = set() + arrow_files_output_list = [] + for arrow_file in arrow_files: + if depth == 0: + if arrow_file not in existed: + arrow_files_output_list.append(arrow_file) + existed.add(arrow_file) + elif depth > 0: + parts = Path(arrow_file).parts + if depth >= len(parts): + continue + else: + arrow_file_part = '/'.join(parts[:-depth]) + if arrow_file_part not in existed: + arrow_files_output_list.append(arrow_file_part) + existed.add(arrow_file_part) + else: + raise ValueError(f'Depth {depth} has exceeded the limit of arrow file {arrow_file}.') + arrow_files_repr = '\n'.join(arrow_files_output_list) + print(arrow_files_repr) + return None + + return_space = '\n' + ' ' * 21 + + if len(arrow_files) <= 4: + arrow_files_repr = return_space.join([arrow_file for arrow_file in arrow_files]) + else: + arrow_files_repr = return_space.join([_ for _ in arrow_files[:2]] + ['...'] + + [_ for _ in arrow_files[-2:]]) + arrow_files_length = len(arrow_files) + + # Format data_type + data_type = data['data_type'] + if isinstance(data_type, str): + data_type = [data_type] + data_type_common = [] + src_files = [] + found_src_files = False + for data_type_item in data_type: + if not found_src_files and data_type_item.strip() != 'src_files=': + data_type_common.append(data_type_item.strip()) + continue + found_src_files = True + if data_type_item.endswith('.json'): + src_files.append(data_type_item.strip()) + else: + data_type_common.append(data_type_item.strip()) + data_type_part2_with_ids = [] + max_id_len = len(str(len(src_files))) + for sid, data_type_item in enumerate(src_files, start=1): + data_type_part2_with_ids.append(f'{str(sid).rjust(max_id_len)}. {data_type_item}') + data_type = data_type_common + data_type_part2_with_ids + data_repr = return_space.join(data_type) + + # Format cum_length examples + cum_length = data['cum_length'] + if len(cum_length) <= 8: + cum_length_repr = ', '.join([str(i) for i in cum_length]) + else: + cum_length_repr = ', '.join([str(i) for i in cum_length[:4]] + ['...'] + [str(i) for i in cum_length[-4:]]) + cum_length_length = len(cum_length) + + if detect_index_type(data) == 'base': + # Format group_length examples + group_length = data['group_length'] + if len(group_length) <= 8: + group_length_repr = ', '.join([str(i) for i in group_length]) + else: + group_length_repr = ', '.join([str(i) for i in group_length[:4]] + ['...'] + [str(i) for i in group_length[-4:]]) + group_length_length = len(group_length) + + # Format indices examples + indices = data['indices'] + if len(indices) == 0 and data['indices_file'] != '': + indices_file = Path(src).parent / data['indices_file'] + if Path(indices_file).exists(): + print(f"Loading indices from {indices_file} ...") + indices = np.load(indices_file)['x'] + print(f"Loaded.") + else: + raise ValueError(f'This Index file contains an extra file {indices_file} which is missed.') + if len(indices) <= 8: + indices_repr = ', '.join([str(i) for i in indices]) + else: + indices_repr = ', '.join([str(i) for i in indices[:4]] + ['...'] + [str(i) for i in indices[-4:]]) + + # Calculate indices total length + indices_length = len(indices) + + print_str = f"""File: {Path(src).absolute()} + +ArrowIndexV2( + \033[4mdata_type:\033[0m {data_repr}""" + + # Process optional data + if opt_data['config_file'] is not None: + print_str += f""" + \033[4mconfig_file:\033[0m {opt_data['config_file']}""" + + # Add common data + print_str += f""" + \033[4mindices_file:\033[0m {data['indices_file']} + \033[4marrow_files: Count = {arrow_files_length:,}\033[0m + Examples: {arrow_files_repr} + \033[4mcum_length: Count = {cum_length_length:,}\033[0m + Examples: {cum_length_repr} + \033[4mgroup_length: Count = {group_length_length:,}\033[0m + Examples: {group_length_repr} + \033[4mindices: Count = {indices_length:,}\033[0m + Examples: {indices_repr}""" + + else: + group_length = data['group_length'] + + indices_file = Path(src).parent / data['indices_file'] + assert Path(indices_file).exists(), f'indices_file {indices_file} not found' + print(f"Loading indices from {indices_file} ...") + indices_data = np.load(indices_file) + print(f"Loaded.") + indices_length = sum([len(indices) for key, indices in indices_data.items()]) + keys = [k for k in group_length.keys() if len(indices_data[k]) > 0] + + resolutions = ResolutionGroup.from_list_of_hxw(keys) + resolutions.attr = [f'{len(indices):>,d}' for k, indices in indices_data.items()] + resolutions.prefix_space = 25 + + print_str = f"""File: {Path(src).absolute()} + +MultiResolutionBucketIndexV2( + \033[4mdata_type:\033[0m {data_repr}""" + + # Process optional data + if opt_data['config_file'] is not None: + print_str += f""" + \033[4mconfig_file:\033[0m {opt_data['config_file']}""" + + # Process config files of base index files + config_files = [] + for src_file in src_files: + src_file = Path(src_file) + if src_file.exists(): + with src_file.open() as f: + base_data = json.load(f) + if 'config_file' in base_data: + config_files.append(base_data['config_file']) + else: + config_files.append('Unknown') + else: + config_files.append('Missing the src file') + if config_files: + config_file_str = return_space.join([f'{str(sid).rjust(max_id_len)}. {config_file}' + for sid, config_file in enumerate(config_files, start=1)]) + print_str += f""" + \033[4mbase config files:\033[0m {config_file_str}""" + + # Add common data + print_str += f""" + \033[4mindices_file:\033[0m {data['indices_file']} + \033[4marrow_files: Count = {arrow_files_length:,}\033[0m + Examples: {arrow_files_repr} + \033[4mcum_length: Count = {cum_length_length:,}\033[0m + Examples: {cum_length_repr} + \033[4mindices: Count = {indices_length:,}\033[0m + \033[4mbuckets: Count = {len(keys)}\033[0m + {resolutions}""" + + print(print_str + '\n)\n') diff --git a/IndexKits/index_kits/dataset/config_parse.py b/IndexKits/index_kits/dataset/config_parse.py new file mode 100644 index 0000000..11e345c --- /dev/null +++ b/IndexKits/index_kits/dataset/config_parse.py @@ -0,0 +1,777 @@ +import yaml +from glob import glob +import pathlib +from pathlib import Path +import pickle +import json +from typing import List +from collections import defaultdict + +def source_names_to_names(source_names_): + """ + Get all arrow file names from a list of source names. + + Examples: + `test.arrow@@x2` means repeat the arrow file `test.arrow` 2 times. + """ + names = [] + exclude_count = 0 + for data_name in source_names_: + if isinstance(data_name, dict): + data_name, value = list(data_name.items())[0] + exclude = value.get('exclude', []) + repeat_times = value.get('repeat', 1) + else: + exclude = [] + repeat_times = 1 + + cur_count = 0 + for name in glob(data_name): + for e in exclude: + if e in name: + print(f"Excluding {name}") + exclude_count += 1 + break + else: + names.append((name, repeat_times)) + cur_count += 1 + print(f"Found {cur_count} arrow files in {data_name}.") + names = list(sorted(names, key=lambda x: x[0])) + return names, exclude_count + + +def get_or_create_arrow_name_list(source_names=None): + + if not isinstance(source_names, list): + raise ValueError(f"`source_names` should be of type list, got {type(source_names)}") + if len(source_names) == 0: + raise ValueError(f"`source_names` is an empty list.") + names, exclude_count = source_names_to_names(source_names) + + print(f"Found {len(names):,} arrow files.") + print(f"Excluded {exclude_count} arrow files.") + return names + + +def to_numeric(x, default_val=0.0): + if isinstance(x, (float, int)): + pass + elif isinstance(x, str): + try: + x = float(x) + except ValueError: + x = default_val + else: + x = default_val + return x + + +TYPE_MAPPER = { + "int": int, + "float": float, + "str": str, + "dict": dict, + "list": list, +} + + +class Operator(object): + def __init__(self, config): + self.config = config + self.arrow_file_keyword = None + + def applicable(self, arrow_file): + # If no keyword is provided, then it is applicable to all arrow files + if self.arrow_file_keyword is None: + return True + # If keyword is provided, then it is applicable to the arrow file if the keyword is in the arrow file name + for keyword in self.arrow_file_keyword: + if keyword in arrow_file: + return True + return False + + @staticmethod + def parse_path(paths): + if isinstance(paths, str): + paths = [paths] + if not isinstance(paths, list): + raise ValueError(f"Invalid type: {type(paths)}") + paths = [Path(p) for p in paths] + return paths + + @staticmethod + def load_md5s(paths: List[pathlib.Path], merge=True): + md5s_list = [] + for path in paths: + if not path.exists(): + raise ValueError(f"Path not found: {path}") + if path.suffix == '.pkl': + with path.open('rb') as f: + md5s = pickle.load(f) + assert isinstance(md5s, (set, dict)), f"Invalid type: {type(md5s)}" + elif path.suffix == '.json': + with path.open() as f: + md5s = json.load(f) + assert isinstance(md5s, dict), f"Invalid type: {type(md5s)}" + elif path.suffix == '.txt': + md5s = set(path.read_text().splitlines()) + else: + raise ValueError(f"Invalid file type: {path}") + md5s_list.append(md5s) + print(f" {path}: {len(md5s):,} md5s") + + if merge: + md5s = md5s_list[0] + for m in md5s_list[1:]: + md5s.update(m) + else: + md5s = md5s_list + return md5s + + +class ColumnFilter(Operator): + numeric_actions = {"eq", "ne", "gt", "lt", "ge", "le"} + str_actions = {"eq", "ne", "len_eq", "len_ne", "len_gt", "len_lt", "len_ge", "len_le", "contains", "not_contains", + "in", "not_in", "lower_last_in"} + int_target_str_actions = {"len_eq", "len_gt", "len_lt", "len_ge", "len_le"} + action_mapper_left = { + "eq": "==", "ne": "!=", "len_eq": ".len()==", + "gt": ">", "len_gt": ".len()>", + "lt": "<", "len_lt": ".len()<", + "ge": ">=", "len_ge": ".len()>=", + "le": "<=", "len_le": ".len()<=", + "contains": ".contains(", + "not_contains": ".not_contains(", + "in": ".isin(", "not_in": ".notin(", + "lower_last_in": "[-1].lower().isin(", + } + action_mapper_right = { + "eq": "", "ne": "", "len_eq": "", + "gt": "", "len_gt": "", + "lt": "", "len_lt": "", + "ge": "", "len_ge": "", + "le": "", "len_le": "", + "contains": ")", + "not_contains": ")", + "in": ")", "not_in": ")", + "lower_last_in": ")", + } + + def __init__(self, config): + super().__init__(config) + self.name = self.column_name = config['name'] + self.dtype = config['type'] + self.action = config['action'] + self.target = config['target'] + self.default = config['default'] + self.arrow_file_cond = config.get('arrow_file', None) + self.arrow_file_keyword = config.get('arrow_file_keyword', None) + if self.arrow_file_keyword is not None: + if isinstance(self.arrow_file_keyword, str): + self.arrow_file_keyword = [self.arrow_file_keyword] + + def check_exists(self, arrow_file, table): + status = True + if self.column_name not in table.column_names: + print(f"Warning: Column `{self.column_name}` not found in {arrow_file}.") + status = status and False + return status + + @property + def data_type(self): + if self.arrow_file_keyword is None or len(self.arrow_file_keyword) == 0: + arrow_file_keyword_repr = "" + else: + arrow_file_keyword_repr = f', arrow_file_keys={self.arrow_file_keyword}' + fmt_str = f"{self.column_name}" \ + f"{self.action_mapper_left[self.action]}{self.target}{self.action_mapper_right[self.action]}" \ + f" (default={self.default}{arrow_file_keyword_repr})" + return fmt_str + + +class NumericFilter(ColumnFilter): + def run_on_column(self, column): + if self.action == 'eq': + return column == self.target + elif self.action == 'ne': + return column != self.target + elif self.action == 'gt': + return column > self.target + elif self.action == 'lt': + return column < self.target + elif self.action == 'ge': + return column >= self.target + elif self.action == 'le': + return column <= self.target + else: + raise ValueError(f"Invalid action: {self.action}") + + +class IntFilter(NumericFilter): + def __call__(self, arrow_file, table): + success = self.check_exists(arrow_file, table) + if not success: + return None + column = table[self.column_name].to_pandas().apply(lambda x: to_numeric(x, self.default)) + return self.run_on_column(column) + + +class FloatFilter(NumericFilter): + def __call__(self, arrow_file, table): + success = self.check_exists(arrow_file, table) + if not success: + return None + column = table[self.column_name].to_pandas().apply(lambda x: to_numeric(x, self.default)) + return self.run_on_column(column) + + +class StrFilter(ColumnFilter): + def __call__(self, arrow_file, table): + success = self.check_exists(arrow_file, table) + if not success: + return None + column = table[self.column_name].to_pandas() + if self.action == 'eq': + return column.apply(lambda x: x == self.target) + elif self.action == 'ne': + return column.apply(lambda x: x != self.target) + elif self.action == 'len_eq': + return column.str.len() == self.target + elif self.action == 'len_ne': + return column.str.len() != self.target + elif self.action == 'len_gt': + return column.str.len() > self.target + elif self.action == 'len_lt': + return column.str.len() < self.target + elif self.action == 'len_ge': + return column.str.len() >= self.target + elif self.action == 'len_le': + return column.str.len() <= self.target + elif self.action == 'contains': + return column.str.contains(self.target) + elif self.action == 'not_contains': + return ~column.str.contains(self.target) + elif self.action == 'in': + return column.apply(lambda x: x in self.target) + elif self.action == 'not_in': + return column.apply(lambda x: x not in self.target) + elif self.action == 'lower_last_in': + return column.apply(lambda x: x[-1].lower() in self.target) + else: + raise ValueError(f"Invalid action: {self.action}") + + +def check_type(parent, config): + if 'type' not in config: + return False, f"Missing required argument: {parent}.type" + if not isinstance(config['type'], str): + return False, f"Argument {parent}.type should be of type str" + if config['type'] not in TYPE_MAPPER: + return False, f"Invalid type: {parent}.type: {config['type']}" + return True, "" + + +def check_keys(parent, config, required_args, types): + for arg, ts in zip(required_args, types): + # Check key exists + if arg not in config: + return False, f"Missing required argument: {parent}.{arg}" + # Check values' type + if ts is not None: + if not isinstance(ts, list): + ts = [ts] + for t in ts: + if not isinstance(config[arg], t): + return False, f"Argument {parent}.{arg} should be of type {t}, got {type(config[arg])}" + return True, "" + + +def get_column_filter(parent, config): + success, msg = check_type(parent, config) + if not success: + return False, msg, None + + if config['type'] == 'str' and config['action'] in ColumnFilter.int_target_str_actions: + dtype = TYPE_MAPPER['int'] + else: + dtype = TYPE_MAPPER[config['type']] + + # Check other keys + required_args = ['name', 'action', 'target', 'default'] + types = [str, str, dtype, dtype] + success, msg = check_keys(parent, config, required_args, types) + if not success: + return False, msg, None + + # Check action values + if config['type'] == 'str': + valid_actions = ColumnFilter.str_actions + else: + valid_actions = ColumnFilter.numeric_actions + if config['action'] not in valid_actions: + return False, f"Invalid action: {parent}.action: {config['action']}", None + + if config['type'] == 'int': + return True, "", IntFilter(config) + elif config['type'] == 'float': + return True, "", FloatFilter(config) + elif config['type'] == 'str': + return True, "", StrFilter(config) + else: + raise ValueError(f"Invalid type: {config['type']}") + + +class MD5Filter(Operator): + valid_actions = { + "list": {"in", "not_in"}, + "dict": {"eq", "ne", "gt", "lt", "ge", "le"}, + } + action_mapper_left = { + "eq": "==", "ne": "!=", + "gt": ">", "lt": "<", + "ge": ">=", "le": "<=", + "in": ".isin(", "not_in": ".isin(", + } + action_mapper_right = { + "eq": "", "ne": "", + "gt": "", "lt": "", + "ge": "", "le": "", + "in": ")", "not_in": ")", + } + + def __init__(self, config): + super().__init__(config) + self.name = config['name'] + self.paths = self.parse_path(config['path']) + self.dtype = config['type'] + self.action = config['action'] + self.target = config.get('target', None) # only for type=='dict' + self.is_valid = config.get('is_valid', None) # only for type=='dict' + self.arrow_file_keyword = config.get('arrow_file_keyword', None) + if self.arrow_file_keyword is not None: + if isinstance(self.arrow_file_keyword, str): + self.arrow_file_keyword = [self.arrow_file_keyword] + + self.md5_data = self.load_md5s(self.paths) + + @property + def data_type(self): + if self.arrow_file_keyword is None or len(self.arrow_file_keyword) == 0: + arrow_file_keyword_repr = "" + else: + arrow_file_keyword_repr = f'(arrow_file_keys={self.arrow_file_keyword})' + return_space = '\n' + ' ' * 22 + if self.dtype == 'list': + fmt_str = ("Good Cases" if self.is_valid else " Bad Cases") \ + + f" (md5): {return_space.join([str(p) for p in self.paths])} " \ + f"{arrow_file_keyword_repr} " \ + f"| {self.name}" + elif self.dtype == 'dict': + fmt_str = ("Good Cases" if self.is_valid else " Bad Cases") \ + + f" (md5): {return_space.join([str(p) for p in self.paths])}\n" \ + f" --> value" \ + f"{self.action_mapper_left[self.action]}{self.target}{self.action_mapper_right[self.action]} " \ + f"{arrow_file_keyword_repr} " \ + f"| {self.name}" + else: + raise ValueError(f"Invalid type: {self.dtype}") + return fmt_str + + +class ListFilter(MD5Filter): + def __call__(self, md5s): + if self.action == "in": + find_valid = self.is_valid + elif self.action == "not_in": + find_valid = not self.is_valid + else: + raise ValueError(f"Invalid action: {self.action}") + + if find_valid: + return md5s.apply(lambda x: x in self.md5_data) + else: + return md5s.apply(lambda x: x not in self.md5_data) + + +class DictFilter(MD5Filter): + def __call__(self, md5s): + if self.is_valid: + return md5s.apply(lambda x: x in self.md5_data and self.cond(self.md5_data[x])) + else: + return md5s.apply(lambda x: (x not in self.md5_data) or not self.cond(self.md5_data[x])) + + def cond(self, value): + if self.action == "eq": + return value == self.target + elif self.action == "ne": + return value != self.target + elif self.action == "gt": + return value > self.target + elif self.action == "lt": + return value < self.target + elif self.action == "ge": + return value >= self.target + elif self.action == "le": + return value <= self.target + else: + raise ValueError(f"Invalid action: {self.action}") + + +def get_md5_filter(parent, config): + success, msg = check_type(parent, config) + if not success: + return False, msg, None + + required_args = ['name', 'path', 'action'] + types = [str, (str, list), str] + if config['type'] == 'dict': + required_args.extend(['target', 'is_valid']) + types.extend([None, bool]) + success, msg = check_keys(parent, config, required_args, types) + if not success: + return False, msg, None + + if config['action'] not in MD5Filter.valid_actions[config['type']]: + return False, f"Invalid action: {parent}.action: {config['action']}", None + + if config['type'] == 'list': + return True, "", ListFilter(config) + elif config['type'] == 'dict': + return True, "", DictFilter(config) + else: + raise ValueError(f"Invalid type: {config['type']}") + + +class FilterCompose(object): + def __init__(self, column_filter_list, md5_filter_list): + self.column_filter_list = column_filter_list + self.md5_filter_list = md5_filter_list + + def __call__(self, arrow_file, table): + stats = {} + length = len(table) + assert length > 0, "Empty table" + + mask = None + for filter_ in self.column_filter_list: + if isinstance(filter_, tuple): + op, filter_list = filter_ + if op == 'logical_or': + sub_mask = None + for sub_filter in filter_list: + if not sub_filter.applicable(arrow_file): + continue + sub_current_mask = sub_filter(arrow_file, table) + if sub_current_mask is not None: + if sub_mask is None: + sub_mask = sub_current_mask + else: + sub_mask = sub_mask | sub_current_mask + if sub_mask is not None: + name = '|'.join([f.name for f in filter_list]) + stats.update({ + name: length - sum(sub_mask) + }) + if mask is None: + mask = sub_mask + else: + mask = mask & sub_mask + else: + raise ValueError(f"Invalid operation: {op}") + else: + if not filter_.applicable(arrow_file): + continue + current_mask = filter_(arrow_file, table) + if current_mask is not None: + stats.update({ + filter_.name: length - sum(current_mask) + }) + if mask is None: + mask = current_mask + else: + mask = mask & current_mask + + md5s = None + for filter_ in self.md5_filter_list: + if not filter_.applicable(arrow_file): + continue + if 'md5' not in table.column_names: + print(f"Warning: Column 'md5' not found in {arrow_file}.") + + md5s = table['md5'].to_pandas() + current_mask = filter_(md5s) + if current_mask is not None: + stats.update({ + filter_.name: length - sum(current_mask) + }) + if mask is None: + mask = current_mask + else: + mask = mask & current_mask + + return mask, stats, md5s + + @property + def data_type(self): + data_type_list = [] + for filter_ in self.column_filter_list + self.md5_filter_list: + if isinstance(filter_, tuple): + data_type_list.append(' || '.join([f.data_type for f in filter_[1]])) + else: + data_type_list.append(filter_.data_type) + return data_type_list + + +class MD5Repeater(Operator): + def __init__(self, config): + super().__init__(config) + self.name = config['name'] + self.paths = self.parse_path(config['path']) + self.dtype = config['type'] + self.plus = config.get('plus', 0) + self.repeat = config.get('repeat', None) + self.arrow_file_keyword = config.get('arrow_file_keyword', None) + if self.arrow_file_keyword is not None: + if isinstance(self.arrow_file_keyword, str): + self.arrow_file_keyword = [self.arrow_file_keyword] + + self.md5_data = self.load_md5s(self.paths) + if self.repeat is None: + # Check if md5_data.values() are integers + for v in self.md5_data.values(): + if not isinstance(v, int): + raise ValueError(f"Values from {self.paths} are not integers. For example: {v}") + # We only check the first value for performance + # We assume all values are the same type + break + + def __call__(self, arrow_file, index, md5): + if md5 in self.md5_data: + if self.repeat is None: + return self.md5_data[md5] + self.plus + else: + return self.repeat + else: + return 1 + + @property + def data_type(self): + path_repr = ', '.join([str(p) for p in self.paths]) + if self.dtype == 'list': + fmt_str = f"[MD5] Repeat {self.repeat} times: {path_repr}" + elif self.dtype == 'dict': + fmt_str = f"[MD5] Repeat multiple times: {path_repr}" + else: + raise ValueError(f"Invalid type: {self.dtype}") + return fmt_str + + +def get_md5_repeater(parent, config): + success, msg = check_type(parent, config) + if not success: + return False, msg, None + + required_args = ['name', 'path'] + types = [str, str] + if config['type'] == 'list': + required_args.extend(['repeat']) + types.extend([int]) + success, msg = check_keys(parent, config, required_args, types) + if not success: + return False, msg, None + + return True, "", MD5Repeater(config) + + +class KeywordRepeater(Operator): + def __init__(self, config): + super().__init__(config) + self.keywords = config['keyword'] + self.repeat = config['repeat'] + self.name = f"Repeat {self.repeat} times" + + def __call__(self, arrow_file, idx, md5): + for key in self.keywords: + if key in arrow_file: + return self.repeat + return 1 + + @property + def data_type(self): + fmt_str = f"[Keyword] Repeat {self.repeat} times: {self.keywords}" + return fmt_str + +def get_keyword_repeater(parent, config): + required_args = ["repeat", "keyword"] + types = [int, list] + success, msg = check_keys(parent, config, required_args, types) + if not success: + return False, msg, None + + return True, "", KeywordRepeater(config) + + +class RepeaterCompose(object): + def __init__(self, repeater_list): + self.repeater_list = repeater_list + + def __call__(self, arrow_file, table, indices, repeat_times=1, md5s=None): + stats = defaultdict(int) + length = len(table) + assert length > 0, "Empty table" + + if md5s is None: + if 'md5' not in table.column_names: + print(f"Warning: Column 'md5' not found in {arrow_file}.") + md5s = None + else: + md5s = table['md5'].to_pandas() + + repeated_indices = [] + for idx in indices: + md5 = md5s[idx] if md5s is not None else None + max_repeat = repeat_times + max_i = -1 + for i, repeater in enumerate(self.repeater_list): + if not repeater.applicable(arrow_file): + continue + repeat = repeater(arrow_file, idx, md5) + if repeat > max_repeat: + max_repeat = repeat + max_i = i + if max_i >= 0: + stats[self.repeater_list[max_i].name] += (max_repeat - 1) + repeated_indices.extend([idx] * max_repeat) + + return repeated_indices, stats + + @property + def data_type(self): + data_type_list = [] + for repeater in self.repeater_list: + data_type_list.append(repeater.data_type) + return data_type_list + + +class DatasetConfig(object): + def __init__(self, work_dir, config_file): + self.work_dir = work_dir + self.config_file = str(Path(config_file).absolute()) + + with open(self.config_file, 'r') as f: + self.data = yaml.safe_load(f) + + self.names = self.parse_names() # arrow names + arrow_max_repeat = max([x[1] for x in self.names]) + self.filter = self.parse_filter() + self.repeater = self.parse_repeater(enable_arrow_repeat=arrow_max_repeat > 1) + + # Extra arguments + self.remove_md5_dup = self.data.get('remove_md5_dup', False) + + def assert_unknown_args(self): + unknown_args = set(self.data.keys()) - { + 'source', 'filter', 'repeater', 'remove_md5_dup' + } + if len(unknown_args) > 0: + raise ValueError(f"Unknown arguments in config file ({self.config_file}): {unknown_args}") + + def parse_filter(self): + column_filter_list = [] + md5_filter_list = [] + if 'filter' in self.data: + filters = self.data['filter'] + if 'column' in filters: + column_filter_configs = filters['column'] + assert isinstance(column_filter_configs, list), "filter.column should be a list." + for i, config in enumerate(column_filter_configs): + if config.get('logical_or', None) is not None: + assert isinstance(config['logical_or'], list), \ + f"filter.column[{i}].logical_or should be a list, got {type(config['logical_or'])}" + sub_column_filter_list = [] + for j, sub_config in enumerate(config['logical_or']): + sub_success, sub_msg, sub_filter_ = get_column_filter(f'filter.column[{i}-logical_or].[{j}]', sub_config) + if not sub_success: + raise ValueError(sub_msg) + sub_column_filter_list.append(sub_filter_) + success = True + msg = '' + filter_ = ('logical_or', sub_column_filter_list) + else: + success, msg, filter_ = get_column_filter(f'filter.column[{i}]', config) + if not success: + raise ValueError(msg) + column_filter_list.append(filter_) + if 'md5' in filters: + md5_filter_configs = filters['md5'] + assert isinstance(md5_filter_configs, list), "filter.md5 should be a list." + for i, config in enumerate(md5_filter_configs): + if config.get('logical_or', None) is not None: + assert isinstance(config['logical_or'], list), \ + f"filter.md5[{i}].logical_or should be a list, got {type(config['logical_or'])}" + sub_md5_filter_list = [] + for j, sub_config in enumerate(config['logical_or']): + sub_success, sub_msg, sub_filter_ = get_md5_filter(f'filter.md5[{i}-logical_or].[{j}]', sub_config) + if not sub_success: + raise ValueError(sub_msg) + sub_md5_filter_list.append(sub_filter_) + success = True + msg = '' + filter_ = ('logical_or', sub_md5_filter_list) + else: + success, msg, filter_ = get_md5_filter(f'filter.md5[{i}]', config) + if not success: + raise ValueError(msg) + md5_filter_list.append(filter_) + + if column_filter_list or md5_filter_list: + composed_filter = FilterCompose(column_filter_list, md5_filter_list) + else: + composed_filter = None + return composed_filter + + def parse_repeater(self, enable_arrow_repeat=False): + repeater_list = [] + if 'repeater' in self.data: + repeaters = self.data['repeater'] + if 'md5' in repeaters: + md5_repeater_configs = repeaters['md5'] + assert isinstance(md5_repeater_configs, list), "repeater.md5 should be a list." + for i, config in enumerate(md5_repeater_configs): + success, msg, repeater = get_md5_repeater(f'repeater.md5[{i}]', config) + if not success: + raise ValueError(msg) + repeater_list.append(repeater) + + if 'arrow_file_keyword' in repeaters: + keyword_repeater_configs = repeaters['arrow_file_keyword'] + assert isinstance(keyword_repeater_configs, list), "repeater.arrow_file_keyword should be a list." + for i, config in enumerate(keyword_repeater_configs): + success, msg, repeater = get_keyword_repeater(f'repeater.arrow_file_keyword[{i}]', config) + if not success: + raise ValueError(msg) + repeater_list.append(repeater) + + if repeater_list or enable_arrow_repeat: + composed_repeater = RepeaterCompose(repeater_list) + else: + composed_repeater = None + return composed_repeater + + def parse_names(self): + if 'source' in self.data: + source = self.data['source'] + assert isinstance(source, list), "source should be a list." + else: + raise ValueError("In the YAML file, the ‘source’ field is filled in incorrectly, please check") + + names = get_or_create_arrow_name_list(source) + return names + + @property + def data_type(self): + data_types = { + 'Filters': [] if self.filter is None else self.filter.data_type, + 'Repeaters': [] if self.repeater is None else self.repeater.data_type, + } + return data_types diff --git a/IndexKits/index_kits/dataset/make_dataset_core.py b/IndexKits/index_kits/dataset/make_dataset_core.py new file mode 100644 index 0000000..6f9f388 --- /dev/null +++ b/IndexKits/index_kits/dataset/make_dataset_core.py @@ -0,0 +1,320 @@ +import json +import pickle +from collections import defaultdict +from glob import glob +from multiprocessing import Pool +from pathlib import Path + +import pandas as pd +import yaml + +import numpy as np +import pyarrow as pa +from tqdm import tqdm + +from index_kits.indexer import IndexV2Builder +from index_kits.bucket import build_multi_resolution_bucket +from index_kits.dataset.config_parse import DatasetConfig + + +def get_table(arrow_file): + return pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, 'r')).read_all() + + +def get_indices(arrow_file, repeat_times, filter_fn, repeat_fn, callback=None): + """ + Get valid indices from a single arrow_file. + + Parameters + ---------- + arrow_file: str + repeat_times: int + Repeat remain indices multiple times. + filter_fn + callback + + Returns + ------- + + """ + try: + table = pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, 'r')).read_all() + except Exception as e: + print(arrow_file, e) + raise e + length = len(table) + + if len(table) == 0: + print(f"Warning: Empty table: {arrow_file}") + indices = [] + stats = {} + + else: + # Apply filter_fn if available + if filter_fn is not None: + mask, stats, md5s = filter_fn(arrow_file, table) + else: + mask = pd.Series([True] * length) + stats = {} + md5s = None + + # Apply callback function if available + if callback is not None: + mask, stats = callback(arrow_file, table, mask, stats, md5s) + + # Get indices + if mask is not None: + indices = np.where(mask)[0].tolist() + else: + indices = list(range(length)) + + # Apply indices repeat + if repeat_fn is not None: + indices, repeat_stats = repeat_fn(arrow_file, table, indices, repeat_times, md5s) + stats.update(repeat_stats) + + return arrow_file, length, indices, stats + + +def load_md5_files(files, name=None): + if isinstance(files, str): + files = [files] + md5s = set() + for file in files: + md5s.update(Path(file).read_text().splitlines()) + print(f" {name} md5s: {len(md5s):,}") + + return md5s + +def load_md52cls_files(files, name=None): + if isinstance(files, str): + files = [files] + md52cls = {} + for file in files: + with Path(file).open() as f: + md52cls.update(json.load(f)) + print(f" {name} md52cls: {len(md52cls):,}") + + return md52cls + + +def merge_and_build_index(data_type, src, dconfig, save_path): + if isinstance(src, str): + files = list(sorted(glob(src))) + else: + files = list(sorted(src)) + print(f"Found {len(files):,} temp pickle files.") + for fname in files: + print(f" {fname}") + arrow_files = [] + table_lengths = [] + indices_list = [] + bad_stats_total = defaultdict(int) + total_indices = 0 + total_processed_length = 0 + for file_name in tqdm(files): + with Path(file_name).open('rb') as f: + data = pickle.load(f) + for arrow_file, table_length, indices, *args in tqdm(data, leave=False): + arrow_files.append(arrow_file) + table_lengths.append(table_length) + total_processed_length += table_length + indices_list.append(indices) + total_indices += len(indices) + if len(args) > 0 and args[0]: + bad_stats = args[0] + for k, v in bad_stats.items(): + bad_stats_total[k] += v + + if len(bad_stats_total): + stats_save_dir = Path(save_path).parent + stats_save_dir.mkdir(parents=True, exist_ok=True) + stats_save_path = stats_save_dir / (Path(save_path).stem + '_stats.txt') + stats_save_path.write_text('\n'.join([f'{k:>50s} {v}' for k, v in bad_stats_total.items()]) + '\n') + print(f"Save stats to {stats_save_path}") + + print(f'Arrow files: {len(arrow_files):,}') + print(f'Processed indices: {total_processed_length:,}') + print(f'Valid indices: {total_indices:,}') + + cum_length = 0 + total_indices = [] + cum_lengths = [] + group_lengths = [] + existed = set() + print(f"Accumulating indices...") + pbar = tqdm(zip(arrow_files, table_lengths, indices_list), total=len(arrow_files), mininterval=1) + _count = 0 + for arrow_file, table_length, indices in pbar: + if len(indices) > 0 and dconfig.remove_md5_dup: + new_indices = [] + table = get_table(arrow_file) + if 'md5' not in table.column_names: + raise ValueError(f"Column 'md5' not found in {arrow_file}. " + f"When `remove_md5_dup: true` is set, md5 column is required.") + md5s = table['md5'].to_pandas() + for i in indices: + md5 = md5s[i] + if md5 in existed: + continue + existed.add(md5) + new_indices.append(i) + indices = new_indices + + total_indices.extend([int(i + cum_length) for i in indices]) + cum_length += table_length + cum_lengths.append(cum_length) + group_lengths.append(len(indices)) + + _count += 1 + + if _count % 100 == 0: + pbar.set_description(f'Indices: {len(total_indices):,}') + + builder = IndexV2Builder(data_type=data_type, + arrow_files=arrow_files, + cum_length=cum_lengths, + group_length=group_lengths, + indices=total_indices, + config_file=dconfig.config_file, + ) + builder.build(save_path) + print(f'Build index finished!\n\n' + f' Save path: {Path(save_path).absolute()}\n' + f' Number of indices: {len(total_indices)}\n' + f'Number of arrow files: {len(arrow_files)}\n' + ) + + +def worker_startup(rank, world_size, dconfig, prefix, work_dir, callback=None): + # Prepare names for this worker + num = (len(dconfig.names) + world_size - 1) // world_size + arrow_names = dconfig.names[rank * num:(rank + 1) * num] + print(f'Rank {rank} has {len(arrow_names):,} names.') + + # Run get indices + print(f"Start getting indices...") + indices = [] + for arrow_name, repeat_times in tqdm(arrow_names, position=rank, desc=f"#{rank}: ", leave=False): + indices.append(get_indices(arrow_name, repeat_times, dconfig.filter, dconfig.repeater, callback)) + + # Save to a temp file + temp_save_path = work_dir / f'data/temp_pickles/{prefix}-{rank + 1}_of_{world_size}.pkl' + temp_save_path.parent.mkdir(parents=True, exist_ok=True) + with temp_save_path.open('wb') as f: + pickle.dump(indices, f) + print(f'Rank {rank} finished. Write temporary data to {temp_save_path}') + + return temp_save_path + + +def startup(config_file, + save, + world_size=1, + work_dir='.', + callback=None, + use_cache=False, + ): + work_dir = Path(work_dir) + save_path = Path(save) + if save_path.suffix != '.json': + save_path = save_path.parent / (save_path.name + '.json') + print(f"Using save_path: {save_path}") + prefix = f"{save_path.stem}" + + # Parse dataset config and build the data_type list + dconfig = DatasetConfig(work_dir, config_file) + data_type = [] + for k, v in dconfig.data_type.items(): + data_type.extend(v) + print(f"{k}:") + for x in v: + print(f' {x}') + if dconfig.remove_md5_dup: + data_type.append('Remove md5 duplicates.') + else: + data_type.append('Keep md5 duplicates.') + + # Start processing + if not use_cache: + temp_pickles = [] + if world_size == 1: + print(f"\nRunning in single process mode...") + temp_pickles.append(worker_startup(rank=0, + world_size=1, + dconfig=dconfig, + prefix=prefix, + work_dir=work_dir, + callback=callback, + )) + else: + print(f"\nRunning in multi-process mode (world_size={world_size})...") + p = Pool(world_size) + temp_pickles_ = [] + for i in range(world_size): + temp_pickles_.append(p.apply_async(worker_startup, args=(i, world_size, dconfig, prefix, work_dir, callback))) + + for res in temp_pickles_: + temp_pickles.append(res.get()) + # close + p.close() + p.join() + else: + temp_pickles = glob(f'{work_dir}/data/temp_pickles/{prefix}-*_of_{world_size}.pkl') + + # Merge temp pickles and build index + merge_and_build_index(data_type, + temp_pickles, + dconfig, + save_path, + ) + + +def make_multireso(target, + config_file=None, + src=None, + base_size=None, + reso_step=None, + target_ratios=None, + align=None, + min_size=None, + md5_file=None, + ): + if config_file is not None: + with Path(config_file).open() as f: + config = yaml.safe_load(f) + else: + config = {} + src = config.get('src', src) + base_size = config.get('base_size', base_size) + reso_step = config.get('reso_step', reso_step) + target_ratios = config.get('target_ratios', target_ratios) + align = config.get('align', align) + min_size = config.get('min_size', min_size) + md5_file = config.get('md5_file', md5_file) + + if src is None: + raise ValueError('src must be provided in either config file or command line.') + if base_size is None: + raise ValueError('base_size must be provided.') + if reso_step is None and target_ratios is None: + raise ValueError('Either reso_step or target_ratios must be provided.') + + if md5_file is not None: + with open(md5_file, 'rb') as f: + md5_hw = pickle.load(f) + print(f'Md5 to height and width: {len(md5_hw):,}') + else: + md5_hw = None + + build_multi_resolution_bucket( + config_file=config_file, + base_size=base_size, + reso_step=reso_step, + target_ratios=target_ratios, + align=align, + min_size=min_size, + src_index_files=src, + save_file=target, + md5_hw=md5_hw, + ) diff --git a/IndexKits/index_kits/indexer.py b/IndexKits/index_kits/indexer.py new file mode 100644 index 0000000..4e27fda --- /dev/null +++ b/IndexKits/index_kits/indexer.py @@ -0,0 +1,832 @@ +import bisect +import io +import json +import random +from pathlib import Path +import ast +from itertools import chain +from collections import defaultdict +from functools import partial +from glob import glob + +import numpy as np +import pyarrow as pa +from PIL import Image +from tqdm import tqdm + + +def get_table(arrow_file): + """ + Read an arrow file and return an arrow table. + """ + return pa.ipc.RecordBatchFileReader(pa.memory_map(f"{arrow_file}", "r")).read_all() + + +def assert_type(data, dtype, msg=''): + if not isinstance(data, dtype): + raise ValueError(f'Expected {msg} type {dtype}, got {type(data)}.') + + +def ndarray_to_list(data): + if isinstance(data, np.ndarray): + data = data.tolist() + elif isinstance(data, dict): + data = {k: ndarray_to_list(v) for k, v in data.items()} + elif isinstance(data, (list, tuple)): + # Assert that all elements in data are python integer, not numpy integer. + # Because numpy integer cannot be serialized to json. + data = [int(x) for x in data] + else: + raise ValueError(f'Expected data type list, tuple, dict or np.ndarray, got {type(data)}.') + return data + + +class ArrowIndexV2(object): + """ + ArrowIndexV2 is a new version of ArrowIndex. + + Parameters + ---------- + index_file: str or pathlib.Path + The path of index file. Either index_file or res_dict should be provided. + res_dict: dict + The index dict. Either index_file or res_dict should be provided. + align: int + Align the length of indices to be a multiple of align. Generally align should be the batch size * world_size. + shadow_file_fn: callable or dict + A callable function to map shadow file path to a new path. If None, the shadow file path will not be + changed. If a dict is provided, the keys are the shadow names to call the function, and the values are the + callable functions to map the shadow file path to a new path. If a callable function is provided, the key + is 'default'. + + Examples + -------- + >>> index_file = 'data.json' + >>> indexObj = ArrowIndexV2(index_file) + >>> pil_image = indexObj.get_image(0) + >>> text = indexObj.get_attribute(0, column='text_zh') + + """ + def __init__(self, index_file=None, res_dict=None, align=1, + shadow_file_fn=None, **kwargs): + if index_file is not None: + with open(index_file, 'r') as f: + res_dict = json.load(f) + elif res_dict is not None: + pass + else: + raise ValueError(f'Either index_file or res_dict should be provided.') + + self.shadow_file_fn = {} + if shadow_file_fn is not None: + if not callable(shadow_file_fn) and not isinstance(shadow_file_fn, dict): + raise ValueError('shadow_file_fn should be a callable function or a dict.') + if callable(shadow_file_fn): + self.shadow_file_fn['default'] = shadow_file_fn + else: + for k, v in shadow_file_fn.items(): + if not callable(v): + raise ValueError(f'{k} should be a callable function.') + self.shadow_file_fn[k] = v + + self._data = res_dict + self.data_type = res_dict['data_type'] + self.arrow_files = res_dict['arrow_files'] + self.cum_length = res_dict['cum_length'] + + self.group_length = res_dict['group_length'] + error_msg = f'Expected group_length type list, got {type(self.group_length)}.' + if isinstance(self.group_length, dict): + raise ValueError(f'{error_msg}\nNote: You may using a multi-resolution index file. ' + 'Try `MultiResolutionBucketIndexV2` instead.') + elif not isinstance(self.group_length, list): + raise ValueError(error_msg) + + self.indices = res_dict['indices'] + if 'indices_file' in res_dict: + self.indices_file = res_dict['indices_file'] + if self.indices_file != '': + indices_file = Path(index_file).parent / self.indices_file + if Path(indices_file).exists(): + self.indices = np.load(indices_file)['x'] + else: + raise ValueError(f'This Index file contains an extra file {indices_file} which is missed.') + else: + self.indices_file = '' + + if not isinstance(self.indices, list) and not isinstance(self.indices, np.ndarray): + raise ValueError(f'Expected indices type list or np.ndarray, got {type(self.indices)}.') + + if align > 1: + if isinstance(self.indices, np.ndarray): + self.indices = self.indices.tolist() + self.align(align) + + self.indices = np.asarray(self.indices, int) + + if len(self.arrow_files) != len(self.cum_length): + raise ValueError(f'Length of arrow_files and cum_length does not match. {len(self.arrow_files)} != {len(self.cum_length)}') + if len(self.arrow_files) != len(self.group_length): + raise ValueError(f'Length of arrow_files and group_length does not match. {len(self.arrow_files)} != {len(self.group_length)}') + if len(self.indices) == 0: + raise ValueError(f'No indices found in index_dict.') + if isinstance(self.indices, list) and self.indices[-1] > self.cum_length[-1] - 1: + raise ValueError(f'Indices exceed cum_length.') + + # Warning: + # Ensure that indices are an increasing array. Currently, + # no checks are performed due to the potential slowness when dealing with hundreds of millions of data points. + + self.bias = self.cum_length + + self._cur_arrow_file = None + self._cur_table_map = None + self._cur_table = None + self._index_bias = 0 + self.last_index = -1 + + self._shadow_cur_arrow_file = {} + self._shadow_cur_table_map = {} + self._shadow_cur_table = {} + self._shadow_index_bias = {} + self.shadow_last_index = {} + for k in self.shadow_file_fn.keys(): + self._shadow_cur_arrow_file[k] = None + self._shadow_cur_table_map[k] = None + self._shadow_cur_table[k] = None + self._shadow_index_bias[k] = 0 + self.shadow_last_index[k] = -1 + + def __len__(self): + return len(self.indices) + + def __repr__(self): + return f""" + ArrowIndexV2( + data_type {self.data_type} + indices_file {self.indices_file} + arrow_files Count={len(self.arrow_files):,} ({self.arrow_files[0]}, ...) + cum_length Count={len(self.cum_length):,} ({self.cum_length[0]}, ...) + group_length Count={len(self.group_length):,} ({self.group_length[0]}, ...) + indices Count={len(self.indices):,} + example_indices Count={len(self._data['example_indices']):,} + ) + """ + + def check_exists(self): + for arrow_file in tqdm(self.arrow_files): + if not Path(arrow_file).exists(): + print(arrow_file) + + def align(self, align): + """ + Repeat the index so that the length is a multiple of batch_size * world_size. + """ + if len(self) % align == 0: + return + + repeat_num = align - len(self) % align + if repeat_num >= len(self): + repeat_n = repeat_num // len(self) + repeat_times = [repeat_n + 1 for _ in self.indices] + group_length_new = [ll * (repeat_n + 1) for ll in self.group_length] + repeat_num -= repeat_n * len(self) + else: + repeat_times = [1 for _ in range(repeat_num)] + group_length_new = [ll for ll in self.group_length] + + for i in range(repeat_num): + repeat_times[-i - 1] += 1 + + repeat_start_idx = len(self) - len(repeat_times) + + group_id = -1 + while group_length_new[group_id] == 0: + group_id -= 1 + + # Allocate the remaining indices that need to be repeated, + # while also counting how many indices have been checked. + # If the count reaches the group_length, switch to the next group + + # The reason for paying attention to group_length is that when repeating indices, + # group_length also needs to be updated synchronously.. + group_acc = 0 + for i in range(repeat_num): + group_length_new[group_id] += 1 + group_acc += 1 + if group_acc == self.group_length[group_id]: + group_id -= 1 + while group_length_new[group_id] == 0: + group_id -= 1 + group_acc = 0 + + temp = [] + for i, value in enumerate(self.indices[repeat_start_idx:]): + temp.extend([value] * repeat_times[i]) + + self.indices = np.concatenate([self.indices[:repeat_start_idx], temp]) + + self.group_length = group_length_new + + def shuffle(self, seed=None, fast=False): + """ + It takes about 30 seconds for an index consisting of 100_000 arrows. + """ + if fast: + return self.shuffle_fast(seed) + + indices = self.indices.tolist() + + if seed is not None: + state = random.getstate() + random.seed(seed) + + indices_group_list = [] + group_cum_len = 0 + for group_len in self.group_length: + indices_group = indices[group_cum_len:group_cum_len + group_len] + random.shuffle(indices_group) + indices_group_list.append((indices_group, group_len)) + group_cum_len += group_len + random.shuffle(indices_group_list) + self.group_length = [x[1] for x in indices_group_list] + self.indices = np.asarray(list(chain.from_iterable([x[0] for x in indices_group_list]))) + + if seed is not None: + random.setstate(state) + + def shuffle_fast(self, seed=None): + if seed is not None: + sampler = np.random.RandomState(seed) + sampler.shuffle(self.indices) + else: + np.random.shuffle(self.indices) + + def get_table(self, arrow_file, shadow=None): + """ + Read an arrow file and return an arrow table. + """ + if shadow is None: + if self._cur_table is not None: + if self._cur_arrow_file == arrow_file: + # This is the same arrow file. Return the cached table. + return self._cur_table + else: + # This is a different arrow file. Clear the cache. + self._cur_table_map.close() + self._cur_table = None + + self._cur_arrow_file = arrow_file + self._cur_table_map = pa.memory_map(f"{arrow_file}", "r") + self._cur_table = pa.ipc.RecordBatchFileReader(self._cur_table_map).read_all() + return self._cur_table + else: + if self._shadow_cur_table[shadow] is not None: + if self._shadow_cur_arrow_file[shadow] == arrow_file: + return self._shadow_cur_table[shadow] + else: + self._shadow_cur_table_map[shadow].close() + self._shadow_cur_table[shadow] = None + + self._shadow_cur_arrow_file[shadow] = arrow_file + self._shadow_cur_table_map[shadow] = pa.memory_map(f"{arrow_file}", "r") + self._shadow_cur_table[shadow] = pa.ipc.RecordBatchFileReader(self._shadow_cur_table_map[shadow]).read_all() + return self._shadow_cur_table[shadow] + + def get_arrow_file_by_index(self, index, return_index_bias=False, shadow=None): + i = bisect.bisect_right(self.cum_length, index) + arrow_file = self.arrow_files[i] + + if return_index_bias: + if i == 0: + index_bias = 0 + else: + index_bias = self.cum_length[i - 1] + + return arrow_file, index_bias + + return arrow_file + + def get_arrow_file(self, ind, shadow=None): + """ + Get arrow file by in-dataset index. + + Parameters + ---------- + ind: int + The in-dataset index. + shadow: str + The shadow name. If None, return the main arrow file. If not None, return the shadow arrow file. + + Returns + ------- + arrow_file: str + The arrow file path. + """ + index = self.indices[ind] + return self.get_arrow_file_by_index(index, shadow=shadow) + + def load_table_by_index(self, index, shadow=None): + if shadow is None: + if index == self.last_index: + return self._cur_table + arrow_file, self._index_bias = \ + self.get_arrow_file_by_index(index, return_index_bias=True) + self._cur_table = self.get_table(arrow_file) + self.last_index = index + return self._cur_table + else: + if index == self.shadow_last_index[shadow]: + return self._shadow_cur_table[shadow] + shadow_arrow_file, _shadow_index_bias = \ + self.get_arrow_file_by_index(index, return_index_bias=True, shadow=shadow) + self._shadow_index_bias[shadow] = _shadow_index_bias + self._shadow_cur_table[shadow] = self.get_table(shadow_arrow_file, shadow=shadow) + self.shadow_last_index[shadow] = index + return self._shadow_cur_table[shadow] + + def get_data_by_index(self, index, columns=None, allow_missing=False, return_meta=True, shadow=None): + table = self.load_table_by_index(index, shadow=shadow) + if isinstance(columns, str): + columns = [columns] + if columns is None: + columns = list(table.column_names) + + index_bias = self._index_bias if shadow is None else self._shadow_index_bias[shadow] + in_arrow_index = index - index_bias + if return_meta: + cur_arrow_file = self._cur_arrow_file if shadow is None else self._shadow_cur_arrow_file[shadow] + data = { + 'index': index, + 'in_arrow_index': in_arrow_index, + 'arrow_name': cur_arrow_file, + } + else: + data = {} + + if allow_missing: + for col in columns: + if col in table.column_names: + data[col] = table[col][in_arrow_index].as_py() + else: + for col in columns: + data[col] = table[col][in_arrow_index].as_py() + return data + + def get_data(self, ind, columns=None, allow_missing=False, return_meta=True, shadow=None): + """ + Get data by in-dataset index. + + Parameters + ---------- + ind: int + The in-dataset index. + columns: str or list + The columns to be returned. If None, return all columns. + allow_missing: bool + If True, omit missing columns. If False, raise an error if the column is missing. + return_meta: bool + If True, the resulting dict will contain some meta information: + in-json index, in-arrow index, and arrow_name. + shadow: str + The shadow name. If None, return the main data. If not None, return the shadow data. + + Returns + ------- + data: dict + A dict containing the data. + """ + index = self.indices[ind] + return self.get_data_by_index(index, columns, allow_missing=allow_missing, return_meta=return_meta, + shadow=shadow) + + def get_attribute_by_index(self, index, column, shadow=None): + table = self.load_table_by_index(index, shadow=shadow) + index_bias = self._index_bias if shadow is None else self._shadow_index_bias[shadow] + return table[column][index - index_bias].as_py() + + def get_attribute(self, ind, column, shadow=None): + """ + Get single attribute by in-dataset index. + + Parameters + ---------- + ind: int + The in-dataset index. + column: str + The column name. + shadow: str + The shadow name. If None, return the main data. If not None, return the shadow data. + + Returns + ------- + data: can be any type + """ + index = self.indices[ind] + return self.get_attribute_by_index(index, column, shadow=shadow) + + def get_image_by_index(self, index, column='image', ret_type='pil', max_size=-1, shadow=None): + table = self.load_table_by_index(index, shadow=shadow) + index_bias = self._index_bias if shadow is None else self._shadow_index_bias[shadow] + + col = 'image' if 'image' in table.column_names else 'binary' + temp = table[col][index - index_bias].as_py() + image_bytes = io.BytesIO(temp) + image_bytes.seek(0) + try: + # convert(RGB) has two purposes: + # 1. Convert the image to RGB mode. Some images are in grayscale/RGBA mode, which will cause channel + # inconsistency in following processing. + # 2. Convert the image to RGB mode. Some images are in P mode, which will be forced to use NEAREST resample + # method in resize (even if you specify LANCZOS), which will cause blurry images. + pil_image = Image.open(image_bytes).convert("RGB") + except Exception as e: + print(f'get_image_by_index | Error: {e} ({self.get_arrow_file_by_index(index), index - index_bias})') + pil_image = Image.new("RGB", (256, 256), (255, 255, 255)) + + if max_size > 0: + # Resize the image to max_size. max_size is the size of long edge + w, h = pil_image.size + if w > h: + new_w = max_size + new_h = int(h * max_size / w) + else: + new_h = max_size + new_w = int(w * max_size / h) + pil_image = pil_image.resize((new_w, new_h)) + + if ret_type == 'numpy': + return np.array(pil_image) + + return pil_image + + def get_image(self, ind, column='image', ret_type='pil', max_size=-1, shadow=None): + """ + Get image by in-dataset index. + + Parameters + ---------- + ind: int + The in-dataset index. + column: str + [Deprecated] The column name of the image. Default to 'image'. + ret_type: str + The return type. Can be 'pil' or 'numpy'. Default to 'pil'. + max_size: int + If not -1, resize the image to max_size. max_size is the size of long edge. + shadow: str + The shadow name. If None, return the main image. If not None, return the shadow image. + + Returns + ------- + image: PIL.Image.Image or np.ndarray + """ + index = self.indices[ind] + return self.get_image_by_index(index, column, ret_type, max_size, shadow=shadow) + + def get_md5_by_index(self, index, shadow=None): + table = self.load_table_by_index(index, shadow=shadow) + index_bias = self._index_bias if shadow is None else self._shadow_index_bias[shadow] + return table['md5'][index - index_bias].as_py() + + def get_md5(self, ind, shadow=None): + index = self.indices[ind] + return self.get_md5_by_index(index, shadow=shadow) + + def get_columns_by_index(self, index, shadow=None): + table = self.load_table_by_index(index, shadow=shadow) + return table.column_names + + def get_columns(self, ind, shadow=None): + index = self.indices[ind] + return self.get_columns_by_index(index, shadow=shadow) + + def source_distribution(self, save_path=None, shadow=None): + sources = defaultdict(int) + for index in tqdm(self.indices): + source = self.get_attribute_by_index(index, 'source', shadow=shadow) + sources[source] += 1 + + sources = sorted(sources.items(), key=lambda x: x[1], reverse=True) + for k, v in sources: + print(f'{k:20s} {v:10d}') + if save_path is not None: + Path(save_path).write_text( + '\n'.join([f'{k:20s} {v:10d}' for k, v in sources])) + + def save(self, save_path): + """ + Save the index to a json file. + + Parameters + ---------- + save_path: str or pathlib.Path + The path to save the index file. + """ + builder = IndexV2Builder(data_type=self.data_type, + arrow_files=self.arrow_files, + cum_length=self.cum_length, + indices=self.indices, + ) + builder.build(save_path) + + def sample_batch_indices(self, n): + return np.random.choice(self.indices, n) + + def sample_batch(self, n, columns, progress=True, shadow=None): + if isinstance(n, int): + indices = self.sample_batch_indices(n) + else: + indices = n + + if progress: + pbar = tqdm(indices) + else: + pbar = indices + + batch_data = [] + for i in pbar: + batch_data.append(self.get_data_by_index(i, columns, shadow=shadow)) + return batch_data + + @staticmethod + def resize_and_crop(image, target_size, resample=Image.LANCZOS, crop_type='random'): + """ + Resize image without changing aspect ratio, then crop the center/random part. + + Parameters + ---------- + image: PIL.Image.Image + The input image to be resized and cropped. + target_size: tuple + The target size of the image. + resample: + The resample method. See PIL.Image.Image.resize for details. Default to Image.LANCZOS. + crop_type: str + 'center' or 'random'. If 'center', crop the center part of the image. If 'random', + crop a random part of the image. Default to 'random'. + + Returns + ------- + image: PIL.Image.Image + The resized and cropped image. + crop_pos: tuple + The position of the cropped part. (crop_left, crop_top) + """ + tw, th = target_size + w, h = image.size + + tr = th / tw + r = h / w + + # resize + if r < tr: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + image = image.resize((resize_width, resize_height), resample=resample) + + if crop_type == 'center': + crop_top = int(round((resize_height - th) / 2.0)) + crop_left = int(round((resize_width - tw) / 2.0)) + elif crop_type == 'random': + crop_top = random.randint(0, resize_height - th) + crop_left = random.randint(0, resize_width - tw) + else: + raise ValueError(f'crop_type must be center or random, but got {crop_type}') + + image = image.crop((crop_left, crop_top, crop_left + tw, crop_top + th)) + return image, (crop_left, crop_top) + + +class IndexV2Builder(object): + def __init__(self, + arrow_files, + indices=None, + cum_length=None, + group_length=None, + data_type=None, + max_indices=5_000_000, + example_num=1000, + config_file=None, + ): + """ + Build index v2 from an index dict. + + Parameters + ---------- + arrow_files: list + A list of arrow files. + indices: list or dict + A list of indices or a dict of indices. + If not provided, it will be specified as range(cum_length[-1]). + cum_length: list + A list of cumulative length of arrow files. + If not provided, it will be calculated from arrow files. + group_length: list + A list of group length or a dict of group length for each arrow file. + If not provided, it will be calculated. + data_type: str or list + Some custom information of this index. + max_indices: int + If the number of indices is larger than max_indices, the indices will be saved in a separate file. + Default to 5_000_000. + example_num: int + The number of examples to be saved in the index file. Default to 1000. + config_file: str + The path of config file. + + Examples + -------- + >>> builder = IndexV2Builder( + >>> data_type='gold', + >>> arrow_files=arrow_files, + >>> cum_length=cum_length, + >>> indices=indices, + >>> ) + >>> builder.build(save_path) + + """ + self.arrow_files = arrow_files + self.indices = indices + self.cum_length = cum_length + self.group_length = group_length + self.data_type = data_type + self.max_indices = max_indices + self.example_num = example_num + self.config_file = config_file + + if isinstance(arrow_files, str): + if '*' in arrow_files or '?' in arrow_files: + self.arrow_files = list(glob(arrow_files)) + else: + self.arrow_files = [arrow_files] + elif isinstance(self.arrow_files, tuple): + self.arrow_files = list(self.arrow_files) + if not isinstance(self.arrow_files, list): + raise ValueError(f'Expected arrow_files to be a list, got {type(self.arrow_files)}.') + + if self.cum_length is None: + continuous = False + if self.indices is None: + self.group_length = [] + continuous = True + + print(f"Calculating cum_length...") + self.cum_length = [] + cur_cum_length = 0 + pbar = tqdm(self.arrow_files) + for arrow_file in pbar: + table_length = len(get_table(arrow_file)) + cur_cum_length += table_length + self.cum_length.append(cur_cum_length) + pbar.set_description(f"{self.cum_length[-1]:>12d}") + + if continuous: + self.group_length.append(table_length) + + if self.indices is None: + self.indices = list(range(self.cum_length[-1])) + + if self.group_length is None: + self.group_length = [] + + if self.data_type is None: + self.data_type = ['Made by IndexV2Builder'] + elif isinstance(self.data_type, str): + self.data_type = [self.data_type] + + assert_type(self.data_type, list, 'data_type') + assert_type(self.cum_length, (list, np.ndarray), 'cum_length') + assert_type(self.group_length, (list, dict, np.ndarray), 'group_length') + assert_type(self.indices, (list, dict, np.ndarray), 'indices') + self.cum_length = ndarray_to_list(self.cum_length) + self.group_length = ndarray_to_list(self.group_length) + self.indices = ndarray_to_list(self.indices) + + if isinstance(self.indices, dict): + for k, v in self.indices.items(): + assert_type(v, list, f'indices[{k}]') + + if len(self.arrow_files) != len(self.cum_length): + raise ValueError(f'Length of arrow_files and cum_length does not match. {len(self.arrow_files)} != {len(self.cum_length)}') + if len(self.indices) == 0: + raise ValueError(f'No indices found in index_dict.') + if isinstance(self.indices, list) and self.indices[-1] > self.cum_length[-1] - 1: + raise ValueError(f'Indices exceed cum_length. {self.indices[-1]} > {self.cum_length[-1] - 1}') + if len(self.group_length) > 0: + if len(self.arrow_files) != len(self.group_length): + raise ValueError(f'Length of arrow_files and group_length does not match. {len(self.arrow_files)} != {len(self.group_length)}') + if sum(self.group_length) != len(self.indices): + raise ValueError(f'Sum of group_length does not match length of indices. {sum(self.group_length)} != {len(self.indices)}') + + def encode(self): + # Encode arrow files + print("Encoding arrow files...") + arrow_files = [] + for arrow_file in tqdm(self.arrow_files): + shortname = arrow_file + arrow_files.append(shortname) + self.arrow_files = arrow_files + + # Calculate group_length + print("Calculating group length...") + if isinstance(self.indices, list): + if len(self.group_length) == 0: + self.group_length = self.calc_group_length(self.indices, self.cum_length) + else: + print("Group length already calculated, skip.") + elif isinstance(self.indices, dict): + if not isinstance(self.group_length, dict): + self.group_length = {} + for k, v in self.indices.items(): + print(f"Calculating group length for {k}...") + if k not in self.group_length or len(self.group_length[k]) == 0: + self.group_length[k] = self.calc_group_length(v, self.cum_length) + else: + print("Group length already calculated, skip.") + else: + raise ValueError(f'Expected indices type list or dict, got {type(self.indices)}.') + + return { + 'data_type': self.data_type, + 'config_file': self.config_file if self.config_file is not None else '', + 'indices_file': '', + 'arrow_files': self.arrow_files, + 'cum_length': self.cum_length, + 'group_length': self.group_length, + 'indices': self.indices, + 'example_indices': [], + } + + def build(self, save_path): + return self.save(save_path) + + def save(self, save_path): + """ + Make index v2 from an index dict. + + Parameters + ---------- + save_path: str or pathlib.Path + The path to save the index file. + """ + index_dict = self.encode() + # Ensure the indices either a list or a dict. + + save_path = Path(save_path) + save_path.parent.mkdir(exist_ok=True, parents=True) + + if isinstance(index_dict['indices'], list) and len(index_dict['indices']) > self.max_indices: + self.example_indices = index_dict['indices'][:self.example_num] + indices_to_save = {'x': index_dict['indices']} + index_dict['indices'] = [] + elif isinstance(index_dict['indices'], dict): + indices_to_save = index_dict['indices'] + index_dict['indices'] = {} + num_keys = len(indices_to_save) + example_num_per_key = max(self.example_num // num_keys, 10) + index_dict['example_indices'] = {k: v[:example_num_per_key] for k, v in index_dict['indices'].items()} + else: + indices_to_save = None + + # save indices + if indices_to_save is not None: + indices_file = save_path.parent / f'{save_path.stem}.index' + indices_dict = {k: np.array(v) for k, v in indices_to_save.items()} + np.savez_compressed(indices_file, **indices_dict) + index_dict['indices_file'] = indices_file.name + '.npz' + + with save_path.open('w') as f: + json.dump(index_dict, f, indent=4, ensure_ascii=False) + + @staticmethod + def calc_group_length(indices, cum_length): + group_lengths = [] + cum_ind = 0 + count = 0 + for index in tqdm(indices): + if index < cum_length[cum_ind]: + # index is still in the current group + count += 1 + else: + # index has exceeded the current group, need to switch to the next group + group_lengths.append(count) + cum_ind += 1 + # if the index exceeds the next group, continue to switch to the next group + while index >= cum_length[cum_ind]: + group_lengths.append(0) + cum_ind += 1 + count = 1 + # The indices array is exhausted, and the last group containing the index should also be added. + group_lengths.append(count) + assert len(group_lengths) <= len(cum_length), (len(group_lengths), len(cum_length)) + # Check if the number of groups is less than the number of cum_length, + # then the last n groups are empty and need to be filled with zeros. + if len(group_lengths) < len(cum_length): + group_lengths.extend([0] * (len(cum_length) - len(group_lengths))) + + return group_lengths \ No newline at end of file diff --git a/IndexKits/index_kits/sampler.py b/IndexKits/index_kits/sampler.py new file mode 100644 index 0000000..ab34ffc --- /dev/null +++ b/IndexKits/index_kits/sampler.py @@ -0,0 +1,138 @@ +import math + +import torch.distributed as dist +from torch.utils.data.distributed import DistributedSampler + + +class BlockDistributedSampler(DistributedSampler): + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False, + batch_size=-1, start_index=0): + super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + "Invalid rank {}, rank should be in the interval" + " [0, {}]".format(rank, num_replicas - 1)) + if batch_size == -1: + raise ValueError("batch_size should be specified") + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + self.shuffle = shuffle + self.seed = seed + self.batch_size = batch_size + self._start_index = start_index + self.recompute_sizes() + + @property + def start_index(self): + return self._start_index + + @start_index.setter + def start_index(self, value): + self._start_index = value + self.recompute_sizes() + + def recompute_sizes(self): + self.num_samples = len(self.dataset) // self.batch_size * self.batch_size // self.num_replicas \ + - self._start_index + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + raw_num_samples = len(indices) // self.batch_size * self.batch_size // self.num_replicas + raw_total_size = raw_num_samples * self.num_replicas + indices = indices[:raw_total_size] + + # We require that the dataset size is divisible by batch_size * num_replicas + # This is naturally satisfied when using index_kits. + # In future, we can remove this assertion. + assert len(indices) == raw_total_size, f"{len(indices)} vs {raw_total_size}" + + # subsample with start_index + indices = indices[self.rank * raw_num_samples + self.start_index:(self.rank + 1) * raw_num_samples] + assert len(indices) + self.start_index == raw_num_samples, \ + f"{len(indices) + self.start_index} vs {raw_num_samples}" + + # This is a sequential sampler. The shuffle operation is done by the dataset itself. + return iter(indices) + + +class DistributedSamplerWithStartIndex(DistributedSampler): + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False, + start_index=0): + super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + "Invalid rank {}, rank should be in the interval" + " [0, {}]".format(rank, num_replicas - 1)) + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + self._start_index = start_index + self.recompute_sizes() + self.shuffle = shuffle + self.seed = seed + + @property + def start_index(self): + return self._start_index + + @start_index.setter + def start_index(self, value): + self._start_index = value + self.recompute_sizes() + + def recompute_sizes(self): + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + if self.drop_last and (len(self.dataset) - self._start_index) % self.num_replicas != 0: # type: ignore[arg-type] + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + ((len(self.dataset) - self._start_index) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil((len(self.dataset) - self._start_index) / self.num_replicas) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + indices = list(range(self._start_index, len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[:self.total_size] + assert len(indices) == self.total_size + + # subsample with start_index + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) diff --git a/IndexKits/setup.py b/IndexKits/setup.py new file mode 100644 index 0000000..57c9e3c --- /dev/null +++ b/IndexKits/setup.py @@ -0,0 +1,28 @@ +import re + +try: + from setuptools import setup +except ImportError: + from distutils.core import setup + +with open("index_kits/__init__.py", "r") as file: + regex_version = r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]' + version = re.search(regex_version, file.read(), re.MULTILINE).group(1) + + +setup( + name='index_kits', + version=version, + author='jarvizhang', + author_email='jarvizhang@tencent.com', + description='An index kits for streaming reading arrow data.', + packages=['index_kits', 'index_kits/dataset'], + scripts=['bin/idk'], + install_requires=[ + "pillow>=9.3.0", + "tqdm>=4.60.0", + "pyarrow>=10.0.1", + "torch>=1.9", + ], + python_requires=">=3.8.12", +) diff --git a/Notice b/Notice index be83e11..39f5f9f 100644 --- a/Notice +++ b/Notice @@ -39,7 +39,6 @@ Redistribution and use in source and binary forms, with or without modification, THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. For the license of other third party components, please refer to the following URL: -https://github.com/pytorch/pytorch/blob/v1.13.1/LICENSE https://github.com/pytorch/pytorch/blob/v1.13.1/NOTICE @@ -56,7 +55,6 @@ Copyright (c) 2011-2023, Open source contributors. A copy of the BSD 3-Clause is included in this file. For the license of other third party components, please refer to the following URL: -https://github.com/pandas-dev/pandas/blob/v2.0.3/LICENSE https://github.com/pandas-dev/pandas/tree/v2.0.3/LICENSES @@ -71,7 +69,6 @@ All rights reserved. A copy of the BSD 3-Clause is included in this file. For the license of other third party components, please refer to the following URL: -https://github.com/numpy/numpy/blob/v1.24.4/LICENSE.txt https://github.com/numpy/numpy/blob/v1.24.4/LICENSES_bundled.txt @@ -126,14 +123,17 @@ Copyright (c) mT5 original author and authors 10. Mistral-7B Copyright (c) 2024 Mistral AI, All rights reserved +11. peft +Copyright 2023 The HuggingFace Team. All rights reserved. + Terms of the Apache License Version 2.0: -------------------------------------------------------------------- -Apache License +Apache License Version 2.0, January 2004 -http://www.apache.org/licenses/ +http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. @@ -164,15 +164,15 @@ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: -You must give any other recipients of the Work or Derivative Works a copy of this License; and +You must give any other recipients of the Work or Derivative Works a copy of this License; and -You must cause any modified files to carry prominent notices stating that You changed the files; and +You must cause any modified files to carry prominent notices stating that You changed the files; and -You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and +You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and -If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. +If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. -You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. @@ -191,7 +191,7 @@ END OF TERMS AND CONDITIONS Open Source Software/Model Licensed under the BSD 3-Clause License: -------------------------------------------------------------------- 1. torchvision -Copyright (c) Soumith Chintala 2016, +Copyright (c) Soumith Chintala 2016, All rights reserved. 2. flash_attn @@ -278,6 +278,12 @@ Copyright (c) sd-vae-ft-ema original author and authors 9. ComfyUI-Diffusers Copyright (c) 2023 Limitex +10. glide-text2im +Copyright (c) 2021 OpenAI + +11. improved-diffusion +Copyright (c) 2021 OpenAI + Terms of the MIT License: -------------------------------------------------------------------- @@ -316,3 +322,13 @@ https://github.com/Stability-AI/generative-models/blob/main/LICENSE-CODE https://github.com/Stability-AI/generative-models/tree/main/model_licenses +Open Source Software/Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein: +-------------------------------------------------------------------- +1. pyarrow +Copyright 2016-2024 The Apache Software Foundation + + +A copy of the Apache License Version 2.0 is included in this file. + +For the license of other third party components, please refer to the following URL: +https://github.com/apache/arrow/blob/main/NOTICE.txt diff --git a/README.md b/README.md index 349a2a1..41e2f7d 100644 --- a/README.md +++ b/README.md @@ -25,8 +25,11 @@ This repo contains PyTorch model definitions, pre-trained weights and inference/ > [**DialogGen: Multi-modal Interactive Dialogue System for Multi-turn Text-to-Image Generation**](https://arxiv.org/abs/2403.08857)
## 🔥🔥🔥 News!! +* Jun 13, 2024: :zap: HYDiT-v1.1 version is released, which mitigates the issue of image oversaturation and alleviates the watermark issue. Please check [HunyuanDiT-v1.1 ](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.1) and +[Distillation-v1.1](https://huggingface.co/Tencent-Hunyuan/Distillation-v1.1) for more details. +* Jun 13, 2024: :truck: The training code is released, offering [full-parameter training](#full-parameter-training) and [LoRA training](#lora). * Jun 06, 2024: :tada: Hunyuan-DiT is now available in ComfyUI. Please check [ComfyUI](#using-comfyui) for more details. -* Jun 06, 2024: 🚀 We introduce Distillation version for Hunyuan-DiT acceleration, which achieves **50%** acceleration on NVIDIA GPUs. Please check [Tencent-Hunyuan/Distillation](https://huggingface.co/Tencent-Hunyuan/Distillation) for more details. +* Jun 06, 2024: 🚀 We introduce Distillation version for Hunyuan-DiT acceleration, which achieves **50%** acceleration on NVIDIA GPUs. Please check [Distillation](https://huggingface.co/Tencent-Hunyuan/Distillation) for more details. * Jun 05, 2024: 🤗 Hunyuan-DiT is now available in 🤗 Diffusers! Please check the [example](#using--diffusers) below. * Jun 04, 2024: :globe_with_meridians: Support Tencent Cloud links to download the pretrained models! Please check the [links](#-download-pretrained-models) below. * May 22, 2024: 🚀 We introduce TensorRT version for Hunyuan-DiT acceleration, which achieves **47%** acceleration on NVIDIA GPUs. Please check [TensorRT-libs](https://huggingface.co/Tencent-Hunyuan/TensorRT-libs) for instructions. @@ -63,8 +66,8 @@ or multi-turn language interactions to create the picture. - [x] Checkpoints - [x] Distillation Version - [x] TensorRT Version - - [ ] Training (Coming later ⏩️) - - [ ] Lora + - [x] Training + - [x] Lora - [ ] Controlnet (Pose, Canny, Depth, Tile) - [ ] IP-adapter - [ ] Hunyuan-DiT-XL checkpoints (0.7B model) @@ -90,6 +93,10 @@ or multi-turn language interactions to create the picture. - [📜 Requirements](#-requirements) - [🛠 Dependencies and Installation](#%EF%B8%8F-dependencies-and-installation) - [🧱 Download Pretrained Models](#-download-pretrained-models) + - [:truck: Training](#truck-training) + - [Data Preparation](#data-preparation) + - [Full Parameter Training](#full-parameter-training) + - [LoRA](#lora) - [🔑 Inference](#-inference) - [Using Gradio](#using-gradio) - [Using Diffusers](#using--diffusers) @@ -278,13 +285,112 @@ All models will be automatically downloaded. For more information about the mode | DialogGen | 7.0B | [DialogGen](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/tree/main/dialoggen) | [DialogGen](https://dit.hunyuan.tencent.com/download/HunyuanDiT/dialoggen.zip) | | sdxl-vae-fp16-fix | 83M | [sdxl-vae-fp16-fix](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/tree/main/t2i/sdxl-vae-fp16-fix) | [sdxl-vae-fp16-fix](https://dit.hunyuan.tencent.com/download/HunyuanDiT/sdxl-vae-fp16-fix.zip) | | Hunyuan-DiT | 1.5B | [Hunyuan-DiT](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT/tree/main/t2i/model) | [Hunyuan-DiT](https://dit.hunyuan.tencent.com/download/HunyuanDiT/model.zip) | +| Data demo | - | - | [Data demo](https://dit.hunyuan.tencent.com/download/HunyuanDiT/data_demo.zip) | + +## :truck: Training + +### Data Preparation + + Refer to the commands below to prepare the training data. + + 1. Install dependencies + + We offer an efficient data management library, named IndexKits, supporting the management of reading hundreds of millions of data during training, see more in [docs](./IndexKits/README.md). + ```shell + # 1 Install dependencies + cd HunyuanDiT + pip install -e ./IndexKits + ``` + 2. Data download + + Feel free to download the [data demo](https://dit.hunyuan.tencent.com/download/HunyuanDiT/data_demo.zip). + ```shell + # 2 Data download + wget -O ./dataset/data_demo.zip https://dit.hunyuan.tencent.com/download/HunyuanDiT/data_demo.zip + unzip ./dataset/data_demo.zip -d ./dataset + mkdir ./dataset/porcelain/arrows ./dataset/porcelain/jsons + ``` + 3. Data conversion + + Create a CSV file for training data with the fields listed in the table below. + + | Fields | Required | Description | Example | + |:---------------:| :------: |:----------------:|:-----------:| + | `image_path` | Required | image path | `./dataset/porcelain/images/0.png` | + | `text_zh` | Required | text | 青花瓷风格,一只蓝色的鸟儿站在蓝色的花瓶上,周围点缀着白色花朵,背景是白色 | + | `md5` | Optional | image md5 (Message Digest Algorithm 5) | `d41d8cd98f00b204e9800998ecf8427e` | + | `width` | Optional | image width | `1024 ` | + | `height` | Optional | image height | ` 1024 ` | + + > ⚠️ Optional fields like MD5, width, and height can be omitted. If omitted, the script below will automatically calculate them. This process can be time-consuming when dealing with large-scale training data. + + We utilize [Arrow](https://github.com/apache/arrow) for training data format, offering a standard and efficient in-memory data representation. A conversion script is provided to transform CSV files into Arrow format. + ```shell + # 3 Data conversion + python ./hydit/data_loader/csv2arrow.py ./dataset/porcelain/csvfile/image_text.csv ./dataset/porcelain/arrows + ``` + + 4. Data Selection and Configuration File Creation + + We configure the training data through YAML files. In these files, you can set up standard data processing strategies for filtering, copying, deduplicating, and more regarding the training data. For more details, see [docs](IndexKits/docs/MakeDataset.md). + + For a sample file, please refer to [file](./dataset/yamls/porcelain.yaml). For a full parameter configuration file, see [file](./IndexKits/docs/MakeDataset.md). + + + 5. Create training data index file using YAML file. + + ```shell + # Single Resolution Data Preparation + cd /HunyuanDiT + idk base -c dataset/yamls/porcelain.yaml -t dataset/porcelain/jsons/porcelain.json + + # Multi Resolution Data Preparation + idk multireso -c dataset/yamls/porcelain_mt.yaml -t dataset/porcelain/jsons/porcelain_mt.json + ``` + + The directory structure for `porcelain` dataset is: + + ```shell + cd ./dataset + + porcelain + ├──images/ (image files) + │ ├──0.png + │ ├──1.png + │ ├──...... + ├──csvfile/ (csv files containing text-image pairs) + │ ├──image_text.csv + ├──arrows/ (arrow files containing all necessary training data) + │ ├──00000.arrow + │ ├──00001.arrow + │ ├──...... + ├──jsons/ (final training data index files which read data from arrow files during training) + │ ├──porcelain.json + │ ├──porcelain_mt.json + ``` + +### Full-parameter Training + + To leverage DeepSpeed in training, you have the flexibility to control **single-node** / **multi-node** training by adjusting parameters such as `--hostfile` and `--master_addr`. For more details, see [link](https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node). + + ```shell + # Single Resolution Data Preparation + PYTHONPATH=./ sh hydit/train.sh --index-file dataset/porcelain/jsons/porcelain.json + + # Multi Resolution Data Preparation + PYTHONPATH=./ sh hydit/train.sh --index-file dataset/porcelain/jsons/porcelain.json --multireso --reso-step 64 + ``` + +### LoRA + +We provide training and inference scripts for LoRA, detailed in the [guidances](./lora/README.md). ## 🔑 Inference ### Using Gradio -Make sure you have activated the conda environment before running the following command. +Make sure the conda environment is activated before running the following command. ```shell # By default, we start a Chinese UI. @@ -304,7 +410,7 @@ python app/hydit_app.py --lang en # If your GPU memory is less than 32GB, use '--load-4bit' to enable 4-bit quantization, which requires at least 22GB of memory. python app/multiTurnT2I_app.py ``` -Then the demo can be accessed through http://0.0.0.0:443. It should be noted that the 0.0.0.0 here needs to be repleaed with the server IP. +Then the demo can be accessed through http://0.0.0.0:443. It should be noted that the 0.0.0.0 here needs to be X.X.X.X with your server IP. ### Using 🤗 Diffusers diff --git a/comfyui-hydit/hydit/config.py b/comfyui-hydit/hydit/config.py index decffd3..7c3fb83 100644 --- a/comfyui-hydit/hydit/config.py +++ b/comfyui-hydit/hydit/config.py @@ -51,7 +51,7 @@ def get_args(default_args=None): parser_hunyuan.add_argument("--negative", type=str, default="错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,", help="Negative prompt.") # Acceleration - parser_hunyuan.add_argument("--use_fp16", action="store_true", help="Use FP16 precision.") + parser_hunyuan.add_argument("--use-fp16", action="store_true", help="Use FP16 precision.") parser_hunyuan.add_argument("--no-fp16", dest="use_fp16", action="store_false") parser_hunyuan.set_defaults(use_fp16=True) diff --git a/comfyui-hydit/hydit/diffusion/pipeline.py b/comfyui-hydit/hydit/diffusion/pipeline.py index cfa07af..9f57453 100644 --- a/comfyui-hydit/hydit/diffusion/pipeline.py +++ b/comfyui-hydit/hydit/diffusion/pipeline.py @@ -754,10 +754,7 @@ def __call__( generator, latents, ) - - # Set the save path - #save_path = "/apdcephfs_cq8/share_1367250/xuhuaren/comfyui_project/comfyui_debug.pt" # Save the variables as a dictionary """ @@ -780,8 +777,6 @@ def __call__( num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - #with open("/apdcephfs_cq8/share_1367250/xuhuaren/dit-open/HunyuanDiT/output_python.txt", "a") as output_file: - # output_file.write(f"{t}\n") # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) diff --git a/dataset/yamls/porcelain.yaml b/dataset/yamls/porcelain.yaml new file mode 100644 index 0000000..95e8c65 --- /dev/null +++ b/dataset/yamls/porcelain.yaml @@ -0,0 +1,15 @@ +source: + - ./dataset/porcelain/arrows/00000.arrow + +filter: + column: + - name: height + type: int + action: ge + target: 1024 + default: 1024 + - name: width + type: int + action: ge + target: 1024 + default: 1024 \ No newline at end of file diff --git a/dataset/yamls/porcelain_mt.yaml b/dataset/yamls/porcelain_mt.yaml new file mode 100644 index 0000000..96ca828 --- /dev/null +++ b/dataset/yamls/porcelain_mt.yaml @@ -0,0 +1,5 @@ +src: + - ./dataset/porcelain/jsons/porcelain.json +base_size: 1024 +reso_step: 64 +min_size: 1024 \ No newline at end of file diff --git a/hydit/config.py b/hydit/config.py index f09a59b..013d9ce 100644 --- a/hydit/config.py +++ b/hydit/config.py @@ -1,35 +1,53 @@ import argparse from .constants import * -from .modules.models import HUNYUAN_DIT_CONFIG +from .modules.models import HUNYUAN_DIT_CONFIG, HUNYUAN_DIT_MODELS +from .diffusion.gaussian_diffusion import ModelVarType + +import deepspeed + +def model_var_type(value): + try: + return ModelVarType[value] + except KeyError: + raise ValueError(f"Invalid choice '{value}', valid choices are {[v.name for v in ModelVarType]}") def get_args(default_args=None): parser = argparse.ArgumentParser() + parser.add_argument("--task-flag", type=str) - # Basic - parser.add_argument("--prompt", type=str, default="一只小猫", help="The prompt for generating images.") - parser.add_argument("--model-root", type=str, default="ckpts", help="Model root path.") + # General Setting + parser.add_argument("--batch-size", type=int, default=1, help="Per-GPU batch size") + parser.add_argument('--seed', type=int, default=42, help="A seed for all the prompts.") + parser.add_argument("--use-fp16", action="store_true", help="Use FP16 precision.") + parser.add_argument("--no-fp16", dest="use_fp16", action="store_false") + parser.set_defaults(use_fp16=True) + parser.add_argument("--extra-fp16", action="store_true", help="Use extra fp16 for vae and text_encoder.") + + # HunYuan-DiT + parser.add_argument("--model", type=str, choices=list(HUNYUAN_DIT_CONFIG.keys()), default='DiT-g/2') parser.add_argument("--image-size", type=int, nargs='+', default=[1024, 1024], help='Image size (h, w). If a single value is provided, the image will be treated to ' '(value, value).') - parser.add_argument("--infer-mode", type=str, choices=["fa", "torch", "trt"], default="torch", - help="Inference mode") + parser.add_argument("--qk-norm", action="store_true", help="Query Key normalization. See http://arxiv.org/abs/2302.05442 for details.") + parser.set_defaults(qk_norm=True) + parser.add_argument("--norm", type=str, choices=["rms", "laryer"], default="layer", help="Normalization layer type") + parser.add_argument("--text-states-dim", type=int, default=1024, help="Hidden size of CLIP text encoder.") + parser.add_argument("--text-len", type=int, default=77, help="Token length of CLIP text encoder output.") + parser.add_argument("--text-states-dim-t5", type=int, default=2048, help="Hidden size of CLIP text encoder.") + parser.add_argument("--text-len-t5", type=int, default=256, help="Token length of T5 text encoder output.") - # HunYuan-DiT - parser.add_argument("--model", type=str, choices=list(HUNYUAN_DIT_CONFIG.keys()), default='DiT-g/2') - parser.add_argument("--norm", type=str, default="layer", help="Normalization layer type") - parser.add_argument("--load-key", type=str, choices=["ema", "module", "distill"], default="ema", help="Load model key for HunYuanDiT checkpoint.") - parser.add_argument('--size-cond', type=int, nargs='+', default=[1024, 1024], - help="Size condition used in sampling. 2 values are required for height and width. " - "If a single value is provided, the image will be treated to (value, value).") - parser.add_argument("--cfg-scale", type=float, default=6.0, help="Guidance scale for classifier-free.") + # LoRA config + parser.add_argument("--training_parts", type=str, default='all', choices=['all', 'lora'], help="Training parts") + parser.add_argument("--rank", type=int, default=64, help="Rank of LoRA") + parser.add_argument("--lora_ckpt", type=str, default=None, help="LoRA checkpoint") + parser.add_argument('--target-modules', type=str, nargs='+', default=['Wqkv', 'q_proj', 'kv_proj', 'out_proj'], + help="Target modules for LoRA fine tune") + parser.add_argument("--output_merge_path", type=str, default=None, help="Output path for merged model") + + - # Prompt enhancement - parser.add_argument("--enhance", action="store_true", help="Enhance prompt with dialoggen.") - parser.add_argument("--no-enhance", dest="enhance", action="store_false") - parser.add_argument("--load-4bit", help="load DialogGen model with 4bit quantization.", action="store_true") - parser.set_defaults(enhance=True) # Diffusion parser.add_argument("--learn-sigma", action="store_true", help="Learn extra channels for sigma.") @@ -41,29 +59,139 @@ def get_args(default_args=None): help="Noise schedule") parser.add_argument("--beta-start", type=float, default=0.00085, help="Beta start value") parser.add_argument("--beta-end", type=float, default=0.03, help="Beta end value") + parser.add_argument("--sigma-small", action="store_true") + parser.add_argument("--mse-loss-weight-type", type=str, default="constant", + help="Min-SNR-gamma. Can be constant or min_snr_ where gamma is a integer. 5 is recommended in the paper.") + parser.add_argument("--model-var-type", type=model_var_type, default=None, help="Specify the model variable type.") + parser.add_argument("--noise-offset", type=float, default=0.0, help="Add extra noise to the input image.") - # Text condition - parser.add_argument("--text-states-dim", type=int, default=1024, help="Hidden size of CLIP text encoder.") - parser.add_argument("--text-len", type=int, default=77, help="Token length of CLIP text encoder output.") - parser.add_argument("--text-states-dim-t5", type=int, default=2048, help="Hidden size of CLIP text encoder.") - parser.add_argument("--text-len-t5", type=int, default=256, help="Token length of T5 text encoder output.") + # ======================================================================================================== + # Inference + # ======================================================================================================== + + # Basic Setting + parser.add_argument("--prompt", type=str, default="一只小猫", help="The prompt for generating images.") + parser.add_argument("--model-root", type=str, default="ckpts", help="Model root path.") + + # Model setting + parser.add_argument("--load-key", type=str, choices=["ema", "module", "distill", 'merge'], default="ema", help="Load model key for HunYuanDiT checkpoint.") + parser.add_argument('--size-cond', type=int, nargs='+', default=[1024, 1024], + help="Size condition used in sampling. 2 values are required for height and width. " + "If a single value is provided, the image will be treated to (value, value).") + parser.add_argument('--target-ratios', type=str, nargs='+', default=None, + help="Target ratios for multi-resolution training.") + parser.add_argument("--cfg-scale", type=float, default=6.0, help="Guidance scale for classifier-free.") parser.add_argument("--negative", type=str, default=None, help="Negative prompt.") # Acceleration - parser.add_argument("--use_fp16", action="store_true", help="Use FP16 precision.") - parser.add_argument("--no-fp16", dest="use_fp16", action="store_false") - parser.set_defaults(use_fp16=True) + parser.add_argument("--infer-mode", type=str, choices=["fa", "torch", "trt"], default="torch", help="Inference mode") parser.add_argument("--onnx-workdir", type=str, default="onnx_model", help="Path to save ONNX model") # Sampling - parser.add_argument("--batch-size", type=int, default=1, help="Per-GPU batch size") parser.add_argument("--sampler", type=str, choices=SAMPLER_FACTORY, default="ddpm", help="Diffusion sampler") parser.add_argument("--infer-steps", type=int, default=100, help="Inference steps") - parser.add_argument('--seed', type=int, default=42, help="A seed for all the prompts.") + + # Prompt enhancement + parser.add_argument("--enhance", action="store_true", help="Enhance prompt with dialoggen.") + parser.add_argument("--no-enhance", dest="enhance", action="store_false") + parser.add_argument("--load-4bit", help="load DialogGen model with 4bit quantization.", action="store_true") + parser.set_defaults(enhance=True) # App parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="Language") + + # ======================================================================================================== + # Training + # ======================================================================================================== + + # Basic Setting + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--epochs", type=int, default=1400) + parser.add_argument("--max-training-steps", type=int, default=10_000_000) + parser.add_argument("--gc-interval", type=int, default=40, + help='To address the memory bottleneck encountered during the preprocessing of the dataset,' + ' memory fragments are reclaimed here by invoking the gc.collect() function.') + parser.add_argument("--log-every", type=int, default=100) + parser.add_argument("--ckpt-every", type=int, default=100_000, help="Create a ckpt every a few steps.") + parser.add_argument("--ckpt-latest-every", type=int, default=10_000, help="Create a ckpt named `latest.pt` every a few steps.") + parser.add_argument("--num-workers", type=int, default=4) + parser.add_argument("--global-seed", type=int, default=1234) + parser.add_argument("--warmup-min-lr", type=float, default=1e-6) + parser.add_argument("--warmup-num-steps", type=float, default=0) + parser.add_argument("--weight-decay", type=float, default=0, help="weight-decay in optimizer") + parser.add_argument("--rope-img", type=str, default=None, choices=['extend', 'base512', 'base1024'], + help="Extend or interpolate the positional embedding of the image.") + parser.add_argument("--rope-real", action="store_true", + help="Use real part and imaginary part separately for RoPE.") + + # Classifier-free + parser.add_argument("--uncond-p", type=float, default=0.2, + help="The probability of dropping training text used for CLIP feature extraction") + parser.add_argument("--uncond-p-t5", type=float, default=0.2, + help="The probability of dropping training text used for mT5 feature extraction") + + # Directory + parser.add_argument("--results-dir", type=str, default="results") + parser.add_argument("--resume", type=str, default=None, help="Resume experiment from a checkpoint") + parser.add_argument("--no-strict", dest="strict", action="store_false", help="Strict loading of checkpoint") + parser.set_defaults(strict=True) + parser.add_argument("--resume-deepspeed", action="store_true", + help="Resume model and ema states from a checkpoint saved by Deepspeed version of DIT.") + parser.add_argument("--resume-split", action="store_true", + help="Resume model and ema states from two checkpoint separated from DeepSpeed ckpt.") + parser.add_argument("--ema-to-module", action="store_true", + help="If true, initialize the module with EMA weights.") + parser.add_argument("--module-to-ema", action="store_true", + help="if true, initialize the ema with Module weights.") + + # Dataset + parser.add_argument("--index-file", type=str, nargs='+', help="During training, provide a JSON file with data indices.") + parser.add_argument("--random-flip", action="store_true", help="Random flip image") + parser.add_argument("--reset-loader", action="store_true", + help="Reset the data loader. It is useful when resuming from a checkpoint but switch to a new dataset.") + parser.add_argument("--multireso", action="store_true", help="Use multi-resolution training.") + parser.add_argument("--reso-step", type=int, default=None, help="Step size for multi-resolution training.") + + # Additional condition + parser.add_argument("--random-shrink-size-cond", action="store_true", + help="Randomly shrink the original size condition.") + parser.add_argument("--merge-src-cond", action="store_true", help="Merge the source condition into a single value.") + + # EMA Model + parser.add_argument("--use-ema", action="store_true", help="Use EMA model") + parser.add_argument("--ema-dtype", type=str, choices=['fp16', 'fp32', 'none'], default="none", + help="EMA data type. If none, use the same data type as the model.") + parser.add_argument("--ema-decay", type=float, default=None, + help="EMA decay rate. If None, use the default value of the model.") + parser.add_argument("--ema-warmup", action="store_true", + help="EMA warmup. If True, perform ema_decay warmup from 0 to ema_decay.") + parser.add_argument("--ema-warmup-power", type=float, default=None, + help="EMA power. If None, use the default value of the model.") + parser.add_argument("--ema-reset-decay", action="store_true", + help="Reset EMA decay to 0 and restart increasing the EMA decay." + "Only works when --ema-warmup is enabled.") + # Acceleration + parser.add_argument("--use-flash-attn", action="store_true", help="During training, " + "flash attention is used to accelerate training.") + parser.add_argument("--no-flash-attn", dest="use_flash_attn", + action="store_false", help="During training, flash attention is not used to accelerate training.") + parser.add_argument("--use-zero-stage", type=int, default=1, help="Use AngelPTM zero stage. Support 2 and 3") + parser.add_argument("--grad-accu-steps", type=int, default=1, help="Gradient accumulation steps.") + + # ======================================================================================================== + # Deepspeed config + # ======================================================================================================== + parser = deepspeed.add_config_arguments(parser) + parser.add_argument('--local_rank', type=int, default=None, + help='local rank passed from distributed launcher.') + parser.add_argument('--deepspeed-optimizer', action='store_true', + help='Switching to the optimizers in DeepSpeed') + parser.add_argument('--remote-device', type=str, default='none', choices=['none', 'cpu', 'nvme'], + help='Remote device for ZeRO-3 initialized parameters.') + parser.add_argument('--zero-stage', type=int, default=1) + parser.add_argument("--async-ema", action="store_true", help="Whether to use multi stream to excut EMA.") + args = parser.parse_args(default_args) return args diff --git a/hydit/constants.py b/hydit/constants.py index fb6a6b4..a469215 100644 --- a/hydit/constants.py +++ b/hydit/constants.py @@ -1,4 +1,7 @@ +import torch + # ======================================================= + NOISE_SCHEDULES = { "linear", "scaled_linear", @@ -12,6 +15,7 @@ } # ======================================================= + NEGATIVE_PROMPT = '错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,' # ======================================================= @@ -23,6 +27,18 @@ # Constants about models # ======================================================= +VAE_EMA_PATH = "ckpts/t2i/sdxl-vae-fp16-fix" +TOKENIZER = "ckpts/t2i/tokenizer" +TEXT_ENCODER = 'ckpts/t2i/clip_text_encoder' +T5_ENCODER = { + 'MT5': 'ckpts/t2i/mt5', + 'attention_mask': True, + 'layer_index': -1, + 'attention_pool': True, + 'torch_dtype': torch.float16, + 'learnable_replace': True +} + SAMPLER_FACTORY = { 'ddpm': { 'scheduler': 'DDPMScheduler', diff --git a/hydit/data_loader/__init__.py b/hydit/data_loader/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hydit/data_loader/arrow_load_stream.py b/hydit/data_loader/arrow_load_stream.py new file mode 100644 index 0000000..ac2bd17 --- /dev/null +++ b/hydit/data_loader/arrow_load_stream.py @@ -0,0 +1,255 @@ +import pickle +import random +from pathlib import Path +import ast +import numpy as np +import re +import json +import time +from functools import partial +from PIL import Image + +import torch +import torchvision.transforms as T +import torch.nn.functional as F +from torchvision.transforms import functional as TF +from torch.utils.data import Dataset + +from IndexKits.index_kits import ArrowIndexV2, MultiResolutionBucketIndexV2, MultiIndexV2 + + +class TextImageArrowStream(Dataset): + def __init__(self, + args, + resolution=512, + random_flip=None, + enable_CN=True, + log_fn=print, + index_file=None, + multireso=False, + batch_size=-1, + world_size=1, + random_shrink_size_cond=False, + merge_src_cond=False, + uncond_p=0.0, + text_ctx_len=77, + tokenizer=None, + uncond_p_t5=0.0, + text_ctx_len_t5=256, + tokenizer_t5=None, + ): + self.args = args + self.resolution = resolution + self.log_fn = lambda x: log_fn(f" {Path(__file__).stem} | " + x) + + self.random_flip = random_flip + # If true, the Chinese prompt from the `text_zh` column will be taken from the arrow file; + # otherwise, the English prompt from the `text_en` column will be taken, + # provided that `text_zh` or `text_en` exists in the arrow file. + self.enable_CN = enable_CN + self.index_file = index_file + self.multireso = multireso + self.batch_size = batch_size + self.world_size = world_size + self.index_manager = self.load_index() + + # clip params + self.uncond_p = uncond_p + self.text_ctx_len = text_ctx_len + self.tokenizer = tokenizer + + # t5 params + self.uncond_p_t5 = uncond_p_t5 + self.text_ctx_len_t5 = text_ctx_len_t5 + self.tokenizer_t5 = tokenizer_t5 + + # size condition + self.random_shrink_size_cond = random_shrink_size_cond + self.merge_src_cond = merge_src_cond + + assert isinstance(resolution, int), f"resolution must be an integer, got {resolution}" + self.flip_norm = T.Compose( + [ + T.RandomHorizontalFlip() if self.random_flip else T.Lambda(lambda x: x), + T.ToTensor(), + T.Normalize([0.5], [0.5]), + ] + ) + + # show info + if self.merge_src_cond: + self.log_fn("Enable merging src condition: (oriW, oriH) --> ((WH)**0.5, (WH)**0.5)") + + self.log_fn("Enable image_meta_size condition (original_size, target_size, crop_coords)") + self.log_fn(f"Image_transforms: {self.flip_norm}") + + def load_index(self): + multireso = self.multireso + index_file = self.index_file + batch_size = self.batch_size + world_size = self.world_size + + if multireso: + if isinstance(index_file, (list, tuple)): + if len(index_file) > 1: + raise ValueError(f"When enabling multireso, index_file should be a single file, but got {index_file}") + index_file = index_file[0] + index_manager = MultiResolutionBucketIndexV2(index_file, batch_size, world_size) + self.log_fn(f"Using MultiResolutionBucketIndexV2: {len(index_manager):,}") + else: + if isinstance(index_file, str): + index_file = [index_file] + if len(index_file) == 1: + index_manager = ArrowIndexV2(index_file[0]) + self.log_fn(f"Using ArrowIndexV2: {len(index_manager):,}") + else: + index_manager = MultiIndexV2(index_file) + self.log_fn(f"Using MultiIndexV2: {len(index_manager):,}") + + return index_manager + + def shuffle(self, seed, fast=False): + self.index_manager.shuffle(seed, fast=fast) + + def get_raw_image(self, index, image_key="image"): + try: + ret = self.index_manager.get_image(index, image_key) + except Exception as e: + self.log_fn(f'get_raw_image | Error: {e}') + ret = Image.new("RGB", (256, 256), (255, 255, 255)) + return ret + + @staticmethod + def random_crop_image(image, origin_size, target_size): + aspect_ratio = float(origin_size[0]) / float(origin_size[1]) + if origin_size[0] < origin_size[1]: + new_width = target_size[0] + new_height = int(new_width / aspect_ratio) + else: + new_height = target_size[1] + new_width = int(new_height * aspect_ratio) + + image = image.resize((new_width, new_height), Image.LANCZOS) + + if new_width > target_size[0]: + x_start = random.randint(0, new_width - target_size[0]) + y_start = 0 + else: + x_start = 0 + y_start = random.randint(0, new_height - target_size[1]) + image_crop = image.crop((x_start, y_start, x_start + target_size[0], y_start + target_size[1])) + crops_coords_top_left = (x_start, y_start) + return image_crop, crops_coords_top_left + + def get_style(self, index): + "Here we use a default learned embedder layer for future extension." + style = 0 + return style + + def get_image_with_hwxy(self, index, image_key="image"): + + image = self.get_raw_image(index, image_key=image_key) + origin_size = image.size + + if self.multireso: + target_size = self.index_manager.get_target_size(index) + image, crops_coords_top_left = self.index_manager.resize_and_crop( + image, target_size, resample=Image.LANCZOS, crop_type='random') + image_tensor = self.flip_norm(image) + else: + target_size = (self.resolution, self.resolution) + image_crop, crops_coords_top_left = self.random_crop_image(image, origin_size, target_size) + image_tensor = self.flip_norm(image_crop) + + if self.random_shrink_size_cond: + origin_size = (1024 if origin_size[0] < 1024 else origin_size[0], + 1024 if origin_size[1] < 1024 else origin_size[1]) + if self.merge_src_cond: + val = (origin_size[0] * origin_size[1]) ** 0.5 + origin_size = (val, val) + + image_meta_size = tuple(origin_size) + tuple(target_size) + tuple(crops_coords_top_left) + kwargs = { + 'image_meta_size': image_meta_size, + } + + style = self.get_style(index) + kwargs['style'] = style + + return image_tensor, kwargs + + def get_text_info_with_encoder(self, description): + pad_num = 0 + text_inputs = self.tokenizer( + description, + padding="max_length", + max_length=self.text_ctx_len, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids[0] + attention_mask = text_inputs.attention_mask[0].bool() + if pad_num > 0: + attention_mask[1:pad_num + 1] = False + return description, text_input_ids, attention_mask + + def fill_t5_token_mask(self, fill_tensor, fill_number, setting_length): + fill_length = setting_length - fill_tensor.shape[1] + if fill_length > 0: + fill_tensor = torch.cat((fill_tensor, fill_number * torch.ones(1, fill_length)), dim=1) + return fill_tensor + + def get_text_info_with_encoder_t5(self, description_t5): + text_tokens_and_mask = self.tokenizer_t5( + description_t5, + max_length=self.text_ctx_len_t5, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors='pt' + ) + text_input_ids_t5=self.fill_t5_token_mask(text_tokens_and_mask["input_ids"], fill_number=1, setting_length=self.text_ctx_len_t5).long() + attention_mask_t5=self.fill_t5_token_mask(text_tokens_and_mask["attention_mask"], fill_number=0, setting_length=self.text_ctx_len_t5).bool() + return description_t5, text_input_ids_t5, attention_mask_t5 + + def get_original_text(self, ind): + text = self.index_manager.get_attribute(ind, 'text_zh' if self.enable_CN else 'text_en') + text = str(text).strip() + return text + + def get_text(self, ind): + text = self.get_original_text(ind) + if text == '': + text = '随机生成一张图片' + return text + + def __getitem__(self, ind): + # Get text + if random.random() < self.uncond_p: + description = "" + else: + description = self.get_text(ind) + + # Get text for t5 + if random.random() < self.uncond_p_t5: + description_t5 = "" + else: + description_t5 = self.get_text(ind) + + original_pil_image, kwargs = self.get_image_with_hwxy(ind) + + # Use encoder to embed tokens online + text, text_embedding, text_embedding_mask = self.get_text_info_with_encoder(description) + + text_t5, text_embedding_t5, text_embedding_mask_t5 = self.get_text_info_with_encoder_t5(description_t5) + return ( + original_pil_image, + text_embedding.clone().detach(), + text_embedding_mask.clone().detach(), + text_embedding_t5.clone().detach(), + text_embedding_mask_t5.clone().detach(), + {k: torch.tensor(np.array(v)).clone().detach() for k, v in kwargs.items()}, + ) + + def __len__(self): + return len(self.index_manager) diff --git a/hydit/data_loader/csv2arrow.py b/hydit/data_loader/csv2arrow.py new file mode 100644 index 0000000..10518e2 --- /dev/null +++ b/hydit/data_loader/csv2arrow.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +import datetime +import gc +import os +import time +from multiprocessing import Pool +import subprocess +import pandas as pd +import pyarrow as pa +from tqdm import tqdm +import hashlib +from PIL import Image +import sys + + +def parse_data(data): + try: + img_path = data[0] + + with open(img_path, "rb") as fp: + image = fp.read() + md5 = hashlib.md5(fp.read()).hexdigest() + + with Image.open(img_path) as f: + width, height = f.size + + return [data[1], md5, width, height, image] + + except Exception as e: + print(f'error: {e}') + return + +def make_arrow(csv_root, dataset_root, start_id=0, end_id=-1): + print(csv_root) + arrow_dir = dataset_root + print(arrow_dir) + + if not os.path.exists(arrow_dir): + os.makedirs(arrow_dir) + + data = pd.read_csv(csv_root) + data = data[["image_path", "text_zh"]] + columns_list = data.columns.tolist() + columns_list.append("image") + + if end_id < 0: + end_id = len(data) + print(f'start_id:{start_id} end_id:{end_id}') + data = data[start_id:end_id] + num_slice = 5000 + start_sub = int(start_id / num_slice) + sub_len = int(len(data) // num_slice) # if int(len(data) // num_slice) else 1 + subs = list(range(sub_len + 1)) + for sub in tqdm(subs): + arrow_path = os.path.join(arrow_dir, '{}.arrow'.format(str(sub + start_sub).zfill(5))) + if os.path.exists(arrow_path): + continue + print(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')} start {sub + start_sub}") + + sub_data = data[sub * num_slice: (sub + 1) * num_slice].values + + bs = pool.map(parse_data, sub_data) + bs = [b for b in bs if b] + print(f'length of this arrow:{len(bs)}') + + columns_list = ["text_zh", "md5", "width", "height", "image"] + dataframe = pd.DataFrame(bs, columns=columns_list) + table = pa.Table.from_pandas(dataframe) + + os.makedirs(dataset_root, exist_ok=True) + with pa.OSFile(arrow_path, "wb") as sink: + with pa.RecordBatchFileWriter(sink, table.schema) as writer: + writer.write_table(table) + del dataframe + del table + del bs + gc.collect() + + +if __name__ == '__main__': + + if len(sys.argv) != 3: + print("Usage: python hydit/data_loader/csv2arrow.py ${csv_root} ${output_arrow_data_path}") + sys.exit(1) + csv_root = sys.argv[1] + output_arrow_data_path = sys.argv[2] + pool = Pool(500) + make_arrow(csv_root, output_arrow_data_path) \ No newline at end of file diff --git a/hydit/diffusion/__init__.py b/hydit/diffusion/__init__.py index e69de29..f6a6811 100644 --- a/hydit/diffusion/__init__.py +++ b/hydit/diffusion/__init__.py @@ -0,0 +1,49 @@ +from . import gaussian_diffusion as gd +from .respace import SpacedDiffusion, space_timesteps + + +def create_diffusion( + *, + steps=1000, + learn_sigma=True, + sigma_small=False, + noise_schedule="linear", + use_kl=False, + predict_type='epsilon', + rescale_timesteps=False, + rescale_learned_sigmas=False, + timestep_respacing="", + mse_loss_weight_type='constant', + beta_start=0.0001, + beta_end=0.02, + noise_offset=0.0, +): + betas = gd.get_named_beta_schedule(noise_schedule, steps, beta_start, beta_end) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [steps] + mean_type = gd.predict_type_dict[predict_type] + + return SpacedDiffusion( + use_timesteps=space_timesteps(steps, timestep_respacing), + betas=betas, + model_mean_type=mean_type, + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + rescale_timesteps=rescale_timesteps, + mse_loss_weight_type=mse_loss_weight_type, + noise_offset=noise_offset, + ) diff --git a/hydit/diffusion/diffusion_utils.py b/hydit/diffusion/diffusion_utils.py new file mode 100644 index 0000000..9c62d5b --- /dev/null +++ b/hydit/diffusion/diffusion_utils.py @@ -0,0 +1,69 @@ +import torch as th +import numpy as np + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs + diff --git a/hydit/diffusion/gaussian_diffusion.py b/hydit/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000..02b0e36 --- /dev/null +++ b/hydit/diffusion/gaussian_diffusion.py @@ -0,0 +1,1365 @@ +import math + +import numpy as np +import torch as th +import enum + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl +from ..utils.tools import assert_shape + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + VELOCITY = enum.auto() # the model predicts v + + +predict_type_dict = { + 'epsilon': ModelMeanType.EPSILON, + 'sample': ModelMeanType.START_X, + 'v_prediction': ModelMeanType.VELOCITY, +} + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert_shape(betas, (num_diffusion_timesteps,)) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, beta_start=0.0001, beta_end=0.02): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * beta_start, # DDPM + beta_end=scale * beta_end, # DDPM + num_diffusion_timesteps=num_diffusion_timesteps, # DDPM + ) + elif schedule_name == "scaled_linear": + return get_beta_schedule( + "quad", + beta_start=beta_start, # StableDiffusion, should be 0.00085 + beta_end=beta_end, # StableDiffusion, should be 0.012 + num_diffusion_timesteps=num_diffusion_timesteps, # StableDiffusion + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + + Ported directly from here, and then adapted over time to further experimentation. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + :param model_mean_type: a ModelMeanType determining what the model outputs. + :param model_var_type: a ModelVarType determining how variance is output. + :param loss_type: a LossType determining the loss function to use. + :param rescale_timesteps: if True, pass floating point timesteps into the + model so that they are always scaled like in the + original paper (0 to 1000). + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, + mse_loss_weight_type='constant', + noise_offset=0.0, + ): + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + self.mse_loss_weight_type = mse_loss_weight_type + self.noise_offset = noise_offset + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + self.alphas_cumprod = alphas_cumprod + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert_shape(self.alphas_cumprod_prev, (self.num_timesteps,)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - self.alphas_cumprod) + ) + + self.sampler = { + "ddpm": self.p_sample_loop, + "ddim": self.ddim_sample_loop, + "plms": self.plms_sample_loop, + } + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + ) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + + In other words, sample from q(x_t | x_0). + + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert_shape(noise, x_start) + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + + q(x_{t-1} | x_t, x_0) + + """ + assert_shape(x_start, x_t) + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert_shape( + posterior_mean.shape[:1], + posterior_variance.shape[:1], + posterior_log_variance_clipped.shape[:1], + x_start.shape[:1], + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance( + self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None, + model_var_type=None, + ): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param model_var_type: if not None, overlap the default self.model_var_type. + It is useful when training with learned var but sampling with fixed var. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + if model_var_type is None: + model_var_type = self.model_var_type + + B, C = x.shape[:2] + assert_shape(t, (B,)) + out_dict = model(x, t, **model_kwargs) + model_output = out_dict['x'] + + if len(out_dict) > 1: + extra = {k: v for k, v in out_dict.items() if k != 'x'} + else: + extra = None + + # self.model_var_type corresponds to model output + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert_shape(model_output, (B, C * 2, *x.shape[2:])) + model_output, model_var_values = th.split(model_output, C, dim=1) + + # model_var_type corresponds to reverse diffusion process + if model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + if model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x.shape + ) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart( + self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) + ) + model_mean = model_output + elif self.model_mean_type in [ + ModelMeanType.START_X, + ModelMeanType.EPSILON, + ModelMeanType.VELOCITY + ]: + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + elif self.model_mean_type == ModelMeanType.EPSILON: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_v(x_t=x, t=t, v=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t + ) + else: + raise NotImplementedError(self.model_mean_type) + + assert_shape(model_mean, model_log_variance, pred_xstart, x) + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert_shape(x_t, eps) + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_xstart_from_v(self, x_t, t, v): + assert_shape(x_t, v) + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def _predict_xstart_from_xprev(self, x_t, t, xprev): + assert_shape(x_t, xprev) + return ( # (xprev - coef2*x_t) / coef1 + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev + - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape + ) + * x_t + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _velocity_from_xstart_and_noise(self, x_start, t, noise): + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * noise + - _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * x_start + ) + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert_shape(decoder_nll, x_start) + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"], "extra": out["extra"]} + + def training_losses(self, model, x_start, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + # Time steps + t = th.randint(0, self.num_timesteps, (x_start.shape[0],), device=x_start.device) + # Noise + if noise is None: + noise = th.randn_like(x_start) + if self.noise_offset > 0: + # Add channel wise noise offset + # https://www.crosslabs.org/blog/diffusion-with-offset-noise + noise = noise + self.noise_offset * th.randn(*x_start.shape[:2], 1, 1, device=x_start.device) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.mse_loss_weight_type == 'constant': + mse_loss_weight = th.ones_like(t) + elif self.mse_loss_weight_type.startswith("min_snr_"): + alpha = _extract_into_tensor(self.sqrt_alphas_cumprod, t, t.shape) + sigma = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, t.shape) + snr = (alpha / sigma) ** 2 + + k = float(self.mse_loss_weight_type.split('min_snr_')[-1]) + # min{snr, k} + mse_loss_weight = th.stack([snr, k * th.ones_like(t)], dim=1).min(dim=1)[0] / snr + else: + raise ValueError(self.mse_loss_weight_type) + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + out_dict = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + ) + terms["loss"] = out_dict["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + extra = out_dict["extra"] + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + out_dict = model(x_t, t, **model_kwargs) + model_output = out_dict['x'] + extra = {k: v for k, v in out_dict.items() if k != 'x'} + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert_shape(model_output, (B, C * 2, *x_t.shape[2:])) + model_output, model_var_values = th.split(model_output, C, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: dict(x=r), + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + if self.model_mean_type == ModelMeanType.VELOCITY: + target = self._velocity_from_xstart_and_noise(x_start, t, noise) + else: + target = { + # ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + # x_start=x_start, x_t=x_t, t=t + # )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert_shape(model_output, target, x_start) + raw_mse = mean_flat((target - model_output) ** 2).detach() + terms["mse"] = mse_loss_weight * mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + terms["raw_loss"] = raw_mse + terms["vb"].detach() + else: + terms["loss"] = terms["mse"] + terms["raw_loss"] = raw_mse + else: + raise NotImplementedError(self.loss_type) + + terms.update(extra) + return terms + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = ( + p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + ) + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn( + x, t, **model_kwargs + ) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t + ) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + model_var_type=None, + **kwargs, + ): + """ + Sample x_{t-1} from the model at the given timestep. + + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param model_var_type: if not None, overlap the default self.model_var_type. + It is useful when training with learned var but sampling with fixed var. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + model_var_type=model_var_type, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean( + cond_fn, out, x, t, model_kwargs=model_kwargs + ) + + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + model_var_type=None, + device=None, + progress=False, + progress_leave=True, + **kwargs, + ): + """ + Generate samples from the model. + + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param model_var_type: if not None, overlap the default self.model_var_type. + It is useful when training with learned var but sampling with fixed var. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + model_var_type=model_var_type, + device=device, + progress=progress, + progress_leave=progress_leave, + **kwargs, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + model_var_type=None, + device=None, + progress=False, + progress_leave=True, + **kwargs, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices, leave=progress_leave) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + model_var_type=model_var_type, + **kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_next) + + th.sqrt(1 - alpha_bar_next) * eps + ) + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + progress_leave=True, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + progress_leave=progress_leave, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + progress_leave=True, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices, leave=progress_leave) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + + This term can't be optimized, as it only depends on the encoder. + + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + def get_eps( + self, + model, + x, + t, + model_kwargs, + cond_fn=None, + ): + model_output = model(x, t, **model_kwargs)['x'] + if isinstance(model_output, tuple): + model_output, _ = model_output + eps = model_output[:, :4] + if cond_fn is not None: + alpha_bar = _extract_into_tensor_lerp(self.alphas_cumprod, t, x.shape) + eps = eps - th.sqrt(1 - alpha_bar) * cond_fn(x, t, **model_kwargs) + return eps + + def eps_to_pred_xstart( + self, + x, + eps, + t, + ): + alpha_bar = _extract_into_tensor_lerp(self.alphas_cumprod, t, x.shape) + return (x - eps * th.sqrt(1 - alpha_bar)) / th.sqrt(alpha_bar) + + def pndm_transfer( + self, + x, + eps, + t_1, + t_2, + ): + pred_xstart = self.eps_to_pred_xstart(x, eps, t_1) + alpha_bar_prev = _extract_into_tensor_lerp(self.alphas_cumprod, t_2, x.shape) + return pred_xstart * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps + + def prk_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model using PRK. + + Same usage as p_sample_loop(). + """ + final = None + for sample in self.prk_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def prk_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Use PRK to sample from the model and yield intermediate samples from + each timestep of PRK. + + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1][1:-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices, leave=False) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.prk_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def prk_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model using fourth-order Pseudo Runge-Kutta + (https://openreview.net/forum?id=PlKWVd2yBkY). + + Same usage as p_sample(). + """ + if model_kwargs is None: + model_kwargs = {} + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + eps_1 = self.get_eps(model, x, t, model_kwargs, cond_fn) + x_1 = self.pndm_transfer(x, eps_1, t, t - 0.5) + eps_2 = self.get_eps(model, x_1, t - 0.5, model_kwargs, cond_fn) + x_2 = self.pndm_transfer(x, eps_2, t, t - 0.5) + eps_3 = self.get_eps(model, x_2, t - 0.5, model_kwargs, cond_fn) + x_3 = self.pndm_transfer(x, eps_3, t, t - 1) + eps_4 = self.get_eps(model, x_3, t - 1, model_kwargs, cond_fn) + eps_prime = (eps_1 + 2 * eps_2 + 2 * eps_3 + eps_4) / 6 + + sample = self.pndm_transfer(x, eps_prime, t, t - 1) + pred_xstart = self.eps_to_pred_xstart(x, eps_prime, t) + pred_xstart = process_xstart(pred_xstart) + return {"sample": sample, "pred_xstart": pred_xstart, "eps": eps_prime} + + def plms_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + progress_leave=True, + ): + """ + Generate samples from the model using PLMS. + + Same usage as p_sample_loop(). + """ + assert self.model_mean_type == ModelMeanType.EPSILON, \ + 'plms_sample only support model_mean_type == ModelMeanType.EPSILON' + final = None + for sample in self.plms_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + progress_leave=progress_leave, + ): + final = sample + return final["sample"] + + def plms_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + progress_leave=True, + ): + """ + Use PLMS to sample from the model and yield intermediate samples from + each timestep of PLMS. + + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1][1:-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices, leave=progress_leave) + + old_eps = [] + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + if len(old_eps) < 3: + out = self.prk_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + else: + out = self.plms_sample( + model, + img, + old_eps, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + old_eps.pop(0) + old_eps.append(out["eps"]) + yield out + img = out["sample"] + + def plms_sample( + self, + model, + x, + old_eps, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model using fourth-order Pseudo Linear Multistep + (https://openreview.net/forum?id=PlKWVd2yBkY). + """ + if model_kwargs is None: + model_kwargs = {} + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + eps = self.get_eps(model, x, t, model_kwargs, cond_fn) + eps_prime = (55 * eps - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + sample = self.pndm_transfer(x, eps_prime, t, t - 1) + pred_xstart = self.eps_to_pred_xstart(x, eps, t) + pred_xstart = process_xstart(pred_xstart) + return {"sample": sample, "pred_xstart": pred_xstart, "eps": eps} + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) + + +def _extract_into_tensor_lerp(arr, timesteps, broadcast_shape): + """ + Extract values from arr with fractional time steps + """ + timesteps = timesteps.float() + frac = timesteps.frac() + while len(frac.shape) < len(broadcast_shape): + frac = frac[..., None] + res_1 = _extract_into_tensor(arr, timesteps.floor().long(), broadcast_shape) + res_2 = _extract_into_tensor(arr, timesteps.ceil().long(), broadcast_shape) + return th.lerp(res_1, res_2, frac) diff --git a/hydit/diffusion/pipeline.py b/hydit/diffusion/pipeline.py index 558608c..d030471 100644 --- a/hydit/diffusion/pipeline.py +++ b/hydit/diffusion/pipeline.py @@ -13,7 +13,10 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union +import PIL +import numpy as np import torch +import torchvision.transforms as T from diffusers.configuration_utils import FrozenDict from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin @@ -24,6 +27,7 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import ( + PIL_INTERPOLATION, deprecate, logging, replace_example_docstring, @@ -39,18 +43,30 @@ EXAMPLE_DOC_STRING = """ Examples: ```py + >>> import requests >>> import torch - >>> from diffusers import StableDiffusionPipeline + >>> from PIL import Image + >>> from io import BytesIO - >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) - >>> pipe = pipe.to("cuda") + >>> from diffusers import StableDiffusionImg2ImgPipeline - >>> prompt = "a photo of an astronaut riding a horse on mars" - >>> image = pipe(prompt).images[0] - ``` -""" + >>> device = "cuda" + >>> model_id_or_path = "runwayml/stable-diffusion-v1-5" + >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + >>> response = requests.get(url) + >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((768, 512)) + >>> prompt = "A fantasy landscape, trending on artstation" + >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images + >>> images[0].save("fantasy_landscape.png") + ``` +""" def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and @@ -64,10 +80,34 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg return noise_cfg - -class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin): +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): r""" - Pipeline for text-to-image generation using Stable Diffusion. + Pipeline for text-guided image-to-image generation using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). @@ -86,8 +126,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): A `BertTokenizer` or `CLIPTokenizer` to tokenize text. unet (Optional[`HunYuanDiT`, `UNet2DConditionModel`]): - A `HunYuanDiT` or `UNet2DConditionModel` to denoise the encoded image latents. - Notice: Here we still keep the word `unet` for compatibility with the previous version of the pipeline. + A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. @@ -186,35 +225,7 @@ def __init__( self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, @@ -245,6 +256,7 @@ def _encode_prompt( return prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt def encode_prompt( self, prompt, @@ -326,7 +338,7 @@ def encode_prompt( untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids + text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode( untruncated_ids[:, tokenizer.model_max_length - 1 : -1] @@ -415,6 +427,34 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask + def _convert_to_rgb(self, image): + return image.convert('RGB') + + def image_transform(self, image_size=224): + transform = T.Compose([ + T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC), + self._convert_to_rgb, + T.ToTensor(), + T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + return transform + + def encode_img(self, img, device, do_classifier_free_guidance): + # print('len', len(img)) + # print('img', img.size) + img = img[0] # TODO: support batch processing + image_preprocess = self.image_transform(224) + img_for_clip = image_preprocess(img) + # print('img_for_clip', img_for_clip.shape) + img_for_clip = img_for_clip.unsqueeze(0) + img_clip_embedding = self.img_encoder(img_for_clip.to(device)).to(dtype=torch.float16) + # print('img_clip_embedding_1_type', img_clip_embedding.dtype) + if do_classifier_free_guidance: + negative_img_clip_embedding = torch.zeros_like(img_clip_embedding) + return img_clip_embedding, negative_img_clip_embedding + + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None @@ -429,6 +469,7 @@ def run_safety_checker(self, image, device, dtype): ) return image, has_nsfw_concept + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) @@ -440,6 +481,7 @@ def decode_latents(self, latents): image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -504,6 +546,15 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: @@ -591,10 +642,6 @@ def __call__( generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor is generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. @@ -734,20 +781,9 @@ def __call__( return_dict=False, ) elif self.infer_mode == "trt": - noise_pred = self.unet( - x=latent_model_input.contiguous(), - t_emb=t_expand.contiguous(), - context=prompt_embeds.contiguous(), - image_meta_size=ims.contiguous(), - style=style.contiguous(), - freqs_cis_img0=freqs_cis_img[0].to(device).contiguous(), - freqs_cis_img1=freqs_cis_img[1].to(device).contiguous(), - text_embedding_mask=attention_mask.contiguous(), - encoder_hidden_states_t5=prompt_embeds_t5.contiguous(), - text_embedding_mask_t5=attention_mask_t5.contiguous(), - ) + raise NotImplementedError("TensorRT model is not supported yet.") else: - raise ValueError("Unknown infer_mode: {self.infer_mode}") + raise ValueError("[ERROR] invalid inference mode! please check your config file") if learn_sigma: noise_pred, _ = noise_pred.chunk(2, dim=1) diff --git a/hydit/diffusion/respace.py b/hydit/diffusion/respace.py new file mode 100644 index 0000000..efd91b9 --- /dev/null +++ b/hydit/diffusion/respace.py @@ -0,0 +1,141 @@ +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + Improved DDPM + + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def get_eps(self, model, *args, **kwargs): + return super().get_eps(self._wrap_model(model), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + """ + Improved DDPM + + When using a subsequent timesteps (e.g., 250), we must wrap the model + for mapping the timesteps 1-250 with step 1 to 1-1000 with step 4 + """ + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + """ + Here we must make a interpolation because `ts` maybe a float (e.g., 4.5) + in the PLMS/PNDM sampler. + """ + ts = ts.float() + frac = ts.frac() + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts_1 = map_tensor[ts.floor().long()] + new_ts_2 = map_tensor[ts.ceil().long()] + new_ts = th.lerp(new_ts_1, new_ts_2, frac) + return self.model(x, new_ts, **kwargs) diff --git a/hydit/ds_config.py b/hydit/ds_config.py new file mode 100644 index 0000000..d108a6d --- /dev/null +++ b/hydit/ds_config.py @@ -0,0 +1,140 @@ +# -*- coding: utf-8 -*- +import os + +def deepspeed_config_from_args(args, global_batch_size): + if args.use_zero_stage == 2: + deepspeed_config = { + "train_batch_size": global_batch_size, + "train_micro_batch_size_per_gpu": args.batch_size, + "gradient_accumulation_steps": args.grad_accu_steps, + "steps_per_print": args.log_every, + "optimizer": { + "type": "AdamW", + "params": { + "lr": args.lr, + "betas": [ + 0.9, + 0.999 + ], + "eps": 1e-08, + "weight_decay": args.weight_decay + } + }, + + "zero_optimization": { + "stage": 2, + "reduce_scatter": False, + "reduce_bucket_size": 1e9, + }, + + "gradient_clipping": 1.0, + "prescale_gradients": True, + + "fp16": { + "enabled": args.use_fp16, + "loss_scale": 0, + "loss_scale_window": 500, + "hysteresis": 2, + "min_loss_scale": 1e-3, + "initial_scale_power": 15 + }, + + "bf16": { + "enabled": False + }, + + "wall_clock_breakdown": False + } + + elif args.use_zero_stage == 3: + deepspeed_config = { + "train_batch_size": args.global_batch_size, + # "train_micro_batch_size_per_gpu": args.batch_size, + "gradient_accumulation_steps": args.grad_accu_steps, + "steps_per_print": args.log_every, + + "optimizer": { + "type": "AdamW", + "params": { + "lr": args.lr, + "betas": [ + 0.9, + 0.999 + ], + "eps": 1e-08, + "weight_decay": args.weight_decay + } + }, + + "zero_optimization": { + "stage": 3, + "allgather_partitions": True, + "overlap_comm": True, + "reduce_scatter": True, + "contiguous_gradients": True, + "stage3_prefetch_bucket_size": 5e8, + "stage3_max_live_parameters" : 6e8, + "reduce_bucket_size": 1.2e9, + "sub_group_size": 1e9, + "sub_group_buffer_num": 10, + "pipeline_optimizer": True, + "max_contigous_event_size": 0, + "cache_sub_group_rate": 0.0, + "prefetch_cache_sub_group_rate": 1.0, + "max_contigous_params_size": -1, + "max_param_reduce_events": 0, + "stage3_param_persistence_threshold": 9e9, + "is_communication_time_profiling": False, + "save_large_model_multi_slice": True, + "use_fused_op_with_grad_norm_overflow": False, + "offload_optimizer": { + "device": "cpu", + "pin_memory": True + }, + "offload_param": { + "device": "cpu", + "pin_memory": True + } + }, + + "gradient_clipping": 1.0, + "prescale_gradients": False, + + "fp16": { + "enabled": True, + "loss_scale": 0, + "loss_scale_window": 500, + "hysteresis": 2, + "min_loss_scale": 1, + "initial_scale_power": 15 + }, + + "bf16": { + "enabled": False + }, + + "wall_clock_breakdown": False, + "mem_chunk": { + "default_chunk_size": 536870911, + "use_fake_dist": False, + "client": { + "mem_tracer": { + "use_async_mem_monitor": True, + "warmup_gpu_chunk_mem_ratio": 0.8, + "overall_gpu_mem_ratio": 0.8, + "overall_cpu_mem_ratio": 1.0, + "margin_use_ratio": 0.8, + "use_fake_dist": False + }, + "opts": { + "with_mem_cache": True, + "with_async_move": True + } + } + } + } + else: + raise ValueError + return deepspeed_config + + diff --git a/hydit/inference.py b/hydit/inference.py index 7751ffb..c49e207 100644 --- a/hydit/inference.py +++ b/hydit/inference.py @@ -15,12 +15,13 @@ from transformers import BertModel, BertTokenizer from transformers.modeling_utils import logger as tf_logger -from .constants import SAMPLER_FACTORY, NEGATIVE_PROMPT, TRT_MAX_WIDTH, TRT_MAX_HEIGHT, TRT_MAX_BATCH_SIZE +from .constants import SAMPLER_FACTORY, NEGATIVE_PROMPT from .diffusion.pipeline import StableDiffusionPipeline from .modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG from .modules.posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop from .modules.text_encoder import MT5Embedder from .utils.tools import set_seeds +from peft import LoraConfig class Resolution: @@ -201,7 +202,6 @@ def __init__(self, args, models_root_path): self.infer_mode = self.args.infer_mode if self.infer_mode in ['fa', 'torch']: - model_dir = self.root / "model" model_path = model_dir / f"pytorch_model_{self.args.load_key}.pt" if not model_path.exists(): raise ValueError(f"model_path not exists: {model_path}") @@ -215,6 +215,15 @@ def __init__(self, args, models_root_path): logger.info(f"Loading torch model {model_path}...") state_dict = torch.load(model_path, map_location=lambda storage, loc: storage) self.model.load_state_dict(state_dict) + + lora_ckpt = args.lora_ckpt + if lora_ckpt is not None and lora_ckpt != "": + logger.info(f"Loading Lora checkpoint {lora_ckpt}...") + + self.model.load_adapter(lora_ckpt) + self.model.merge_and_unload() + + self.model.eval() logger.info(f"Loading torch model finished") elif self.infer_mode == 'trt': @@ -301,8 +310,7 @@ def predict(self, seed = random.randint(0, 1_000_000) if not isinstance(seed, int): raise TypeError(f"`seed` must be an integer, but got {type(seed)}") - generator = set_seeds(seed) - + generator = set_seeds(seed, device=self.device) # ======================================================================== # Arguments: target_width, target_height # ======================================================================== diff --git a/hydit/lr_scheduler.py b/hydit/lr_scheduler.py new file mode 100644 index 0000000..027c7e6 --- /dev/null +++ b/hydit/lr_scheduler.py @@ -0,0 +1,761 @@ +""" +Implementation of learning rate schedules. + +Taken and modified from PyTorch v1.0.1 source +https://github.com/pytorch/pytorch/blob/v1.1.0/torch/optim/lr_scheduler.py +""" + +import argparse +from torch.optim import Optimizer +import math + +LR_SCHEDULE = 'lr_schedule' +LR_RANGE_TEST = 'LRRangeTest' +ONE_CYCLE = 'OneCycle' +WARMUP_LR = 'WarmupLR' +WARMUP_DECAY_LR = 'WarmupDecayLR' +VALID_LR_SCHEDULES = [LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR] + +LR_RANGE_TEST_MIN_LR = 'lr_range_test_min_lr' +LR_RANGE_TEST_STEP_RATE = 'lr_range_test_step_rate' +LR_RANGE_TEST_STEP_SIZE = 'lr_range_test_step_size' +LR_RANGE_TEST_STAIRCASE = 'lr_range_test_staircase' + +EDGE_VALUE = 'edge_value' +MID_VALUE = 'mid_value' + +CYCLE_FIRST_STEP_SIZE = 'cycle_first_step_size' +CYCLE_FIRST_STAIR_COUNT = 'cycle_first_stair_count' +CYCLE_SECOND_STEP_SIZE = 'cycle_second_step_size' +CYCLE_SECOND_STAIR_COUNT = 'cycle_second_stair_count' +DECAY_STEP_SIZE = 'decay_step_size' + +CYCLE_MIN_LR = 'cycle_min_lr' +CYCLE_MAX_LR = 'cycle_max_lr' +DECAY_LR_RATE = 'decay_lr_rate' + +CYCLE_MIN_MOM = 'cycle_min_mom' +CYCLE_MAX_MOM = 'cycle_max_mom' +DECAY_MOM_RATE = 'decay_mom_rate' + +WARMUP_MIN_LR = 'warmup_min_lr' +WARMUP_MAX_LR = 'warmup_max_lr' +WARMUP_NUM_STEPS = 'warmup_num_steps' +WARMUP_TYPE = 'warmup_type' +WARMUP_LOG_RATE = 'log' +WARMUP_LINEAR_RATE = 'linear' + +TOTAL_NUM_STEPS = 'total_num_steps' + + +def add_tuning_arguments(parser): + group = parser.add_argument_group('Convergence Tuning', 'Convergence tuning configurations') + + # LR scheduler + group.add_argument('--lr_schedule', type=str, default=None, help='LR schedule for training.') + + # Learning rate range test + group.add_argument("--lr_range_test_min_lr", type=float, default=0.001, help='Starting lr value.') + group.add_argument("--lr_range_test_step_rate", type=float, default=1.0, help='scaling rate for LR range test.') + group.add_argument("--lr_range_test_step_size", type=int, default=1000, help='training steps per LR change.') + group.add_argument("--lr_range_test_staircase", + type=bool, + default=False, + help='use staircase scaling for LR range test.') + + # OneCycle schedule + group.add_argument("--cycle_first_step_size", + type=int, + default=1000, + help='size of first step of 1Cycle schedule (training steps).') + group.add_argument("--cycle_first_stair_count", + type=int, + default=-1, + help='first stair count for 1Cycle schedule.') + group.add_argument("--cycle_second_step_size", + type=int, + default=-1, + help='size of second step of 1Cycle schedule (default first_step_size).') + group.add_argument("--cycle_second_stair_count", + type=int, + default=-1, + help='second stair count for 1Cycle schedule.') + group.add_argument("--decay_step_size", + type=int, + default=1000, + help='size of intervals for applying post cycle decay (training steps).') + + # 1Cycle LR + group.add_argument("--cycle_min_lr", type=float, default=0.01, help='1Cycle LR lower bound.') + group.add_argument("--cycle_max_lr", type=float, default=0.1, help='1Cycle LR upper bound.') + group.add_argument("--decay_lr_rate", type=float, default=0.0, help='post cycle LR decay rate.') + + # 1Cycle Momentum + group.add_argument('--cycle_momentum', default=False, action='store_true', help='Enable 1Cycle momentum schedule.') + group.add_argument("--cycle_min_mom", type=float, default=0.8, help='1Cycle momentum lower bound.') + group.add_argument("--cycle_max_mom", type=float, default=0.9, help='1Cycle momentum upper bound.') + group.add_argument("--decay_mom_rate", type=float, default=0.0, help='post cycle momentum decay rate.') + + # Warmup LR + group.add_argument('--warmup_min_lr', type=float, default=0, help='WarmupLR minimum/initial LR value') + group.add_argument('--warmup_max_lr', type=float, default=0.001, help='WarmupLR maximum LR value.') + group.add_argument('--warmup_num_steps', type=int, default=1000, help='WarmupLR step count for LR warmup.') + group.add_argument('--warmup_type', + type=str, + default=WARMUP_LOG_RATE, + help='WarmupLR increasing function during warmup') + return parser + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser = add_tuning_arguments(parser) + + lr_sched_args, unknown_args = parser.parse_known_args() + return lr_sched_args, unknown_args + + +def override_lr_range_test_params(args, params): + if hasattr(args, LR_RANGE_TEST_MIN_LR) and args.lr_range_test_min_lr is not None: + params[LR_RANGE_TEST_MIN_LR] = args.lr_range_test_min_lr + + if hasattr(args, LR_RANGE_TEST_STEP_RATE) and args.lr_range_test_step_rate is not None: + params[LR_RANGE_TEST_STEP_RATE] = args.lr_range_test_step_rate + + if hasattr(args, LR_RANGE_TEST_STEP_SIZE) and args.lr_range_test_step_size is not None: + params[LR_RANGE_TEST_STEP_SIZE] = args.lr_range_test_step_size + + if hasattr(args, LR_RANGE_TEST_STAIRCASE) and args.lr_range_test_staircase is not None: + params[LR_RANGE_TEST_STAIRCASE] = args.lr_range_test_staircase + + +def override_1cycle_params(args, params): + if hasattr(args, CYCLE_FIRST_STEP_SIZE) and args.cycle_first_step_size is not None: + params[CYCLE_FIRST_STEP_SIZE] = args.cycle_first_step_size + + if hasattr(args, CYCLE_FIRST_STAIR_COUNT) and args.cycle_first_stair_count is not None: + params[CYCLE_FIRST_STAIR_COUNT] = args.cycle_first_stair_count + + if hasattr(args, CYCLE_SECOND_STEP_SIZE) and args.cycle_second_step_size is not None: + params[CYCLE_SECOND_STEP_SIZE] = args.cycle_second_step_size + + if hasattr(args, CYCLE_SECOND_STAIR_COUNT) and args.cycle_second_stair_count is not None: + params[CYCLE_SECOND_STAIR_COUNT] = args.cycle_second_stair_count + + if hasattr(args, DECAY_STEP_SIZE) and args.decay_step_size is not None: + params[DECAY_STEP_SIZE] = args.decay_step_size + + # 1Cycle LR params + if hasattr(args, CYCLE_MIN_LR) and args.cycle_min_lr is not None: + params[CYCLE_MIN_LR] = args.cycle_min_lr + + if hasattr(args, CYCLE_MAX_LR) and args.cycle_max_lr is not None: + params[CYCLE_MAX_LR] = args.cycle_max_lr + + if hasattr(args, DECAY_LR_RATE) and args.decay_lr_rate is not None: + params[DECAY_LR_RATE] = args.decay_lr_rate + + # 1Cycle MOM params + if hasattr(args, CYCLE_MIN_MOM) and args.cycle_min_mom is not None: + params[CYCLE_MIN_MOM] = args.cycle_min_mom + + if hasattr(args, CYCLE_MAX_MOM) and args.cycle_max_mom is not None: + params[CYCLE_MAX_MOM] = args.cycle_max_mom + + if hasattr(args, DECAY_MOM_RATE) and args.decay_mom_rate is not None: + params[DECAY_MOM_RATE] = args.decay_mom_rate + + +def override_warmupLR_params(args, params): + if hasattr(args, WARMUP_MIN_LR) and args.warmup_min_lr is not None: + params[WARMUP_MIN_LR] = args.warmup_min_lr + + if hasattr(args, WARMUP_MAX_LR) and args.warmup_max_lr is not None: + params[WARMUP_MAX_LR] = args.warmup_max_lr + + if hasattr(args, WARMUP_NUM_STEPS) and args.warmup_num_steps is not None: + params[WARMUP_NUM_STEPS] = args.warmup_num_steps + + if hasattr(args, WARMUP_TYPE) and args.warmup_type is not None: + params[WARMUP_TYPE] = args.warmup_type + + +def override_params(args, params): + # LR range test params + override_lr_range_test_params(args, params) + + # 1Cycle params + override_1cycle_params(args, params) + + # WarmupLR params + override_warmupLR_params(args, params) + + +def get_config_from_args(args): + if not hasattr(args, LR_SCHEDULE) or args.lr_schedule is None: + return None, '--{} not specified on command line'.format(LR_SCHEDULE) + + if not args.lr_schedule in VALID_LR_SCHEDULES: + return None, '{} is not supported LR schedule'.format(args.lr_schedule) + + config = {} + config['type'] = args.lr_schedule + config['params'] = {} + + if args.lr_schedule == LR_RANGE_TEST: + override_lr_range_test_params(args, config['params']) + elif args.lr_schedule == ONE_CYCLE: + override_1cycle_params(args, config['params']) + else: + override_warmupLR_params(args, config['params']) + + return config, None + + +def get_lr_from_config(config): + if not 'type' in config: + return None, 'LR schedule type not defined in config' + + if not 'params' in config: + return None, 'LR schedule params not defined in config' + + lr_schedule = config['type'] + lr_params = config['params'] + + if not lr_schedule in VALID_LR_SCHEDULES: + return None, '{} is not a valid LR schedule'.format(lr_schedule) + + if lr_schedule == LR_RANGE_TEST: + return lr_params[LR_RANGE_TEST_MIN_LR], '' + if lr_schedule == ONE_CYCLE: + return lr_params[CYCLE_MAX_LR], '' + # Warmup LR + return lr_params[WARMUP_MAX_LR], '' + + +""" +Only optimizers that are subclass of torch.optim.Optimizer are supported. So check the passed optimizer and wrapped +optimizer to see if requirement is satisfied. +TODO: Looking under the hood to examine the wrapped optimizer is a hack that requires a better long-term fix. +""" + + +def get_torch_optimizer(optimizer): + if isinstance(optimizer, Optimizer): + return optimizer + + if hasattr(optimizer, 'optimizer') and isinstance(optimizer.optimizer, Optimizer): + return optimizer.optimizer + + raise TypeError('{} is not a subclass of torch.optim.Optimizer'.format(type(optimizer).__name__)) + + +class LRRangeTest(object): + """Sets the learning rate of each parameter group according to + learning rate range test (LRRT) policy. The policy increases learning + rate starting from a base value with a constant frequency, as detailed in + the paper `A disciplined approach to neural network hyper-parameters: Part1`_. + + LRRT policy is used for finding maximum LR that trains a model without divergence, and can be used to + configure the LR boundaries for Cyclic LR schedules. + + LRRT changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + + Args: + optimizer (Optimizer): Wrapped optimizer. + lr_range_test_min_lr (float or list): Initial learning rate which is the + lower boundary in the range test for each parameter group. + lr_range_test_step_size (int): Interval of training steps to increase learning rate. Default: 2000 + lr_range_test_step_rate (float): Scaling rate for range test. Default: 1.0 + lr_range_test_staircase (bool): Scale in staircase fashion, rather than continuous. Default: False. + last_batch_iteration (int): The index of the last batch. This parameter is used when + resuming a training job. Since `step()` should be invoked after each + batch instead of after each epoch, this number represents the total + number of *batches* computed, not the total number of epochs computed. + When last_batch_iteration=-1, the schedule is started from the beginning. + Default: -1 + + Example: + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = LRRangeTest(optimizer) + >>> data_loader = torch.utils.data.DataLoader(...) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> scheduler.step() + + _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay: + https://arxiv.org/abs/1803.09820 +""" + + def __init__(self, + optimizer: Optimizer, + lr_range_test_min_lr: float = 1e-3, + lr_range_test_step_size: int = 2000, + lr_range_test_step_rate: float = 1.0, + lr_range_test_staircase: bool = False, + last_batch_iteration: int = -1): + + self.optimizer = get_torch_optimizer(optimizer) + + if isinstance(lr_range_test_min_lr, list) or isinstance(lr_range_test_min_lr, tuple): + if len(lr_range_test_min_lr) != len(self.optimizer.param_groups): + raise ValueError("expected {} lr_range_test_min_lr, got {}".format(len(self.optimizer.param_groups), + len(lr_range_test_min_lr))) + self.min_lr = list(lr_range_test_min_lr) + else: + self.min_lr = [lr_range_test_min_lr] * len(self.optimizer.param_groups) + + self.step_size = lr_range_test_step_size + self.step_rate = lr_range_test_step_rate + self.last_batch_iteration = last_batch_iteration + self.staircase = lr_range_test_staircase + self.interval_fn = self._staircase_interval if lr_range_test_staircase else self._continuous_interval + + if last_batch_iteration == -1: + self._update_optimizer(self.min_lr) + + def _staircase_interval(self): + return math.floor(float(self.last_batch_iteration + 1) / self.step_size) + + def _continuous_interval(self): + return float(self.last_batch_iteration + 1) / self.step_size + + def _get_increase(self): + return (1 + self.step_rate * self.interval_fn()) + + def get_lr(self): + lr_increase = self._get_increase() + return [lr_range_test_min_lr * lr_increase for lr_range_test_min_lr in self.min_lr] + + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + assert getattr(self, '_last_lr', None) is not None, "need to call step() first" + return self._last_lr + + def _update_optimizer(self, group_lrs): + for param_group, lr in zip(self.optimizer.param_groups, group_lrs): + param_group['lr'] = lr + + def step(self, batch_iteration=None): + if batch_iteration is None: + batch_iteration = self.last_batch_iteration + 1 + self.last_batch_iteration = batch_iteration + self._update_optimizer(self.get_lr()) + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + + def state_dict(self): + return {'last_batch_iteration': self.last_batch_iteration} + + def load_state_dict(self, sd): + self.last_batch_iteration = sd['last_batch_iteration'] + + + +class OneCycle(object): + """Sets the learning rate of each parameter group according to + 1Cycle learning rate policy (1CLR). 1CLR is a variation of the + Cyclical Learning Rate (CLR) policy that involves one cycle followed by + decay. The policy simultaneously cycles the learning rate (and momentum) + between two boundaries with a constant frequency, as detailed in + the paper `A disciplined approach to neural network hyper-parameters`_. + + 1CLR policy changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + + This implementation was adapted from the github repo: `pytorch/pytorch`_ + + Args: + optimizer (Optimizer): Wrapped optimizer. + cycle_min_lr (float or list): Initial learning rate which is the + lower boundary in the cycle for each parameter group. + cycle_max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (cycle_max_lr - cycle_min_lr). + The lr at any cycle is the sum of cycle_min_lr + and some scaling of the amplitude; therefore + cycle_max_lr may not actually be reached depending on + scaling function. + decay_lr_rate(float): Decay rate for learning rate. Default: 0. + cycle_first_step_size (int): Number of training iterations in the + increasing half of a cycle. Default: 2000 + cycle_second_step_size (int): Number of training iterations in the + decreasing half of a cycle. If cycle_second_step_size is None, + it is set to cycle_first_step_size. Default: None + cycle_first_stair_count(int): Number of stairs in first half of cycle phase. This means + lr/mom are changed in staircase fashion. Default 0, means staircase disabled. + cycle_second_stair_count(int): Number of stairs in second half of cycle phase. This means + lr/mom are changed in staircase fashion. Default 0, means staircase disabled. + decay_step_size (int): Intervals for applying decay in decay phase. Default: 0, means no decay. + cycle_momentum (bool): If ``True``, momentum is cycled inversely + to learning rate between 'cycle_min_mom' and 'cycle_max_mom'. + Default: True + cycle_min_mom (float or list): Initial momentum which is the + lower boundary in the cycle for each parameter group. + Default: 0.8 + cycle_max_mom (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (cycle_max_mom - cycle_min_mom). + The momentum at any cycle is the difference of cycle_max_mom + and some scaling of the amplitude; therefore + cycle_min_mom may not actually be reached depending on + scaling function. Default: 0.9 + decay_mom_rate (float): Decay rate for momentum. Default: 0. + last_batch_iteration (int): The index of the last batch. This parameter is used when + resuming a training job. Since `step()` should be invoked after each + batch instead of after each epoch, this number represents the total + number of *batches* computed, not the total number of epochs computed. + When last_batch_iteration=-1, the schedule is started from the beginning. + Default: -1 + + Example: + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = OneCycle(optimizer, 0.0001, 0.0010) + >>> data_loader = torch.utils.data.DataLoader(...) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> scheduler.step() + + + .. _A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch size, momentum, and weight decay: https://arxiv.org/abs/1803.09820 + """ + + def __init__(self, + optimizer, + cycle_min_lr, + cycle_max_lr, + decay_lr_rate=0., + cycle_first_step_size=2000, + cycle_second_step_size=None, + cycle_first_stair_count=0, + cycle_second_stair_count=None, + decay_step_size=0, + cycle_momentum=True, + cycle_min_mom=0.8, + cycle_max_mom=0.9, + decay_mom_rate=0., + last_batch_iteration=-1): + + self.optimizer = get_torch_optimizer(optimizer) + + # Initialize cycle shape + self._initialize_cycle(cycle_first_step_size, cycle_second_step_size, cycle_first_stair_count, + cycle_second_stair_count, decay_step_size) + + # Initialize cycle lr + self._initialize_lr(self.optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, last_batch_iteration) + + # Initialize cyclic momentum + self.cycle_momentum = cycle_momentum + if cycle_momentum: + self._initialize_momentum(self.optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate, + last_batch_iteration) + + # Initialize batch iteration tracker + self.last_batch_iteration = last_batch_iteration + + # Configure cycle shape + + def _initialize_cycle(self, cycle_first_step_size, cycle_second_step_size, cycle_first_stair_count, + cycle_second_stair_count, decay_step_size): + cycle_first_step_size = float(cycle_first_step_size) + cycle_second_step_size = float( + cycle_second_step_size) if cycle_second_step_size is not None else cycle_first_step_size + + self.total_size = cycle_first_step_size + cycle_second_step_size + self.step_ratio = cycle_first_step_size / self.total_size + self.first_stair_count = cycle_first_stair_count + self.second_stair_count = cycle_first_stair_count if cycle_second_stair_count is None else cycle_second_stair_count + self.decay_step_size = decay_step_size + + if math.isclose(self.decay_step_size, 0): + self.skip_lr_decay = True + self.skip_mom_decay = True + else: + self.skip_lr_decay = False + self.skip_mom_decay = False + + # Configure lr schedule + def _initialize_lr(self, optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, last_batch_iteration): + self.min_lrs = [cycle_min_lr] * len(optimizer.param_groups) + if last_batch_iteration == -1: + for lr, group in zip(self.min_lrs, optimizer.param_groups): + group['lr'] = lr + + self.max_lrs = [cycle_max_lr] * len(optimizer.param_groups) + self.decay_lr_rate = decay_lr_rate + + if math.isclose(self.decay_lr_rate, 0): + self.skip_lr_decay = True + + # Configure momentum schedule + def _initialize_momentum(self, optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate, last_batch_iteration): + if 'betas' not in optimizer.defaults: + optimizer_name = type(optimizer).__name__ + print( + f"cycle_momentum is disabled because optimizer {optimizer_name} does not support momentum, no betas attribute in defaults" + ) + self.cycle_momentum = False + return + + self.decay_mom_rate = decay_mom_rate + self.min_moms = [(cycle_min_mom, 0.99)] * len(optimizer.param_groups) + self.max_moms = [(cycle_max_mom, 0.99)] * len(optimizer.param_groups) + + if last_batch_iteration == -1: + for momentum, group in zip(self.min_moms, optimizer.param_groups): + group['betas'] = momentum + + if math.isclose(self.decay_mom_rate, 0): + self.skip_mom_decay = True + + def _get_scale_factor(self): + batch_iteration = (self.last_batch_iteration + 1) + cycle = math.floor(1 + batch_iteration / self.total_size) + x = 1. + batch_iteration / self.total_size - cycle + if x <= self.step_ratio: + scale_factor = x / self.step_ratio + else: + scale_factor = (x - 1) / (self.step_ratio - 1) + + return scale_factor + + def _get_cycle_mom(self): + scale_factor = self._get_scale_factor() + momentums = [] + for base_betas, max_betas in zip(self.min_moms, self.max_moms): + cycle_min_mom = base_betas[0] + cycle_max_mom = max_betas[0] + base_height = (cycle_max_mom - cycle_min_mom) * scale_factor + momentum = cycle_max_mom - base_height + momentums.append((momentum, base_betas[1])) + return momentums + + def _get_cycle_lr(self): + scale_factor = self._get_scale_factor() + lrs = [] + for cycle_min_lr, cycle_max_lr in zip(self.min_lrs, self.max_lrs): + base_height = (cycle_max_lr - cycle_min_lr) * scale_factor + lr = cycle_min_lr + base_height + lrs.append(lr) + + return lrs + + def _get_decay_mom(self, decay_batch_iteration): + if self.skip_mom_decay: + return self.max_moms + + decay_interval = decay_batch_iteration / self.decay_step_size + mom_decay_factor = (1 + self.decay_mom_rate * decay_interval) + momentums = [(beta0 * mom_decay_factor, beta1) for beta0, beta1 in self.max_moms] + + return momentums + + def _get_decay_lr(self, decay_batch_iteration): + """Calculates the learning rate at batch index. This function is used + after the cycle completes and post cycle decaying of lr/mom is enabled. + This function treats `self.last_batch_iteration` as the last batch index. + """ + if self.skip_lr_decay: + return self.min_lrs + + decay_interval = decay_batch_iteration / self.decay_step_size + lr_decay_factor = (1 + self.decay_lr_rate * decay_interval) + lrs = [cycle_min_lr / lr_decay_factor for cycle_min_lr in self.min_lrs] + + return lrs + + def get_lr(self): + """Calculates the learning rate at batch index. This function treats + `self.last_batch_iteration` as the last batch index. + """ + if self.last_batch_iteration < self.total_size: + return self._get_cycle_lr() + return self._get_decay_lr(self.last_batch_iteration - self.total_size + 1) + + def get_mom(self): + """Calculates the momentum at batch index. This function treats + `self.last_batch_iteration` as the last batch index. + """ + if not self.cycle_momentum: + return None + + if self.last_batch_iteration < self.total_size: + return self._get_cycle_mom() + return self._get_decay_mom(self.last_batch_iteration - self.total_size + 1) + + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + assert getattr(self, '_last_lr', None) is not None, "need to call step() first" + return self._last_lr + + def step(self, batch_iteration=None): + """ Updates the optimizer with the learning rate for the last batch index. + `self.last_batch_iteration` is treated as the last batch index. + + If self.cycle_momentum is true, also updates optimizer momentum. + """ + if batch_iteration is None: + batch_iteration = self.last_batch_iteration + 1 + + self.last_batch_iteration = batch_iteration + for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): + param_group['lr'] = lr + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + + if self.cycle_momentum: + momentums = self.get_mom() + for param_group, momentum in zip(self.optimizer.param_groups, momentums): + param_group['betas'] = momentum + + def state_dict(self): + return {'last_batch_iteration': self.last_batch_iteration} + + def load_state_dict(self, sd): + self.last_batch_iteration = sd['last_batch_iteration'] + + + +class WarmupLR(object): + """Increase the learning rate of each parameter group from min lr to max lr + over warmup_num_steps steps, and then fix at max lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + warmup_min_lr (float or list): minimum learning rate. Default: 0 + warmup_max_lr (float or list): maximum learning rate. Default: 0.001 + warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000 + warmup_type {‘log’, ‘linear’}: increasing function from min_lr to max_lr during warmup. Default: log + last_batch_iteration (int): The index of the last batch. Default: -1. + Example: + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = WarmupLR(optimizer) + >>> data_loader = torch.utils.data.DataLoader(...) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> scheduler.step() + + """ + + def __init__(self, + optimizer: Optimizer, + warmup_min_lr: float = 0.0, + warmup_max_lr: float = 0.001, + warmup_num_steps: int = 1000, + warmup_type: str = WARMUP_LOG_RATE, + last_batch_iteration: int = -1): + + self.optimizer = get_torch_optimizer(optimizer) + + self.min_lrs = self._format_param(self.optimizer, warmup_min_lr, "min_lr") + self.max_lrs = self._format_param(self.optimizer, warmup_max_lr, "max_lr") + self.delta_lrs = [big - small for big, small in zip(self.max_lrs, self.min_lrs)] + self.warmup_num_steps = max(2, warmup_num_steps) + # Currently only support linear and log function + if warmup_type not in {WARMUP_LOG_RATE, WARMUP_LINEAR_RATE}: + print(f"Using unknown warmup_type: {warmup_type}. The increasing function " + f"is set to default (log)") + warmup_type = WARMUP_LOG_RATE + self.warmup_type = warmup_type + self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps) + self.last_batch_iteration = last_batch_iteration + + def get_lr(self): + if self.last_batch_iteration < 0: + print("Attempting to get learning rate from scheduler before it has started") + return [0.0] + gamma = self._get_gamma() + return [min_lr + (delta_lr * gamma) for min_lr, delta_lr in zip(self.min_lrs, self.delta_lrs)] + + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + assert getattr(self, '_last_lr', None) is not None, "need to call step() first" + return self._last_lr + + def step(self, last_batch_iteration=None): + if last_batch_iteration is None: + last_batch_iteration = self.last_batch_iteration + 1 + self.last_batch_iteration = last_batch_iteration + for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): + param_group['lr'] = lr + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + + def state_dict(self): + return {'last_batch_iteration': self.last_batch_iteration} + + def load_state_dict(self, sd): + self.last_batch_iteration = sd['last_batch_iteration'] + + def _get_gamma(self): + if self.last_batch_iteration < self.warmup_num_steps: + if self.warmup_type == WARMUP_LOG_RATE: + return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1) + elif self.warmup_type == WARMUP_LINEAR_RATE: + return self.last_batch_iteration / self.warmup_num_steps + return 1.0 + + def _format_param(self, optimizer, param_value, param_name): + if isinstance(param_value, list) or isinstance(param_value, tuple): + if len(param_value) != len(optimizer.param_groups): + raise ValueError("expected {} value for {}, got {}".format(len(optimizer.param_groups), param_name, + FileNotFoundError(param_value))) + return list(param_value) + return [param_value] * len(optimizer.param_groups) + + + +class WarmupDecayLR(WarmupLR): + """Increase the learning rate of each parameter group from min lr to max lr + over warmup_num_steps steps, and then decay at linear rate over the remaining training steps. + + Args: + optimizer (Optimizer): Wrapped optimizer. + total_num_steps (int): total number of training steps + warmup_min_lr (float or list): minimum learning rate. Default: 0 + warmup_max_lr (float or list): maximum learning rate. Default: 0.001 + warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000 + warmup_type {‘log’, ‘linear’}: increasing function from min_lr to max_lr during warmup. Default: log + last_batch_iteration (int): The index of the last batch. Default: -1. + Example: + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = WarmupDecayLR(optimizer, 1000000) + >>> data_loader = torch.utils.data.DataLoader(...) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> scheduler.step() + + """ + + def __init__(self, + optimizer: Optimizer, + total_num_steps: int, + warmup_min_lr: float = 0.0, + warmup_max_lr: float = 0.001, + warmup_num_steps: int = 1000, + warmup_type: str = WARMUP_LOG_RATE, + last_batch_iteration: int = -1): + + self.total_num_steps = total_num_steps + super(WarmupDecayLR, self).__init__(optimizer, warmup_min_lr, warmup_max_lr, warmup_num_steps, warmup_type, + last_batch_iteration) + if self.total_num_steps < self.warmup_num_steps: + print('total_num_steps {} is less than warmup_num_steps {}'.format( + total_num_steps, warmup_num_steps)) + + def _get_gamma(self): + if self.last_batch_iteration < self.warmup_num_steps: + if self.warmup_type == WARMUP_LOG_RATE: + return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1) + elif self.warmup_type == WARMUP_LINEAR_RATE: + return self.last_batch_iteration / self.warmup_num_steps + return max( + 0.0, + float(self.total_num_steps - self.last_batch_iteration) / + float(max(1.0, self.total_num_steps - self.warmup_num_steps))) diff --git a/hydit/modules/ema.py b/hydit/modules/ema.py new file mode 100644 index 0000000..acc31ee --- /dev/null +++ b/hydit/modules/ema.py @@ -0,0 +1,127 @@ +from collections import OrderedDict +from copy import deepcopy + +import torch +from deepspeed.utils import instrument_w_nvtx +from pathlib import Path + +def requires_grad(model, flag=True): + """ + Set requires_grad flag for all parameters in a model. + """ + for p in model.parameters(): + p.requires_grad = flag + + +class EMA(object): + def __init__(self, args, model, device, logger): + if args.ema_dtype == 'fp32': + self.warmup = args.ema_warmup + self.update_after_step = 0 + self.max_value = args.ema_decay if args.ema_decay is not None else 0.9999 + self.inv_gamma = 1.0 + self.power = args.ema_warmup_power if args.ema_warmup_power is not None else 2 / 3 + self.min_value = 0.0 + else: + self.warmup = args.ema_warmup + self.update_after_step = 0 + self.max_value = args.ema_decay if args.ema_decay is not None else 0.992 + self.inv_gamma = 1.0 + self.power = args.ema_warmup_power if args.ema_warmup_power is not None else 0.446249 + # 0.446249 == math.log(1 - 0.992) / math.log(50000) + self.min_value = 0.0 + + self.ema_reset_decay = args.ema_reset_decay + self.decay_steps = 0 + + if args.ema_dtype == 'none': + ema_dtype = 'fp16' if args.use_fp16 else 'fp32' + else: + ema_dtype = args.ema_dtype + + # 由于module.half()和module.float()会发生inplace类型修改,因此需要先copy后修改类型 + self.ema_model = deepcopy(model) + if ema_dtype == 'fp16': + self.ema_model = self.ema_model.half().to(device) + elif ema_dtype == 'fp32': + self.ema_model = self.ema_model.float().to(device) + else: + raise ValueError(f"Unknown EMA dtype {ema_dtype}.") + + requires_grad(self.ema_model, False) + + logger.info(f" Using EMA with date type {args.ema_dtype} " + f"(decay={args.ema_decay}, warmup={args.ema_warmup}, warmup_power={args.ema_warmup_power}, " + f"reset_decay={args.ema_reset_decay}).") + + def get_decay(self): + """ + @crowsonkb's notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + + @jarvizhang's notes on EMA max_value when enabling FP16: + If using FP16 for EMA, max_value=0.995 is better (Don't larger than 0.999, unless you know + what you are doing). This is because FP16 has less precision than FP32, so the EMA value can + be pushed out of the range of FP16. + + gamma=1, power=0.446249 are good values for models (reaches decay factor 0.99 at 30K steps, + 0.992 at 50K steps). + """ + if self.warmup: + step = max(0, self.decay_steps - self.update_after_step - 1) + value = 1 - (1 + step / self.inv_gamma) ** -self.power + + if step <= 0: + return 0.0 + + return max(self.min_value, min(value, self.max_value)) + else: + return self.max_value + + @torch.no_grad() + @instrument_w_nvtx + def update(self, model, step, decay=None): + """ + Step the EMA model towards the current model. + + Parameters + ---------- + model: nn.Module + The current model + step: int + The current training step. This is used to determine the decay factor. If you want to control + the decay, you can pass in a custom step instead. + For example, if you want to restart the EMA decay, you can pass in step=0 at start and increase + step by step. + decay: float + The decay factor. If None, will be determined by the current step. + """ + if decay is None: + if self.ema_reset_decay: + self.decay_steps += 1 + else: + self.decay_steps = step + decay = self.get_decay() + + ema_params = OrderedDict(self.ema_model.named_parameters()) + model_params = OrderedDict(model.named_parameters()) + for name, param in model_params.items(): + # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed + ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) + + return None + + def state_dict(self, *args, **kwargs): + return self.ema_model.state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs): + return self.ema_model.load_state_dict(*args, **kwargs) + + def train(self): + self.ema_model.train() + + def eval(self): + self.ema_model.eval() diff --git a/hydit/modules/fp16_layers.py b/hydit/modules/fp16_layers.py new file mode 100644 index 0000000..6812811 --- /dev/null +++ b/hydit/modules/fp16_layers.py @@ -0,0 +1,86 @@ +import torch +from torch.autograd import Variable +from torch.nn.parameter import Parameter + + +_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) +_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) + + +def conversion_helper(val, conversion): + """Apply conversion to val. Recursively apply conversion if `val` + #is a nested tuple/list structure.""" + if isinstance(val, dict): + res_dict = {} + for k, v in val.items(): + if k!= 'cos_cis_img' and k != 'sin_cis_img': + res_dict[k] = conversion_helper(v, conversion) + else: + res_dict[k] = v + return res_dict + if not isinstance(val, (tuple, list)): + return conversion(val) + rtn = [conversion_helper(v, conversion) for v in val] + if isinstance(val, tuple): + rtn = tuple(rtn) + return rtn + + +def fp32_to_float16(val, float16_convertor): + """Convert fp32 `val` to fp16/bf16""" + def half_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, _FLOAT_TYPES): + val = float16_convertor(val) + return val + return conversion_helper(val, half_conversion) + + +def float16_to_fp32(val): + """Convert fp16/bf16 `val` to fp32""" + def float_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, (_HALF_TYPES,)): + val = val.float() + return val + return conversion_helper(val, float_conversion) + + +class Float16Module(torch.nn.Module): + + def __init__(self, module, args): + super(Float16Module, self).__init__() + + self.add_module('module', module.half()) + + def float16_convertor(val): + return val.half() + + self.float16_convertor = float16_convertor + + self.config = self.module.config + self.dtype = torch.float16 + + def forward(self, *inputs, **kwargs): + inputs = fp32_to_float16(inputs, self.float16_convertor) + kwargs = fp32_to_float16(kwargs, self.float16_convertor) + outputs = self.module(*inputs, **kwargs) + outputs = float16_to_fp32(outputs) + return outputs + + def state_dict(self, destination=None, prefix='', keep_vars=False): + return self.module.state_dict(destination, prefix, keep_vars) + + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', + keep_vars=False): + return self.module.state_dict_for_save_checkpoint(destination, prefix, + keep_vars) + + def load_state_dict(self, state_dict, strict=True): + self.module.load_state_dict(state_dict, strict=strict) + diff --git a/hydit/modules/models.py b/hydit/modules/models.py index d125aa9..059f9bd 100644 --- a/hydit/modules/models.py +++ b/hydit/modules/models.py @@ -10,6 +10,14 @@ from .norm_layers import RMSNorm from .poolers import AttentionPool +from transformers.integrations import PeftAdapterMixin +from typing import Any, Optional, Union +from tqdm import tqdm +from peft.utils import ( + ModulesToSaveWrapper, + _get_submodules, +) + def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) @@ -136,12 +144,14 @@ def forward(self, x, c): return x -class HunYuanDiT(ModelMixin, ConfigMixin): +class HunYuanDiT(ModelMixin, ConfigMixin, PeftAdapterMixin): """ HunYuanDiT: Diffusion model with a Transformer backbone. Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. + Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline. + Parameters ---------- args: argparse.Namespace @@ -191,10 +201,10 @@ def __init__( self.text_len_t5 = args.text_len_t5 self.norm = args.norm - use_flash_attn = args.infer_mode == 'fa' + use_flash_attn = args.infer_mode == 'fa' or args.use_flash_attn if use_flash_attn: log_fn(f" Enable Flash Attention.") - qk_norm = True # See http://arxiv.org/abs/2302.05442 for details. + qk_norm = args.qk_norm # See http://arxiv.org/abs/2302.05442 for details. self.mlp_t5 = nn.Sequential( nn.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True), @@ -287,7 +297,6 @@ def forward(self, return_dict: bool Whether to return a dictionary. """ - text_states = encoder_hidden_states # 2,77,1024 text_states_t5 = encoder_hidden_states_t5 # 2,256,2048 text_states_mask = text_embedding_mask.bool() # 2,77 @@ -396,6 +405,82 @@ def unpatchify(self, x, h, w): imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) return imgs + def _replace_module(self, parent, child_name, new_module, child) -> None: + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.get_base_layer() + elif hasattr(child, "quant_linear_module"): + # TODO maybe not necessary to have special treatment? + child = child.quant_linear_module + + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + # dispatch to correct device + for name, module in new_module.named_modules(): + # if any(prefix in name for prefix in PREFIXES): + # module.to(child.weight.device) + if "ranknum" in name: + module.to(child.weight.device) + + def merge_and_unload(self, + merge=True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names = None,): + if merge: + if getattr(self, "quantization_method", None) == "gptq": + raise ValueError("Cannot merge layers when the model is gptq quantized") + + def merge_recursively(module): + # helper function to recursively merge the base_layer of the target + path = [] + layer = module + while hasattr(layer, "base_layer"): + path.append(layer) + layer = layer.base_layer + for layer_before, layer_after in zip(path[:-1], path[1:]): + layer_after.merge(safe_merge=safe_merge, adapter_names=adapter_names) + layer_before.base_layer = layer_after.base_layer + module.merge(safe_merge=safe_merge, adapter_names=adapter_names) + + key_list = [key for key, _ in self.named_modules()] + desc = "Unloading " + ("and merging " if merge else "") + "model" + + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self, key) + except AttributeError: + continue + + if hasattr(target, "base_layer"): + if merge: + merge_recursively(target) + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` + new_module = target.modules_to_save[target.active_adapter] + if hasattr(new_module, "base_layer"): + # check if the module is itself a tuner layer + if merge: + new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names) + new_module = new_module.get_base_layer() + setattr(parent, target_name, new_module) + + ################################################################################# # HunYuanDiT Configs # @@ -404,6 +489,14 @@ def unpatchify(self, x, h, w): HUNYUAN_DIT_CONFIG = { 'DiT-g/2': {'depth': 40, 'hidden_size': 1408, 'patch_size': 2, 'num_heads': 16, 'mlp_ratio': 4.3637}, 'DiT-XL/2': {'depth': 28, 'hidden_size': 1152, 'patch_size': 2, 'num_heads': 16}, - 'DiT-L/2': {'depth': 24, 'hidden_size': 1024, 'patch_size': 2, 'num_heads': 16}, - 'DiT-B/2': {'depth': 12, 'hidden_size': 768, 'patch_size': 2, 'num_heads': 12}, } + +def DiT_g_2(args, **kwargs): + return HunYuanDiT(args, depth=40, hidden_size=1408, patch_size=2, num_heads=16, mlp_ratio=4.3637, **kwargs) +def DiT_XL_2(args, **kwargs): + return HunYuanDiT(args, depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) + +HUNYUAN_DIT_MODELS = { + 'DiT-g/2': DiT_g_2, + 'DiT-XL/2': DiT_XL_2, +} \ No newline at end of file diff --git a/hydit/modules/posemb_layers.py b/hydit/modules/posemb_layers.py index 62c83df..dcb41a7 100644 --- a/hydit/modules/posemb_layers.py +++ b/hydit/modules/posemb_layers.py @@ -10,12 +10,12 @@ def _to_tuple(x): return x -def get_fill_resize_and_crop(src, tgt): # src 来源的分辨率 tgt base 分辨率 +def get_fill_resize_and_crop(src, tgt): th, tw = _to_tuple(tgt) h, w = _to_tuple(src) - tr = th / tw # base 分辨率 - r = h / w # 目标分辨率 + tr = th / tw # base resolution + r = h / w # target resolution # resize if r > tr: @@ -23,7 +23,7 @@ def get_fill_resize_and_crop(src, tgt): # src 来源的分辨率 tgt base resize_width = int(round(th / h * w)) else: resize_width = tw - resize_height = int(round(tw / w * h)) # 根据base分辨率,将目标分辨率resize下来 + resize_height = int(round(tw / w * h)) # resize the target resolution down based on the base resolution crop_top = int(round((th - resize_height) / 2.0)) crop_left = int(round((tw - resize_width) / 2.0)) @@ -44,13 +44,13 @@ def get_meshgrid(start, *args): num = (stop[0] - start[0], stop[1] - start[1]) elif len(args) == 2: # start is start, args[0] is stop, args[1] is num - start = _to_tuple(start) # 左上角 eg: 12,0 - stop = _to_tuple(args[0]) # 右下角 eg: 20,32 - num = _to_tuple(args[1]) # 目标大小 eg: 32,124 + start = _to_tuple(start) + stop = _to_tuple(args[0]) + num = _to_tuple(args[1]) else: raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}") - grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32) # 12-20 中间差值32份 0-32 中间差值124份 + grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32) grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) # [2, W, H] @@ -137,7 +137,7 @@ def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True): [HW, D/2] """ grid = get_meshgrid(start, *args) # [2, H, w] - grid = grid.reshape([2, 1, *grid.shape[1:]]) # 返回一个采样矩阵 分辨率与目标分辨率一致 + grid = grid.reshape([2, 1, *grid.shape[1:]]) # Returns a sampling matrix with the same resolution as the target resolution pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real) return pos_embed @@ -193,14 +193,13 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float def calc_sizes(rope_img, patch_size, th, tw): - """ 计算 RoPE 的尺寸. """ if rope_img == 'extend': - # 拓展模式 + # Expansion mode sub_args = [(th, tw)] elif rope_img.startswith('base'): - # 基于一个尺寸, 其他尺寸插值获得. - base_size = int(rope_img[4:]) // 8 // patch_size # 基于512作为base,其他根据512差值得到 - start, stop = get_fill_resize_and_crop((th, tw), base_size) # 需要在32x32里面 crop的左上角和右下角 + # Based on the specified dimensions, other dimensions are obtained through interpolation. + base_size = int(rope_img[4:]) // 8 // patch_size + start, stop = get_fill_resize_and_crop((th, tw), base_size) sub_args = [start, stop, (th, tw)] else: raise ValueError(f"Unknown rope_img: {rope_img}") @@ -218,7 +217,7 @@ def init_image_posemb(rope_img, freqs_cis_img = {} for reso in resolutions: th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size - sub_args = calc_sizes(rope_img, patch_size, th, tw) # [左上角, 右下角, 目标高宽] 需要在32x32里面 crop的左上角和右下角 + sub_args = calc_sizes(rope_img, patch_size, th, tw) freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real) log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) " f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}") diff --git a/hydit/modules/trt/engine.py b/hydit/modules/trt/engine.py deleted file mode 100644 index bba110e..0000000 --- a/hydit/modules/trt/engine.py +++ /dev/null @@ -1,152 +0,0 @@ -# -# Copyright 2022 The HuggingFace Inc. team. -# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -import os -from collections import OrderedDict -from copy import copy - -import numpy as np -import tensorrt as trt -import torch -from polygraphy import cuda -from polygraphy.backend.common import bytes_from_path -from polygraphy.backend.trt import CreateConfig, Profile -from polygraphy.backend.trt import engine_from_bytes, engine_from_network, network_from_onnx_path, save_engine -from polygraphy.backend.trt import util as trt_util -import ctypes -from glob import glob -from cuda import cudart - -TRT_LOGGER = trt.Logger(trt.Logger.INFO) -trt_util.TRT_LOGGER = TRT_LOGGER - - -class Engine(): - def __init__( - self, - model_name, - engine_dir, - onnx_file=None, - ): - self.engine_path = os.path.join(engine_dir, model_name + '.plan') - self.engine = None - self.context = None - self.buffers = OrderedDict() - self.tensors = OrderedDict() - - self.weightNameList = None - self.refitter = None - self.onnx_initializers = None - self.onnx_file = onnx_file - self.trt_lora_weight = None - self.trt_lora_weight_mem = None - self.torch_weight = None - - def __del__(self): - del self.engine - del self.context - del self.buffers - del self.tensors - - def build(self, onnx_path, fp16, input_profile=None, enable_preview=False, sparse_weights=False): - print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") - p = Profile() - if input_profile: - for name, dims in input_profile.items(): - assert len(dims) == 3 - p.add(name, min=dims[0], opt=dims[1], max=dims[2]) - - preview_features = [] - if enable_preview: - trt_version = [int(i) for i in trt.__version__.split(".")] - # FASTER_DYNAMIC_SHAPES_0805 should only be used for TRT 8.5.1 or above. - if trt_version[0] > 8 or \ - (trt_version[0] == 8 and (trt_version[1] > 5 or (trt_version[1] == 5 and trt_version[2] >= 1))): - preview_features = [trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805] - - engine = engine_from_network(network_from_onnx_path(onnx_path), config=CreateConfig(fp16=fp16, profiles=[p], - preview_features=preview_features, - sparse_weights=sparse_weights)) - save_engine(engine, path=self.engine_path) - - def activate(self, plugin_path=""): - ctypes.cdll.LoadLibrary(plugin_path) - self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) - self.context = self.engine.create_execution_context() - - def get_shared_memory(self): - _, device_memory = cudart.cudaMalloc(self.engine.device_memory_size) - self.device_memory = device_memory - return self.device_memory - - def set_shared_memory(self, device_memory_size): - self.context.device_memory = device_memory_size - - def binding_input(self, name, shape): - idx = self.engine.get_binding_index(name) - result = self.context.set_binding_shape(idx, shape) - return result - - def allocate_buffers(self, shape_dict=None, device='cuda'): - print("Allocate buffers and bindings inputs:") - for idx in range(trt_util.get_bindings_per_profile(self.engine)): - binding = self.engine[idx] - print("binding: ", binding) - if shape_dict and binding in shape_dict: - shape = shape_dict[binding] - else: - shape = self.engine.get_binding_shape(binding) - nv_dtype = self.engine.get_binding_dtype(binding) - dtype_map = {trt.DataType.FLOAT: np.float32, - trt.DataType.HALF: np.float16, - trt.DataType.INT8: np.int8, - trt.DataType.INT64: np.int64, - trt.DataType.BOOL: bool} - if hasattr(trt.DataType, 'INT32'): - dtype_map[trt.DataType.INT32] = np.int32 - dtype = dtype_map[nv_dtype] - if self.engine.binding_is_input(binding): - self.context.set_binding_shape(idx, shape) - # Workaround to convert np dtype to torch - np_type_tensor = np.empty(shape=[], dtype=dtype) - torch_type_tensor = torch.from_numpy(np_type_tensor) - tensor = torch.empty(tuple(shape), dtype=torch_type_tensor.dtype).to(device=device) - - print(f" binding={binding}, shape={shape}, dtype={tensor.dtype}") - self.tensors[binding] = tensor - self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype) - - def infer(self, feed_dict, stream): - start_binding, end_binding = trt_util.get_active_profile_bindings(self.context) - # shallow copy of ordered dict - device_buffers = copy(self.buffers) - for name, buf in feed_dict.items(): - assert isinstance(buf, cuda.DeviceView) - device_buffers[name] = buf - self.binding_input(name, buf.shape) - bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()] - noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr) - if not noerror: - raise ValueError(f"ERROR: inference failed.") - - for idx in range(trt_util.get_bindings_per_profile(self.engine)): - binding = self.engine[idx] - if not self.engine.binding_is_input(binding): - shape = self.context.get_binding_shape(idx) - self.tensors[binding].resize_(tuple(shape)) - return self.tensors diff --git a/hydit/modules/trt/hcf_model.py b/hydit/modules/trt/hcf_model.py deleted file mode 100644 index b0ded21..0000000 --- a/hydit/modules/trt/hcf_model.py +++ /dev/null @@ -1,123 +0,0 @@ -import numpy as np -import torch -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.models import ModelMixin -from polygraphy import cuda - -from .engine import Engine - - -class TRTModel(ModelMixin, ConfigMixin): - @register_to_config - def __init__( - self, - in_channels=4, - model_name="unet-dyn", - engine_dir="./unet", - device_id=0, - fp16=True, - image_width=1024, - image_height=1024, - text_maxlen=77, - embedding_dim=768, - max_batch_size=1, - plugin_path="./ckpts/trt_model/fmha_plugins/9.2_plugin_cuda11/fMHAPlugin.so", - ): - super().__init__() - # create engine - self.in_channels = in_channels # For pipeline compatibility - self.fp16 = fp16 - self.max_batch_size = max_batch_size - self.model_name = model_name - self.engine_dir = engine_dir - self.engine = Engine(self.model_name, self.engine_dir) - self.engine.activate(plugin_path) - # create cuda stream - self.stream = cuda.Stream() - # create inputs buffer - self.latent_width = image_width // 8 - self.latent_height = image_height // 8 - self.text_maxlen = text_maxlen - self.embedding_dim = embedding_dim - shape_dict = { - 'x': (2 * self.max_batch_size, 4, self.latent_height, self.latent_width), - 't': (2 * self.max_batch_size,), - 'encoder_hidden_states': (2 * self.max_batch_size, self.text_maxlen, self.embedding_dim), - 'text_embedding_mask': (2 * self.max_batch_size, self.text_maxlen), - 'encoder_hidden_states_t5': (2 * self.max_batch_size, 256, 2048), - 'text_embedding_mask_t5': (2 * self.max_batch_size, 256), - 'image_meta_size': (2 * self.max_batch_size, 6), - 'style': (2 * self.max_batch_size,), - 'cos_cis_img': (6400, 88), - 'sin_cis_img': (6400, 88), - 'output': (2 * self.max_batch_size, 8, self.latent_height, self.latent_width), - } - device = "cuda:{}".format(device_id) - self.engine_device = torch.device(device) - self.engine.allocate_buffers(shape_dict=shape_dict, device=device) - - print("[INFO] Create hcf nv controlled unet success") - - @property - def device(self): - return self.engine_device - - def __call__(self, x, t_emb, context, image_meta_size, style, freqs_cis_img0, - freqs_cis_img1, text_embedding_mask, encoder_hidden_states_t5, text_embedding_mask_t5): - return self.forward(x=x, t_emb=t_emb, context=context, image_meta_size=image_meta_size, style=style, - freqs_cis_img0=freqs_cis_img0, freqs_cis_img1=freqs_cis_img1, - text_embedding_mask=text_embedding_mask, encoder_hidden_states_t5=encoder_hidden_states_t5, - text_embedding_mask_t5=text_embedding_mask_t5) - - def get_shared_memory(self): - return self.engine.get_shared_memory() - - def set_shared_memory(self, shared_memory): - self.engine.set_shared_memory(shared_memory) - - def forward(self, x, t_emb, context, image_meta_size, style, freqs_cis_img0, - freqs_cis_img1, text_embedding_mask, encoder_hidden_states_t5, text_embedding_mask_t5): - x_c = x.half() - t_emb_c = t_emb.half() - context_c = context.half() - image_meta_size_c = image_meta_size.half() - style_c = style.long() - freqs_cis_img0_c = freqs_cis_img0.float() - freqs_cis_img1_c = freqs_cis_img1.float() - text_embedding_mask_c = text_embedding_mask.long() - encoder_hidden_states_t5_c = encoder_hidden_states_t5.half() - text_embedding_mask_t5_c = text_embedding_mask_t5.long() - dtype = np.float16 - batch_size = x.shape[0] // 2 - if batch_size <= self.max_batch_size: - sample_inp = cuda.DeviceView(ptr=x_c.reshape(-1).data_ptr(), shape=x_c.shape, dtype=np.float16) - t_emb_inp = cuda.DeviceView(ptr=t_emb_c.reshape(-1).data_ptr(), shape=t_emb_c.shape, dtype=np.float16) - embeddings_inp = cuda.DeviceView(ptr=context_c.reshape(-1).data_ptr(), shape=context_c.shape, - dtype=np.float16) - image_meta_size_inp = cuda.DeviceView(ptr=image_meta_size_c.reshape(-1).data_ptr(), - shape=image_meta_size_c.shape, dtype=np.float16) - style_inp = cuda.DeviceView(ptr=style_c.reshape(-1).data_ptr(), shape=style_c.shape, dtype=np.int64) - freqs_cis_img0_inp = cuda.DeviceView(ptr=freqs_cis_img0_c.reshape(-1).data_ptr(), - shape=freqs_cis_img0_c.shape, dtype=np.float32) - freqs_cis_img1_inp = cuda.DeviceView(ptr=freqs_cis_img1_c.reshape(-1).data_ptr(), - shape=freqs_cis_img1_c.shape, dtype=np.float32) - text_embedding_mask_inp = cuda.DeviceView(ptr=text_embedding_mask_c.reshape(-1).data_ptr(), - shape=text_embedding_mask_c.shape, dtype=np.int64) - encoder_hidden_states_t5_inp = cuda.DeviceView(ptr=encoder_hidden_states_t5_c.reshape(-1).data_ptr(), - shape=encoder_hidden_states_t5_c.shape, dtype=np.float16) - text_embedding_mask_t5_inp = cuda.DeviceView(ptr=text_embedding_mask_t5_c.reshape(-1).data_ptr(), - shape=text_embedding_mask_t5_c.shape, dtype=np.int64) - feed_dict = {"x": sample_inp, - "t": t_emb_inp, - "encoder_hidden_states": embeddings_inp, - "image_meta_size": image_meta_size_inp, - "text_embedding_mask": text_embedding_mask_inp, - "encoder_hidden_states_t5": encoder_hidden_states_t5_inp, - "text_embedding_mask_t5": text_embedding_mask_t5_inp, - "style": style_inp, "cos_cis_img": freqs_cis_img0_inp, - "sin_cis_img": freqs_cis_img1_inp} - latent = self.engine.infer(feed_dict, self.stream) - return latent['output'] - else: - raise ValueError( - "[ERROR] Input batch_size={} execeed max_batch_size={}".format(batch_size, self.max_batch_size)) diff --git a/hydit/run_g.sh b/hydit/run_g.sh new file mode 100644 index 0000000..50b043c --- /dev/null +++ b/hydit/run_g.sh @@ -0,0 +1,8 @@ +model='DiT-g/2' +params=" \ + --qk-norm \ + --model ${model} \ + --rope-img base512 \ + --rope-real \ + " +deepspeed hydit/train_deepspeed.py ${params} "$@" \ No newline at end of file diff --git a/hydit/train.sh b/hydit/train.sh new file mode 100644 index 0000000..65742b7 --- /dev/null +++ b/hydit/train.sh @@ -0,0 +1,41 @@ +task_flag="dit_g2_full_1024p" # the task flag is used to identify folders. +resume=./ckpts/t2i/model/ # checkpoint root for resume +index_file=dataset/index_v2_json/jade.json # index file for dataloader +results_dir=./log_EXP # save root for results +batch_size=1 # training batch size +image_size=1024 # training image resolution +grad_accu_steps=1 # gradient accumulation +warmup_num_steps=0 # warm-up steps +lr=0.0001 # learning rate +ckpt_every=10000 # create a ckpt every a few steps. +ckpt_latest_every=5000 # create a ckpt named `latest.pt` every a few steps. + + +sh $(dirname "$0")/run_g.sh \ + --task-flag ${task_flag} \ + --noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.03 \ + --predict-type v_prediction \ + --uncond-p 0.44 \ + --uncond-p-t5 0.44 \ + --index-file ${index_file} \ + --random-flip \ + --lr ${lr} \ + --batch-size ${batch_size} \ + --image-size ${image_size} \ + --global-seed 999 \ + --grad-accu-steps ${grad_accu_steps} \ + --warmup-num-steps ${warmup_num_steps} \ + --use-flash-attn \ + --use-fp16 \ + --use-ema \ + --ema-dtype fp32 \ + --results-dir ${results_dir} \ + --resume-split \ + --resume ${resume} \ + --ckpt-every ${ckpt_every} \ + --ckpt-latest-every ${ckpt_latest_every} \ + --log-every 10 \ + --deepspeed \ + --deepspeed-optimizer \ + --use-zero-stage 2 \ + "$@" \ No newline at end of file diff --git a/hydit/train_deepspeed.py b/hydit/train_deepspeed.py new file mode 100644 index 0000000..854a5d9 --- /dev/null +++ b/hydit/train_deepspeed.py @@ -0,0 +1,517 @@ +import gc +import json +import os +import random +import sys +import time +from functools import partial +from glob import glob +from pathlib import Path +import numpy as np + +import deepspeed +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from torch.utils.data import DataLoader +from torch.distributed.optim import ZeroRedundancyOptimizer +from torchvision.transforms import functional as TF +from diffusers.models import AutoencoderKL +from transformers import BertModel, BertTokenizer, logging as tf_logging + +from hydit.config import get_args +from hydit.constants import VAE_EMA_PATH, TEXT_ENCODER, TOKENIZER, T5_ENCODER +from hydit.lr_scheduler import WarmupLR +from hydit.data_loader.arrow_load_stream import TextImageArrowStream +from hydit.diffusion import create_diffusion +from hydit.ds_config import deepspeed_config_from_args +from hydit.modules.ema import EMA +from hydit.modules.fp16_layers import Float16Module +from hydit.modules.models import HUNYUAN_DIT_MODELS +from hydit.modules.posemb_layers import init_image_posemb +from hydit.utils.tools import create_logger, set_seeds, create_exp_folder, model_resume, get_trainable_params +from IndexKits.index_kits import ResolutionGroup +from IndexKits.index_kits.sampler import DistributedSamplerWithStartIndex, BlockDistributedSampler +from peft import LoraConfig, get_peft_model + + +def deepspeed_initialize(args, logger, model, opt, deepspeed_config): + logger.info(f"Initialize deepspeed...") + logger.info(f" Using deepspeed optimizer") + + def get_learning_rate_scheduler(warmup_min_lr, lr, warmup_num_steps, opt): + return WarmupLR(opt, warmup_min_lr, lr, warmup_num_steps) + + logger.info(f" Building scheduler with warmup_min_lr={args.warmup_min_lr}, warmup_num_steps={args.warmup_num_steps}") + model, opt, _, scheduler = deepspeed.initialize( + model=model, + model_parameters=get_trainable_params(model), + config_params=deepspeed_config, + args=args, + lr_scheduler=partial(get_learning_rate_scheduler, args.warmup_min_lr, args.lr, args.warmup_num_steps) if args.warmup_num_steps > 0 else None, + ) + return model, opt, scheduler + +def save_checkpoint(args, rank, logger, model, ema, epoch, train_steps, checkpoint_dir): + def save_lora_weight(checkpoint_dir, client_state, tag=f"{train_steps:07d}.pt"): + cur_ckpt_save_dir = f"{checkpoint_dir}/{tag}" + if rank == 0: + if args.use_fp16: + model.module.module.save_pretrained(cur_ckpt_save_dir) + else: + model.module.save_pretrained(cur_ckpt_save_dir) + + checkpoint_path = "[Not rank 0. Disabled output.]" + + client_state = { + "steps": train_steps, + "epoch": epoch, + "args": args + } + if ema is not None: + client_state['ema'] = ema.state_dict() + + dst_paths = [] + if train_steps % args.ckpt_every == 0: + checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt" + try: + if args.training_parts == "lora": + save_lora_weight(checkpoint_dir, client_state, tag=f"{train_steps:07d}.pt") + else: + model.save_checkpoint(checkpoint_dir, client_state=client_state, tag=f"{train_steps:07d}.pt") + dst_paths.append(checkpoint_path) + logger.info(f"Saved checkpoint to {checkpoint_path}") + except: + logger.error(f"Saved failed to {checkpoint_path}") + + if train_steps % args.ckpt_latest_every == 0 or train_steps == args.max_training_steps: + save_name = "latest.pt" + checkpoint_path = f"{checkpoint_dir}/{save_name}" + try: + if args.training_parts == "lora": + save_lora_weight(checkpoint_dir, client_state, tag=f"{save_name}") + else: + model.save_checkpoint(checkpoint_dir, client_state=client_state, tag=f"{save_name}") + dst_paths.append(checkpoint_path) + logger.info(f"Saved checkpoint to {checkpoint_path}") + except: + logger.error(f"Saved failed to {checkpoint_path}") + + dist.barrier() + if rank == 0 and len(dst_paths) > 0: + # Delete optimizer states to avoid occupying too much disk space. + for dst_path in dst_paths: + for opt_state_path in glob(f"{dst_path}/zero_dp_rank_*_tp_rank_00_pp_rank_00_optim_states.pt"): + os.remove(opt_state_path) + + return checkpoint_path + +@torch.no_grad() +def prepare_model_inputs(args, batch, device, vae, text_encoder, text_encoder_t5, freqs_cis_img): + image, text_embedding, text_embedding_mask, text_embedding_t5, text_embedding_mask_t5, kwargs = batch + + # clip & mT5 text embedding + text_embedding = text_embedding.to(device) + text_embedding_mask = text_embedding_mask.to(device) + encoder_hidden_states = text_encoder( + text_embedding.to(device), + attention_mask=text_embedding_mask.to(device), + )[0] + text_embedding_t5 = text_embedding_t5.to(device).squeeze(1) + text_embedding_mask_t5 = text_embedding_mask_t5.to(device).squeeze(1) + with torch.no_grad(): + output_t5 = text_encoder_t5( + input_ids=text_embedding_t5, + attention_mask=text_embedding_mask_t5 if T5_ENCODER['attention_mask'] else None, + output_hidden_states=True + ) + encoder_hidden_states_t5 = output_t5['hidden_states'][T5_ENCODER['layer_index']].detach() + + # additional condition + image_meta_size = kwargs['image_meta_size'].to(device) + style = kwargs['style'].to(device) + + if args.extra_fp16: + image = image.half() + image_meta_size = image_meta_size.half() if image_meta_size is not None else None + + # Map input images to latent space + normalize latents: + image = image.to(device) + vae_scaling_factor = vae.config.scaling_factor + latents = vae.encode(image).latent_dist.sample().mul_(vae_scaling_factor) + + # positional embedding + _, _, height, width = image.shape + reso = f"{height}x{width}" + cos_cis_img, sin_cis_img = freqs_cis_img[reso] + + # Model conditions + model_kwargs = dict( + encoder_hidden_states=encoder_hidden_states, + text_embedding_mask=text_embedding_mask, + encoder_hidden_states_t5=encoder_hidden_states_t5, + text_embedding_mask_t5=text_embedding_mask_t5, + image_meta_size=image_meta_size, + style=style, + cos_cis_img=cos_cis_img, + sin_cis_img=sin_cis_img, + ) + + return latents, model_kwargs + +def main(args): + if args.training_parts == "lora": + args.use_ema = False + + assert torch.cuda.is_available(), "Training currently requires at least one GPU." + + dist.init_process_group("nccl") + world_size = dist.get_world_size() + batch_size = args.batch_size + grad_accu_steps = args.grad_accu_steps + global_batch_size = world_size * batch_size * grad_accu_steps + + rank = dist.get_rank() + device = rank % torch.cuda.device_count() + seed = args.global_seed * world_size + rank + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.cuda.set_device(device) + print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.") + deepspeed_config = deepspeed_config_from_args(args, global_batch_size) + + # Setup an experiment folder + experiment_dir, checkpoint_dir, logger = create_exp_folder(args, rank) + + # Log all the arguments + logger.info(sys.argv) + logger.info(str(args)) + # Save to a json file + args_dict = vars(args) + args_dict['world_size'] = world_size + with open(f"{experiment_dir}/args.json", 'w') as f: + json.dump(args_dict, f, indent=4) + + # Disable the message "Some weights of the model checkpoint at ... were not used when initializing BertModel." + # If needed, just comment the following line. + tf_logging.set_verbosity_error() + + # =========================================================================== + # Building HYDIT + # =========================================================================== + + logger.info("Building HYDIT Model.") + + # --------------------------------------------------------------------------- + # Training sample base size, such as 256/512/1024. Notice that this size is + # just a base size, not necessary the actual size of training samples. Actual + # size of the training samples are correlated with `resolutions` when enabling + # multi-resolution training. + # --------------------------------------------------------------------------- + image_size = args.image_size + if len(image_size) == 1: + image_size = [image_size[0], image_size[0]] + if len(image_size) != 2: + raise ValueError(f"Invalid image size: {args.image_size}") + assert image_size[0] % 8 == 0 and image_size[1] % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder). " \ + f"got {image_size}" + latent_size = [image_size[0] // 8, image_size[1] // 8] + + # initialize model by deepspeed + assert args.deepspeed, f"Must enable deepspeed in this script: train_deepspeed.py" + with deepspeed.zero.Init(data_parallel_group=torch.distributed.group.WORLD, + remote_device=None if args.remote_device == 'none' else args.remote_device, + config_dict_or_path=deepspeed_config, + mpu=None, + enabled=args.zero_stage == 3): + model = HUNYUAN_DIT_MODELS[args.model](args, + input_size=latent_size, + log_fn=logger.info, + ) + # Multi-resolution / Single-resolution training. + if args.multireso: + resolutions = ResolutionGroup(image_size[0], + align=16, + step=args.reso_step, + target_ratios=args.target_ratios).data + else: + resolutions = ResolutionGroup(image_size[0], + align=16, + target_ratios=['1:1']).data + + freqs_cis_img = init_image_posemb(args.rope_img, + resolutions=resolutions, + patch_size=model.patch_size, + hidden_size=model.hidden_size, + num_heads=model.num_heads, + log_fn=logger.info, + rope_real=args.rope_real, + ) + + # Create EMA model and convert to fp16 if needed. + ema = None + if args.use_ema: + ema = EMA(args, model, device, logger) + + # Setup FP16 main model: + if args.use_fp16: + model = Float16Module(model, args) + logger.info(f" Using main model with data type {'fp16' if args.use_fp16 else 'fp32'}") + + diffusion = create_diffusion( + noise_schedule=args.noise_schedule, + predict_type=args.predict_type, + learn_sigma=args.learn_sigma, + mse_loss_weight_type=args.mse_loss_weight_type, + beta_start=args.beta_start, + beta_end=args.beta_end, + noise_offset=args.noise_offset, + ) + + # Setup VAE + logger.info(f" Loading vae from {VAE_EMA_PATH}") + vae = AutoencoderKL.from_pretrained(VAE_EMA_PATH) + # Setup BERT text encoder + logger.info(f" Loading Bert text encoder from {TEXT_ENCODER}") + text_encoder = BertModel.from_pretrained(TEXT_ENCODER, False, revision=None) + # Setup BERT tokenizer: + logger.info(f" Loading Bert tokenizer from {TOKENIZER}") + tokenizer = BertTokenizer.from_pretrained(TOKENIZER) + # Setup T5 text encoder + from hydit.modules.text_encoder import MT5Embedder + mt5_path = T5_ENCODER['MT5'] + embedder_t5 = MT5Embedder(mt5_path, torch_dtype=T5_ENCODER['torch_dtype'], max_length=args.text_len_t5) + tokenizer_t5 = embedder_t5.tokenizer + text_encoder_t5 = embedder_t5.model + + if args.extra_fp16: + logger.info(f" Using fp16 for extra modules: vae, text_encoder") + vae = vae.half().to(device) + text_encoder = text_encoder.half().to(device) + text_encoder_t5 = text_encoder_t5.half().to(device) + else: + vae = vae.to(device) + text_encoder = text_encoder.to(device) + text_encoder_t5 = text_encoder_t5.to(device) + + logger.info(f" Optimizer parameters: lr={args.lr}, weight_decay={args.weight_decay}") + logger.info(" Using deepspeed optimizer") + opt = None + + # =========================================================================== + # Building Dataset + # =========================================================================== + + logger.info(f"Building Streaming Dataset.") + logger.info(f" Loading index file {args.index_file} (v2)") + + dataset = TextImageArrowStream(args=args, + resolution=image_size[0], + random_flip=args.random_flip, + log_fn=logger.info, + index_file=args.index_file, + multireso=args.multireso, + batch_size=batch_size, + world_size=world_size, + random_shrink_size_cond=args.random_shrink_size_cond, + merge_src_cond=args.merge_src_cond, + uncond_p=args.uncond_p, + text_ctx_len=args.text_len, + tokenizer=tokenizer, + uncond_p_t5=args.uncond_p_t5, + text_ctx_len_t5=args.text_len_t5, + tokenizer_t5=tokenizer_t5, + ) + if args.multireso: + sampler = BlockDistributedSampler(dataset, num_replicas=world_size, rank=rank, seed=args.global_seed, + shuffle=False, drop_last=True, batch_size=batch_size) + else: + sampler = DistributedSamplerWithStartIndex(dataset, num_replicas=world_size, rank=rank, seed=args.global_seed, + shuffle=False, drop_last=True) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=sampler, + num_workers=args.num_workers, pin_memory=True, drop_last=True) + logger.info(f" Dataset contains {len(dataset):,} images.") + logger.info(f" Index file: {args.index_file}.") + if args.multireso: + logger.info(f' Using MultiResolutionBucketIndexV2 with step {dataset.index_manager.step} ' + f'and base size {dataset.index_manager.base_size}') + logger.info(f'\n {dataset.index_manager.resolutions}') + + # =========================================================================== + # Loading parameter + # =========================================================================== + + logger.info(f"Loading parameter") + start_epoch = 0 + start_epoch_step = 0 + train_steps = 0 + # Resume checkpoint if needed + if args.resume is not None or len(args.resume) > 0: + model, ema, start_epoch, start_epoch_step, train_steps = model_resume(args, model, ema, logger, len(loader)) + + if args.training_parts == "lora": + loraconfig = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + target_modules=args.target_modules + ) + if args.use_fp16: + model.module = get_peft_model(model.module, loraconfig) + else: + model = get_peft_model(model, loraconfig) + + logger.info(f" Training parts: {args.training_parts}") + + model, opt, scheduler = deepspeed_initialize(args, logger, model, opt, deepspeed_config) + + # =========================================================================== + # Training + # =========================================================================== + + model.train() + if args.use_ema: + ema.eval() + + print(f" Worker {rank} ready.") + dist.barrier() + + iters_per_epoch = len(loader) + logger.info(" ****************************** Running training ******************************") + logger.info(f" Number GPUs: {world_size}") + logger.info(f" Number training samples: {len(dataset):,}") + logger.info(f" Number parameters: {sum(p.numel() for p in model.parameters()):,}") + logger.info(f" Number trainable params: {sum(p.numel() for p in get_trainable_params(model)):,}") + logger.info(" ------------------------------------------------------------------------------") + logger.info(f" Iters per epoch: {iters_per_epoch:,}") + logger.info(f" Batch size per device: {batch_size}") + logger.info(f" Batch size all device: {batch_size * world_size * grad_accu_steps:,} (world_size * batch_size * grad_accu_steps)") + logger.info(f" Gradient Accu steps: {args.grad_accu_steps}") + logger.info(f" Total optimization steps: {args.epochs * iters_per_epoch // grad_accu_steps:,}") + + logger.info(f" Training epochs: {start_epoch}/{args.epochs}") + logger.info(f" Training epoch steps: {start_epoch_step:,}/{iters_per_epoch:,}") + logger.info(f" Training total steps: {train_steps:,}/{min(args.max_training_steps, args.epochs * iters_per_epoch):,}") + logger.info(" ------------------------------------------------------------------------------") + logger.info(f" Noise schedule: {args.noise_schedule}") + logger.info(f" Beta limits: ({args.beta_start}, {args.beta_end})") + logger.info(f" Learn sigma: {args.learn_sigma}") + logger.info(f" Prediction type: {args.predict_type}") + logger.info(f" Noise offset: {args.noise_offset}") + + logger.info(" ------------------------------------------------------------------------------") + logger.info(f" Using EMA model: {args.use_ema} ({args.ema_dtype})") + if args.use_ema: + logger.info(f" Using EMA decay: {ema.max_value if args.use_ema else None}") + logger.info(f" Using EMA warmup power: {ema.power if args.use_ema else None}") + logger.info(f" Using main model fp16: {args.use_fp16}") + logger.info(f" Using extra modules fp16: {args.extra_fp16}") + logger.info(" ------------------------------------------------------------------------------") + logger.info(f" Experiment directory: {experiment_dir}") + logger.info(" *******************************************************************************") + + if args.gc_interval > 0: + gc.disable() + gc.collect() + + # Variables for monitoring/logging purposes: + log_steps = 0 + running_loss = 0 + start_time = time.time() + + if args.async_ema: + ema_stream = torch.cuda.Stream() + + # Training loop + for epoch in range(start_epoch, args.epochs): + logger.info(f" Start random shuffle with seed={seed}") + # Makesure all processors use the same seed to shuffle dataset. + dataset.shuffle(seed=args.global_seed + epoch, fast=True) + logger.info(f" End of random shuffle") + + # Move sampler to start_index + if not args.multireso: + start_index = start_epoch_step * world_size * batch_size + if start_index != sampler.start_index: + sampler.start_index = start_index + # Reset start_epoch_step to zero, to ensure next epoch will start from the beginning. + start_epoch_step = 0 + logger.info(f" Iters left this epoch: {len(loader):,}") + + logger.info(f" Beginning epoch {epoch}...") + step = 0 + for batch in loader: + step += 1 + + latents, model_kwargs = prepare_model_inputs(args, batch, device, vae, text_encoder, text_encoder_t5, freqs_cis_img) + + # training model by deepspeed while use fp16 + if args.use_fp16: + if args.use_ema and args.async_ema: + with torch.cuda.stream(ema_stream): + ema.update(model.module.module, step=step) + torch.cuda.current_stream().wait_stream(ema_stream) + + loss_dict = diffusion.training_losses(model, latents, model_kwargs) + loss = loss_dict["loss"].mean() + model.backward(loss) + last_batch_iteration = (train_steps + 1) // (global_batch_size // (batch_size * world_size)) + model.step(lr_kwargs={'last_batch_iteration': last_batch_iteration}) + + if args.use_ema and not args.async_ema or (args.async_ema and step == len(loader)-1): + if args.use_fp16: + ema.update(model.module.module, step=step) + else: + ema.update(model.module, step=step) + + # =========================================================================== + # Log loss values: + # =========================================================================== + running_loss += loss.item() + log_steps += 1 + train_steps += 1 + if train_steps % args.log_every == 0: + # Measure training speed: + torch.cuda.synchronize() + end_time = time.time() + steps_per_sec = log_steps / (end_time - start_time) + # Reduce loss history over all processes: + avg_loss = torch.tensor(running_loss / log_steps, device=device) + dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) + avg_loss = avg_loss.item() / world_size + # get lr from deepspeed fused optimizer + logger.info(f"(step={train_steps:07d}) " + + (f"(update_step={train_steps // args.grad_accu_steps:07d}) " if args.grad_accu_steps > 1 else "") + + f"Train Loss: {avg_loss:.4f}, " + f"Lr: {opt.param_groups[0]['lr']:.6g}, " + f"Steps/Sec: {steps_per_sec:.2f}, " + f"Samples/Sec: {int(steps_per_sec * batch_size * world_size):d}") + # Reset monitoring variables: + running_loss = 0 + log_steps = 0 + start_time = time.time() + + # collect gc: + if args.gc_interval > 0 and (step % args.gc_interval == 0): + gc.collect() + + if (train_steps % args.ckpt_every == 0 or train_steps % args.ckpt_latest_every == 0 # or train_steps == args.max_training_steps + ) and train_steps > 0: + save_checkpoint(args, rank, logger, model, ema, epoch, train_steps, checkpoint_dir) + + if train_steps >= args.max_training_steps: + logger.info(f"Breaking step loop at {train_steps}.") + break + + if train_steps >= args.max_training_steps: + logger.info(f"Breaking epoch loop at {epoch}.") + break + + dist.destroy_process_group() + + +if __name__ == "__main__": + # Start + main(get_args()) diff --git a/hydit/utils/tools.py b/hydit/utils/tools.py index 66c0b03..71d8292 100644 --- a/hydit/utils/tools.py +++ b/hydit/utils/tools.py @@ -1,7 +1,21 @@ import random +import logging +from pathlib import Path +import shutil import numpy as np +from PIL import Image import torch +import torch.distributed as dist +from tqdm.auto import tqdm +import math +import torch.nn.functional as F +import os + +def get_trainable_params(model): + params = model.parameters() + params = [p for p in params if p.requires_grad] + return params def set_seeds(seed_list, device=None): @@ -15,3 +29,162 @@ def set_seeds(seed_list, device=None): torch.cuda.manual_seed_all(seed) return torch.Generator(device).manual_seed(seed) + +def get_start_epoch(resume_path, ckpt, steps_per_epoch): + if 'epoch' in ckpt: + start_epoch = ckpt['epoch'] + else: + start_epoch = 0 + if 'steps' in ckpt: + train_steps = ckpt['steps'] + else: + try: + train_steps = int(Path(resume_path).stem) + except: + train_steps = start_epoch * steps_per_epoch + + start_epoch_step = train_steps % steps_per_epoch + 1 + return start_epoch, start_epoch_step, train_steps + +def assert_shape(*args): + if len(args) < 2: + return + cond = True + fail_str = f"{args[0] if isinstance(args[0], (list, tuple)) else args[0].shape}" + for i in range(1, len(args)): + shape1 = args[i] if isinstance(args[i], (list, tuple)) else args[i].shape + shape2 = args[i - 1] if isinstance(args[i - 1], (list, tuple)) else args[i - 1].shape + cond = cond and (shape1 == shape2) + fail_str += f" vs {args[i] if isinstance(args[i], (list, tuple)) else args[i].shape}" + assert cond, fail_str + + +def create_logger(logging_dir=None, logging_file=None, ddp=True): + """ + Create a logger that writes to a log file and stdout. + """ + if not ddp or (ddp and dist.get_rank() == 0): # real logger + if logging_file is not None: + file_handler = [logging.FileHandler(logging_file)] + elif logging_dir is not None: + file_handler = [logging.FileHandler(f"{logging_dir}/log.txt")] + else: + file_handler = [] + logging.basicConfig( + level=logging.INFO, + format='[\033[34m%(asctime)s\033[0m] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[logging.StreamHandler()] + file_handler + ) + logger = logging.getLogger(__name__) + else: + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + +def create_exp_folder(args, rank): + if rank == 0: + os.makedirs(args.results_dir, exist_ok=True) + existed_experiments = list(Path(args.results_dir).glob("*dit*")) + if len(existed_experiments) == 0: + experiment_index = 1 + else: + existed_experiments.sort() + print('existed_experiments', existed_experiments) + experiment_index = max([int(x.stem.split('-')[0]) for x in existed_experiments]) + 1 + dist.barrier() + model_string_name = args.task_flag if args.task_flag else args.model.replace("/", "-") + experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder + checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints + if rank == 0: + os.makedirs(checkpoint_dir, exist_ok=True) + logger = create_logger(experiment_dir) + logger.info(f"Experiment directory created at {experiment_dir}") + else: + logger = create_logger() + experiment_dir = "" + + return experiment_dir, checkpoint_dir, logger + + +def model_resume(args, model, ema, logger, len_loader): + """ + Load pretrained weights. + """ + start_epoch = 0 + start_epoch_step = 0 + train_steps = 0 + resume_path = args.resume + if not Path(resume_path).exists(): + raise FileNotFoundError(f" Cannot find checkpoint from {resume_path}") + + logger.info(f" Resume deepspeed={args.resume_deepspeed}, " + f"Resume split={args.resume_split}, " + f"Resume from checkpoint {resume_path}") + # Resume model and ema states (not include optimizer states) from a checkpoint saved by Deepspeed version of DIT. + if args.resume_deepspeed: + assert 'mp_rank_00_model_states.pt' in os.listdir(resume_path), f' Cannot find dp chkpt from {resume_path}' + resume_ckpt = torch.load(os.path.join(resume_path, 'mp_rank_00_model_states.pt'), + map_location=lambda storage, loc: storage) + # Resume main model + if args.ema_to_module: + logger.info(" Resume main model from the ema states.") + model.load_state_dict(resume_ckpt['ema'], strict=args.strict) + else: + logger.info(" Resume main model from the main states.") + model.load_state_dict(resume_ckpt['module'], strict=args.strict) + # Resume EMA model + if args.use_ema: + if args.module_to_ema: + logger.info(" Resume EMA model from the main states.") + ema.load_state_dict(resume_ckpt['module'], strict=args.strict) + else: + logger.info(" Resume EMA model from the EMA states.") + ema.load_state_dict(resume_ckpt['ema'], strict=args.strict) + if not args.reset_loader: + start_epoch, start_epoch_step, train_steps = get_start_epoch(args.resume, resume_ckpt, len_loader) + # Resume model and ema states (not include optimizer states) from two checkpoints separated from DeepSpeed ckpt. + elif args.resume_split: + # Resume main model + if args.ema_to_module: + assert 'pytorch_model_ema.pt' in os.listdir( + resume_path), f' Cannot find pytorch_model_ema.pt from {resume_path}' + logger.info(f" Resume main model from ema states.") + resume_ckpt_ema = torch.load(os.path.join(resume_path, 'pytorch_model_ema.pt'), + map_location=lambda storage, loc: storage) + model.load_state_dict(resume_ckpt_ema, strict=args.strict) + else: + assert 'pytorch_model_module.pt' in os.listdir( + resume_path), f' Cannot find pytorch_model_module.pt from {resume_path}' + logger.info(f" Resume main model from main states.") + resume_ckpt_module = torch.load(os.path.join(resume_path, 'pytorch_model_module.pt'), + map_location=lambda storage, loc: storage) + model.load_state_dict(resume_ckpt_module, strict=args.strict) + # Resume ema model + if args.use_ema: + if args.module_to_ema: + if "resume_ckpt_module" in locals(): + logger.info(f" Resume ema model from main states.") + ema.load_state_dict(resume_ckpt_module, strict=args.strict) + else: + assert 'pytorch_model_module.pt' in os.listdir( + resume_path), f' Cannot find pytorch_model_module.pt from {resume_path}' + logger.info(f" Resume ema model from module states.") + resume_ckpt_module = torch.load(os.path.join(resume_path, 'pytorch_model_module.pt'), + map_location=lambda storage, loc: storage) + ema.load_state_dict(resume_ckpt_module, strict=args.strict) + else: + if "resume_ckpt_ema" in locals(): + logger.info(f" Resume ema model from EMA states.") + ema.load_state_dict(resume_ckpt_ema, strict=args.strict) + else: + assert 'pytorch_model_ema.pt' in os.listdir( + resume_path), f' Cannot find pytorch_model_ema.pt from {resume_path}' + logger.info(f" Resume ema model from EMA states.") + resume_ckpt_ema = torch.load(os.path.join(resume_path, 'pytorch_model_ema.pt'), + map_location=lambda storage, loc: storage) + ema.load_state_dict(resume_ckpt_ema, strict=args.strict) + else: + raise ValueError(" “If `resume` is True, then either `resume_split` or `resume_deepspeed` must be true.”") + + return model, ema, start_epoch, start_epoch_step, train_steps \ No newline at end of file diff --git a/lora/README.md b/lora/README.md new file mode 100644 index 0000000..47f87bc --- /dev/null +++ b/lora/README.md @@ -0,0 +1,102 @@ + +## Using LoRA to fine-tune HunyuanDiT + +### Training + +We provide three types of weights for fine-tuning LoRA, `ema`, `module` and `distill`, and you can choose according to the actual effect. By default, we use `ema` weights. + +Here is an example, we load the `ema` weights into the main model and perform LoRA fine-tuning through the `--ema-to-module` parameter. + +If you want to load the `module` weights into the main model, just remove the `--ema-to-module` parameter. + +If multiple resolution are used, you need to add the `--multireso` and `--reso-step 64 ` parameter. + +```bash +model='DiT-g/2' # model type +task_flag="lora_porcelain_ema_rank64" # task flag +resume=./ckpts/t2i/model/ # resume checkpoint +index_file=dataset/porcelain/jsons/porcelain.json # the selected data indices +results_dir=./log_EXP # save root for results +batch_size=1 # training batch size +image_size=1024 # training image resolution +grad_accu_steps=2 # gradient accumulation steps +warmup_num_steps=0 # warm-up steps +lr=0.0001 # learning rate +ckpt_every=100 # create a ckpt every a few steps. +ckpt_latest_every=2000 # create a ckpt named `latest.pt` every a few steps. +rank=64 # rank of lora +max_training_steps=2000 # Maximum training iteration steps + +PYTHONPATH=./ deepspeed hydit/train_deepspeed.py \ + --task-flag ${task_flag} \ + --model ${model} \ + --training_parts lora \ + --rank ${rank} \ + --resume-split \ + --resume ${resume} \ + --ema-to-module \ + --lr ${lr} \ + --noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.03 \ + --predict-type v_prediction \ + --uncond-p 0.44 \ + --uncond-p-t5 0.44 \ + --index-file ${index_file} \ + --random-flip \ + --batch-size ${batch_size} \ + --image-size ${image_size} \ + --global-seed 999 \ + --grad-accu-steps ${grad_accu_steps} \ + --warmup-num-steps ${warmup_num_steps} \ + --use-flash-attn \ + --use-fp16 \ + --ema-dtype fp32 \ + --results-dir ${results_dir} \ + --ckpt-every ${ckpt_every} \ + --max-training-steps ${max_training_steps}\ + --ckpt-latest-every ${ckpt_latest_every} \ + --log-every 10 \ + --deepspeed \ + --deepspeed-optimizer \ + --use-zero-stage 2 \ + --qk-norm \ + --rope-img base512 \ + --rope-real \ + "$@" +``` + +Recommended parameter settings + +| Parameter | Description | Recommended Parameter Value | Note| +|:---------------:|:---------:|:---------------------------------------------------:|:--:| +| `--batch_size` | Training batch size | 1 | Depends on GPU memory| +| `--grad-accu-steps` | Size of gradient accumulation | 2 | - | +| `--rank` | Rank of lora | 64 | choosing from 8-128 | +| `--max-training-steps` | Training steps | 2000 | Depend on training data size, for reference apply 2000 steps on 100 images| +| `--lr` | Learning rate | 0.0001 | - | + + +### Inference + +After the training is complete, you can use the following command line for inference. +We provide the `--lora_ckpt` parameter for selecting the folder which contains lora weights and configurations. + +a. Using LoRA during inference + +```bash +python sample_t2i.py --prompt "青花瓷风格,一只小狗" --no-enhance --lora_ckpt log_EXP/001-lora_porcelain_ema_rank64/checkpoints/0001000.pt +``` + +b. Using LoRA in gradio +```bash +python app/hydit_app.py --infer-mode fa --no-enhance --lora_ckpt log_EXP/001-lora_porcelain_ema_rank64/checkpoints/0001000.pt +``` + +c. Merge LoRA weights into the main model + +We provide the `--output_merge_path` parameter to set the path for saving the merged weights. + +```bash +PYTHONPATH=./ python lora/merge.py --lora_ckpt log_EXP/001-lora_porcelain_ema_rank64/checkpoints/0001000.pt --output_merge_path ./ckpts/t2i/model/pytorch_model_merge.pt +``` + +d. For more information, please refer to [HYDiT-LoRA](https://huggingface.co/Tencent-Hunyuan/HYDiT-LoRA). diff --git a/lora/merge.py b/lora/merge.py new file mode 100644 index 0000000..4fd7dc6 --- /dev/null +++ b/lora/merge.py @@ -0,0 +1,28 @@ +import torch +import os +from hydit.config import get_args +from hydit.modules.models import HUNYUAN_DIT_MODELS + +from hydit.inference import _to_tuple + +args = get_args() + +image_size = _to_tuple(args.image_size) +latent_size = (image_size[0] // 8, image_size[1] // 8) + +model = HUNYUAN_DIT_MODELS[args.model](args, + input_size=latent_size, + log_fn=print, + ) +model_path = os.path.join(args.model_root, 't2i', 'model', f"pytorch_model_{args.load_key}.pt") +state_dict = torch.load(model_path, map_location=lambda storage, loc: storage) + +print(f"Loading model from {model_path}") +model.load_state_dict(state_dict) + +print(f"Loading lora from {args.lora_ckpt}") +model.load_adapter(args.lora_ckpt) +model.merge_and_unload() + +torch.save(model.state_dict(), args.output_merge_path) +print(f"Model saved to {args.output_merge_path}") \ No newline at end of file diff --git a/lora/train_lora.sh b/lora/train_lora.sh new file mode 100644 index 0000000..d252f4a --- /dev/null +++ b/lora/train_lora.sh @@ -0,0 +1,50 @@ +model='DiT-g/2' # model type +task_flag="lora_jade_ema_rank64" # task flag +resume=./ckpts/t2i/model/ # resume checkpoint +index_file=dataset/index_v2_json/jade.json # the selected data indices +results_dir=./log_EXP # save root for results +batch_size=1 # training batch size +image_size=1024 # training image resolution +grad_accu_steps=2 # gradient accumulation steps +warmup_num_steps=0 # warm-up steps +lr=0.0001 # learning rate +ckpt_every=100 # create a ckpt every a few steps. +ckpt_latest_every=2000 # create a ckpt named `latest.pt` every a few steps. +rank=64 # rank of lora +max_training_steps=2000 # Maximum training iteration steps + +PYTHONPATH=./ deepspeed hydit/train_deepspeed.py \ + --task-flag ${task_flag} \ + --model ${model} \ + --training_parts lora \ + --rank ${rank} \ + --resume-split \ + --resume ${resume} \ + --ema-to-module \ + --lr ${lr} \ + --noise-schedule scaled_linear --beta-start 0.00085 --beta-end 0.03 \ + --predict-type v_prediction \ + --uncond-p 0.44 \ + --uncond-p-t5 0.44 \ + --index-file ${index_file} \ + --random-flip \ + --batch-size ${batch_size} \ + --image-size ${image_size} \ + --global-seed 999 \ + --grad-accu-steps ${grad_accu_steps} \ + --warmup-num-steps ${warmup_num_steps} \ + --use-flash-attn \ + --use-fp16 \ + --ema-dtype fp32 \ + --results-dir ${results_dir} \ + --ckpt-every ${ckpt_every} \ + --max-training-steps ${max_training_steps}\ + --ckpt-latest-every ${ckpt_latest_every} \ + --log-every 10 \ + --deepspeed \ + --deepspeed-optimizer \ + --use-zero-stage 2 \ + --qk-norm \ + --rope-img base512 \ + --rope-real \ + "$@" diff --git a/requirements.txt b/requirements.txt index f350e8a..8de35fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,8 @@ diffusers==0.21.2 peft==0.10.0 protobuf==3.19.0 torchvision==0.14.1 -transformers==4.37.2 +transformers==4.39.1 +peft==0.10.0 accelerate==0.29.3 loguru==0.7.2 einops==0.7.0 @@ -17,3 +18,5 @@ onnx-graphsurgeon==0.3.27 polygraphy==0.47.1 pandas==2.0.3 gradio==3.50.2 +deepspeed==0.6.3 +pyarrow==16.1.0 \ No newline at end of file