-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathEnterprise.py
1710 lines (1556 loc) · 89.4 KB
/
Enterprise.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import numpy as np
import math
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from DatasetLoad import DatasetLoad
from DatasetLoad import AddGaussianNoise
from sklearn.datasets import make_blobs
from torch import optim
import random
import copy
import time
from sys import getsizeof
from Crypto.PublicKey import RSA
from hashlib import sha256
from Models import ConcatModel, CombinedModel, Generator
# FedAnil++: Consortium Blockchain
from Block import Block
# FedAnil++: Consortium Blockchain
from Consortium_Blockchain import Consortium_Blockchain
from torchvision import transforms
# FedAnil++: Affinity Propagation
from sklearn.cluster import AffinityPropagation
import sys
import warnings
import pyae
import tenseal as ts
import torch.nn as nn
warnings.filterwarnings('ignore')
# FedAnil++: Cosine Similarity (Threshold 1 and 2)
# NOTE(review): "TRESHOLD" (sic) spelling kept -- other modules may reference
# these exact names.
TRESHOLD1 = -0.7
TRESHOLD2 = 0.7
# candidate model families; each enterprise samples a subset in __init__
m_type = ["cnn", "glove", "resnet"]
# module-level counters -- not used in this chunk; presumably used further
# down the file, confirm before removing
flcnt = 0
lastprc = 0
class Enterprise:
def __init__(self, idx, assigned_train_ds, assigned_test_dl, local_batch_size, learning_rate, loss_func, opti, network_stability, net, dev, miner_acception_wait_time, miner_accepted_transactions_size_limit, validator_threshold, pow_difficulty, even_link_speed_strength, base_data_transmission_speed, even_computation_power, is_malicious, noise_variance, check_signature, not_resync_chain, malicious_updates_discount, knock_out_rounds, lazy_local_enterprise_knock_out_rounds):
self.idx = idx
# deep learning variables
self.train_ds = assigned_train_ds
self.test_dl = assigned_test_dl
self.local_batch_size = local_batch_size
self.loss_func = loss_func
self.network_stability = network_stability
self.net = copy.deepcopy(net)
if opti == "SGD":
self.opti = optim.SGD(self.net.parameters(), lr=learning_rate, momentum=0.9)
else:
self.opti = optim.Adam(self.net.parameters(), lr=learning_rate, betas=(0.9, 0.9))
self.dev = dev
# in real system, new data can come in, so train_dl should get reassigned before training when that happens
self.train_dl = DataLoader(self.train_ds, batch_size=self.local_batch_size, shuffle=True)
self.local_train_parameters = None
self.initial_net_parameters = None
self.global_parameters = None
# FedAnil++: Consortium_blockchain variables
self.role = None
self.pow_difficulty = pow_difficulty
if even_link_speed_strength:
self.link_speed = base_data_transmission_speed
else:
self.link_speed = random.random() * base_data_transmission_speed
self.enterprises_dict = None
self.aio = False
''' simulating hardware equipment strength, such as good processors and RAM capacity. Following recorded times will be shrunk by this value of times
# for local_enterprises, its update time
# for miners, its PoW time
# for validators, its validation time
# might be able to simulate molopoly on computation power when there are block size limit, as faster enterprises' transactions will be accepted and verified first
'''
if even_computation_power:
self.computation_power = 1
else:
self.computation_power = random.randint(0, 4)
self.peer_list = set()
# used in cross_verification and in the PoS
self.online = True
self.rewards = 0
# FedAnil++: Consortium Blockchain
self.consortium_blockchain = Consortium_Blockchain()
# init key pair
self.modulus = None
self.private_key = None
self.public_key = None
self.generate_rsa_key()
# black_list stores enterprise index rather than the object
self.black_list = set()
self.knock_out_rounds = knock_out_rounds
self.lazy_local_enterprise_knock_out_rounds = lazy_local_enterprise_knock_out_rounds
self.local_enterprise_accuracy_accross_records = {}
self.has_added_block = False
self.the_added_block = None
self.is_malicious = is_malicious
#if self.is_malicious:
# print(f"Malicious Node created {self.idx}")
self.noise_variance = noise_variance
self.check_signature = check_signature
self.not_resync_chain = not_resync_chain
self.malicious_updates_discount = malicious_updates_discount
# used to identify slow or lazy local_enterprises
self.active_local_enterprise_record_by_round = {}
self.untrustworthy_local_enterprises_record_by_comm_round = {}
self.untrustworthy_validators_record_by_comm_round = {}
# for picking PoS legitimate blockd;bs
# self.stake_tracker = {} # used some tricks in main.py for ease of programming
# used to determine the slowest enterprise round end time to compare PoW with PoS round end time. If simulate under computation_power = 0, this may end up equaling infinity
self.round_end_time = 0
self.AE = None
self.coded_data_after_ac = None
''' For local_enterprises '''
self.local_updates_rewards_per_transaction = 0
self.received_block_from_miner = None
self.accuracy_this_round = float('-inf')
self.local_enterprise_associated_validator = None
self.local_enterprise_associated_miner = None
self.local_update_time = None
self.local_total_epoch = 0
# FedAnil++: Training Models via Models Random Selection by Each Enterprise.
self.model_type = random.sample(m_type, random.randint(1, 3))
''' For validators '''
self.validator_associated_local_enterprise_set = set()
self.validation_rewards_this_round = 0
self.accuracies_this_round = {}
self.validator_associated_miner = None
# when validator directly accepts local_enterprises' updates
self.unordered_arrival_time_accepted_local_enterprise_transactions = {}
self.validator_accepted_broadcasted_local_enterprise_transactions = None or []
self.final_transactions_queue_to_validate = {}
self.post_validation_transactions_queue = None or []
self.validator_threshold = validator_threshold
self.validator_local_accuracy = None
''' For miners '''
self.miner_associated_local_enterprise_set = set()
self.miner_associated_validator_set = set()
# dict cannot be added to set()
self.unconfirmmed_transactions = None or []
self.broadcasted_transactions = None or []
self.mined_block = None
self.received_propagated_block = None
self.received_propagated_validator_block = None
self.miner_acception_wait_time = miner_acception_wait_time
self.miner_accepted_transactions_size_limit = miner_accepted_transactions_size_limit
# when miner directly accepts validators' updates
self.unordered_arrival_time_accepted_validator_transactions = {}
self.miner_accepted_broadcasted_validator_transactions = None or []
self.final_candidate_transactions_queue_to_mine = {}
self.block_generation_time_point = None
self.unordered_propagated_block_processing_queue = {} # pure simulation queue and does not exist in real distributed system
''' For malicious node '''
self.variance_of_noises = None or []
self.size_of_encoded_data = 0
''' Common Methods '''
''' setters '''
def set_enterprises_dict_and_aio(self, enterprises_dict, aio):
self.enterprises_dict = enterprises_dict
self.aio = aio
def generate_rsa_key(self):
keyPair = RSA.generate(bits=1024)
self.modulus = keyPair.n
self.private_key = keyPair.d
self.public_key = keyPair.e
def init_global_parameters(self):
self.initial_net_parameters = self.net.state_dict()
self.global_parameters = self.net.state_dict()
def return_global_parametesrs(self):
return self.global_parameters
def assign_role(self):
# equal probability
role_choice = random.randint(0, 2)
if role_choice == 0:
self.role = "local_enterprise"
elif role_choice == 1:
self.role = "miner"
else:
self.role = "validator"
# used for hard_assign
def assign_miner_role(self):
self.role = "miner"
def assign_local_enterprise_role(self):
self.role = "local_enterprise"
def assign_validator_role(self):
self.role = "validator"
''' getters '''
def return_idx(self):
return self.idx
def return_rsa_pub_key(self):
return {"modulus": self.modulus, "pub_key": self.public_key}
def return_peers(self):
return self.peer_list
def return_role(self):
return self.role
def is_online(self):
return self.online
def return_is_malicious(self):
return self.is_malicious
def return_black_list(self):
return self.black_list
# FedAnil++: Consortium Blockchain
def return_consortium_blockchain_object(self):
return self.consortium_blockchain
def return_stake(self):
return self.rewards
def return_computation_power(self):
return self.computation_power
def return_the_added_block(self):
return self.the_added_block
def return_round_end_time(self):
return self.round_end_time
''' functions '''
def sign_msg(self, msg):
hash = int.from_bytes(sha256(str(msg).encode('utf-8')).digest(), byteorder='big')
# pow() is python built-in modular exponentiation function
signature = pow(hash, self.private_key, self.modulus)
return signature
def add_peers(self, new_peers):
if isinstance(new_peers, Enterprise):
self.peer_list.add(new_peers)
else:
self.peer_list.update(new_peers)
def remove_peers(self, peers_to_remove):
if isinstance(peers_to_remove, Enterprise):
self.peer_list.discard(peers_to_remove)
else:
self.peer_list.difference_update(peers_to_remove)
def return_model_type(self, index):
return m_type[0]
def online_switcher(self):
old_status = self.online
online_indicator = random.random()
if online_indicator < self.network_stability:
self.online = True
# if back online, update peer and resync chain
if old_status == False:
print(f"{self.idx} goes back online.")
# update peer list
self.update_peer_list()
# resync chain
if self.pow_resync_chain():
self.update_model_after_chain_resync()
else:
self.online = False
print(f"{self.idx} goes offline.")
return self.online
    def update_peer_list(self):
        """Gossip-style peer refresh: absorb online peers' peer lists, drop
        self and blacklisted peers, log the changes, and return whether any
        peers remain."""
        print(f"\n{self.idx} - {self.role} is updating peer list...")
        # shallow copy is enough -- only set membership is compared below
        old_peer_list = copy.copy(self.peer_list)
        online_peers = set()
        for peer in self.peer_list:
            if peer.is_online():
                online_peers.add(peer)
        # for online peers, suck in their peer list
        for online_peer in online_peers:
            self.add_peers(online_peer.return_peers())
        # remove itself from the peer_list if there is
        self.remove_peers(self)
        # remove malicious peers
        removed_peers = []
        potential_malicious_peer_set = set()
        for peer in self.peer_list:
            if peer.return_idx() in self.black_list:
                potential_malicious_peer_set.add(peer)
        self.remove_peers(potential_malicious_peer_set)
        removed_peers.extend(potential_malicious_peer_set)
        # print updated peer result
        if old_peer_list == self.peer_list:
            print("Peer list NOT changed.")
        else:
            print("Peer list has been changed.")
            added_peers = self.peer_list.difference(old_peer_list)
            if potential_malicious_peer_set:
                print("These malicious peers are removed")
                for peer in removed_peers:
                    print(f"e_{peer.return_idx().split('_')[-1]} - {peer.return_role()[0]}", end=', ')
                print()
            if added_peers:
                print("These peers are added")
                for peer in added_peers:
                    print(f"e_{peer.return_idx().split('_')[-1]} - {peer.return_role()[0]}", end=', ')
                print()
        print("Final peer list:")
        for peer in self.peer_list:
            print(f"e_{peer.return_idx().split('_')[-1]} - {peer.return_role()[0]}", end=', ')
        print()
        # WILL ALWAYS RETURN TRUE AS OFFLINE PEERS WON'T BE REMOVED ANY MORE, UNLESS ALL PEERS ARE Malicious Enterprises...but then it should not register with any other peer. Original purpose - if peer_list ends up empty, randomly register with another enterprise
        return False if not self.peer_list else True
