Skip to content

Commit

Permalink
add a flag to enable bf16
Browse files Browse the repository at this point in the history
  • Loading branch information
FindDefinition committed Dec 25, 2023
1 parent 125a194 commit a717f0d
Show file tree
Hide file tree
Showing 31 changed files with 197 additions and 112 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Changelog
## [2.3.7] - 2023-12-25
### Added
- Add a flag, `SPCONV_ADD_BF16`, that enables bf16 support; spconv must be compiled from source to use it.

## [2.3.6] - 2023-04-19
### Fixed
- Fix a CI bug where the CPU builds of cumm and spconv used different gcc compilers; both must use the same compiler.
Expand Down
5 changes: 3 additions & 2 deletions example/mnist/mnist_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ def train(args, model, device, train_loader, optimizer, epoch):
model.train()
scaler = torch.cuda.amp.grad_scaler.GradScaler()
amp_ctx = contextlib.nullcontext()
assert args.fp16
if args.fp16:
amp_ctx = torch.cuda.amp.autocast()
amp_ctx = torch.cuda.amp.autocast(dtype=torch.float16)
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
Expand Down Expand Up @@ -107,7 +108,7 @@ def test(args, model, device, test_loader):
correct = 0
amp_ctx = contextlib.nullcontext()
if args.fp16:
amp_ctx = torch.cuda.amp.autocast()
amp_ctx = torch.cuda.amp.autocast(dtype=torch.float16)

with torch.no_grad():
for data, target in test_loader:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm>=0.4.8"]
requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.11", "cumm>=0.5.0"]
# requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu120-0.4.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"]
# requires = ["setuptools>=41.0", "wheel", "pccm>=0.4.0", "cumm @ file:///io/dist/cumm_cu117-0.4.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"]
build-backend = "setuptools.build_meta"
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@
cuda_ver_str = cuda_ver.replace(".", "") # 10.2 to 102

RELEASE_NAME += "-cu{}".format(cuda_ver_str)
deps = ["cumm-cu{}>=0.4.5, <0.5.0".format(cuda_ver_str)]
deps = ["cumm-cu{}>=0.5.0, <0.6.0".format(cuda_ver_str)]
else:
deps = ["cumm>=0.4.5, <0.5.0"]
deps = ["cumm>=0.5.0, <0.6.0"]



Expand All @@ -53,7 +53,7 @@
VERSION = None

# What packages are required for this module to be executed?
REQUIRED = ["pccm>=0.4.0", "ccimport>=0.4.0", "pybind11>=2.6.0", "fire", "numpy", *deps]
REQUIRED = ["pccm>=0.4.11", "ccimport>=0.4.0", "pybind11>=2.6.0", "fire", "numpy", *deps]

