Merge branch 'public-dev' into public-main
amogh7joshi committed Sep 9, 2023
2 parents 0c380be + 906240f commit 73e646f
Showing 11 changed files with 269 additions and 48 deletions.
2 changes: 1 addition & 1 deletion agml/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-__version__ = '0.5.1'
+__version__ = '0.5.2'
__all__ = ['data', 'synthetic', 'backend', 'viz', 'io']


122 changes: 108 additions & 14 deletions agml/data/loader.py
@@ -16,6 +16,7 @@
import json
import copy
import glob
+import fnmatch
from typing import Union
from collections.abc import Sequence
from decimal import getcontext, Decimal
@@ -31,7 +32,9 @@
from agml.utils.data import load_public_sources
from agml.utils.general import NoArgument, resolve_list_value
from agml.utils.random import inject_random_state
-from agml.backend.config import data_save_path, synthetic_data_save_path
+from agml.backend.config import (
+    data_save_path, synthetic_data_save_path, SUPER_BASE_DIR
+)
from agml.backend.experimental import AgMLExperimentalFeatureWrapper
from agml.backend.tftorch import (
    get_backend, set_backend,
@@ -101,14 +104,25 @@ class AgMLDataLoader(AgMLSerializable, metaclass = AgMLDataLoaderMeta):
    See the methods for examples on how to use an `AgMLDataLoader` effectively.
    """
    serializable = frozenset((
-        'info', 'builder', 'manager', 'train_data',
-        'val_data', 'test_data', 'is_split', 'meta_properties'))
+        'info', 'builder', 'manager', 'train_data', 'train_content', 'val_data',
+        'val_content', 'test_data', 'test_content', 'is_split', 'meta_properties'))

    def __new__(cls, dataset, **kwargs):
        # If a single dataset is passed, then we use the base `AgMLDataLoader`.
        # However, if an iterable of datasets is passed, then we need to
        # dispatch to the subclass `AgMLMultiDatasetLoader` for them.
        if isinstance(dataset, (str, DatasetMetadata)):
+            if '*' in dataset:  # enables wildcard search for datasets
+                valid_datasets = fnmatch.filter(load_public_sources().keys(), dataset)
+                if len(valid_datasets) == 0:
+                    raise ValueError(
+                        f"Wildcard search for dataset '{dataset}' yielded no results.")
+                if len(valid_datasets) == 1:
+                    log(f"Wildcard search for dataset '{dataset}' yielded only "
+                        f"one result. Returning a regular, single-element data loader.")
+                    return super(AgMLDataLoader, cls).__new__(cls)
+                from agml.data.multi_loader import AgMLMultiDatasetLoader
+                return AgMLMultiDatasetLoader(valid_datasets, **kwargs)
            return super(AgMLDataLoader, cls).__new__(cls)
        elif isinstance(dataset, Sequence):
            if len(dataset) == 1:
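
For context, a minimal sketch of the wildcard loading added above, assuming at least one public dataset name matches the `fnmatch`-style pattern (the pattern below is illustrative):

import agml

# Several matches dispatch to an AgMLMultiDatasetLoader; exactly one
# match falls back to a regular single-dataset loader.
loaders = agml.data.AgMLDataLoader('apple_*')
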
@@ -152,8 +166,11 @@ def __init__(self, dataset, **kwargs):
        # If the dataset is split, then the `AgMLDataLoader`s with the
        # split and reduced data are stored as accessible class properties.
        self._train_data = None
+        self._train_content = None
        self._val_data = None
+        self._val_content = None
        self._test_data = None
+        self._test_content = None
        self._is_split = False

        # Set the direct access metadata properties like `num_images` and
@@ -208,7 +225,9 @@ def custom(cls, name, dataset_path = None, classes = None, **kwargs):
        Parameters
        ----------
        name : str
-            A name for the custom dataset (this can be any valid string).
+            A name for the custom dataset (this can be any valid string). This
+            can also be a path to the dataset (in which case the name will be
+            the base directory inferred from the path).
        dataset_path : str, optional
            A custom path to load the dataset from. If this is not passed,
            we will assume that the dataset is at the traditional path:
@@ -231,6 +250,11 @@ def custom(cls, name, dataset_path = None, classes = None, **kwargs):
f"a string that is not an existing dataset in "
f"the AgML public data source repository.")

# Check if the `name` is itself the path to the dataset.
if os.path.exists(name):
dataset_path = name
name = os.path.basename(name)

# Locate the path to the dataset.
if dataset_path is None:
dataset_path = os.path.abspath(os.path.join(data_save_path(), name))
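
A short sketch of the new path-as-name shorthand in `custom`, assuming the local dataset directory below exists (the path is hypothetical):

import agml

# Equivalent to custom(name = 'my_orchard', dataset_path = '/data/my_orchard');
# the dataset name is inferred from the base directory of the path.
loader = agml.data.AgMLDataLoader.custom('/data/my_orchard')
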
@@ -624,7 +648,7 @@ def train_data(self):
        if isinstance(self._train_data, AgMLDataLoader):
            return self._train_data
        self._train_data = self._generate_split_loader(
-            self._train_data, split = 'train')
+            self._train_content, split = 'train')
        return self._train_data

    @property
@@ -633,7 +657,7 @@ def val_data(self):
        if isinstance(self._val_data, AgMLDataLoader):
            return self._val_data
        self._val_data = self._generate_split_loader(
-            self._val_data, split = 'val')
+            self._val_content, split = 'val')
        return self._val_data

    @property
@@ -642,7 +666,7 @@ def test_data(self):
        if isinstance(self._test_data, AgMLDataLoader):
            return self._test_data
        self._test_data = self._generate_split_loader(
-            self._test_data, split = 'test')
+            self._test_content, split = 'test')
        return self._test_data

    def eval(self) -> "AgMLDataLoader":
@@ -980,8 +1004,7 @@ def __call__(self, contents, name):
        # Re-map the annotation ID.
        category_ids = annotations['category_id']
        category_ids[np.where(category_ids == 0)[0]] = 1  # fix
-        new_ids = np.array([self._map[c]
-                            for c in category_ids])
+        new_ids = np.array([self._map[c] for c in category_ids])
        annotations['category_id'] = new_ids
        return image, annotations

@@ -1175,14 +1198,87 @@ def split(self, train = None, val = None, test = None, shuffle = True):

            # Build new `DataBuilder`s and `DataManager`s for the split data.
            for split, content in contents.items():
-                setattr(self, f'_{split}_data', content)
+                setattr(self, f'_{split}_content', content)

        # Otherwise, raise an error for an invalid type.
        else:
            raise TypeError(
                "Expected either only ints or only floats when generating "
                f"a data split, got {[type(i) for i in arg_dict.values()]}.")

+    def save_split(self, name, overwrite = False):
+        """Saves the current split of data to an internal location.
+
+        This method can be used to save the current split of data to an
+        internal file, such that the same split can be later loaded using
+        the `load_split` method (for reproducibility). This method will only
+        save the actual split of data, not any of the transforms or other
+        parameters which have been applied to the loader.
+
+        Parameters
+        ----------
+        name: str
+            The name of the split to save. This name will be used to identify
+            the split when loading it later.
+        overwrite: bool
+            Whether to overwrite an existing split with the same name.
+        """
+        # Ensure that there exist data splits (train/val/test data).
+        if (
+            self._train_content is None
+            and self._val_content is None
+            and self._test_content is None
+        ):
+            raise NotImplementedError("Cannot save a split of data when no "
+                                      "split has been generated.")
+
+        # Get each of the individual splits.
+        splits = {'train': self._train_content,
+                  'val': self._val_content,
+                  'test': self._test_content}
+
+        # Save the split to the internal location.
+        split_dir = os.path.join(SUPER_BASE_DIR, 'splits', self.name)
+        os.makedirs(split_dir, exist_ok = True)
+        if os.path.exists(os.path.join(split_dir, f'{name}.json')):
+            if not overwrite:
+                raise FileExistsError(f"A split with the name {name} already exists.")
+        with open(os.path.join(split_dir, f'{name}.json'), 'w') as f:
+            json.dump(splits, f)
+
+    def load_split(self, name, **kwargs):
+        """Loads a previously saved split of data.
+
+        This method can be used to load a previously saved split of data
+        if the split was saved using the `save_split` method. This method
+        will only load the actual split of data, not any of the transforms
+        or other parameters which have been applied to the loader. You can
+        use the traditional split accessors (`train_data`, `val_data`, and
+        `test_data`) to access the loaded data.
+
+        Parameters
+        ----------
+        name: str
+            The name of the split to load. This name will be used to identify
+            the split to load.
+        """
+        if kwargs.get('manual_split_set', False):
+            splits = kwargs['manual_split_set']
+
+        else:
+            # Ensure that the split exists.
+            split_dir = os.path.join(SUPER_BASE_DIR, 'splits', self.name)
+            if not os.path.exists(os.path.join(split_dir, f'{name}.json')):
+                raise FileNotFoundError(f"Could not find a split with the name {name}.")
+
+            # Load the split from the internal location.
+            with open(os.path.join(split_dir, f'{name}.json'), 'r') as f:
+                splits = json.load(f)
+
+        # Set the split contents.
+        for split, content in splits.items():
+            setattr(self, f'_{split}_content', content)
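
A round-trip sketch of the new split persistence, assuming 'grape_detection_californiaday' is a downloaded public dataset and 'exp1' is an arbitrary split name:

import agml

loader = agml.data.AgMLDataLoader('grape_detection_californiaday')
loader.split(train = 0.8, val = 0.1, test = 0.1)
loader.save_split('exp1')

# Later: reproduce the identical split in a fresh loader.
loader2 = agml.data.AgMLDataLoader('grape_detection_californiaday')
loader2.load_split('exp1')
train_loader = loader2.train_data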

    def batch(self, batch_size = None):
        """Batches sets of image and annotation data according to a size.
@@ -1611,11 +1707,9 @@ def export_torch(self, **loader_kwargs):

        # The `DataLoader` automatically batches objects using its
        # own mechanism, so we remove batching from this DataLoader.
-        batch_size = loader_kwargs.pop(
-            'batch_size', obj._manager._batch_size)
+        batch_size = loader_kwargs.pop('batch_size', obj._manager._batch_size)
        obj.batch(None)
-        shuffle = loader_kwargs.pop(
-            'shuffle', obj._manager._shuffle)
+        shuffle = loader_kwargs.pop('shuffle', obj._manager._shuffle)

        # The `collate_fn` for object detection is different because
        # the COCO JSON dictionaries each have different formats. So,
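
For context, a hedged sketch of the export path above: batching set on the AgML loader is stripped, and the returned torch DataLoader batches instead (the dataset name and sizes are illustrative):

import agml

loader = agml.data.AgMLDataLoader('apple_detection_usa')
loader.batch(8)  # removed by export_torch; the DataLoader re-batches
train = loader.export_torch(batch_size = 16, shuffle = True)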
