Commit 29766aa

support training, lora, and upgrading the Hydit-v1.1 model

zml-ai committed Jun 13, 2024 · 1 parent ebfb793

Showing 45 changed files with 8,502 additions and 411 deletions.
148 changes: 148 additions & 0 deletions IndexKits/README.md
# IndexKits

[TOC]

## Introduction

Index Kits (`index_kits`) for streaming Arrow data.

* Supports creating datasets from configuration files.
* Supports creating Index V2 files from arrows.
* Supports the creation, reading, and usage of index files.
* Supports the creation, reading, and usage of multi-bucket, multi-resolution indexes.

`index_kits` has the following advantages:

* Optimizes memory usage for streaming reads, supporting datasets at the billion-sample scale.
* Optimizes the memory usage and reading speed of index files.
* Supports the creation of Base/Multireso Index V2 datasets, including data filtering, repeating, and deduplication during creation.
* Supports loading various dataset types (Arrow files, Base Index V2, multiple Base Index V2, Multireso Index V2, multiple Multireso Index V2).
* Provides a unified API for accessing images, text, and other attributes across all dataset types.
* Built-in shuffle and `resize_and_crop`, which returns the crop coordinates.

## Installation

* **Install from pre-compiled whl (recommended)**

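A minimal sketch (the wheel filename is illustrative; use the actual artifact shipped with the release):

```shell
pip install index_kits-*.whl
```
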
* **Install from source**

```shell
cd IndexKits
pip install -e .
```

## Usage

### Loading an Index V2 File

```python
from index_kits import ArrowIndexV2

index_manager = ArrowIndexV2('data.json')

# Shuffle the dataset with a seed. When a seed is provided (i.e., not None), the shuffle does not affect any outside random state.
index_manager.shuffle(1234, fast=True)

for i in range(len(index_manager)):
    # Get an image (requires the arrow to contain an image/binary column):
    pil_image = index_manager.get_image(i)
    # Get text (requires the arrow to contain a text_zh column):
    text = index_manager.get_attribute(i, column='text_zh')
    # Get MD5 (requires the arrow to contain an md5 column):
    md5 = index_manager.get_md5(i)
    # Get any attribute by specifying the column (must be contained in the arrow):
    ocr_num = index_manager.get_attribute(i, column='ocr_num')
    # Get multiple attributes at once
    item = index_manager.get_data(i, columns=['text_zh', 'md5'])  # i: in-dataset index
    print(item)
    # {
    #     'index': 3,            # in-json index
    #     'in_arrow_index': 3,   # in-arrow index
    #     'arrow_name': '/HunYuanDiT/dataset/porcelain/00000.arrow',
    #     'text_zh': 'Fortune arrives with the auspicious green porcelain tea cup',
    #     'md5': '1db68f8c0d4e95f97009d65fdf7b441c'
    # }
```


### Loading a Set of Arrow Files

If you have a small set of arrow files, you can load them directly with `IndexV2Builder`, without creating an Index V2 file.

```python
from index_kits import IndexV2Builder

arrow_files = ['/path/to/00000.arrow', '/path/to/00001.arrow']  # your local arrow files
index_manager = IndexV2Builder(arrow_files).to_index_v2()
```
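
The returned `index_manager` exposes the same unified API as `ArrowIndexV2`, so the access patterns shown above carry over; a minimal sketch, assuming the arrows contain image and `text_zh` columns:

```python
for i in range(len(index_manager)):
    pil_image = index_manager.get_image(i)                   # image/binary column
    text = index_manager.get_attribute(i, column='text_zh')  # text column
```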

### Loading a Multi-Bucket (Multi-Resolution) Index V2 File

When using a multi-bucket (multi-resolution) index file, refer to the following example, paying particular attention to the definition of `SimpleDataset` and the usage of `MultiResolutionBucketIndexV2`.

```python
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T

from index_kits import MultiResolutionBucketIndexV2
from index_kits.sampler import BlockDistributedSampler

class SimpleDataset(Dataset):
    def __init__(self, index_file, batch_size, world_size):
        # With multi-bucket (multi-resolution) data, batch_size and world_size are
        # required so the underlying data can be aligned across ranks.
        self.index_manager = MultiResolutionBucketIndexV2(index_file, batch_size, world_size)

        self.flip_norm = T.Compose([
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize([0.5], [0.5]),
        ])

    def shuffle(self, seed, fast=False):
        self.index_manager.shuffle(seed, fast=fast)

    def __len__(self):
        return len(self.index_manager)

    def __getitem__(self, idx):
        image = self.index_manager.get_image(idx)
        original_size = image.size                              # (w, h)
        target_size = self.index_manager.get_target_size(idx)   # (w, h)
        image, crops_coords_left_top = self.index_manager.resize_and_crop(image, target_size)
        image = self.flip_norm(image)
        text = self.index_manager.get_attribute(idx, column='text_zh')
        return image, text, (original_size, target_size, crops_coords_left_top)

batch_size = 8 # batch_size per GPU
world_size = 8 # total number of GPUs
rank = 0 # rank of the current process
num_workers = 4 # customizable based on actual conditions
shuffle = False # must be set to False
drop_last = True # must be set to True

# The correct batch_size and world_size must be passed in; otherwise, the multi-resolution data cannot be aligned correctly.
dataset = SimpleDataset('data_multireso.json', batch_size=batch_size, world_size=world_size)
# BlockDistributedSampler is required to ensure all samples in a batch share the same resolution.
sampler = BlockDistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                  shuffle=shuffle, drop_last=drop_last, batch_size=batch_size)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler,
                    num_workers=num_workers, pin_memory=True, drop_last=drop_last)

for epoch in range(10):
    # Use the shuffle method provided by the dataset, not the DataLoader's shuffle parameter.
    dataset.shuffle(epoch, fast=True)
    for batch in loader:
        pass
```
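
In an actual multi-GPU run, `rank` and `world_size` come from the distributed environment rather than constants; a short sketch using standard `torch.distributed` calls (assuming a `torchrun`-style launch):

```python
import torch.distributed as dist

dist.init_process_group(backend='nccl')  # reads rank/world size from the launcher's env
rank = dist.get_rank()
world_size = dist.get_world_size()
```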


## Fast Shuffle

When your Index V2 file contains hundreds of millions of samples, the default shuffle can be quite slow. In that case, it is recommended to enable `fast=True` mode:

```python
index_manager.shuffle(seed=1234, fast=True)
```

Fast mode shuffles globally without keeping indices from the same arrow file together. This can reduce reading speed, so weigh the faster shuffle against your model's forward time.
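
If you are unsure whether the trade-off pays off for your setup, you can time both modes directly; a minimal sketch using only the `shuffle` signature shown above:

```python
import time

for fast in (False, True):
    start = time.time()
    index_manager.shuffle(seed=1234, fast=fast)
    print(f'fast={fast}: shuffled in {time.time() - start:.1f}s')
```
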
149 changes: 149 additions & 0 deletions IndexKits/bin/idk
#!/usr/bin/env python3

import argparse

from index_kits.dataset.make_dataset_core import startup, make_multireso
from index_kits.common import show_index_info
from index_kits import __version__


def common_args(parser):
    parser.add_argument('-t', '--target', type=str, required=True, help='Save path')


def get_args():
    parser = argparse.ArgumentParser(description="""
IndexKits is a tool to build and manage index files for large-scale datasets.
It supports both base index and multi-resolution index.

Introduction
------------
This command line tool provides the following functionalities:
    1. Show index v2 information
    2. Build base index v2
    3. Build multi-resolution index v2

Examples
--------
1. Show index v2 information
        index_kits show /path/to/index.json

2. Build base index v2
    Default usage:
        index_kits base -c /path/to/config.yaml -t /path/to/index.json
    Use multiple processes:
        index_kits base -c /path/to/config.yaml -t /path/to/index.json -w 40

3. Build multi-resolution index v2
    Build with a configuration file:
        index_kits multireso -c /path/to/config.yaml -t /path/to/index_mb_gt512.json
    Build by specifying arguments without a configuration file:
        index_kits multireso --src /path/to/index.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json
    Build by specifying target-ratios:
        index_kits multireso --src /path/to/index.json --base-size 512 --target-ratios 1:1 4:3 3:4 16:9 9:16 --min-size 512 -t /path/to/index_mb_gt512.json
    Build with multiple source index files:
        index_kits multireso --src /path/to/index1.json /path/to/index2.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json
""", formatter_class=argparse.RawTextHelpFormatter)
    sub_parsers = parser.add_subparsers(dest='task', required=True)

    # Show index information
    show_parser = sub_parsers.add_parser('show', description="""
Show base/multireso index v2 information.

Example
-------
    index_kits show /path/to/index.json
""", formatter_class=argparse.RawTextHelpFormatter)
    show_parser.add_argument('src', type=str, help='Path to a base/multireso index file.')
    show_parser.add_argument('--arrow-files', action='store_true', help='Show arrow files only.')
    show_parser.add_argument('--depth', type=int, default=1,
                             help='Arrow file depth. Default is 1, i.e., the last folder in the arrow file path. '
                                  'Set it to 0 to show the full path, including `xxx/last_folder/*.arrow`.')

    # Single-resolution bucket
    base_parser = sub_parsers.add_parser('base', description="""
Build base index v2.

Example
-------
    index_kits base -c /path/to/config.yaml -t /path/to/index.json
""", formatter_class=argparse.RawTextHelpFormatter)
    base_parser.add_argument('-c', '--config', type=str, required=True, help='Configuration file path')
    common_args(base_parser)
    base_parser.add_argument('-w', '--world-size', type=int, default=1)
    base_parser.add_argument('--work-dir', type=str, default='.', help='Work directory')
    base_parser.add_argument('--use-cache', action='store_true',
                             help='Use cache to avoid reprocessing; merge the cached pkl results directly.')

    # Multi-resolution bucket
    mo_parser = sub_parsers.add_parser('multireso', description="""
Build multi-resolution index v2.

Example
-------
Build with a configuration file:
    index_kits multireso -c /path/to/config.yaml -t /path/to/index_mb_gt512.json
Build by specifying arguments without a configuration file:
    index_kits multireso --src /path/to/index.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json
Build by specifying target-ratios:
    index_kits multireso --src /path/to/index.json --base-size 512 --target-ratios 1:1 4:3 3:4 16:9 9:16 --min-size 512 -t /path/to/index_mb_gt512.json
Build with multiple source index files:
    index_kits multireso --src /path/to/index1.json /path/to/index2.json --base-size 512 --reso-step 32 --min-size 512 -t /path/to/index_mb_gt512.json
""", formatter_class=argparse.RawTextHelpFormatter)
    mo_parser.add_argument('-c', '--config', type=str, default=None,
                           help='Configuration file path in yaml format. Either --config or --src must be provided.')
    mo_parser.add_argument('-s', '--src', type=str, nargs='+', default=None,
                           help='Source index files. Either --config or --src must be provided.')
    common_args(mo_parser)
    mo_parser.add_argument('--base-size', type=int, default=None,
                           help='Base size. Typically 256/512/1024, matching the image size used to train the model.')
    mo_parser.add_argument('--reso-step', type=int, default=None,
                           help='Resolution step. Either --reso-step or --target-ratios must be provided.')
    mo_parser.add_argument('--target-ratios', type=str, nargs='+', default=None,
                           help='Target ratios. Either --reso-step or --target-ratios must be provided.')
    mo_parser.add_argument('--md5-file', type=str, default=None,
                           help='Optional pickle file containing a dict that maps md5 to a (height, width) tuple; '
                                'providing it accelerates the process.')
    mo_parser.add_argument('--align', type=int, default=16,
                           help='Used when --target-ratios is provided. Alignment of the source image height and width.')
    mo_parser.add_argument('--min-size', type=int, default=0,
                           help='Minimum size. Images smaller than this size will be ignored.')

    # Common
    parser.add_argument('-v', '--version', action='version', version=f'%(prog)s {__version__}')

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = get_args()
    if args.task == 'show':
        show_index_info(args.src,
                        args.arrow_files,
                        args.depth,
                        )
    elif args.task == 'base':
        startup(args.config,
                args.target,
                args.world_size,
                args.work_dir,
                use_cache=args.use_cache,
                )
    elif args.task == 'multireso':
        make_multireso(args.target,
                       args.config,
                       args.src,
                       args.base_size,
                       args.reso_step,
                       args.target_ratios,
                       args.align,
                       args.min_size,
                       args.md5_file,
                       )
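
The `base` and `multireso` subcommands accept a YAML configuration file. As an illustration only (the actual schema is defined by `make_dataset_core`; the `source` key below is an assumption modeled on the repository's example configs):

```yaml
# Hypothetical minimal config for `idk base`; verify keys against the repo's example yamls.
source:
  - /HunYuanDiT/dataset/porcelain/00000.arrow
  - /HunYuanDiT/dataset/porcelain/00001.arrow
```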