From 675726323034625ded612f220b389ba956ce9687 Mon Sep 17 00:00:00 2001 From: AgentDS Date: Fri, 24 Sep 2021 16:24:46 +0800 Subject: [PATCH] remove abstract method sample_nums_count in data partition --- fedlab/utils/dataset/functional.py | 9 +++++++++ fedlab/utils/dataset/partition.py | 6 ++---- requirements.txt | 1 + 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/fedlab/utils/dataset/functional.py b/fedlab/utils/dataset/functional.py index d434af4c..2a981482 100644 --- a/fedlab/utils/dataset/functional.py +++ b/fedlab/utils/dataset/functional.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np +import pandas as pd import warnings @@ -240,3 +241,11 @@ def client_inner_dirichlet_partition(targets, num_clients, num_classes, dir_alph client_dict = dict([(cid, client_indices[cid]) for cid in range(num_clients)]) return client_dict + + +def samples_num_count(client_dict, num_clients): + client_samples_nums = [[cid, client_dict[cid].shape[0]] for cid in + range(num_clients)] + client_sample_count = pd.DataFrame(data=client_samples_nums, + columns=['client', 'num_samples']).set_index('client') + return client_sample_count diff --git a/fedlab/utils/dataset/partition.py b/fedlab/utils/dataset/partition.py index 797fc8f0..0ecc32a8 100644 --- a/fedlab/utils/dataset/partition.py +++ b/fedlab/utils/dataset/partition.py @@ -37,10 +37,6 @@ def __getitem__(self, index): def __len__(self): raise NotImplementedError - @abstractmethod - def _samples_num_count(self): - raise NotImplementedError - class CIFAR10Partitioner(DataPartitioner): """CIFAR10 data partitioner. @@ -127,6 +123,8 @@ def __init__(self, targets, num_clients, # perform partition according to setting self.client_dict = self._perform_partition() + # get sample number count for each client + self.client_sample_count = F.samples_num_count(self.client_dict, self.num_clients) def _perform_partition(self): if self.balance is None: diff --git a/requirements.txt b/requirements.txt index 1b54422d..9b13df6d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ numpy +pandas spacy pynvml