Skip to content

Commit

Permalink
fix GroupNorm init (#10494)
Browse files Browse the repository at this point in the history
Creating GroupNorm with device and dtype throws Exceptions.

```python
import oneflow as flow
m = flow.nn.GroupNorm(2, 3, device=flow.device("cpu"), dtype=flow.float32)
```

Exception messages:
```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/zhengjianhua/oneflow/python/oneflow/nn/modules/normalization.py", line 140, in __init__
    self.weight = flow.nn.Parameter(flow.Tensor(num_channels, **factory_kwargs))
TypeError: Error: _legacy_tensor_ctor(): received an invalid combination of arguments. The valid signatures are:
        *0: Tensor (*, Device device=None)
        *1: Tensor (*, Placement placement, SbpList sbp)
        *2: Tensor (Tensor other)
        *3: Tensor (PyObject* data, *, Device device=None)
        *4: Tensor (PyObject* data, *, Placement placement, SbpList sbp)
        *5: Tensor (Shape size, *, Device device=None)
        *6: Tensor (Shape size, *, Placement placement, SbpList sbp)
```

---------

Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
  • Loading branch information
fpzh2011 and oneflow-ci-bot committed May 20, 2024
1 parent ea585f6 commit b8c457c
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions python/oneflow/nn/modules/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,12 @@ def __init__(
if dtype:
factory_kwargs["dtype"] = dtype
if self.affine:
self.weight = flow.nn.Parameter(flow.Tensor(num_channels, **factory_kwargs))
self.bias = flow.nn.Parameter(flow.Tensor(num_channels, **factory_kwargs))
self.weight = flow.nn.Parameter(
flow.Tensor(num_channels).to(**factory_kwargs)
)
self.bias = flow.nn.Parameter(
flow.Tensor(num_channels).to(**factory_kwargs)
)
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
Expand Down

0 comments on commit b8c457c

Please sign in to comment.