def check_pow_proof(self, block_to_check):
# remove its block hash(compute_hash() by default) to verify pow_proof as block hash was set after pow
pow_proof = block_to_check.return_pow_proof()
# print("pow_proof", pow_proof)
# print("compute_hash", block_to_check.compute_hash())
return pow_proof.startswith('0' * self.pow_difficulty) and pow_proof == block_to_check.compute_hash()
def check_chain_validity(self, chain_to_check):
chain_len = chain_to_check.return_chain_length()
if chain_len == 0 or chain_len == 1:
pass
else:
chain_to_check = chain_to_check.return_chain_structure()
for block in chain_to_check[1:]:
if self.check_pow_proof(block) and block.return_previous_block_hash() == chain_to_check[chain_to_check.index(block) - 1].compute_hash(hash_entire_block=True):
pass
else:
return False
return True
def accumulate_chain_stake(self, chain_to_accumulate):
accumulated_stake = 0
chain_to_accumulate = chain_to_accumulate.return_chain_structure()
for block in chain_to_accumulate:
accumulated_stake += self.enterprises_dict[block.return_mined_by()].return_stake()
return accumulated_stake
def resync_chain(self, mining_consensus):
if self.not_resync_chain:
return # temporary workaround to save GPU memory
if mining_consensus == 'PoW':
self.pow_resync_chain()
else:
self.pos_resync_chain()
    def pos_resync_chain(self):
        """PoS resync: adopt the valid peer chain whose miners hold the most
        accumulated stake; returns True if the local chain was replaced."""
        print(f"{self.role} {self.idx} is looking for a chain with the highest accumulated miner's stake in the network...")
        highest_stake_chain = None
        updated_from_peer = None
        # FedAnil++: Consortium Blockchain
        curr_chain_stake = self.accumulate_chain_stake(self.return_consortium_blockchain_object())
        for peer in self.peer_list:
            if peer.is_online():
                peer_chain = peer.return_consortium_blockchain_object()
                peer_chain_stake = self.accumulate_chain_stake(peer_chain)
                if peer_chain_stake > curr_chain_stake:
                    if self.check_chain_validity(peer_chain):
                        print(f"A chain from {peer.return_idx()} with total stake {peer_chain_stake} has been found (> currently compared chain stake {curr_chain_stake}) and verified.")
                        # Higher stake valid chain found!
                        curr_chain_stake = peer_chain_stake
                        highest_stake_chain = peer_chain
                        updated_from_peer = peer.return_idx()
                    else:
                        print(f"A chain from {peer.return_idx()} with higher stake has been found BUT NOT verified. Skipped this chain for syncing.")
        if highest_stake_chain:
            # compare chain difference
            highest_stake_chain_structure = highest_stake_chain.return_chain_structure()
            # need more efficient machenism which is to reverse updates by # of blocks
            self.return_consortium_blockchain_object().replace_chain(highest_stake_chain_structure)
            print(f"{self.idx} chain resynced from peer {updated_from_peer}.")
            #return block_iter
            return True
        print("Chain not resynced.")
        return False
    def pow_resync_chain(self):
        """PoW resync: adopt the longest valid peer chain (longest-chain rule);
        returns True if the local chain was replaced."""
        print(f"{self.role} {self.idx} is looking for a longer chain in the network...")
        longest_chain = None
        updated_from_peer = None
        curr_chain_len = self.return_consortium_blockchain_object().return_chain_length()
        for peer in self.peer_list:
            if peer.is_online():
                peer_chain = peer.return_consortium_blockchain_object()
                if peer_chain.return_chain_length() > curr_chain_len:
                    if self.check_chain_validity(peer_chain):
                        print(f"A longer chain from {peer.return_idx()} with chain length {peer_chain.return_chain_length()} has been found (> currently compared chain length {curr_chain_len}) and verified.")
                        # Longer valid chain found!
                        curr_chain_len = peer_chain.return_chain_length()
                        longest_chain = peer_chain
                        updated_from_peer = peer.return_idx()
                    else:
                        print(f"A longer chain from {peer.return_idx()} has been found BUT NOT verified. Skipped this chain for syncing.")
        if longest_chain:
            # compare chain difference
            longest_chain_structure = longest_chain.return_chain_structure()
            # need more efficient machenism which is to reverse updates by # of blocks
            self.return_consortium_blockchain_object().replace_chain(longest_chain_structure)
            print(f"{self.idx} chain resynced from peer {updated_from_peer}.")
            #return block_iter
            return True
        print("Chain not resynced.")
        return False
