Skip to content

Commit

Permalink
StackX documentation and unet example
Browse files Browse the repository at this point in the history
  • Loading branch information
chrischoy committed Apr 14, 2021
1 parent 3fd51aa commit af1fdac
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
- Updated README for pytorch 1.8.1 support
- Use custom `gpu_storage` instead of thrust vector for faster constructors
- pytorch installation instruction updates
- fix transpose kernel map with `kernel_size == stride_size`
- Update reconstruction and vae examples for v0.5 API
- `stack_unet.py` example, API updates
- `MinkowskiToFeature` layer

## [0.5.2]

Expand Down
20 changes: 20 additions & 0 deletions MinkowskiEngine/MinkowskiOps.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,26 @@ def __repr__(self):
return self.__class__.__name__ + "()"


class MinkowskiToFeature(MinkowskiModuleBase):
r"""
Extract features from a sparse tensor and returns a pytorch tensor.
Can be used to to make a network construction simpler.
Example::
>>> net = nn.Sequential(MinkowskiConvolution(...), MinkowskiGlobalMaxPooling(...), MinkowskiToFeature(), nn.Linear(...))
>>> torch_tensor = net(sparse_tensor)
"""

def forward(self, x: SparseTensor):
assert isinstance(
x, (SparseTensor, TensorField)
), "Invalid input type for MinkowskiToFeature"
return x.F


class MinkowskiStackCat(torch.nn.Sequential):
def forward(self, x):
return cat([module(x) for module in self])
Expand Down
3 changes: 2 additions & 1 deletion MinkowskiEngine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
# of the code.
__version__ = "0.5.2"
__version__ = "0.5.3"

import os
import sys
Expand Down Expand Up @@ -189,6 +189,7 @@
MinkowskiLinear,
MinkowskiToSparseTensor,
MinkowskiToDenseTensor,
MinkowskiToFeature,
MinkowskiStackCat,
MinkowskiStackSum,
MinkowskiStackMean,
Expand Down
48 changes: 48 additions & 0 deletions docs/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,51 @@ MinkowskiToDenseTensor
.. autoclass:: MinkowskiEngine.MinkowskiToDenseTensor

.. automethod:: __init__


MinkowskiToFeature
------------------

.. autoclass:: MinkowskiEngine.MinkowskiToFeature

.. automethod:: __init__


MinkowskiStackCat
-----------------

.. autoclass:: MinkowskiEngine.MinkowskiStackCat
:members: forward
:undoc-members:

.. automethod:: __init__


MinkowskiStackSum
-----------------

.. autoclass:: MinkowskiEngine.MinkowskiStackSum
:members: forward
:undoc-members:

.. automethod:: __init__


MinkowskiStackMean
------------------

.. autoclass:: MinkowskiEngine.MinkowskiStackMean
:members: forward
:undoc-members:

.. automethod:: __init__


MinkowskiStackVar
-----------------

.. autoclass:: MinkowskiEngine.MinkowskiStackVar
:members: forward
:undoc-members:

.. automethod:: __init__
100 changes: 100 additions & 0 deletions examples/stack_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
# of the code.
import torch
import torch.nn as nn

import MinkowskiEngine as ME
import MinkowskiEngine.MinkowskiFunctional as MF

from tests.python.common import data_loader


class StackUNet(ME.MinkowskiNetwork):
def __init__(self, in_nchannel, out_nchannel, D):
ME.MinkowskiNetwork.__init__(self, D)
channels = [in_nchannel, 16, 32]
self.net = nn.Sequential(
ME.MinkowskiStackSum(
ME.MinkowskiConvolution(
channels[0],
channels[1],
kernel_size=3,
stride=1,
dimension=D,
),
nn.Sequential(
ME.MinkowskiConvolution(
channels[0],
channels[1],
kernel_size=3,
stride=2,
dimension=D,
),
ME.MinkowskiStackSum(
nn.Identity(),
nn.Sequential(
ME.MinkowskiConvolution(
channels[1],
channels[2],
kernel_size=3,
stride=2,
dimension=D,
),
ME.MinkowskiConvolutionTranspose(
channels[2],
channels[1],
kernel_size=3,
stride=1,
dimension=D,
),
ME.MinkowskiPoolingTranspose(
kernel_size=2, stride=2, dimension=D
),
),
),
ME.MinkowskiPoolingTranspose(kernel_size=2, stride=2, dimension=D),
),
),
ME.MinkowskiToFeature(),
nn.Linear(channels[1], out_nchannel, bias=True),
)

def forward(self, x):
return self.net(x)


if __name__ == "__main__":
# loss and network
net = StackUNet(3, 5, D=2)
print(net)

# a data loader must return a tuple of coords, features, and labels.
coords, feat, label = data_loader()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = net.to(device)
input = ME.SparseTensor(feat, coords, device=device)

# Forward
output = net(input)
4 changes: 2 additions & 2 deletions examples/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import MinkowskiEngine as ME
import MinkowskiEngine.MinkowskiFunctional as MF

from tests.common import data_loader
from tests.python.common import data_loader


class UNet(ME.MinkowskiNetwork):
Expand Down Expand Up @@ -115,7 +115,7 @@ def forward(self, x):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

net = net.to(device)
input = ME.SparseTensor(feat, coords=coords).to(device)
input = ME.SparseTensor(feat, coords, device=device)

# Forward
output = net(input)

0 comments on commit af1fdac

Please sign in to comment.