Skip to content

Commit

Permalink
fixed discriminator
Browse files Browse the repository at this point in the history
  • Loading branch information
mk314k committed Nov 25, 2023
1 parent c9103bf commit 3a8dc39
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
4 changes: 2 additions & 2 deletions models/discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def conv3d(in_channel:int, out_channel:int):
nn.Module: Conv3d Module
"""
return nn.Sequential(
nn.Conv3d(channel[0], channel[1], kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm3d(channel[1]),
nn.Conv3d(in_channel, out_channel, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm3d(out_channel),
nn.LeakyReLU(0.2, inplace=True),
)

Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
fancy_einsum
torch
torchvision
tqdm
Expand Down
10 changes: 5 additions & 5 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,17 @@ def preprocess(images, sizex=None, sizey=None):



def downsample(mat:np.ndarray, down_sample=64)->np.ndarray:
def downsample(mat:np.ndarray, new_size=64)->np.ndarray:
"""
Args:
mat (np.ndarray): 3d image
down_sample (int, optional): the new number of voxels cubes. Defaults to 64.
new_size (int, optional): the new number of voxels cubes. Defaults to 64.
Returns:
np.ndarray: 3d images after downsampling
"""
lost_dim = int(mat.shape[2]/down_sample)
return mat.reshape((-1, down_sample, lost_dim, down_sample, lost_dim, down_sample, lost_dim)).mean(
lost_dim = int(mat.shape[2]/new_size)
return mat.reshape((-1, new_size, lost_dim, new_size, lost_dim, new_size, lost_dim)).mean(
axis=(2, 4, 6)
)

Expand Down Expand Up @@ -97,4 +97,4 @@ def load_data(data_path:str, down_sample=None):
img.append(cv2.imread(folder_path + "/" + format(i, "03d") + ".png")) # pylint: disable=no-member
img2d.append(preprocess(np.array(img)))
img3d.append(mat_data)
return img2d, img3d
return img2d, img3d

0 comments on commit 3a8dc39

Please sign in to comment.