def load_data_by_index(self, index):
centers = [[1, 1], [-1, -1], [1, -1]]
X, labels_true = make_blobs(
n_samples=300, centers=centers, cluster_std=0.5, random_state=0
)
return X, labels_true
def receive_rewards(self, rewards):
self.rewards += rewards
    def verify_miner_transaction_by_signature(self, transaction_to_verify, miner_enterprise_idx):
        """Verify a miner-signed transaction: reject blacklisted miners, then
        (when check_signature is on) recompute the hash of the transaction
        minus its signature and compare with the RSA-decrypted signature.
        Returns True/False."""
        if miner_enterprise_idx in self.black_list:
            print(f"{miner_enterprise_idx} is in miner's blacklist. Trasaction won't get verified.")
            return False
        if self.check_signature:
            transaction_before_signed = copy.deepcopy(transaction_to_verify)
            del transaction_before_signed["miner_signature"]
            modulus = transaction_to_verify['miner_rsa_pub_key']["modulus"]
            pub_key = transaction_to_verify['miner_rsa_pub_key']["pub_key"]
            signature = transaction_to_verify["miner_signature"]
            # verify
            # sorted() makes the hash independent of dict insertion order
            hash = int.from_bytes(sha256(str(sorted(transaction_before_signed.items())).encode('utf-8')).digest(), byteorder='big')
            hashFromSignature = pow(signature, pub_key, modulus)
            if hash == hashFromSignature:
                print(f"A transaction recorded by miner {miner_enterprise_idx} in the block is verified!")
                return True
            else:
                print(f"Signature invalid. Transaction recorded by {miner_enterprise_idx} is NOT verified.")
                return False
        else:
            # signature checking disabled by configuration -- trust the miner
            print(f"A transaction recorded by miner {miner_enterprise_idx} in the block is verified!")
            return True
    def verify_block(self, block_to_verify, sending_miner):
        """Full block verification: online check, black lists, PoW proof,
        optional miner-signature check and previous-hash linkage against
        own chain.

        Returns (block_to_verify, verification_time) when every check
        passes, otherwise (False, False).
        """
        if not self.online_switcher():
            print(f"{self.idx} goes offline when verifying a block")
            return False, False
        verification_time = time.time()
        mined_by = block_to_verify.return_mined_by()
        if sending_miner in self.black_list:
            print(f"The miner propagating/sending this block {sending_miner} is in {self.idx}'s black list. Block will not be verified.")
            return False, False
        if mined_by in self.black_list:
            print(f"The miner {mined_by} mined this block is in {self.idx}'s black list. Block will not be verified.")
            return False, False
        # check if the proof is valid(verify _block_hash).
        if not self.check_pow_proof(block_to_verify):
            print(f"PoW proof of the block from miner {self.idx} is not verified.")
            return False, False
        # # check if miner's signature is valid
        if self.check_signature:
            signature_dict = block_to_verify.return_miner_rsa_pub_key()
            modulus = signature_dict["modulus"]
            pub_key = signature_dict["pub_key"]
            signature = block_to_verify.return_signature()
            # verify signature
            # hash is recomputed over a signature-stripped deep copy, matching
            # what the miner hashed before signing
            block_to_verify_before_sign = copy.deepcopy(block_to_verify)
            block_to_verify_before_sign.remove_signature_for_verification()
            hash = int.from_bytes(sha256(str(block_to_verify_before_sign.__dict__).encode('utf-8')).digest(), byteorder='big')
            hashFromSignature = pow(signature, pub_key, modulus)
            if hash != hashFromSignature:
                print(f"Signature of the block sent by miner {sending_miner} mined by miner {mined_by} is not verified by {self.role} {self.idx}.")
                return False, False
            # check previous hash based on own chain
            last_block = self.return_consortium_blockchain_object().return_last_block()
            if last_block is not None:
                # check if the previous_hash referred in the block and the hash of latest block in the chain match.
                last_block_hash = last_block.compute_hash(hash_entire_block=True)
                if block_to_verify.return_previous_block_hash() != last_block_hash:
                    print(f"Block sent by miner {sending_miner} mined by miner {mined_by} has the previous hash recorded as {block_to_verify.return_previous_block_hash()}, but the last block's hash in chain is {last_block_hash}. This is possibly due to a forking event from last round. Block not verified and won't be added. Enterprise needs to resync chain next round.")
                    return False, False
        # All verifications done.
        print(f"Block accepted from miner {sending_miner} mined by {mined_by} has been verified by {self.idx}!")
        # NOTE(review): ZeroDivisionError if computation_power == 0 (possible
        # when even_computation_power is False) -- confirm callers guard this
        verification_time = (time.time() - verification_time)/self.computation_power
        return block_to_verify, verification_time