# What packages are optional?
EXTRAS = {
Expand Down
5 changes: 4 additions & 1 deletion spconv/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,7 @@ class AllocKeys:

SPCONV_INT8_DEBUG = os.getenv("SPCONV_INT8_DEBUG", "0") == "1"

SPCONV_DO_SORT = os.getenv("SPCONV_DO_SORT", "1") == "1"
SPCONV_DO_SORT = os.getenv("SPCONV_DO_SORT", "1") == "1"


SPCONV_ADD_BF16 = True # not available in release package
138 changes: 72 additions & 66 deletions spconv/core.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/hash/core.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview import Tensor
class HashTable:
key_itemsize: int
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/sparse/all/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview import Tensor
from cumm.tensorview import CUDAKernelTimer
class ThrustCustomAllocatorV2:
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/sparse/all/ops1d.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview import Tensor
class Point2Voxel:
hashdata: Tensor
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/sparse/all/ops2d.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview import Tensor
class Point2Voxel:
hashdata: Tensor
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/sparse/all/ops3d.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview import Tensor
class Point2Voxel:
hashdata: Tensor
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/sparse/all/ops4d.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview import Tensor
class Point2Voxel:
hashdata: Tensor
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/sparse/all/ops_cpu1d.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview import Tensor
class Point2VoxelCPU:
densehashdata: Tensor
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/sparse/all/ops_cpu2d.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview import Tensor
class Point2VoxelCPU:
densehashdata: Tensor
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/sparse/all/ops_cpu3d.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview import Tensor
class Point2VoxelCPU:
densehashdata: Tensor
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/sparse/all/ops_cpu4d.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview import Tensor
class Point2VoxelCPU:
densehashdata: Tensor
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/sparse/alloc.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview import Tensor
class ExternalAllocator:
def zeros(self, name: str, shape: List[int], dtype: int, device: int, stream: int = 0, is_temp_memory: bool = False, scale: float = 1.0) -> Tensor:
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/sparse/convops/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview.gemm import GemmAlgoDesp
from cumm.tensorview.gemm import ConvAlgoDesp
from cumm.tensorview import Tensor
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/sparse/convops/convops.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview.gemm import ConvAlgoDesp
from cumm.tensorview import Tensor
from cumm.tensorview.gemm import NVRTCParams
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/sparse/convops/gemmops.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview.gemm import GemmAlgoDesp
from cumm.tensorview import Tensor
from cumm.tensorview.gemm import NVRTCParams
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/sparse/convops/spops.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview import Tensor
from cumm.tensorview.gemm import Activation
from cumm.tensorview import CUDAKernelTimer
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/sparse/inference.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview import Tensor
from cumm.tensorview.gemm import Activation
class InferenceOps:
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/csrc/utils/boxops.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview import Tensor
class BoxOps:
@staticmethod
Expand Down
8 changes: 4 additions & 4 deletions spconv/core_cc/csrc/utils/pcc.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview import Tensor
class PointCloudCompress:
@staticmethod
Expand Down Expand Up @@ -43,8 +43,8 @@ class PointCloudCompress:
data:
"""
...
class EncodeType:
XYZ_8 = EnumClassValue(0) # type: EnumClassValue
XYZI_8 = EnumClassValue(1) # type: EnumClassValue
class EncodeType(enum.Enum):
XYZ_8 = 0
XYZI_8 = 1
@staticmethod
def __members__() -> Dict[str, EnumClassValue]: ...
2 changes: 1 addition & 1 deletion spconv/core_cc/cumm/common.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
class CompileInfo:
@staticmethod
def get_compiled_cuda_version() -> Tuple[int, int]: ...
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/cumm/conv/main.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview.gemm import ConvParams
class ConvMainUnitTest:
@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion spconv/core_cc/cumm/gemm/main.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import overload, Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from pccm.stubs import EnumValue, EnumClassValue
from pccm.stubs import EnumValue, EnumClassValue, enum
from cumm.tensorview.gemm import GemmAlgoDesp
from cumm.tensorview.gemm import GemmParams
class GemmMainUnitTest:
Expand Down
1 change: 1 addition & 0 deletions spconv/pytorch/cppcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
torch.float32: tv.float32,
torch.float64: tv.float64,
torch.float16: tv.float16,
torch.bfloat16: tv.bfloat16,
torch.int32: tv.int32,
torch.int64: tv.int64,
torch.int8: tv.int8,
Expand Down
62 changes: 54 additions & 8 deletions spconv/pytorch/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import sys
import pickle

Expand All @@ -33,6 +34,13 @@
from functools import reduce
from cumm import tensorview as tv

import collections
import collections.abc
import numpy as np
HAS_NUMPY = True
from torch._six import string_classes
from typing import Any

_MAX_INT32 = 2147483647

_T = TypeVar("_T")
Expand All @@ -41,20 +49,57 @@
def identity_decorator(func: _T) -> _T:
    """Return ``func`` unchanged.

    Serves as a do-nothing decorator wherever a real decorator is not
    applied.
    """
    return func

def _cast(value, dtype):
    """Recursively cast eligible tensors inside ``value`` to ``dtype``.

    A tensor is cast only when it is floating point, lives on a CUDA device
    and is not ``torch.float64``; any other tensor is returned unchanged.
    Strings, bytes and ``np.ndarray`` objects are passed through untouched
    because they would otherwise be treated as generic iterables.

    Args:
        value: A tensor, a (possibly nested) container of tensors, or any
            other object.
        dtype: Target ``torch.dtype`` for eligible tensors.

    Returns:
        ``value`` itself for leaves that are not eligible; a new ``dict`` for
        mappings (keys and values are both processed); a new container of the
        same type for lists and tuples; a lazy ``map`` iterator for any other
        iterable.
    """
    if isinstance(value, torch.Tensor):
        is_eligible = (value.is_floating_point() and value.is_cuda
                       and value.dtype is not torch.float64)
        return value.to(dtype) if is_eligible else value
    # ``(str, bytes)`` is spelled out instead of using
    # ``torch._six.string_classes`` so this helper does not depend on a
    # private torch module (presumably removed in newer PyTorch releases --
    # verify against the supported torch versions).
    if isinstance(value, (str, bytes)):
        return value
    # numpy is imported unconditionally by this module, so no availability
    # flag needs to be consulted before the isinstance check.
    if isinstance(value, np.ndarray):
        return value
    if isinstance(value, collections.abc.Mapping):
        return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
    if isinstance(value, collections.abc.Iterable):
        iterable = map(lambda v: _cast(v, dtype), value)
        if isinstance(value, (list, tuple)):
            return type(value)(iterable)
        # Other iterables (sets, generators, ...) are returned lazily.
        return iterable
    return value

if PYTORCH_VERSION >= [1, 6, 0]:
import torch.cuda.amp as amp
_TORCH_CUSTOM_FWD = amp.custom_fwd(cast_inputs=torch.float16)
_TORCH_CUSTOM_FWD = amp.custom_fwd
_TORCH_CUSTOM_BWD = amp.custom_bwd

else:
_TORCH_CUSTOM_FWD = identity_decorator
_TORCH_CUSTOM_BWD = identity_decorator

def custom_fwd_based_on_current_autocast_state(fwd=None):
    """Decorate an autograd ``forward`` so inputs follow the active autocast dtype.

    When autocast is enabled at call time, the positional and keyword
    arguments are passed through ``_cast`` with the current autocast GPU
    dtype and ``fwd`` is executed with autocast disabled. When autocast is
    not enabled, ``fwd`` is called with its arguments untouched.

    The first positional argument (the autograd ``ctx``) always gets
    ``_fwd_used_autocast = False`` because ``fwd`` itself never runs inside
    an autocast region.

    Args:
        fwd: The ``forward`` function to wrap. ``None`` (the default) returns
            this decorator itself so the ``@decorator()`` call form works.

    Returns:
        The wrapped function; ``fwd`` unchanged when ``PYTORCH_VERSION`` is
        below 1.6.0; or this decorator when ``fwd`` is ``None``.
    """
    # Handle the bare-call form first so it works regardless of version.
    if fwd is None:
        return custom_fwd_based_on_current_autocast_state
    if PYTORCH_VERSION < [1, 6, 0]:
        return fwd

    @functools.wraps(fwd)
    def decorate_fwd(*args, **kwargs):
        ctx = args[0]
        ctx._fwd_used_autocast = False
        # Older torch releases have no ``get_autocast_gpu_dtype``; fall back
        # to float16 there instead of raising AttributeError.
        query_dtype = getattr(torch, "get_autocast_gpu_dtype", None)
        autocast_dtype = query_dtype() if query_dtype is not None else torch.float16
        # NOTE(review): newer torch ``custom_bwd`` implementations appear to
        # read ``ctx._dtype`` when re-entering autocast in backward; record
        # it here so backward does not fail on a missing attribute -- confirm
        # against the supported torch versions.
        ctx._dtype = autocast_dtype
        if not torch.is_autocast_enabled():
            return fwd(*args, **kwargs)
        with torch.cuda.amp.autocast(enabled=False):
            return fwd(*_cast(args, autocast_dtype),
                       **_cast(kwargs, autocast_dtype))

    return decorate_fwd


class SparseConvFunction(Function):
@staticmethod
@_TORCH_CUSTOM_FWD
@custom_fwd_based_on_current_autocast_state
def forward(ctx,
features,
filters,
Expand All @@ -67,6 +112,7 @@ def forward(ctx,
act_alpha: float = 0.0,
act_beta: float = 0.0,
act_type: tv.gemm.Activation = tv.gemm.Activation.None_):

ctx.save_for_backward(indice_pairs, indice_pair_num, features, filters)
ctx.algo = algo
ctx.timer = timer
Expand Down Expand Up @@ -119,7 +165,7 @@ def backward(ctx, grad_output):

class SparseInverseConvFunction(Function):
@staticmethod
@_TORCH_CUSTOM_FWD
@custom_fwd_based_on_current_autocast_state
def forward(ctx,
features,
filters,
Expand Down Expand Up @@ -186,7 +232,7 @@ def backward(ctx, grad_output):

class SparseImplicitGemmFunction(Function):
@staticmethod
@_TORCH_CUSTOM_FWD
@custom_fwd_based_on_current_autocast_state
def forward(ctx,
features: torch.Tensor,
filters: torch.Tensor,
Expand Down Expand Up @@ -288,7 +334,7 @@ def backward(ctx, grad_output):

class SubMConvFunction(Function):
@staticmethod
@_TORCH_CUSTOM_FWD
@custom_fwd_based_on_current_autocast_state
def forward(ctx,
features,
filters,
Expand Down Expand Up @@ -355,7 +401,7 @@ def backward(ctx, grad_output):

class SparseMaxPoolFunction(Function):
@staticmethod
@_TORCH_CUSTOM_FWD
@custom_fwd_based_on_current_autocast_state
def forward(ctx, features, indice_pairs, indice_pair_num,
num_activate_out):
out = ops.indice_maxpool(features, indice_pairs, indice_pair_num,
Expand All @@ -375,7 +421,7 @@ def backward(ctx, grad_output):

class SparseMaxPoolImplicitGemmFunction(Function):
@staticmethod
@_TORCH_CUSTOM_FWD
@custom_fwd_based_on_current_autocast_state
def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor,
indice_pairs_bwd: torch.Tensor, num_activate_out: int):
out = ops.indice_maxpool_implicit_gemm(features, indice_pairs_fwd,
Expand All @@ -395,7 +441,7 @@ def backward(ctx, grad_output):

class SparseAvgPoolImplicitGemmFunction(Function):
@staticmethod
@_TORCH_CUSTOM_FWD
@custom_fwd_based_on_current_autocast_state
def forward(ctx, features: torch.Tensor, indice_pairs_fwd: torch.Tensor,
indice_pairs_bwd: torch.Tensor, num_activate_out: int,
calc_count):
Expand Down
Loading

0 comments on commit a717f0d

Please sign in to comment.