Skip to content

Commit

Permalink
Merge pull request #160 from libffcv/v1.0.0
Browse files Browse the repository at this point in the history
V1.0.0
  • Loading branch information
GuillaumeLeclerc committed Mar 3, 2023
2 parents b865918 + 7cd3442 commit 700bdf3
Show file tree
Hide file tree
Showing 29 changed files with 1,114 additions and 298 deletions.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,29 @@ Keep your training algorithm the same, just replace the data loader! Look at the
<img src="docs/_static/perf_scatterplot.svg" width='830px'/>

## Installation
### Linux
```
conda create -y -n ffcv python=3.9 cupy pkg-config compilers libjpeg-turbo opencv pytorch torchvision cudatoolkit=11.3 numba -c pytorch -c conda-forge
conda activate ffcv
pip install ffcv
```
Troubleshooting note: if the above commands result in a package conflict error, try running ``conda config --env --set channel_priority flexible`` in the environment and rerunning the installation command.

### Windows
* Install <a href="https://opencv.org/releases/">opencv4</a>
* Add `..../opencv/build/x64/vc15/bin` to PATH environment variable
* Install <a href="https://sourceforge.net/projects/libjpeg-turbo/files/">libjpeg-turbo</a>, download libjpeg-turbo-x.x.x-vc64.exe, not gcc64
* Add `..../libjpeg-turbo64/bin` to PATH environment variable
* Install <a href="https://www.sourceware.org/pthreads-win32/">pthread</a>, download last release.zip
* After unzip, rename Pre-build.2 folder to pthread
* Open `pthread/include/pthread.h`, and add the code below to the top of the file.
```cpp
#define HAVE_STRUCT_TIMESPEC
```
* Add `..../pthread/dll` to PATH environment variable
* Install <a href="https://docs.cupy.dev/en/stable/install.html#installing-cupy">cupy</a> depending on your CUDA Toolkit version.
* `pip install ffcv`

## Citation
If you use FFCV, please cite it as:

Expand Down
Binary file added ffcv/.DS_Store
Binary file not shown.
4 changes: 2 additions & 2 deletions ffcv/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from .basics import FloatField, IntField
from .rgb_image import RGBImageField
from .bytes import BytesField
from .ndarray import NDArrayField
from .ndarray import NDArrayField, TorchTensorField
from .json import JSONField

__all__ = ['Field', 'BytesField', 'IntField', 'FloatField', 'RGBImageField',
'NDArrayField', 'JSONField']
'NDArrayField', 'JSONField', 'TorchTensorField']
25 changes: 24 additions & 1 deletion ffcv/fields/ndarray.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Callable, TYPE_CHECKING, Tuple, Type
import warnings
import json
from dataclasses import replace

import numpy as np
import torch as ch

from .base import Field, ARG_TYPE
from ..pipeline.operation import Operation
Expand Down Expand Up @@ -55,6 +57,10 @@ def __init__(self, dtype:np.dtype, shape:Tuple[int, ...]):
self.dtype = dtype
self.shape = shape
self.element_size = dtype.itemsize * np.prod(shape)
if dtype == np.uint16:
warnings.warn("Pytorch currently doesn't support uint16"
"we recommend storing as int16 and reinterpret your data later"
"in your pipeline")

@property
def metadata_type(self) -> np.dtype:
Expand Down Expand Up @@ -93,4 +99,21 @@ def encode(self, destination, field, malloc):
data_region[:] = field.reshape(-1).view('<u1')

def get_decoder_class(self) -> Type[Operation]:
return NDArrayDecoder
return NDArrayDecoder


class TorchTensorField(NDArrayField):
"""A subclass of :class:`~ffcv.fields.Field` supporting
multi-dimensional fixed size matrices of any torch type.
"""
def __init__(self, dtype:ch.dtype, shape:Tuple[int, ...]):
self.dtype = dtype
self.shape = shape
dtype = ch.zeros(0, dtype=dtype).numpy().dtype

super().__init__(dtype, shape)


def encode(self, destination, field, malloc):
field = field.numpy()
return super().encode(destination, field, malloc)
11 changes: 8 additions & 3 deletions ffcv/libffcv.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import ctypes
from numba import njit
import numpy as np
import platform
from ctypes import CDLL, c_int64, c_uint8, c_uint64, POINTER, c_void_p, c_uint32, c_bool, cdll
import ffcv._libffcv

lib = CDLL(ffcv._libffcv.__file__)
libc = cdll.LoadLibrary('libc.so.6')
if platform.system() == "Windows":
libc = cdll.msvcrt
read_c = libc._read
else:
libc = cdll.LoadLibrary('libc.so.6')
read_c = libc.pread

read_c = libc.pread
read_c.argtypes = [c_uint32, c_void_p, c_uint64, c_uint64]