# FedAnil++: get global model from blockchain
def fetch_global_model(self, blockchain):
try:
last_block_in_blockchain = blockchain.return_last_block()
transaction = last_block_in_blockchain["transaction"]
self.local_model = transaction["gradients"]
except:
pass
# FedAnil++: upload local model in compressed format
def upload_local_model(self, compress_parameters):
new_transaction = {}
new_transaction["gradients"] = compress_parameters
new_block = Block(idx=-1, transactions=new_transaction, miner_rsa_pub_key=0)
self.consortium_blockchain.new_local_block(new_block, self.coded_data_after_ac)
# FedAnil++: fetch list ofcl all local models
def fetch_local_models(self):
self.coded_data_after_ac = self.return_consortium_blockchain_object().return_last_cdata()
return self.return_consortium_blockchain_object().return_local_chain()
# FedAnil++: add consortium blockchain blockchain
def add_block(self, block_to_add):
self.return_consortium_blockchain_object().append_block(block_to_add)
print(f"e_{self.idx.split('_')[-1]} - {self.role[0]} has appened a block to its chain. Chain length now - {self.return_consortium_blockchain_object().return_chain_length()}")
# TODO delete has_added_block
# self.has_added_block = True
self.the_added_block = block_to_add
return True
# also accumulate rewards here
def process_block(self, block_to_process, log_files_folder_path, conn, conn_cursor, when_resync=False):
# collect usable updated params, malicious enterprises identification, get rewards and do local udpates
processing_time = time.time()
if not self.online_switcher():
print(f"{self.role} {self.idx} goes offline when processing the added block. Model not updated and rewards information not upgraded. Outdated information may be obtained by this node if it never resyncs to a different chain.") # may need to set up a flag indicating if a block has been processed
if block_to_process:
mined_by = block_to_process.return_mined_by()
if mined_by in self.black_list:
# in this system black list is also consistent across enterprises as it is calculated based on the information on chain, but individual enterprise can decide its own validation/verification mechanisms and has its own
print(f"The added block is mined by miner {block_to_process.return_mined_by()}, which is in this enterprise's black list. Block will not be processed.")
else:
# process validator sig valid transactions
# used to count positive and negative transactions local_enterprise by local_enterprise, select the transaction to do global update and identify potential malicious local_enterprise
self_rewards_accumulator = 0
valid_transactions_records_by_local_enterprise = {}
valid_validator_sig_local_enterprise_transacitons_in_block = block_to_process.return_transactions()['valid_validator_sig_transacitons']
comm_round = block_to_process.return_block_idx()
self.active_local_enterprise_record_by_round[comm_round] = set()
for valid_validator_sig_local_enterprise_transaciton in valid_validator_sig_local_enterprise_transacitons_in_block:
# verify miner's signature(miner does not get reward for receiving and aggregating)
if self.verify_miner_transaction_by_signature(valid_validator_sig_local_enterprise_transaciton, mined_by):
local_enterprise_enterprise_idx = valid_validator_sig_local_enterprise_transaciton['local_enterprise_enterprise_idx']
self.active_local_enterprise_record_by_round[comm_round].add(local_enterprise_enterprise_idx)
if not local_enterprise_enterprise_idx in valid_transactions_records_by_local_enterprise.keys():
valid_transactions_records_by_local_enterprise[local_enterprise_enterprise_idx] = {}
valid_transactions_records_by_local_enterprise[local_enterprise_enterprise_idx]['positive_epochs'] = set()
valid_transactions_records_by_local_enterprise[local_enterprise_enterprise_idx]['negative_epochs'] = set()
valid_transactions_records_by_local_enterprise[local_enterprise_enterprise_idx]['all_valid_epochs'] = set()
valid_transactions_records_by_local_enterprise[local_enterprise_enterprise_idx]['finally_used_params'] = None
# epoch of this local_enterprise's local update
local_epoch_seq = valid_validator_sig_local_enterprise_transaciton['local_total_accumulated_epochs_this_round']
positive_direction_validators = valid_validator_sig_local_enterprise_transaciton['positive_direction_validators']
negative_direction_validators = valid_validator_sig_local_enterprise_transaciton['negative_direction_validators']
#all_direction_validators = valid_validator_sig_local_enterprise_transaciton['all_valid_epochs']
# FedAnil++: validation enterprise local update by all validators
#if len(positive_direction_validators) >= len(negative_direction_validators):
if len(negative_direction_validators) == 0:
# local_enterprise transaction can be used
valid_transactions_records_by_local_enterprise[local_enterprise_enterprise_idx]['positive_epochs'].add(local_epoch_seq)
valid_transactions_records_by_local_enterprise[local_enterprise_enterprise_idx]['all_valid_epochs'].add(local_epoch_seq)
# see if this is the latest epoch from this local_enterprise
if local_epoch_seq == max(valid_transactions_records_by_local_enterprise[local_enterprise_enterprise_idx]['all_valid_epochs']):
valid_transactions_records_by_local_enterprise[local_enterprise_enterprise_idx]['finally_used_params'] = valid_validator_sig_local_enterprise_transaciton['local_updates_params']
# give rewards to this local_enterprise
if self.idx == local_enterprise_enterprise_idx:
self_rewards_accumulator += valid_validator_sig_local_enterprise_transaciton['local_updates_rewards']
else:
if self.malicious_updates_discount:
# local_enterprise transaction voted negative and has to be applied for a discount
valid_transactions_records_by_local_enterprise[local_enterprise_enterprise_idx]['negative_epochs'].add(local_epoch_seq)
valid_transactions_records_by_local_enterprise[local_enterprise_enterprise_idx]['all_valid_epochs'].add(local_epoch_seq)
# see if this is the latest epoch from this local_enterprise
if local_epoch_seq == max(valid_transactions_records_by_local_enterprise[local_enterprise_enterprise_idx]['all_valid_epochs']):
# apply discount
discounted_valid_validator_sig_local_enterprise_transaciton_local_updates_params = copy.deepcopy(valid_validator_sig_local_enterprise_transaciton['local_updates_params'])
for var in discounted_valid_validator_sig_local_enterprise_transaciton_local_updates_params:
discounted_valid_validator_sig_local_enterprise_transaciton_local_updates_params[var] *= self.malicious_updates_discount
valid_transactions_records_by_local_enterprise[local_enterprise_enterprise_idx]['finally_used_params'] = discounted_valid_validator_sig_local_enterprise_transaciton_local_updates_params
# local_enterprise receive discounted rewards for negative update
if self.idx == local_enterprise_enterprise_idx:
self_rewards_accumulator += valid_validator_sig_local_enterprise_transaciton['local_updates_rewards'] * self.malicious_updates_discount
else:
# discount specified as 0, local_enterprise transaction voted negative and cannot be used
valid_transactions_records_by_local_enterprise[local_enterprise_enterprise_idx]['negative_epochs'].add(local_epoch_seq)
# local_enterprise does not receive rewards for negative update
# give rewards to validators and the miner in this transaction
for validator_record in positive_direction_validators + negative_direction_validators:
if self.idx == validator_record['validator']:
self_rewards_accumulator += validator_record['validation_rewards']
if self.idx == validator_record['miner_enterprise_idx']:
self_rewards_accumulator += validator_record['miner_rewards_for_this_tx']
else:
print(f"one validator transaction miner sig found invalid in this block. {self.idx} will drop this block and roll back rewards information")
return
# identify potentially malicious local_enterprise
self.untrustworthy_local_enterprises_record_by_comm_round[comm_round] = set()
for local_enterprise_idx, local_updates_direction_records in valid_transactions_records_by_local_enterprise.items():
if len(local_updates_direction_records['negative_epochs']) > len(local_updates_direction_records['positive_epochs']):
self.untrustworthy_local_enterprises_record_by_comm_round[comm_round].add(local_enterprise_idx)
kick_out_accumulator = 1
# check previous rounds
for comm_round_to_check in range(comm_round - self.knock_out_rounds + 1, comm_round):
if comm_round_to_check in self.untrustworthy_local_enterprises_record_by_comm_round.keys():
if local_enterprise_idx in self.untrustworthy_local_enterprises_record_by_comm_round[comm_round_to_check]:
kick_out_accumulator += 1
if kick_out_accumulator == self.knock_out_rounds:
# kick out
self.black_list.add(local_enterprise_idx)
# is it right?
if when_resync:
msg_end = " when resyncing!\n"
else:
msg_end = "!\n"
if self.enterprises_dict[local_enterprise_idx].return_is_malicious():
msg = f"{self.idx} has successfully identified a malicious local_enterprise enterprise {local_enterprise_idx} in comm_round {comm_round}{msg_end}"
with open(f"{log_files_folder_path}/correctly_kicked_local_enterprises.txt", 'a') as file:
file.write(msg)
conn_cursor.execute("INSERT INTO malicious_local_enterprises_log VALUES (?, ?, ?, ?, ?, ?)", (local_enterprise_idx, 1, self.idx, "", comm_round, when_resync))
conn.commit()
else:
msg = f"WARNING: {self.idx} has mistakenly regard {local_enterprise_idx} as a malicious local_enterprise enterprise in comm_round {comm_round}{msg_end}"
with open(f"{log_files_folder_path}/mistakenly_kicked_local_enterprises.txt", 'a') as file:
file.write(msg)
conn_cursor.execute("INSERT INTO malicious_local_enterprises_log VALUES (?, ?, ?, ?, ?, ?)", (local_enterprise_idx, 0, "", self.idx, comm_round, when_resync))
conn.commit()
print(msg)
# cont = print("Press ENTER to continue")
# identify potentially compromised validator
self.untrustworthy_validators_record_by_comm_round[comm_round] = set()
invalid_validator_sig_local_enterprise_transacitons_in_block = block_to_process.return_transactions()['invalid_validator_sig_transacitons']
for invalid_validator_sig_local_enterprise_transaciton in invalid_validator_sig_local_enterprise_transacitons_in_block:
if self.verify_miner_transaction_by_signature(invalid_validator_sig_local_enterprise_transaciton, mined_by):
validator_enterprise_idx = invalid_validator_sig_local_enterprise_transaciton['validator']
self.untrustworthy_validators_record_by_comm_round[comm_round].add(validator_enterprise_idx)
kick_out_accumulator = 1
# check previous rounds
for comm_round_to_check in range(comm_round - self.knock_out_rounds + 1, comm_round):
if comm_round_to_check in self.untrustworthy_validators_record_by_comm_round.keys():
if validator_enterprise_idx in self.untrustworthy_validators_record_by_comm_round[comm_round_to_check]:
kick_out_accumulator += 1
if kick_out_accumulator == self.knock_out_rounds:
# kick out
self.black_list.add(validator_enterprise_idx)
print(f"{validator_enterprise_idx} has been regarded as a compromised validator by {self.idx} in {comm_round}.")
# actually, we did not let validator do malicious thing if is_malicious=1 is set to this enterprise. In the submission of 2020/10, we only focus on catching malicious local_enterprise
# is it right?
# if when_resync:
# msg_end = " when resyncing!\n"
# else:
# msg_end = "!\n"
# if self.enterprises_dict[validator_enterprise_idx].return_is_malicious():
# msg = f"{self.idx} has successfully identified a compromised validator enterprise {validator_enterprise_idx} in comm_round {comm_round}{msg_end}"
# with open(f"{log_files_folder_path}/correctly_kicked_validators.txt", 'a') as file:
# file.write(msg)
# else:
# msg = f"WARNING: {self.idx} has mistakenly regard {validator_enterprise_idx} as a compromised validator enterprise in comm_round {comm_round}{msg_end}"
# with open(f"{log_files_folder_path}/mistakenly_kicked_validators.txt", 'a') as file:
# file.write(msg)
# print(msg)
# cont = print("Press ENTER to continue")
else:
print(f"one validator transaction miner sig found invalid in this block. {self.idx} will drop this block and roll back rewards information")
return
# give rewards to the miner in this transaction
if self.idx == invalid_validator_sig_local_enterprise_transaciton['miner_enterprise_idx']:
self_rewards_accumulator += invalid_validator_sig_local_enterprise_transaciton['miner_rewards_for_this_tx']
# miner gets mining rewards
if self.idx == mined_by:
self_rewards_accumulator += block_to_process.return_mining_rewards()
# set received rewards this round based on info from this block
self.receive_rewards(self_rewards_accumulator)
print(f"{self.role} {self.idx} has received total {self_rewards_accumulator} rewards for this comm round.")
# collect usable local_enterprise updates and do global updates
finally_used_local_params = []
for local_enterprise_enterprise_idx, local_params_record in valid_transactions_records_by_local_enterprise.items():
if local_params_record['finally_used_params']:
# could be None
finally_used_local_params.append((local_enterprise_enterprise_idx, local_params_record['finally_used_params']))
if self.online_switcher():
self.global_update(finally_used_local_params)
else:
print(f"Unfortunately, {self.role} {self.idx} goes offline when it's doing global_updates.")
processing_time = (time.time() - processing_time)/self.computation_power
return processing_time
def add_to_round_end_time(self, time_to_add):
    """Accumulate extra elapsed time onto this round's end-time counter."""
    self.round_end_time = self.round_end_time + time_to_add
