diff --git a/tests/diffusion_labs/test_adm.py b/tests/diffusion_labs/test_adm.py index c1ad55a39..b6066bb40 100644 --- a/tests/diffusion_labs/test_adm.py +++ b/tests/diffusion_labs/test_adm.py @@ -11,10 +11,10 @@ from tests.test_utils import assert_expected, set_rng_seed from torch import nn from torchmultimodal.diffusion_labs.models.adm_unet.adm import ADMStack, ADMUNet -from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import ( +from torchmultimodal.diffusion_labs.modules.layers.attention_block import ( ADMAttentionBlock, ) -from torchmultimodal.diffusion_labs.models.adm_unet.res_block import ADMResBlock +from torchmultimodal.diffusion_labs.modules.layers.res_block import ADMResBlock @pytest.fixture(autouse=True) diff --git a/tests/diffusion_labs/test_adm_blocks.py b/tests/diffusion_labs/test_adm_blocks.py index 2b8789555..d89c6df91 100644 --- a/tests/diffusion_labs/test_adm_blocks.py +++ b/tests/diffusion_labs/test_adm_blocks.py @@ -10,10 +10,10 @@ import torch from tests.test_utils import assert_expected, set_rng_seed -from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import ( +from torchmultimodal.diffusion_labs.modules.layers.attention_block import ( ADMAttentionBlock, ) -from torchmultimodal.diffusion_labs.models.adm_unet.res_block import ( +from torchmultimodal.diffusion_labs.modules.layers.res_block import ( adm_res_block, adm_res_downsample_block, adm_res_upsample_block, diff --git a/tests/diffusion_labs/test_adm_crossattention.py b/tests/diffusion_labs/test_adm_crossattention.py index fd61c778d..23543993c 100644 --- a/tests/diffusion_labs/test_adm_crossattention.py +++ b/tests/diffusion_labs/test_adm_crossattention.py @@ -8,7 +8,7 @@ import torch from tests.test_utils import assert_expected, set_rng_seed -from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import ( +from torchmultimodal.diffusion_labs.modules.layers.attention_block import ( adm_attention, ADMCrossAttention, ) diff --git a/torchmultimodal/diffusion_labs/models/adm_unet/adm.py b/torchmultimodal/diffusion_labs/models/adm_unet/adm.py index 7d1793781..9eab9530d 100644 --- a/torchmultimodal/diffusion_labs/models/adm_unet/adm.py +++ b/torchmultimodal/diffusion_labs/models/adm_unet/adm.py @@ -9,10 +9,8 @@ import torch from torch import nn, Tensor -from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import ( - adm_attn_block, -) -from torchmultimodal.diffusion_labs.models.adm_unet.res_block import ( +from torchmultimodal.diffusion_labs.modules.layers.attention_block import adm_attn_block +from torchmultimodal.diffusion_labs.modules.layers.res_block import ( adm_res_block, adm_res_downsample_block, adm_res_upsample_block, diff --git a/torchmultimodal/diffusion_labs/models/adm_unet/attention_block.py b/torchmultimodal/diffusion_labs/modules/layers/attention_block.py similarity index 100% rename from torchmultimodal/diffusion_labs/models/adm_unet/attention_block.py rename to torchmultimodal/diffusion_labs/modules/layers/attention_block.py diff --git a/torchmultimodal/diffusion_labs/models/adm_unet/res_block.py b/torchmultimodal/diffusion_labs/modules/layers/res_block.py similarity index 100% rename from torchmultimodal/diffusion_labs/models/adm_unet/res_block.py rename to torchmultimodal/diffusion_labs/modules/layers/res_block.py