def read(fileno:int, destination:np.ndarray, offset:int):
Expand Down Expand Up @@ -47,5 +52,5 @@ def imdecode(source: np.ndarray, dst: np.ndarray,
ctypes_memcopy.argtypes = [c_void_p, c_void_p, c_uint64]

def memcpy(source: np.ndarray, dest: np.ndarray):
return ctypes_memcopy(source.ctypes.data, dest.ctypes.data, source.size)
return ctypes_memcopy(source.ctypes.data, dest.ctypes.data, source.size*source.itemsize)

86 changes: 50 additions & 36 deletions ffcv/loader/epoch_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@
(`OrderOption.QUASI_RANDOM`) in the dataloader constructor's `order` argument.
'''

def select_buffer(buffer, batch_slot, count):
"""Util function to select the relevent subpart of a buffer for a given
batch_slot and batch size"""
if buffer is None:
return None
if isinstance(buffer, tuple):
return tuple(select_buffer(x, batch_slot, count) for x in buffer)

return buffer[batch_slot][:count]


class EpochIterator(Thread):
def __init__(self, loader: 'Loader', order: Sequence[int]):
super().__init__(daemon=True)
Expand All @@ -33,6 +44,10 @@ def __init__(self, loader: 'Loader', order: Sequence[int]):
self.terminate_event = Event()
self.memory_context = self.loader.memory_manager.schedule_epoch(
batches)

if IS_CUDA:
self.current_stream = ch.cuda.current_stream()

try:
self.memory_context.__enter__()
except MemoryError as e:
Expand All @@ -44,23 +59,13 @@ def __init__(self, loader: 'Loader', order: Sequence[int]):

self.storage_state = self.memory_context.state

self.memory_bank_per_stage = defaultdict(list)

self.cuda_streams = [(ch.cuda.Stream() if IS_CUDA else None)
for _ in range(self.loader.batches_ahead + 2)]

# Allocate all the memory
memory_allocations = {}
for (p_id, p) in self.loader.pipelines.items():
memory_allocations[p_id] = p.allocate_memory(self.loader.batch_size,
self.loader.batches_ahead + 2)

# Assign each memory bank to the pipeline stage it belongs to
for s_ix, banks in self.loader.memory_bank_keys_per_stage.items():
for (pipeline_name, op_id) in banks:
self.memory_bank_per_stage[s_ix].append(
memory_allocations[pipeline_name][op_id]
)
self.memory_allocations = self.loader.graph.allocate_memory(
self.loader.batch_size,
self.loader.batches_ahead + 2
)

self.start()

Expand All @@ -77,6 +82,7 @@ def run(self):
self.current_batch_slot = (
slot + 1) % (self.loader.batches_ahead + 2)
result = self.run_pipeline(b_ix, ixes, slot, events[slot])
# print("RES", b_ix, "ready")
to_output = (slot, result)
while True:
try:
Expand All @@ -88,23 +94,24 @@ def run(self):
if self.terminate_event.is_set():
return
if IS_CUDA:
# print("SUB", b_ix)
# We were able to submit this batch
# Therefore it means that the user must have entered the for loop for
# (batch_slot - batch_ahead + 1) % (batches ahead + 2)
# Therefore batch_slot - batch_ahead must have all it's work submitted
# We will record an event of all the work submitted on the main stream
# and make sure no one overwrite the data until they are done
just_finished_slot = (slot - self.loader.batches_ahead) % (self.loader.batches_ahead + 2)
just_finished_slot = (slot - self.loader.batches_ahead - 1) % (self.loader.batches_ahead + 2)
# print("JFS", just_finished_slot)
event = ch.cuda.Event()
event.record(ch.cuda.default_stream())
event.record(self.current_stream)
events[just_finished_slot] = event
b_ix += 1

except StopIteration:
self.output_queue.put(None)

def run_pipeline(self, b_ix, batch_indices, batch_slot, cuda_event):
# print(b_ix, batch_indices)
self.memory_context.start_batch(b_ix)
args = []
if IS_CUDA:
Expand All @@ -114,28 +121,35 @@ def run_pipeline(self, b_ix, batch_indices, batch_slot, cuda_event):
ctx = nullcontext()
first_stage = False


code, outputs = self.loader.code
with ctx:
if IS_CUDA:
if cuda_event:
cuda_event.wait()
for stage, banks in self.memory_bank_per_stage.items():
args.insert(0, batch_indices)
for bank in banks:
if bank is not None:
if isinstance(bank, tuple):
bank = tuple(x[batch_slot] for x in bank)
else:
bank = bank[batch_slot]
args.append(bank)
args.append(self.metadata)
args.append(self.storage_state)
code = self.loader.code_per_stage[stage]
result = code(*args)
args = list(result)
if first_stage:
first_stage = False
self.memory_context.end_batch(b_ix)
return tuple(x[:len(batch_indices)] for x in args)

args = {
'batch_indices': batch_indices,
'storage_state': self.storage_state,
'metadata': self.metadata,
**{
f'memory_{k}':select_buffer(v, batch_slot, len(batch_indices))
for (k, v) in self.memory_allocations['operation'].items()
},
**{
f'shared_memory_{k}': select_buffer(v, batch_slot, len(batch_indices))
for (k, v) in self.memory_allocations['shared'].items()
}
}

for stage_code, define_outputs in code:
results = stage_code(**args)
for node_id, result in zip(define_outputs, results):
args[f'result_{node_id}'] = result
pass

result = tuple(args[f'result_{x}'] for x in outputs)
return result

def __next__(self):
result = self.output_queue.get()
Expand All @@ -146,7 +160,7 @@ def __next__(self):
if IS_CUDA:
stream = self.cuda_streams[slot]
# We wait for the copy to be done
ch.cuda.current_stream().wait_stream(stream)
self.current_stream.wait_stream(stream)
return result

def __iter__(self):
Expand Down
Loading

0 comments on commit 700bdf3

Please sign in to comment.