def other_tasks_at_the_end_of_comm_round(self, this_comm_round, log_files_folder_path):
    """Run end-of-round housekeeping; currently only the lazy-enterprise kick-out."""
    self.kick_out_slow_or_lazy_local_enterprises(this_comm_round, log_files_folder_path)
def kick_out_slow_or_lazy_local_enterprises(self, this_comm_round, log_files_folder_path):
    """Blacklist local enterprises that were inactive for
    lazy_local_enterprise_knock_out_rounds consecutive comm rounds.

    A peer counts as inactive in a round when its idx is missing from
    self.active_local_enterprise_record_by_round for that round. Kicked
    peers are logged to kicked_lazy_local_enterprises.txt.
    """
    for peer in self.peer_list:
        if peer.return_role() != 'local_enterprise':
            continue
        if this_comm_round not in self.active_local_enterprise_record_by_round.keys():
            # may happen when an enterprise was blacklisted by every
            # local_enterprise in a certain comm round
            continue
        peer_idx = peer.return_idx()
        if peer_idx in self.active_local_enterprise_record_by_round[this_comm_round]:
            continue
        # inactive this round; look back over the previous
        # (lazy_local_enterprise_knock_out_rounds - 1) rounds
        inactive_rounds = 1
        for past_round in range(this_comm_round - self.lazy_local_enterprise_knock_out_rounds + 1, this_comm_round):
            if past_round in self.active_local_enterprise_record_by_round.keys():
                if peer_idx not in self.active_local_enterprise_record_by_round[past_round]:
                    inactive_rounds += 1
        if inactive_rounds == self.lazy_local_enterprise_knock_out_rounds:
            # kick out
            self.black_list.add(peer_idx)
            msg = f"local_enterprise {peer.return_idx()} has been regarded as a lazy local_enterprise by {self.idx} in comm_round {this_comm_round}.\n"
            with open(f"{log_files_folder_path}/kicked_lazy_local_enterprises.txt", 'a') as file:
                file.write(msg)
