Skip to content

Commit

Permalink
Merge pull request #78 from ayrna/datasets
Browse files Browse the repository at this point in the history
[API] `FGNet` interface changes to match `VisionDataset`
  • Loading branch information
franberchez authored Jul 23, 2024
2 parents e4198dc + a7e7684 commit 728e643
Show file tree
Hide file tree
Showing 8 changed files with 1,084 additions and 2,058 deletions.
186 changes: 156 additions & 30 deletions dlordinal/datasets/fgnet.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import re
import shutil
from pathlib import Path
from typing import Union
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from PIL import Image
from skimage.io import imread, imsave
from skimage.transform import resize
from skimage.util import img_as_ubyte
Expand All @@ -18,44 +19,82 @@ class FGNet(VisionDataset):
"""
Base class for FGNet dataset.
Attributes
----------
root : Path
Root directory of the dataset.
target_size : tuple
Size of the images after resizing.
categories : list
List of categories to be used.
test_size : float
Size of the test set.
validation_size : float
Size of the validation set.
transform : callable, optional
A function/transform that takes in a PIL image and returns a transformed version.
target_transform : callable, optional
A function/transform that takes in the target and transforms it.
data : pd.DataFrame
Dataframe containing the dataset.
Parameters
----------
root : str or Path
Root directory of dataset
download : bool, optional
If True, downloads the dataset from the internet and puts it in root directory.
If dataset is already downloaded, it is not downloaded again.
process_data : bool, optional
If True, processes the dataset and puts it in root directory.
If dataset is already processed, it is not processed again.
Root directory of the dataset.
download : bool, optional, default = True
If True, downloads the dataset from the internet and puts it in the root directory.
If the dataset is already downloaded, it is not downloaded again.
target_size : tuple, optional
Size of the images after resizing.
Size of the images after resizing. Default is (128, 128).
categories : list, optional
List of categories to be used.
List of categories to be used. Default is [3, 11, 16, 24, 40].
test_size : float, optional
Size of the test set.
Size of the test set. Default is 0.2.
validation_size : float, optional
Size of the validation set.
Size of the validation set. Default is 0.15.
train : bool, optional
If True, returns the training dataset, otherwise returns the test dataset. Default is True.
transform : callable, optional
A function/transform that takes in a PIL image and returns a transformed version.
target_transform : callable, optional
A function/transform that takes in the target and transforms it.
"""

# Attributes
root: Path
target_size: tuple
categories: list
test_size: float
validation_size: float
transform: Optional[Callable]
target_transform: Optional[Callable]
data: pd.DataFrame

def __init__(
self,
root: Union[str, Path],
download: bool = False,
process_data: bool = True,
download: bool = True,
target_size: tuple = (128, 128),
categories: list = [3, 11, 16, 24, 40],
test_size: float = 0.2,
validation_size: float = 0.15,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super(FGNet, self).__init__(root)
super(FGNet, self).__init__(
str(root), transform=transform, target_transform=target_transform
)

self.root = Path(self.root)
self.root = Path(root)
self.root.parent.mkdir(parents=True, exist_ok=True)
self.target_size = target_size
self.categories = categories
self.test_size = test_size
self.validation_size = validation_size
self.transform = transform
self.target_transform = target_transform

original_path = self.root / "FGNET/images"
processed_path = self.root / "FGNET/data_processed"
Expand All @@ -76,20 +115,97 @@ def __init__(
" download it"
)

if process_data:
self.process(original_path, processed_path)
self.split(
original_csv_path,
train_csv_path,
test_csv_path,
original_images_path,
train_images_path,
test_images_path,
)
self.process(original_path, processed_path)
self.split(
original_csv_path,
train_csv_path,
test_csv_path,
original_images_path,
train_images_path,
test_images_path,
)

# Load train and test dataframes
if train:
self.data = pd.read_csv(train_csv_path)
else:
self.data = pd.read_csv(test_csv_path)

def __str__(self) -> str:
return "FGNet"

def __len__(self) -> int:
"""
Obtain the number of samples in the dataset.
Returns
-------
int
Number of samples in the dataset.
"""

return len(self.data)

def __getitem__(self, index: int) -> Tuple[Any, Any]:
    """
    Get a sample from the dataset.

    Parameters
    ----------
    index : int
        Positional index of the sample inside ``self.data``.

    Returns
    -------
    tuple
        ``(image, target)``. ``image`` is the RGB image (after ``transform``
        when one is set) and ``target`` is the integer value of the
        ``category`` column (after ``target_transform`` when one is set).
    """
    # Fetch the row once; both the relative path and the label come from it.
    row = self.data.iloc[index]
    img_path = self.root / "FGNET" / "data_processed" / row["path"]

    # Image.open keeps the underlying file open until the image is closed.
    # convert() decodes the pixels into a new in-memory image, so the file
    # can be released as soon as the conversion is done.
    with Image.open(img_path) as raw_image:
        image = raw_image.convert("RGB")

    # Compare against None so a callable that happens to be falsy
    # (e.g. an empty container-like transform) is still applied.
    if self.transform is not None:
        image = self.transform(image)

    target = int(row["category"])
    if self.target_transform is not None:
        target = self.target_transform(target)

    return image, target

@property
def targets(self) -> List[int]:
"""
Return the targets of the dataset.
Returns
-------
list
List of targets.
"""

if self.target_transform:
return self.target_transform(self.data["category"])
else:
return self.data["category"].tolist()

@property
def classes(self) -> List[int]:
"""
Return the unique classes in the dataset.
Returns
-------
list
List of unique classes.
"""
return np.unique(self.data["category"]).tolist()

def download(self) -> None:
"""
Download the FGNet dataset and extract it.
Expand All @@ -100,7 +216,7 @@ def download(self) -> None:

download_and_extract_archive(
"http://yanweifu.github.io/FG_NET_data/FGNET.zip",
self.root,
str(self.root),
filename="fgnet.zip",
md5="1206978cac3626321b84c22b24cc8d19",
)
Expand Down Expand Up @@ -205,7 +321,9 @@ def get_age_from_filename(self, filename):
Filename of the image.
"""
m = re.match("[0-9]+A([0-9]+).*", filename)
return int(m.groups()[0])
if m:
return int(m.groups()[0])
return None

def find_category(self, real_age):
"""
Expand Down Expand Up @@ -289,7 +407,11 @@ def split_dataframe(
x = np.array(df["path"])
y = np.array(df["category"])
x_train, x_test, y_train, y_test = train_test_split(
x, y, test_size=self.test_size, random_state=1
x,
y,
test_size=self.test_size,
random_state=1,
stratify=y,
)

for path, label in zip(x_train, y_train):
Expand All @@ -303,7 +425,11 @@ def split_dataframe(
shutil.copy(original_images_path / path, test_path / path)

x_train, x_val, y_train, y_val = train_test_split(
x_train, y_train, test_size=self.validation_size, random_state=1
x_train,
y_train,
test_size=self.validation_size,
random_state=1,
stratify=y_train,
)

train = np.hstack((x_train[:, np.newaxis], y_train[:, np.newaxis]))
Expand Down
Loading

0 comments on commit 728e643

Please sign in to comment.