Implement CPP axpy where the alpha is a Tensor.
axpy(const Tensor &alpha, ..), where alpha has a single element.

print the DType and Lang in the log output

handle data types (int and float) in tensor.cc
nudles committed Aug 9, 2021
1 parent 03a1ba3 commit 9c0d5db
Showing 6 changed files with 185 additions and 316 deletions.
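For context, the commit title describes an axpy overload whose scaling factor alpha is carried in a one-element tensor rather than passed as a plain scalar. Below is a minimal, self-contained C++ sketch of that idea only; the function name, container types, and signature are illustrative assumptions and not SINGA's actual Axpy API.

#include <cassert>
#include <cstdio>
#include <vector>

// y = alpha[0] * x + y, with alpha supplied as a one-element container
// instead of a plain float (hypothetical stand-in for a 1-element Tensor).
void Axpy(const std::vector<float> &alpha, const std::vector<float> &x,
          std::vector<float> *y) {
  assert(alpha.size() == 1);        // alpha must hold a single element
  assert(x.size() == y->size());
  const float a = alpha[0];
  for (size_t i = 0; i < x.size(); ++i) (*y)[i] += a * x[i];
}

int main() {
  std::vector<float> alpha = {2.0f};
  std::vector<float> x = {1.0f, 2.0f, 3.0f};
  std::vector<float> y = {10.0f, 10.0f, 10.0f};
  Axpy(alpha, x, &y);
  std::printf("%g %g %g\n", y[0], y[1], y[2]);  // prints 12 14 16
  return 0;
}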
51 changes: 22 additions & 29 deletions python/singa/tensor.py
@@ -674,72 +674,65 @@ def __add__(self, rhs):
if isinstance(rhs, Tensor):
return from_raw_tensor(singa.__add__(self.data, rhs.data))
else:
return _call_singa_func(singa.AddFloat, self.data, rhs)

return _call_singa_func(singa.AddFloat, self.data, float(rhs))
def __sub__(self, rhs):
if isinstance(rhs, Tensor):
return from_raw_tensor(singa.__sub__(self.data, rhs.data))
else:
return _call_singa_func(singa.SubFloat, self.data, rhs)

return _call_singa_func(singa.SubFloat, self.data, float(rhs))
def __mul__(self, rhs):
if isinstance(rhs, Tensor):
return from_raw_tensor(singa.__mul__(self.data, rhs.data))
else:
return _call_singa_func(singa.MultFloat, self.data, rhs)

return _call_singa_func(singa.MultFloat, self.data, float(rhs))
def __div__(self, rhs):
if isinstance(rhs, Tensor):
return from_raw_tensor(singa.__div__(self.data, rhs.data))
else:
return _call_singa_func(singa.DivFloat, self.data, rhs)

return _call_singa_func(singa.DivFloat, self.data, float(rhs))
def __truediv__(self, rhs):
if isinstance(rhs, Tensor):
return from_raw_tensor(singa.__div__(self.data, rhs.data))
else:
return _call_singa_func(singa.DivFloat, self.data, rhs)

return self.__div__(rhs)

def __floordiv__(self, rhs):
if isinstance(rhs, Tensor):
tmp = from_raw_tensor(singa.__div__(self.data, rhs.data))
return _call_singa_func(singa.Floor, tmp.data)
else:
tmp = _call_singa_func(singa.DivFloat, self.data, rhs)
return _call_singa_func(singa.Floor, tmp.data)
tmp = self.__div__(rhs)
return _call_singa_func(singa.Floor, tmp.data)

def __lt__(self, rhs):
if isinstance(rhs, Tensor):
return from_raw_tensor(singa.__lt__(self.data, rhs.data))
else:
return _call_singa_func(singa.LTFloat, self.data, rhs)

return _call_singa_func(singa.LTFloat, self.data, float(rhs))
def __le__(self, rhs):
if isinstance(rhs, Tensor):
return from_raw_tensor(singa.__le__(self.data, rhs.data))
else:
return _call_singa_func(singa.LEFloat, self.data, rhs)

return _call_singa_func(singa.LEFloat, self.data, float(rhs))
def __gt__(self, rhs):
if isinstance(rhs, Tensor):
return from_raw_tensor(singa.__gt__(self.data, rhs.data))
else:
return _call_singa_func(singa.GTFloat, self.data, rhs)

return _call_singa_func(singa.GTFloat, self.data, float(rhs))
def __ge__(self, rhs):
if isinstance(rhs, Tensor):
return from_raw_tensor(singa.__ge__(self.data, rhs.data))
else:
return _call_singa_func(singa.GEFloat, self.data, rhs)

return _call_singa_func(singa.GEFloat, self.data, float(rhs))
def __eq__(self, rhs):
if isinstance(rhs, Tensor):
return from_raw_tensor(singa.__eq__(self.data, rhs.data))
elif rhs is None:
return False
else:
return _call_singa_func(singa.EQFloat, self.data, rhs)

return _call_singa_func(singa.EQFloat, self.data, float(rhs))
def __radd__(self, lhs):
lhs = float(lhs)
one = Tensor(self.shape, self.device, self.dtype)
109 changes: 76 additions & 33 deletions src/core/tensor/tensor.cc
@@ -1037,33 +1037,59 @@ Tensor SoftMaxBackward(const Tensor &in, int axis, const Tensor &fdout) {
}); \
} while (0)

#define GenBinaryTensorFn(op, fn) \
Tensor op(const Tensor &lhs, const Tensor &rhs) { \
if (lhs.shape() != rhs.shape()) { \
auto lhs_ = Broadcast(lhs, rhs.shape()); \
auto rhs_ = Broadcast(rhs, lhs.shape()); \
Tensor ret(lhs_.shape(), lhs.device(), lhs.data_type()); \
fn(lhs_, rhs_, &ret); \
return ret; \
} else { \
Tensor ret(lhs.shape(), lhs.device(), lhs.data_type()); \
fn(lhs, rhs, &ret); \
return ret; \
} \
} \
void fn(const Tensor &lhs, const Tensor &rhs, Tensor *ret) { \
CHECK_EQ(lhs.device(), ret->device()); \
CHECK_EQ(rhs.device(), ret->device()); \
if (lhs.shape() != rhs.shape()) { \
auto lhs_ = Broadcast(lhs, rhs.shape()); \
auto rhs_ = Broadcast(rhs, lhs.shape()); \
CHECK(lhs_.shape() == ret->shape()); \
EltwiseBinaryTensorFn(fn, lhs_, rhs_, ret); \
} else { \
CHECK(lhs.shape() == ret->shape()); \
EltwiseBinaryTensorFn(fn, lhs, rhs, ret); \
} \
} // namespace singa
#define GenBinaryTensorFn(op, fn) \
Tensor op(const Tensor &lhs, const Tensor &rhs) { \
if (lhs.shape() != rhs.shape()) { \
if (lhs.data_type() == kFloat32 && rhs.data_type() == kFloat32) { \
auto lhs_ = Broadcast(lhs, rhs.shape()); \
auto rhs_ = Broadcast(rhs, lhs.shape()); \
Tensor ret(lhs_.shape(), lhs.device(), lhs.data_type()); \
fn(lhs_, rhs_, &ret); \
return ret; \
} else { \
/* lhs tensor and rhs tensor are not both in float, cast to float */\
Tensor tmp_lhs = lhs.Clone().AsType(kFloat32); \
Tensor tmp_rhs = rhs.Clone().AsType(kFloat32); \
tmp_lhs = Broadcast(tmp_lhs, tmp_rhs.shape()); \
tmp_rhs = Broadcast(tmp_rhs, tmp_lhs.shape()); \
Tensor ret(tmp_lhs.shape(), tmp_lhs.device(), tmp_lhs.data_type()); \
fn(tmp_lhs, tmp_rhs, &ret); \
/* if lhs and rhs are both int, cast back to int */ \
if (lhs.data_type() == kInt && rhs.data_type() == kInt) \
return ret.Clone().AsType(kInt); \
return ret; \
} \
} else { \
if (lhs.data_type() == kFloat32 && rhs.data_type() == kFloat32) { \
Tensor ret(lhs.shape(), lhs.device(), lhs.data_type()); \
fn(lhs, rhs, &ret); \
return ret; \
} else { \
/* lhs tensor and rhs tensor are not both in float, cast to float */\
Tensor tmp_lhs = lhs.Clone().AsType(kFloat32); \
Tensor tmp_rhs = rhs.Clone().AsType(kFloat32); \
Tensor ret(tmp_lhs.shape(), tmp_lhs.device(), tmp_lhs.data_type()); \
fn(tmp_lhs, tmp_rhs, &ret); \
/* if lhs and rhs are both int, cast back to int */ \
if (lhs.data_type() == kInt && rhs.data_type() == kInt) \
return ret.Clone().AsType(kInt); \
return ret; \
} \
} \
} \
void fn(const Tensor &lhs, const Tensor &rhs, Tensor *ret) { \
CHECK_EQ(lhs.device(), ret->device()); \
CHECK_EQ(rhs.device(), ret->device()); \
if (lhs.shape() != rhs.shape()) { \
auto lhs_ = Broadcast(lhs, rhs.shape()); \
auto rhs_ = Broadcast(rhs, lhs.shape()); \
CHECK(lhs_.shape() == ret->shape()); \
EltwiseBinaryTensorFn(fn, lhs_, rhs_, ret); \
} else { \
CHECK(lhs.shape() == ret->shape()); \
EltwiseBinaryTensorFn(fn, lhs, rhs, ret); \
} \
}

// broadcasting operations:
// https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
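The new GenBinaryTensorFn branch above encodes a simple promotion rule: if both operands are already kFloat32, operate on them directly; otherwise clone both operands to float, run the float kernel, and cast the result back to int only when both inputs were int. The following is a self-contained sketch of just that rule, not SINGA code; the enum values are stand-ins for SINGA's DataType.

#include <cstdio>

enum DType { kFloat32, kInt };  // stand-ins for SINGA's DataType values

// Result dtype under the promotion rule used by the new macro branch:
//   float op float -> float (fast path, no cast)
//   int   op int   -> computed in float, then cast back to int
//   mixed          -> float
DType PromotedResultType(DType lhs, DType rhs) {
  if (lhs == kFloat32 && rhs == kFloat32) return kFloat32;
  if (lhs == kInt && rhs == kInt) return kInt;
  return kFloat32;
}

int main() {
  std::printf("%d %d %d\n",
              PromotedResultType(kFloat32, kFloat32),   // 0 -> float
              PromotedResultType(kInt, kInt),           // 1 -> int
              PromotedResultType(kInt, kFloat32));      // 0 -> float
  return 0;
}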
@@ -1093,12 +1119,29 @@ GenBinaryTensorFn(ReLUBackward, ReLUBackward);
} while (0)

#define GenTensorScalarFn(op, fn) \
template <typename SType> \
Tensor op(const Tensor &in, const SType x) { \
Tensor ret(in.shape(), in.device(), in.data_type()); \
fn(in, x, &ret); \
return ret; \
} \
template <typename SType> \
Tensor op(const Tensor &in, const SType x) { \
if (in.data_type() == kFloat32 && std::is_same<SType, float>::value){ \
Tensor ret(in.shape(), in.device(), in.data_type()); \
fn(in, x, &ret); \
return ret; \
} else if (in.data_type() == kFloat32) { \
Tensor ret(in.shape(), in.device(), in.data_type()); \
float tmp_x = x; \
fn(in, tmp_x, &ret); \
return ret; \
} else { \
/* tensor and scalar are not both in float, cast to float */ \
Tensor tmp_in = in.Clone().AsType(kFloat32); \
float tmp_x = x; \
Tensor ret(tmp_in.shape(), tmp_in.device(), tmp_in.data_type()); \
fn(tmp_in, tmp_x, &ret); \
/* if tensor and scalar are both int, cast back to int */ \
if (in.data_type() == kInt && std::is_same<SType, int>::value) \
return ret.Clone().AsType(kInt); \
return ret; \
} \
} \
template <typename SType> \
void fn(const Tensor &in, const SType x, Tensor *ret) { \
EltwiseTensorScalarFn(fn, in, x, ret); \
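GenTensorScalarFn above combines a compile-time check on the scalar's C++ type (std::is_same<SType, float>) with a runtime check on the tensor's dtype to decide whether to promote to float before calling the kernel. Here is a compact, self-contained sketch of that dispatch pattern, using a plain vector in place of a Tensor; it is illustrative only, not SINGA code.

#include <cstdio>
#include <type_traits>
#include <vector>

enum DType { kFloat32, kInt };  // stand-ins for SINGA's DataType values

// Multiply every element by a scalar, mirroring the macro's dispatch:
// float tensor with float scalar is the fast path; everything else is
// computed in float, and the result counts as int only if the tensor and
// the scalar were both int.
template <typename SType>
std::vector<float> MultScalar(const std::vector<float> &values, DType dtype,
                              SType x, DType *out_dtype) {
  std::vector<float> out(values.size());
  const float fx = static_cast<float>(x);   // promote the scalar to float
  for (size_t i = 0; i < values.size(); ++i) out[i] = values[i] * fx;
  if (dtype == kFloat32 && std::is_same<SType, float>::value) {
    *out_dtype = kFloat32;                  // fast path: no cast back needed
  } else if (dtype == kInt && std::is_same<SType, int>::value) {
    *out_dtype = kInt;                      // both int: cast result back to int
  } else {
    *out_dtype = kFloat32;                  // mixed types: stay in float
  }
  return out;
}

int main() {
  DType out_dtype;
  std::vector<float> t = {1.0f, 2.0f, 3.0f};    // values of an int tensor, stored as float here
  auto r = MultScalar(t, kInt, 2, &out_dtype);  // int tensor * int scalar
  std::printf("dtype=%s r[2]=%g\n", out_dtype == kInt ? "int" : "float", r[2]);
  return 0;
}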