def update_model_after_chain_resync(self, log_files_folder_path, conn, conn_cursor):
    """Rebuild the global model after a chain resync by replaying every block.

    Resets global parameters to the initial network weights, then processes
    each block on the (resynced) chain with when_resync=True.
    """
    self.global_parameters = copy.deepcopy(self.initial_net_parameters)
    # future work: an efficient incremental update based on the chain difference
    for chain_block in self.return_consortium_blockchain_object().return_chain_structure():
        self.process_block(chain_block, log_files_folder_path, conn, conn_cursor, when_resync=True)
def return_pow_difficulty(self):
    """Return the proof-of-work difficulty configured for this enterprise."""
    return self.pow_difficulty
def register_in_the_network(self, check_online=False):
    """Register this enterprise in the P2P network.

    In all-in-one (aio) mode every known enterprise simply becomes a peer.
    Otherwise a registrar is picked at random from the other enterprises;
    the registrant adds the registrar, absorbs the registrar's peer list,
    and is then added to the registrar's peers (in that order, so the
    registrant does not pick up itself from the registrar's list).

    Args:
        check_online: when True, fall back to a currently-online registrar
            if the first pick is offline.

    Returns:
        True on successful registration, False when check_online is set and
        no registrar is online. (The aio branch returns None, as before.)
    """
    if self.aio:
        self.add_peers(set(self.enterprises_dict.values()))
    else:
        potential_registrars = set(self.enterprises_dict.values())
        # it cannot register with itself
        potential_registrars.discard(self)
        # pick a registrar
        # NOTE: random.sample() no longer accepts a set (deprecated in 3.9,
        # removed in Python 3.11), so materialize to a tuple and use choice()
        registrar = random.choice(tuple(potential_registrars))
        if check_online:
            if not registrar.is_online():
                # distinct loop variable so the chosen registrar is not shadowed
                online_registrars = {candidate for candidate in potential_registrars if candidate.is_online()}
                if not online_registrars:
                    return False
                registrar = random.choice(tuple(online_registrars))
        # registrant adds registrar to its peer list
        self.add_peers(registrar)
        # this enterprise absorbs the registrar's peer list
        self.add_peers(registrar.return_peers())
        # registrar adds registrant (must be in this order, or the registrant
        # would add itself from the registrar's peer list)
        registrar.add_peers(self)
        return True
''' Local Enterprise '''
def malicious_local_enterprise_add_noise_to_weights(self, m):
    """Inject Gaussian noise into a layer's weights (malicious-node simulation).

    Meant to be used via torch.nn.Module.apply(); modules without a
    'weight' attribute are left untouched. The variance of each injected
    noise tensor is appended to self.variance_of_noises.
    """
    with torch.no_grad():
        if not hasattr(m, 'weight'):
            return
        noise = self.noise_variance * torch.randn(m.weight.size())
        variance_of_noise = torch.var(noise)
        m.weight.add_(noise.to(self.dev))
        self.variance_of_noises.append(float(variance_of_noise))
def malicious_local_enterprise_add_noise_to_datas(self, m):
    """Intentional no-op: data-level noise for malicious enterprises is
    applied at dataset-loading time instead (see DatasetLoad.py)."""
    # done in DatasetLoad.py
    pass
