diff --git a/CHANGELOG.md b/CHANGELOG.md index 6072972e..dd87e494 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ - Updated README for pytorch 1.8.1 support - Use custom `gpu_storage` instead of thrust vector for faster constructors - pytorch installation instruction updates +- fix transpose kernel map with `kernel_size == stride_size` +- Update reconstruction and vae examples for v0.5 API +- `stack_unet.py` example, API updates +- `MinkowskiToFeature` layer ## [0.5.2] diff --git a/MinkowskiEngine/MinkowskiOps.py b/MinkowskiEngine/MinkowskiOps.py index 39ce41ce..7a39d58f 100644 --- a/MinkowskiEngine/MinkowskiOps.py +++ b/MinkowskiEngine/MinkowskiOps.py @@ -457,6 +457,26 @@ def __repr__(self): return self.__class__.__name__ + "()" +class MinkowskiToFeature(MinkowskiModuleBase): + r""" + Extract the features (``x.F``) of a sparse tensor or tensor field and return them as a pytorch tensor. + + Can be used to make network construction simpler. + + Example:: + + >>> net = nn.Sequential(MinkowskiConvolution(...), MinkowskiGlobalMaxPooling(...), MinkowskiToFeature(), nn.Linear(...)) + >>> torch_tensor = net(sparse_tensor) + + """ + + def forward(self, x: SparseTensor): + assert isinstance( + x, (SparseTensor, TensorField) + ), "Invalid input type for MinkowskiToFeature" + return x.F + + class MinkowskiStackCat(torch.nn.Sequential): def forward(self, x): return cat([module(x) for module in self]) diff --git a/MinkowskiEngine/__init__.py b/MinkowskiEngine/__init__.py index 8a734d82..4b404250 100644 --- a/MinkowskiEngine/__init__.py +++ b/MinkowskiEngine/__init__.py @@ -22,7 +22,7 @@ # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part # of the code. 
-__version__ = "0.5.2" +__version__ = "0.5.3" import os import sys @@ -189,6 +189,7 @@ MinkowskiLinear, MinkowskiToSparseTensor, MinkowskiToDenseTensor, + MinkowskiToFeature, MinkowskiStackCat, MinkowskiStackSum, MinkowskiStackMean, diff --git a/docs/utils.rst b/docs/utils.rst index 871cf650..f187289e 100644 --- a/docs/utils.rst +++ b/docs/utils.rst @@ -68,3 +68,51 @@ MinkowskiToDenseTensor .. autoclass:: MinkowskiEngine.MinkowskiToDenseTensor .. automethod:: __init__ + + +MinkowskiToFeature +------------------ + +.. autoclass:: MinkowskiEngine.MinkowskiToFeature + + .. automethod:: __init__ + + +MinkowskiStackCat +----------------- + +.. autoclass:: MinkowskiEngine.MinkowskiStackCat + :members: forward + :undoc-members: + + .. automethod:: __init__ + + +MinkowskiStackSum +----------------- + +.. autoclass:: MinkowskiEngine.MinkowskiStackSum + :members: forward + :undoc-members: + + .. automethod:: __init__ + + +MinkowskiStackMean +------------------ + +.. autoclass:: MinkowskiEngine.MinkowskiStackMean + :members: forward + :undoc-members: + + .. automethod:: __init__ + + +MinkowskiStackVar +----------------- + +.. autoclass:: MinkowskiEngine.MinkowskiStackVar + :members: forward + :undoc-members: + + .. automethod:: __init__ \ No newline at end of file diff --git a/examples/stack_unet.py b/examples/stack_unet.py new file mode 100644 index 00000000..b7e849c8 --- /dev/null +++ b/examples/stack_unet.py @@ -0,0 +1,100 @@ +# Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +# of the Software, and to permit persons to whom the Software is furnished to do +# so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural +# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part +# of the code. 
+import torch +import torch.nn as nn + +import MinkowskiEngine as ME +import MinkowskiEngine.MinkowskiFunctional as MF + +from tests.python.common import data_loader + + +class StackUNet(ME.MinkowskiNetwork): + def __init__(self, in_nchannel, out_nchannel, D): + ME.MinkowskiNetwork.__init__(self, D) + channels = [in_nchannel, 16, 32] + self.net = nn.Sequential( + ME.MinkowskiStackSum( + ME.MinkowskiConvolution( + channels[0], + channels[1], + kernel_size=3, + stride=1, + dimension=D, + ), + nn.Sequential( + ME.MinkowskiConvolution( + channels[0], + channels[1], + kernel_size=3, + stride=2, + dimension=D, + ), + ME.MinkowskiStackSum( + nn.Identity(), + nn.Sequential( + ME.MinkowskiConvolution( + channels[1], + channels[2], + kernel_size=3, + stride=2, + dimension=D, + ), + ME.MinkowskiConvolutionTranspose( + channels[2], + channels[1], + kernel_size=3, + stride=1, + dimension=D, + ), + ME.MinkowskiPoolingTranspose( + kernel_size=2, stride=2, dimension=D + ), + ), + ), + ME.MinkowskiPoolingTranspose(kernel_size=2, stride=2, dimension=D), + ), + ), + ME.MinkowskiToFeature(), + nn.Linear(channels[1], out_nchannel, bias=True), + ) + + def forward(self, x): + return self.net(x) + + +if __name__ == "__main__": + # network (no loss is defined in this example) + net = StackUNet(3, 5, D=2) + print(net) + + # a data loader must return a tuple of coords, features, and labels. 
+ coords, feat, label = data_loader() + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + net = net.to(device) + input = ME.SparseTensor(feat, coords, device=device) + + # Forward + output = net(input) diff --git a/examples/unet.py b/examples/unet.py index 89348788..21b9fdc6 100644 --- a/examples/unet.py +++ b/examples/unet.py @@ -26,7 +26,7 @@ import MinkowskiEngine as ME import MinkowskiEngine.MinkowskiFunctional as MF -from tests.common import data_loader +from tests.python.common import data_loader class UNet(ME.MinkowskiNetwork): @@ -115,7 +115,7 @@ def forward(self, x): device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') net = net.to(device) - input = ME.SparseTensor(feat, coords=coords).to(device) + input = ME.SparseTensor(feat, coords, device=device) # Forward output = net(input)