From a983c19a36ff5260a2d6f552cc81a69cd2014bfb Mon Sep 17 00:00:00 2001 From: anwai98 Date: Fri, 26 May 2023 09:29:50 +0200 Subject: [PATCH] Update Teacher Initialization for MT Setups --- .../probabilistic_domain_adaptation/livecell/punet_adamt.py | 3 ++- .../livecell/punet_mean_teacher.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/experiments/probabilistic_domain_adaptation/livecell/punet_adamt.py b/experiments/probabilistic_domain_adaptation/livecell/punet_adamt.py index d660c278..686f5765 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/punet_adamt.py +++ b/experiments/probabilistic_domain_adaptation/livecell/punet_adamt.py @@ -78,6 +78,7 @@ def _train_source_target(args, source_cell_type, target_cell_type): device=device, log_image_interval=100, save_root=args.save_root, + reinit_teacher=True, compile_model=False ) trainer.fit(args.n_iterations) @@ -119,7 +120,7 @@ def run_evaluation(args): def main(): - parser = common.get_parser(default_iterations=100000, default_batch_size=4) + parser = common.get_parser(default_iterations=10000, default_batch_size=2) parser.add_argument("--confidence_threshold", default=None, type=float) parser.add_argument("--consensus_masking", action='store_true') args = parser.parse_args() diff --git a/experiments/probabilistic_domain_adaptation/livecell/punet_mean_teacher.py b/experiments/probabilistic_domain_adaptation/livecell/punet_mean_teacher.py index 9866b3ab..9b649acf 100644 --- a/experiments/probabilistic_domain_adaptation/livecell/punet_mean_teacher.py +++ b/experiments/probabilistic_domain_adaptation/livecell/punet_mean_teacher.py @@ -82,6 +82,7 @@ def _train_source_target(args, source_cell_type, target_cell_type): device=device, log_image_interval=100, save_root=args.save_root, + reinit_teacher=False, compile_model=False ) trainer.fit(args.n_iterations)