From b98ad74f9aa46fe080498a3c555b54c60dc50e54 Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Fri, 13 Oct 2023 15:43:47 -0700 Subject: [PATCH] [diffusion_labs] Move attention and res blocks under modules [ghstack-poisoned] --- tests/diffusion_labs/test_adm.py | 4 ++-- tests/diffusion_labs/test_adm_blocks.py | 4 ++-- tests/diffusion_labs/test_adm_crossattention.py | 2 +- torchmultimodal/diffusion_labs/models/adm_unet/adm.py | 6 ++---- .../{models/adm_unet => modules/layers}/attention_block.py | 0 .../{models/adm_unet => modules/layers}/res_block.py | 0 6 files changed, 7 insertions(+), 9 deletions(-) rename torchmultimodal/diffusion_labs/{models/adm_unet => modules/layers}/attention_block.py (100%) rename torchmultimodal/diffusion_labs/{models/adm_unet => modules/layers}/res_block.py (100%) diff --git a/tests/diffusion_labs/test_adm.py b/tests/diffusion_labs/test_adm.py index c1ad55a39..b6066bb40 100644 --- a/tests/diffusion_labs/test_adm.py +++ b/tests/diffusion_labs/test_adm.py @@ -11,10 +11,10 @@ from tests.test_utils import assert_expected, set_rng_seed from torch import nn from torchmultimodal.diffusion_labs.models.adm_unet.adm import ADMStack, ADMUNet -from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import ( +from torchmultimodal.diffusion_labs.modules.layers.attention_block import ( ADMAttentionBlock, ) -from torchmultimodal.diffusion_labs.models.adm_unet.res_block import ADMResBlock +from torchmultimodal.diffusion_labs.modules.layers.res_block import ADMResBlock @pytest.fixture(autouse=True) diff --git a/tests/diffusion_labs/test_adm_blocks.py b/tests/diffusion_labs/test_adm_blocks.py index 2b8789555..d89c6df91 100644 --- a/tests/diffusion_labs/test_adm_blocks.py +++ b/tests/diffusion_labs/test_adm_blocks.py @@ -10,10 +10,10 @@ import torch from tests.test_utils import assert_expected, set_rng_seed -from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import ( +from torchmultimodal.diffusion_labs.modules.layers.attention_block import ( ADMAttentionBlock, ) -from torchmultimodal.diffusion_labs.models.adm_unet.res_block import ( +from torchmultimodal.diffusion_labs.modules.layers.res_block import ( adm_res_block, adm_res_downsample_block, adm_res_upsample_block, diff --git a/tests/diffusion_labs/test_adm_crossattention.py b/tests/diffusion_labs/test_adm_crossattention.py index fd61c778d..23543993c 100644 --- a/tests/diffusion_labs/test_adm_crossattention.py +++ b/tests/diffusion_labs/test_adm_crossattention.py @@ -8,7 +8,7 @@ import torch from tests.test_utils import assert_expected, set_rng_seed -from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import ( +from torchmultimodal.diffusion_labs.modules.layers.attention_block import ( adm_attention, ADMCrossAttention, ) diff --git a/torchmultimodal/diffusion_labs/models/adm_unet/adm.py b/torchmultimodal/diffusion_labs/models/adm_unet/adm.py index fc831cc9c..9aa9b00fc 100644 --- a/torchmultimodal/diffusion_labs/models/adm_unet/adm.py +++ b/torchmultimodal/diffusion_labs/models/adm_unet/adm.py @@ -9,10 +9,8 @@ import torch from torch import nn, Tensor -from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import ( - adm_attn_block, -) -from torchmultimodal.diffusion_labs.models.adm_unet.res_block import ( +from torchmultimodal.diffusion_labs.modules.layers.attention_block import adm_attn_block +from torchmultimodal.diffusion_labs.modules.layers.res_block import ( adm_res_block, adm_res_downsample_block, adm_res_upsample_block, diff --git a/torchmultimodal/diffusion_labs/models/adm_unet/attention_block.py b/torchmultimodal/diffusion_labs/modules/layers/attention_block.py similarity index 100% rename from torchmultimodal/diffusion_labs/models/adm_unet/attention_block.py rename to torchmultimodal/diffusion_labs/modules/layers/attention_block.py diff --git a/torchmultimodal/diffusion_labs/models/adm_unet/res_block.py b/torchmultimodal/diffusion_labs/modules/layers/res_block.py similarity index 100% rename from torchmultimodal/diffusion_labs/models/adm_unet/res_block.py rename to torchmultimodal/diffusion_labs/modules/layers/res_block.py