
Commit

Merge pull request #2 from rjbruin/demo
Refactor of ckconv.nn APIs + demo notebook for Arxiv paper
rjbruin authored Dec 7, 2021
2 parents c67d144 + 9d57fd7 commit 3a6c0bb
Showing 17 changed files with 1,080 additions and 167 deletions.
15 changes: 6 additions & 9 deletions antialiasing.py
@@ -5,8 +5,10 @@


def regularize_gabornet(
model, kernel_size, target, fn, method, factor, gauss_stddevs=1.0, gauss_factor=0.5
model, horizon, factor, target="gabor", fn="l2_relu", method="together", gauss_stddevs=1.0
):
"""Regularize a FlexNet.
"""
# if method != "summed":
# raise NotImplementedError()

@@ -35,7 +37,7 @@ def regularize_gabornet(
else:
flexconv_freqs = magnet_freqs

nyquist_freq = torch.ones_like(flexconv_freqs) * nyquist_frequency(kernel_size)
nyquist_freq = torch.ones_like(flexconv_freqs) * nyquist_frequency(horizon)

elif method == "together" and target == "gabor":
if masks:
@@ -44,7 +46,7 @@
flexconv_freqs = magnet_freqs

# Divide Nyquist frequency by amount of filters in each layer
nyquist_freq = torch.ones_like(flexconv_freqs) * nyquist_frequency(kernel_size)
nyquist_freq = torch.ones_like(flexconv_freqs) * nyquist_frequency(horizon)
nyquist_freq = nyquist_freq / nyquist_freq.shape[1]

elif method in ["together", "together+mask"] and target == "gabor+mask":
@@ -59,14 +61,9 @@
raise NotImplementedError()

# Divide Nyquist frequency by amount of filters in each layer
nyquist_freq = torch.ones_like(flexconv_freqs) * nyquist_frequency(kernel_size)
nyquist_freq = torch.ones_like(flexconv_freqs) * nyquist_frequency(horizon)
nyquist_freq = nyquist_freq / nyquist_freq.shape[1]

# if method == "distributed":
# # Further divide Nyquist freq between sines and gausses
# nyquist_freq[:, :, 0] = nyquist_freq[:, :, 0] * (1.0 - gauss_factor)
# nyquist_freq[:, :, 1] = nyquist_freq[:, :, 1] * gauss_factor

if fn == "l2_relu":
# L2 ReLU
return factor * l2_relu(flexconv_freqs, nyquist_freq)
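For readers tracking the API change from kernel_size to horizon above: the regularizer compares the learned filter frequencies against the Nyquist frequency implied by the kernel's spatial extent and penalizes only the excess. The helpers below are an illustrative sketch, not the repository's implementations of nyquist_frequency and l2_relu, and the sampling assumptions are my own:

```python
import torch
import torch.nn.functional as F


def nyquist_frequency(horizon: int) -> float:
    # Assumption: kernels are sampled at `horizon` points over [-1, 1], giving a
    # sampling rate of (horizon - 1) / 2 and a Nyquist limit of half that rate.
    return (horizon - 1) / 4.0


def l2_relu(freqs: torch.Tensor, nyquist: torch.Tensor) -> torch.Tensor:
    # L2 penalty on only the portion of each frequency that exceeds its budget;
    # frequencies at or below the Nyquist limit contribute nothing.
    return torch.sum(F.relu(freqs - nyquist) ** 2)
```

Dividing nyquist_freq by nyquist_freq.shape[1] in the "together" branches above then splits this frequency budget evenly across the filters of a layer, as the in-code comments note.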
19 changes: 6 additions & 13 deletions cfg/config.yaml
@@ -121,16 +121,9 @@ profile:
time: False
summary: [0, ]
summary_depth: 5













testcase:
load: False
save: False
epochs: 1
batches: 20
path: ""
235 changes: 127 additions & 108 deletions ckconv/nn/ckconv.py

Large diffs are not rendered by default.

573 changes: 573 additions & 0 deletions demo/flexconv.ipynb

Large diffs are not rendered by default.

Binary file added demo/kernel_sizes_flexnet.pdf
Binary file not shown.
Binary file added demo/kernel_sizes_flexnet.png
Binary file not shown.
Binary file added demo/truck5.png
Binary file not shown.
87 changes: 75 additions & 12 deletions models/ckresnet.py
@@ -80,37 +80,100 @@ def __init__(
nonlinearity = net_config.nonlinearity

# Unpack dim_linear
dim_linear = kernel_config.dim_linear
scale_sigma = kernel_config.srf.scale
kernel_scale_sigma = kernel_config.srf.scale

# Unpack conv_type
conv_type = conv_config.type

# Unpack kernel_config
kernel_type = kernel_config.type
kernel_dim_linear = kernel_config.dim_linear
kernel_no_hidden = kernel_config.no_hidden
kernel_no_layers = kernel_config.no_layers
kernel_activ_function = kernel_config.activ_function
kernel_norm = kernel_config.norm
kernel_omega_0 = kernel_config.omega_0
kernel_learn_omega_0 = kernel_config.learn_omega_0
kernel_weight_norm = kernel_config.weight_norm
kernel_steerable = kernel_config.steerable
kernel_init_spatial_value = kernel_config.init_spatial_value
kernel_bias_init = kernel_config.bias_init
kernel_input_scale = kernel_config.input_scale
kernel_sampling_rate_norm = kernel_config.sampling_rate_norm

# Define Convolution Type:
# -------------------------
# Unpack other conv_config values in case normal convolutions are used.
conv_use_fft = conv_config.use_fft
conv_horizon = conv_config.horizon
conv_padding = conv_config.padding
conv_stride = conv_config.stride
conv_bias = conv_config.bias

# Unpack mask_config
mask_use = mask_config.use
mask_type = mask_config.type
mask_init_value = mask_config.init_value
mask_temperature = mask_config.temperature
mask_dynamic_cropping = mask_config.dynamic_cropping
mask_threshold = mask_config.threshold

# Define partials for types of convs
if conv_type == "CKConv":
ConvType = partial(
ckconv.nn.CKConv,
kernel_config=kernel_config,
conv_config=conv_config,
horizon=conv_horizon,
kernel_type=kernel_type,
kernel_dim_linear=kernel_dim_linear,
kernel_no_hidden=kernel_no_hidden,
kernel_no_layers=kernel_no_layers,
kernel_activ_function=kernel_activ_function,
kernel_norm=kernel_norm,
kernel_omega_0=kernel_omega_0,
kernel_learn_omega_0=kernel_learn_omega_0,
kernel_weight_norm=kernel_weight_norm,
kernel_steerable=kernel_steerable,
kernel_init_spatial_value=kernel_init_spatial_value,
kernel_bias_init=kernel_bias_init,
kernel_input_scale=kernel_input_scale,
kernel_sampling_rate_norm=kernel_sampling_rate_norm,
conv_use_fft=conv_use_fft,
conv_bias=conv_bias,
conv_padding=conv_padding,
conv_stride=1,
)
elif conv_type == "FlexConv":
ConvType = partial(
ckconv.nn.FlexConv,
kernel_config=kernel_config,
conv_config=conv_config,
mask_config=mask_config,
horizon=conv_horizon,
kernel_type=kernel_type,
kernel_dim_linear=kernel_dim_linear,
kernel_no_hidden=kernel_no_hidden,
kernel_no_layers=kernel_no_layers,
kernel_activ_function=kernel_activ_function,
kernel_norm=kernel_norm,
kernel_omega_0=kernel_omega_0,
kernel_learn_omega_0=kernel_learn_omega_0,
kernel_weight_norm=kernel_weight_norm,
kernel_steerable=kernel_steerable,
kernel_init_spatial_value=kernel_init_spatial_value,
kernel_bias_init=kernel_bias_init,
kernel_input_scale=kernel_input_scale,
kernel_sampling_rate_norm=kernel_sampling_rate_norm,
conv_use_fft=conv_use_fft,
conv_bias=conv_bias,
conv_padding=conv_padding,
conv_stride=conv_stride,
mask_use=mask_use,
mask_type=mask_type,
mask_init_value=mask_init_value,
mask_temperature=mask_temperature,
mask_dynamic_cropping=mask_dynamic_cropping,
mask_threshold=mask_threshold,
)
elif conv_type == "Conv":
ConvType = partial(
getattr(torch.nn, f"Conv{dim_linear}d"),
getattr(torch.nn, f"Conv{kernel_dim_linear}d"),
kernel_size=int(conv_horizon),
padding=conv_padding,
stride=conv_stride,
@@ -123,15 +186,15 @@ def __init__(
init_order=4.0,
init_scale=0.0,
use_cuda=True, # NOTE(rjbruin): hardcoded for now
scale_sigma=scale_sigma,
scale_sigma=kernel_scale_sigma,
)
else:
raise NotImplementedError(f"conv_type = {conv_type}")
# -------------------------

# Define NormType
NormType = {
"BatchNorm": getattr(torch.nn, f"BatchNorm{dim_linear}d"),
"BatchNorm": getattr(torch.nn, f"BatchNorm{kernel_dim_linear}d"),
"LayerNorm": ckconv.nn.LayerNorm,
}[norm]

@@ -140,7 +203,7 @@ def __init__(
]

# Define LinearType
LinearType = getattr(ckconv.nn, f"Linear{dim_linear}d")
LinearType = getattr(ckconv.nn, f"Linear{kernel_dim_linear}d")

# Create Input Layers
self.cconv1 = ConvType(in_channels=in_channels, out_channels=hidden_channels)
@@ -215,7 +278,7 @@ def __init__(
# -------------------------

# Save variables in self
self.dim_linear = dim_linear
self.dim_linear = kernel_dim_linear

def forward(self, x):
# First layers
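To make the refactored calling convention concrete: rather than handing whole kernel_config/conv_config/mask_config objects to CKConv and FlexConv, every setting is now pre-bound as a flat keyword argument with functools.partial, so each block only supplies channel counts. A minimal sketch of the same pattern using the plain "Conv" branch (the numeric values are assumptions, not defaults from this repository):

```python
from functools import partial

import torch

kernel_dim_linear = 2  # assumed; taken from kernel_config.dim_linear in the diff above
conv_horizon = 3       # assumed; taken from conv_config.horizon

# Bind everything except the channel counts, as the "Conv" branch above does.
ConvType = partial(
    getattr(torch.nn, f"Conv{kernel_dim_linear}d"),
    kernel_size=int(conv_horizon),
    padding=conv_horizon // 2,
    stride=1,
    bias=True,
)

# Each block then only specifies channels, mirroring
# self.cconv1 = ConvType(in_channels=in_channels, out_channels=hidden_channels).
conv = ConvType(in_channels=3, out_channels=32)
print(conv)  # Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
```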
97 changes: 82 additions & 15 deletions models/cktcn.py
@@ -16,6 +16,7 @@ def __init__(
in_channels: int,
out_channels: int,
ConvType: Union[CKConv, FlexConv, Conv1d, Conv2d],
NonlinearType: torch.nn.Module,
NormType: torch.nn.Module,
LinearType: torch.nn.Module,
dropout: float,
@@ -44,15 +45,16 @@ def __init__(
in_channels=in_channels,
out_channels=out_channels,
ConvType=ConvType,
NonlinearType=NonlinearType,
NormType=NormType,
LinearType=LinearType,
dropout=dropout,
)

def forward(self, x):
shortcut = self.shortcut(x)
out = self.dp(torch.relu(self.norm1(self.cconv1(x))))
out = torch.relu(self.dp(torch.relu(self.norm2(self.cconv2(out)))) + shortcut)
out = self.dp(self.nonlinear(self.norm1(self.cconv1(x))))
out = self.nonlinear(self.dp(self.nonlinear(self.norm2(self.cconv2(out)))) + shortcut)
return out
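
The change to forward above swaps the hard-coded torch.relu for the injected nonlinearity module. Below is a self-contained sketch of the same pattern with standard PyTorch layers standing in for the repository's conv, norm, and linear types; layer choices and shapes are assumptions:

```python
import torch


class ResidualBlockSketch(torch.nn.Module):
    """Residual block with an injected nonlinearity, mimicking the refactor above."""

    def __init__(self, in_channels, out_channels, NonlinearType=torch.nn.ReLU, dropout=0.1):
        super().__init__()
        self.cconv1 = torch.nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.cconv2 = torch.nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = torch.nn.BatchNorm1d(out_channels)
        self.norm2 = torch.nn.BatchNorm1d(out_channels)
        self.nonlinear = NonlinearType()
        self.dp = torch.nn.Dropout(dropout)
        # 1x1 convolution on the shortcut when channel counts differ, identity otherwise.
        self.shortcut = (
            torch.nn.Conv1d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else torch.nn.Identity()
        )

    def forward(self, x):
        shortcut = self.shortcut(x)
        out = self.dp(self.nonlinear(self.norm1(self.cconv1(x))))
        out = self.nonlinear(self.dp(self.nonlinear(self.norm2(self.cconv2(out)))) + shortcut)
        return out


# Usage: a LeakyReLU variant, matching the NonlinearType lookup added in this commit.
block = ResidualBlockSketch(16, 32, NonlinearType=torch.nn.LeakyReLU)
y = block(torch.randn(8, 16, 100))  # (batch, channels, sequence length)
```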


@@ -74,38 +76,98 @@ def __init__(
norm = net_config.norm
dropout = net_config.dropout
block_width_factors = net_config.block_width_factors
nonlinearity = net_config.nonlinearity

# Unpack dim_linear
dim_linear = kernel_config.dim_linear

# Unpack conv_type
conv_type = conv_config.type
# Unpack kernel_config
kernel_type = kernel_config.type
kernel_dim_linear = kernel_config.dim_linear
kernel_no_hidden = kernel_config.no_hidden
kernel_no_layers = kernel_config.no_layers
kernel_activ_function = kernel_config.activ_function
kernel_norm = kernel_config.norm
kernel_omega_0 = kernel_config.omega_0
kernel_learn_omega_0 = kernel_config.learn_omega_0
kernel_weight_norm = kernel_config.weight_norm
kernel_steerable = kernel_config.steerable
kernel_init_spatial_value = kernel_config.init_spatial_value
kernel_bias_init = kernel_config.bias_init
kernel_input_scale = kernel_config.input_scale
kernel_sampling_rate_norm = kernel_config.sampling_rate_norm

# Define Convolution Type:
# -------------------------
# Unpack other conv_config values in case normal convolutions are used.
conv_type = conv_config.type
conv_horizon = conv_config.horizon
conv_padding = conv_config.padding
conv_stride = conv_config.stride
conv_bias = conv_config.bias
conv_use_fft = conv_config.use_fft

# Unpack mask_config
mask_use = mask_config.use
mask_type = mask_config.type
mask_init_value = mask_config.init_value
mask_temperature = mask_config.temperature
mask_dynamic_cropping = mask_config.dynamic_cropping
mask_threshold = mask_config.threshold

# Define partials for types of convs
if conv_type == "CKConv":
ConvType = partial(
ckconv.nn.CKConv,
kernel_config=kernel_config,
conv_config=conv_config,
horizon=conv_horizon,
kernel_type=kernel_type,
kernel_dim_linear=kernel_dim_linear,
kernel_no_hidden=kernel_no_hidden,
kernel_no_layers=kernel_no_layers,
kernel_activ_function=kernel_activ_function,
kernel_norm=kernel_norm,
kernel_omega_0=kernel_omega_0,
kernel_learn_omega_0=kernel_learn_omega_0,
kernel_weight_norm=kernel_weight_norm,
kernel_steerable=kernel_steerable,
kernel_init_spatial_value=kernel_init_spatial_value,
kernel_bias_init=kernel_bias_init,
kernel_input_scale=kernel_input_scale,
kernel_sampling_rate_norm=kernel_sampling_rate_norm,
conv_use_fft=conv_use_fft,
conv_padding=conv_padding,
conv_stride=conv_stride,
conv_bias=conv_bias,
)
elif conv_type == "FlexConv":
ConvType = partial(
ckconv.nn.FlexConv,
kernel_config=kernel_config,
conv_config=conv_config,
mask_config=mask_config,
horizon=conv_horizon,
kernel_type=kernel_type,
kernel_dim_linear=kernel_dim_linear,
kernel_no_hidden=kernel_no_hidden,
kernel_no_layers=kernel_no_layers,
kernel_activ_function=kernel_activ_function,
kernel_norm=kernel_norm,
kernel_omega_0=kernel_omega_0,
kernel_learn_omega_0=kernel_learn_omega_0,
kernel_weight_norm=kernel_weight_norm,
kernel_steerable=kernel_steerable,
kernel_init_spatial_value=kernel_init_spatial_value,
kernel_bias_init=kernel_bias_init,
kernel_input_scale=kernel_input_scale,
kernel_sampling_rate_norm=kernel_sampling_rate_norm,
conv_use_fft=conv_use_fft,
conv_padding=conv_padding,
conv_stride=conv_stride,
conv_bias=conv_bias,
mask_use=mask_use,
mask_type=mask_type,
mask_init_value=mask_init_value,
mask_temperature=mask_temperature,
mask_dynamic_cropping=mask_dynamic_cropping,
mask_threshold=mask_threshold,
)
elif conv_type == "Conv":
ConvType = partial(
getattr(torch.nn, f"Conv{dim_linear}d"),
getattr(torch.nn, f"Conv{kernel_dim_linear}d"),
kernel_size=int(conv_horizon),
padding=conv_padding,
stride=conv_stride,
@@ -117,12 +179,16 @@

# Define NormType
NormType = {
"BatchNorm": getattr(torch.nn, f"BatchNorm{dim_linear}d"),
"BatchNorm": getattr(torch.nn, f"BatchNorm{kernel_dim_linear}d"),
"LayerNorm": ckconv.nn.LayerNorm,
}[norm]

NonlinearType = {"ReLU": torch.nn.ReLU, "LeakyReLU": torch.nn.LeakyReLU}[
nonlinearity
]

# Define LinearType
LinearType = getattr(ckconv.nn, f"Linear{dim_linear}d")
LinearType = getattr(ckconv.nn, f"Linear{kernel_dim_linear}d")

# Create Blocks
# -------------------------
@@ -159,6 +225,7 @@ def __init__(
in_channels=input_ch,
out_channels=hidden_ch,
ConvType=ConvType,
NonlinearType=NonlinearType,
NormType=NormType,
LinearType=LinearType,
dropout=dropout,
Binary file added testcases/c10-fn5-reg.npy
Binary file not shown.
Binary file added testcases/c10-fn5.npy
Binary file not shown.
Binary file added testcases/mnist-fn7-randomsettings.npy
Binary file not shown.
Binary file added testcases/scifar-tcn-flexconv.npy
Binary file not shown.
Binary file added testcases/smnist-tcn-ckconv.npy
Binary file not shown.

0 comments on commit 3a6c0bb
