Skip to content

Commit

Permalink
✨ feat(datasets): Separate the data preprocessing
Browse files Browse the repository at this point in the history
The data preprocessings of trainset and testset are defined separately with no overlap.

Add `train()` and `eval()` for `BaseDataset` that for informing FL-bench that which data preprocessing should be adopted.

🐞fix(Elastic): Remove deprecated arg in torch's lr_scheduler

🐞fix(ADCOL): Add definitions of some variables in `__init__()`
  • Loading branch information
KarhouTam committed Mar 28, 2024
1 parent ff136dd commit e613cda
Show file tree
Hide file tree
Showing 23 changed files with 134 additions and 103 deletions.
152 changes: 77 additions & 75 deletions data/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,28 @@ def __init__(self) -> None:
self.targets: torch.Tensor = None
self.train_data_transform = None
self.train_target_transform = None
self.general_data_transform = None
self.general_target_transform = None
self.enable_train_transform = True
self.test_data_transform = None
self.test_target_transform = None
self.data_transform = None
self.target_transform = None

def __getitem__(self, index):
data, targets = self.data[index], self.targets[index]
if self.enable_train_transform and self.train_data_transform is not None:
data = self.train_data_transform(data)
if self.enable_train_transform and self.train_target_transform is not None:
targets = self.train_target_transform(targets)
if self.general_data_transform is not None:
data = self.general_data_transform(data)
if self.general_target_transform is not None:
targets = self.general_target_transform(targets)
if self.data_transform is not None:
data = self.data_transform(data)
if self.target_transform is not None:
targets = self.target_transform(targets)

return data, targets

def train(self):
self.data_transform = self.train_data_transform
self.target_transform = self.train_target_transform

def eval(self):
self.data_transform = self.test_data_transform
self.target_transform = self.test_target_transform

def __len__(self):
return len(self.targets)