# TODO change to computation power
# FedAnil++: Local Update
def local_enterprise_local_update(self, rewards, log_files_folder_path_comm_round, comm_round, local_epochs=1):
    """Run local training for this enterprise and return its simulated duration.

    Trains each selected model type for `local_epochs` epochs on the local
    training data, accumulating per-sample rewards, then applies the
    FedAnil++ post-processing pipeline (sparsification, K-Medoids based
    quantization, arithmetic entropy encoding, homomorphic encryption) and
    uploads the resulting local model. Per-epoch and final accuracies are
    appended to log files in `log_files_folder_path_comm_round`.

    Args:
        rewards: reward units granted per training sample processed.
        log_files_folder_path_comm_round: folder for this round's accuracy logs.
        comm_round: current communication round index.
        local_epochs: epochs per model type; usually 1 when a validator
            acceptance time is enforced.

    Returns:
        Simulated local update time (wall time / computation power), or
        float('inf') when computation power is zero.
    """
    print(f"Local Enterprise {self.idx} is doing local_update with computation power {self.computation_power} and link speed {round(self.link_speed,3)} bytes/s")
    self.net.load_state_dict(self.global_parameters, strict=True)
    # Total Computation Cost (Second) — start the wall clock
    self.local_update_time = time.time()
    # "M"/"B" tag used in the log-file lines below
    is_malicious_node = "M" if self.return_is_malicious() else "B"
    self.local_updates_rewards_per_transaction = 0
    # FedAnil++: train each of the selected model types
    for mt in self.model_type:
        model_type_name = self.return_model_type(mt)
        for epoch in range(local_epochs):
            for data, label in self.train_dl:
                data, label = data.to(self.dev), label.to(self.dev)
                preds = self.net(data, model_type_name)
                loss = self.loss_func(preds, label)
                loss.backward()
                self.opti.step()
                self.opti.zero_grad()
                # reward is proportional to the batch size processed
                self.local_updates_rewards_per_transaction += rewards * (label.shape[0])
            # record accuracies to find good -vh
            with open(f"{log_files_folder_path_comm_round}/local_enterprise_{self.idx}_{is_malicious_node}_local_updating_accuracies_comm_{comm_round}.txt", "a") as file:
                file.write(f"{self.return_idx()} epoch_{epoch+1} {self.return_role()} {is_malicious_node}: {self.validate_model_weights(self.net.state_dict())}\n")
            self.local_total_epoch += 1
    # FedAnil++: Sparsification
    self.net.first_filter(self.global_parameters)
    # FedAnil++: K-Medoids based Quantization
    self.net.kmedoids_update()
    # FedAnil++: Arithmetic Entropy Encoding
    self.arithmetic_entropy_coding()
    # FedAnil++: Homomorphic Encryption
    self.homomorphic_encryption()
    # local update done
    # Total Computation Cost (Second)
    try:
        self.local_update_time = (time.time() - self.local_update_time)/self.computation_power
    except ZeroDivisionError:
        # zero computation power models an enterprise that can never finish
        self.local_update_time = float('inf')
    # disabled experiment: weight-noise injection by malicious local enterprises
    #if self.is_malicious:
        #self.net.apply(self.malicious_local_enterprise_add_noise_to_weights)
        #print(f"malicious local_enterprise {self.idx} has added noise to its local updated weights before transmitting")
        #with open(f"{log_files_folder_path_comm_round}/comm_{comm_round}_variance_of_noises.txt", "a") as file:
            #file.write(f"{self.return_idx()} {self.return_role()} {is_malicious_node} noise variances: {self.variance_of_noises}\n")
    # record accuracies to find good -vh
    with open(f"{log_files_folder_path_comm_round}/local_enterprise_final_local_accuracies_comm_{comm_round}.txt", "a") as file:
        file.write(f"{self.return_idx()} {self.return_role()} {is_malicious_node}: {self.validate_model_weights(self.net.state_dict())}\n")
    print(f"Done {local_epochs} epoch(s) and total {self.local_total_epoch} epochs")
    self.local_train_parameters = self.net.state_dict()
    self.upload_local_model(self.net.state_dict())
    return self.local_update_time
# used to simulate time waste when local_enterprise goes offline during transmission to validator
def waste_one_epoch_local_update_time(self, opti):
    """Simulate the time wasted by one local training epoch when the
    local_enterprise goes offline during transmission to the validator.

    Trains a throwaway deep copy of the net for one pass over the training
    data so that self.net is left untouched.

    Args:
        opti: optimizer name; 'SGD' uses SGD with momentum 0.9, anything
            else falls back to Adam.

    Returns:
        (simulated_epoch_time, state_dict_of_throwaway_net), or
        (float('inf'), None) when computation power is zero.
    """
    if self.computation_power == 0:
        return float('inf'), None
    else:
        validation_net = copy.deepcopy(self.net)
        currently_used_lr = 0.01
        # reuse the learning rate of the real optimizer (last param group wins)
        for param_group in self.opti.param_groups:
            currently_used_lr = param_group['lr']
        # by default use SGD. Did not implement others
        # NOTE(review): Adam betas=(0.9, 0.9) differs from the usual (0.9, 0.999) — confirm intentional
        if opti == 'SGD':
            validation_opti = optim.SGD(validation_net.parameters(), lr=currently_used_lr, momentum=0.9)
        else:
            validation_opti = optim.Adam(validation_net.parameters(), lr=currently_used_lr, betas=(0.9, 0.9))
        local_update_time = time.time()
        # one full pass over the local training data on the throwaway net
        # NOTE(review): calls validation_net(data) without a model_type_name,
        # unlike local_enterprise_local_update — verify the net supports this
        for data, label in self.train_dl:
            data, label = data.to(self.dev), label.to(self.dev)
            preds = validation_net(data)
            loss = self.loss_func(preds, label)
            loss.backward()
            validation_opti.step()
            validation_opti.zero_grad()
        return (time.time() - local_update_time)/self.computation_power, validation_net.state_dict()
def set_accuracy_this_round(self, accuracy):
    """Record this enterprise's validation accuracy for the current round."""
    self.accuracy_this_round = accuracy
def return_accuracy_this_round(self):
    """Return the accuracy recorded for the current comm round."""
    return self.accuracy_this_round
def return_link_speed(self):
    """Return this enterprise's simulated link speed (bytes/s)."""
    return self.link_speed
def return_local_updates_and_signature(self, comm_round):
    """Package this round's local update into a transaction dict and sign it.

    local_total_accumulated_epochs_this_round doubles as the latest epoch
    sequence number for this transaction (the params were computed after
    that many local epochs this round). Spent-time fields could hint at
    computation power, which nodes may prefer not to disclose.
    """
    local_updates_dict = {
        'local_enterprise_enterprise_idx': self.idx,
        'in_round_number': comm_round,
        "local_updates_params": copy.deepcopy(self.local_train_parameters),
        "local_updates_rewards": self.local_updates_rewards_per_transaction,
        "local_iteration(s)_spent_time": self.local_update_time,
        "local_total_accumulated_epochs_this_round": self.local_total_epoch,
        "local_enterprise_rsa_pub_key": self.return_rsa_pub_key(),
    }
    local_updates_dict["local_enterprise_signature"] = self.sign_msg(sorted(local_updates_dict.items()))
    return local_updates_dict
def local_enterprise_reset_vars_for_new_round(self):
    """Clear all per-round local-enterprise state ahead of a new comm round."""
    self.received_block_from_miner = None
    self.the_added_block = None
    self.local_enterprise_associated_validator = None
    self.local_enterprise_associated_miner = None
    self.local_update_time = None
    self.accuracy_this_round = float('-inf')
    self.local_updates_rewards_per_transaction = 0
    self.local_total_epoch = 0
    self.round_end_time = 0
    self.has_added_block = False
    self.variance_of_noises.clear()
def receive_block_from_miner(self, received_block, source_miner):
    """Accept a propagated block unless its sender or its miner is blacklisted.

    On acceptance a deep copy of the block is stored in
    self.received_block_from_miner; otherwise the block is rejected.
    """
    mined_by = received_block.return_mined_by()
    if mined_by in self.black_list or source_miner in self.black_list:
        print(f"Either the block sending miner {source_miner} or the miner {received_block.return_mined_by()} mined this block is in local_enterprise {self.idx}'s black list. Block is not accepted.")
    else:
        self.received_block_from_miner = copy.deepcopy(received_block)
        print(f"{self.role} {self.idx} has received a new block from {source_miner} mined by {received_block.return_mined_by()}.")
