Skip to content

Commit

Permalink
[diffusion_labs] Move attention and res blocks under modules
Browse files Browse the repository at this point in the history
ghstack-source-id: b9ab08988a1272c658b37b4dcf1cddf48cc20270
Pull Request resolved: #490
  • Loading branch information
ebsmothers committed Oct 16, 2023
1 parent a60a47e commit 011c1a0
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 9 deletions.
4 changes: 2 additions & 2 deletions tests/diffusion_labs/test_adm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from tests.test_utils import assert_expected, set_rng_seed
from torch import nn
from torchmultimodal.diffusion_labs.models.adm_unet.adm import ADMStack, ADMUNet
from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import (
from torchmultimodal.diffusion_labs.modules.layers.attention_block import (
ADMAttentionBlock,
)
from torchmultimodal.diffusion_labs.models.adm_unet.res_block import ADMResBlock
from torchmultimodal.diffusion_labs.modules.layers.res_block import ADMResBlock


@pytest.fixture(autouse=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/diffusion_labs/test_adm_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@

import torch
from tests.test_utils import assert_expected, set_rng_seed
from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import (
from torchmultimodal.diffusion_labs.modules.layers.attention_block import (
ADMAttentionBlock,
)
from torchmultimodal.diffusion_labs.models.adm_unet.res_block import (
from torchmultimodal.diffusion_labs.modules.layers.res_block import (
adm_res_block,
adm_res_downsample_block,
adm_res_upsample_block,
Expand Down
2 changes: 1 addition & 1 deletion tests/diffusion_labs/test_adm_crossattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import torch
from tests.test_utils import assert_expected, set_rng_seed
from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import (
from torchmultimodal.diffusion_labs.modules.layers.attention_block import (
adm_attention,
ADMCrossAttention,
)
Expand Down
6 changes: 2 additions & 4 deletions torchmultimodal/diffusion_labs/models/adm_unet/adm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@

import torch
from torch import nn, Tensor
from torchmultimodal.diffusion_labs.models.adm_unet.attention_block import (
adm_attn_block,
)
from torchmultimodal.diffusion_labs.models.adm_unet.res_block import (
from torchmultimodal.diffusion_labs.modules.layers.attention_block import adm_attn_block
from torchmultimodal.diffusion_labs.modules.layers.res_block import (
adm_res_block,
adm_res_downsample_block,
adm_res_upsample_block,
Expand Down

0 comments on commit 011c1a0

Please sign in to comment.