Expand All @@ -47,8 +53,8 @@ def __init__(
self,
root,
args=None,
general_data_transform=None,
general_target_transform=None,
test_data_transform=None,
test_target_transform=None,
train_data_transform=None,
train_target_transform=None,
) -> None:
Expand All @@ -68,8 +74,8 @@ def __init__(
self.data = torch.from_numpy(data).float().reshape(-1, 1, 28, 28)
self.targets = torch.from_numpy(targets).long()
self.classes = list(range(62))
self.general_data_transform = general_data_transform
self.general_target_transform = general_target_transform
self.test_data_transform = test_data_transform
self.test_target_transform = test_target_transform
self.train_data_transform = train_data_transform
self.train_target_transform = train_target_transform

Expand Down Expand Up @@ -99,8 +105,8 @@ def __init__(
self,
root,
args=None,
general_data_transform=None,
general_target_transform=None,
test_data_transform=None,
test_target_transform=None,
train_data_transform=None,
train_target_transform=None,
) -> None:
Expand All @@ -119,8 +125,8 @@ def __init__(

self.data = torch.from_numpy(data).permute([0, -1, 1, 2]).float()
self.targets = torch.from_numpy(targets).long()
self.general_data_transform = general_data_transform
self.general_target_transform = general_target_transform
self.test_data_transform = test_data_transform
self.test_target_transform = test_target_transform
self.train_data_transform = train_data_transform
self.train_target_transform = train_target_transform
self.classes = [0, 1]
Expand All @@ -131,8 +137,8 @@ def __init__(
self,
root,
args=None,
general_data_transform=None,
general_target_transform=None,
test_data_transform=None,
test_target_transform=None,
train_data_transform=None,
train_target_transform=None,
):
Expand All @@ -146,8 +152,8 @@ def __init__(
self.targets = (
torch.Tensor(np.load(root / "raw" / "ydata.npy")).long().squeeze()
)
self.general_data_transform = general_data_transform
self.general_target_transform = general_target_transform
self.test_data_transform = test_data_transform
self.test_target_transform = test_target_transform
self.train_data_transform = train_data_transform
self.train_target_transform = train_target_transform

Expand All @@ -157,8 +163,8 @@ def __init__(
self,
root,
args=None,
general_data_transform=None,
general_target_transform=None,
test_data_transform=None,
test_target_transform=None,
train_data_transform=None,
train_target_transform=None,
):
Expand All @@ -174,8 +180,8 @@ def __init__(
torch.Tensor(np.load(root / "raw" / "ydata.npy")).long().squeeze()
)
self.classes = [0, 1, 2, 3]
self.general_data_transform = general_data_transform
self.general_target_transform = general_target_transform
self.test_data_transform = test_data_transform
self.test_target_transform = test_target_transform
self.train_data_transform = train_data_transform
self.train_target_transform = train_target_transform

Expand All @@ -185,8 +191,8 @@ def __init__(
self,
root,
args=None,
general_data_transform=None,
general_target_transform=None,
test_data_transform=None,
test_target_transform=None,
train_data_transform=None,
train_target_transform=None,
):
Expand All @@ -203,8 +209,8 @@ def __init__(
self.data = torch.cat([train_data, test_data])
self.targets = torch.cat([train_targets, test_targets])
self.classes = list(range(10))
self.general_data_transform = general_data_transform
self.general_target_transform = general_target_transform
self.test_data_transform = test_data_transform
self.test_target_transform = test_target_transform
self.train_data_transform = train_data_transform
self.train_target_transform = train_target_transform

Expand All @@ -214,8 +220,8 @@ def __init__(
self,
root,
args=None,
general_data_transform=None,
general_target_transform=None,
test_data_transform=None,
test_target_transform=None,
train_data_transform=None,
train_target_transform=None,
):
Expand All @@ -232,8 +238,8 @@ def __init__(
self.data = torch.cat([train_data, test_data])
self.targets = torch.cat([train_targets, test_targets])
self.classes = list(range(10))
self.general_data_transform = general_data_transform
self.general_target_transform = general_target_transform
self.test_data_transform = test_data_transform
self.test_target_transform = test_target_transform
self.train_data_transform = train_data_transform
self.train_target_transform = train_target_transform

Expand All @@ -243,8 +249,8 @@ def __init__(
self,
root,
args=None,
general_data_transform=None,
general_target_transform=None,
test_data_transform=None,
test_target_transform=None,
train_data_transform=None,
train_target_transform=None,
):
Expand All @@ -258,8 +264,8 @@ def __init__(
self.data = torch.cat([train_data, test_data])
self.targets = torch.cat([train_targets, test_targets])
self.classes = train_part.classes
self.general_data_transform = general_data_transform
self.general_target_transform = general_target_transform
self.test_data_transform = test_data_transform
self.test_target_transform = test_target_transform
self.train_data_transform = train_data_transform
self.train_target_transform = train_target_transform

Expand All @@ -269,8 +275,8 @@ def __init__(
self,
root,
args=None,
general_data_transform=None,
general_target_transform=None,
test_data_transform=None,
test_target_transform=None,
train_data_transform=None,
train_target_transform=None,
):
Expand All @@ -284,8 +290,8 @@ def __init__(
self.data = torch.cat([train_data, test_data])
self.targets = torch.cat([train_targets, test_targets])
self.classes = train_part.classes
self.general_data_transform = general_data_transform
self.general_target_transform = general_target_transform
self.test_data_transform = test_data_transform
self.test_target_transform = test_target_transform
self.train_data_transform = train_data_transform
self.train_target_transform = train_target_transform

Expand All @@ -295,8 +301,8 @@ def __init__(
self,
root,
args,
general_data_transform=None,
general_target_transform=None,
test_data_transform=None,
test_target_transform=None,
train_data_transform=None,
train_target_transform=None,
):
Expand All @@ -319,8 +325,8 @@ def __init__(
self.data = torch.cat([train_data, test_data])
self.targets = torch.cat([train_targets, test_targets])
self.classes = train_part.classes
self.general_data_transform = general_data_transform
self.general_target_transform = general_target_transform
self.test_data_transform = test_data_transform
self.test_target_transform = test_target_transform
self.train_data_transform = train_data_transform
self.train_target_transform = train_target_transform

Expand All @@ -330,8 +336,8 @@ def __init__(
self,
root,
args=None,
general_data_transform=None,
general_target_transform=None,
test_data_transform=None,
test_target_transform=None,
train_data_transform=None,
train_target_transform=None,
):
Expand All @@ -345,8 +351,8 @@ def __init__(
self.data = torch.cat([train_data, test_data])
self.targets = torch.cat([train_targets, test_targets])
self.classes = train_part.classes
self.general_data_transform = general_data_transform
self.general_target_transform = general_target_transform
self.test_data_transform = test_data_transform
self.test_target_transform = test_target_transform
self.train_data_transform = train_data_transform
self.train_target_transform = train_target_transform

Expand All @@ -356,8 +362,8 @@ def __init__(
self,
root,
args,
general_data_transform=None,
general_target_transform=None,
test_data_transform=None,
test_target_transform=None,
train_data_transform=None,
train_target_transform=None,
):
Expand All @@ -371,8 +377,8 @@ def __init__(
self.data = torch.cat([train_data, test_data])
self.targets = torch.cat([train_targets, test_targets])
self.classes = train_part.classes
self.general_data_transform = general_data_transform
self.general_target_transform = general_target_transform
self.test_data_transform = test_data_transform
self.test_target_transform = test_target_transform
self.train_data_transform = train_data_transform
self.train_target_transform = train_target_transform
super_class = None
Expand Down Expand Up @@ -420,8 +426,8 @@ def __init__(
self,
root,
args=None,
general_data_transform=None,
general_target_transform=None,
test_data_transform=None,
test_target_transform=None,
train_data_transform=None,
train_target_transform=None,
):
Expand Down Expand Up @@ -472,8 +478,8 @@ def __init__(

self.data = torch.load(root / "data.pt")
self.targets = torch.load(root / "targets.pt")
self.general_data_transform = general_data_transform
self.general_target_transform = general_target_transform
self.test_data_transform = test_data_transform
self.test_target_transform = test_target_transform
self.train_data_transform = train_data_transform
self.train_target_transform = train_target_transform

Expand All @@ -483,8 +489,8 @@ def __init__(
self,
root,
args=None,
general_data_transform=None,
general_target_transform=None,
test_data_transform=None,
test_target_transform=None,
train_data_transform=None,
train_target_transform=None,
):
Expand Down Expand Up @@ -528,8 +534,8 @@ def __init__(

self.data = torch.load(root / "data.pt")
self.targets = torch.load(root / "targets.pt")
self.general_data_transform = general_data_transform
self.general_target_transform = general_target_transform
self.test_data_transform = test_data_transform
self.test_target_transform = test_target_transform
self.train_data_transform = train_data_transform
self.train_target_transform = train_target_transform

Expand All @@ -539,8 +545,8 @@ def __init__(
self,
root,
args=None,
general_data_transform=None,
general_target_transform=None,
test_data_transform=None,
test_target_transform=None,
train_data_transform=None,
train_target_transform=None,
) -> None:
Expand Down Expand Up @@ -576,22 +582,18 @@ def __init__(
transforms.ToTensor(),
]
)
self.general_data_transform = general_data_transform
self.general_target_transform = general_target_transform
self.test_data_transform = test_data_transform
self.test_target_transform = test_target_transform
self.train_data_transform = train_data_transform
self.train_target_transform = train_target_transform

def __getitem__(self, index):
data = self.pre_transform(Image.open(self.filename_list[index]).convert("RGB"))
targets = self.targets[index]
if self.enable_train_transform and self.train_data_transform is not None:
data = self.train_data_transform(data)
if self.enable_train_transform and self.train_target_transform is not None:
targets = self.train_target_transform(targets)
if self.general_data_transform is not None:
data = self.general_data_transform(data)
if self.general_target_transform is not None:
targets = self.general_target_transform(targets)
if self.data_transform is not None:
data = self.data_transform(data)
if self.target_transform is not None:
targets = self.target_transform(targets)
return data, targets


Expand Down
2 changes: 2 additions & 0 deletions src/client/adcol.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ def __init__(self, model, discriminator, args, logger, device, client_num):
self.discriminator = discriminator
self.discriminator.to(self.device)
self.client_num = client_num
self.featrure_list = []

def fit(self):
self.model.train()
self.discriminator.eval()
self.dataset.train()
self.featrure_list = []
for i in range(self.local_epoch):
for x, y in self.trainloader:
Expand Down
1 change: 1 addition & 0 deletions src/client/apfl.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def save_state(self):
def fit(self):
self.model.train()
self.local_model.train()
self.dataset.train()
for i in range(self.local_epoch):
for x, y in self.trainloader:
if len(x) <= 1:
Expand Down
1 change: 1 addition & 0 deletions src/client/ditto.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def save_state(self):

def fit(self):
self.model.train()
self.dataset.train()
for _ in range(self.local_epoch):
for x, y in self.trainloader:
if len(x) <= 1:
Expand Down
Loading

0 comments on commit e613cda

Please sign in to comment.