def toss_received_block(self):
    """Discard the block most recently received from a miner."""
    self.received_block_from_miner = None
def reset_last(self):
    """Reset the module-level `lastprc` counter used by global_update's GAN guard."""
    global lastprc
    lastprc = 0
def return_received_block_from_miner(self):
    """Return the block most recently received from a miner (None if none)."""
    return self.received_block_from_miner
# FedAnil++: Total Accuracy (%)
def validate_model_weights(self, weights_to_eval=None):
    """Evaluate classification accuracy on the local test set.

    Loads `weights_to_eval` into the net when given (truthy), otherwise the
    current global parameters, then returns the mean of per-batch accuracies.
    """
    with torch.no_grad():
        state = weights_to_eval if weights_to_eval else self.global_parameters
        self.net.load_state_dict(state, strict=True)
        sum_accu = 0
        num = 0
        for data, label in self.test_dl:
            data, label = data.to(self.dev), label.to(self.dev)
            predicted_labels = torch.argmax(self.net(data), dim=1)
            sum_accu += (predicted_labels == label).float().mean()
            num += 1
        return sum_accu / num
# FedAnil++: Global Update
def global_update(self, local_update_params_potentially_to_be_used): # FedAnil++ global aggregation
    """Aggregate benign local updates into new global parameters (FedAnil++).

    Pipeline: fetch local models from the consortium chain, entropy-decode
    them, score each parameter tensor by cosine similarity against the
    previous global model, cluster the similarity matrix with Affinity
    Propagation, then FedAvg the local parameters into
    self.global_parameters. On early calls only (guarded by the module
    global `lastprc`), a GAN pass is run per model type.

    Args:
        local_update_params_potentially_to_be_used: iterable of
            (local_enterprise_idx, state_dict) pairs collected from a block;
            updates from blacklisted enterprises are skipped.
    """
    # FedAnil++: get local updates from consortium blockchain
    self.get_local_params_by_local_enterprises = self.fetch_local_models()
    self.global_time = time.time()
    # FedAnil++: Arithmetic Entropy Decoding
    self.arithmetic_entropy_decoding()
    # FedAnil++: Calculating of Cosine Similarity: Object Handler
    cosine_similarity_operator = torch.nn.CosineSimilarity(dim=0)
    # filter local_params: drop updates from blacklisted local enterprises
    # NOTE(review): `flcnt` is declared global here but never used in this view — confirm
    global lastprc, flcnt
    local_params_by_benign_local_enterprises = []
    for (local_enterprise_enterprise_idx, local_params) in local_update_params_potentially_to_be_used:
        if not local_enterprise_enterprise_idx in self.black_list:
            local_params_by_benign_local_enterprises.append(local_params)
        else:
            print(f"global update skipped for a local_enterprise {local_enterprise_enterprise_idx} in {self.idx}'s black list")
    if local_params_by_benign_local_enterprises:
        nums_of_local_params = len(local_params_by_benign_local_enterprises)
        nums_of_local_param_len = len(local_params_by_benign_local_enterprises[0])
        # rows: benign local updates; cols: parameter tensors of the state dict
        similarity_matrix = np.zeros((nums_of_local_params, nums_of_local_param_len))
        i = 0
        sum_parameters = None
        for local_updates_params in local_params_by_benign_local_enterprises:
            j = 0
            for var in local_updates_params:
                # FedAnil++: Calculating of Cosine Similarity: Distance of Local Models and Global Model in Prior Round
                similarity = cosine_similarity_operator(local_params_by_benign_local_enterprises[i][var].view(-1), self.global_parameters[var].view(-1))
                # only similarities inside the (TRESHOLD1, TRESHOLD2) band are kept; others stay 0
                if similarity > TRESHOLD1 and similarity < TRESHOLD2:
                    #sum_parameters[var] += local_updates_params[var]
                    similarity_matrix[i, j] = similarity
                j += 1
            i += 1
        # FedAnil++: Affinity Propagation
        # NOTE(review): `sum_of_params_by_ap[cluster_index][var]` indexes a NumPy
        # row with a state-dict key (a string), which raises and is swallowed by
        # the bare except below — the AP aggregation result is effectively unused.
        # The bare except also hides genuine clustering failures; verify intent.
        try:
            affinity_propagation = AffinityPropagation().fit(similarity_matrix)
            cluster_centers = affinity_propagation.cluster_centers_
            labels = affinity_propagation.labels_
            cluster_nums = np.zeros((len(cluster_centers), 1))
            sum_of_params_by_ap = np.zeros((len(cluster_centers), nums_of_local_param_len))
            for it in range(nums_of_local_params):
                cluster_index = labels[it]
                for var in local_params_by_benign_local_enterprises[it]:
                    sum_of_params_by_ap[cluster_index][var] += local_params_by_benign_local_enterprises[it][var]
                cluster_nums[cluster_index] += 1
        except:
            pass
        # FedAnil++: FedAvg the gradients
        # NOTE(review): the first update is deep-copied whole, then only vars whose
        # names start with a prefix in `m_type` accumulate the remaining updates;
        # yet every var is divided by num_participants below — non-matching vars
        # may be scaled down unintentionally. Confirm against the paper.
        num_participants = len(local_params_by_benign_local_enterprises)
        sum_parameters = None
        for mt in m_type:
            for local_updates_params in local_params_by_benign_local_enterprises:
                if sum_parameters is None:
                    sum_parameters = copy.deepcopy(local_updates_params)
                else:
                    for var in sum_parameters:
                        if var.startswith(mt):
                            sum_parameters[var] += local_updates_params[var]
        for var in self.global_parameters:
            self.global_parameters[var] = (sum_parameters[var] / num_participants)
        print(f"global updates done by {self.idx}")
    else:
        print(f"There are no available local params for {self.idx} to perform global updates in this comm round.")
    # run the GAN refinement only while the module-global counter is below 1
    if 1 > lastprc:
        lastprc = lastprc + 1
        # FedAnil++: GAN
        print("Genrative Adversial Network process start")
        for mt in ['cnn', 'resnet', 'glove']:
            #print(f"GAN ({mt})")
            self.GAN(mt)
        #print("GAN end")
    self.global_time = time.time() - self.global_time
def GAN(self, model_type):
#print("GAN start")
select_model = model_type
discriminator = CombinedModel()
discriminator.load_state_dict(self.global_parameters)
generator = Generator()
batch_size = self.local_batch_size
lr = 0.001
num_epochs = 5
loss_function = nn.MSELoss()
train_loader = self.train_dl