From 9c0d5db71f32f9b03fe95d440f7fe80342dc6bef Mon Sep 17 00:00:00 2001
From: fredwang
Date: Sat, 7 Aug 2021 16:24:05 +0800
Subject: [PATCH] Implement C++ Axpy where alpha is a Tensor

Add Axpy(const Tensor &alpha, ...), where alpha holds a single element.
Print the DType and Lang in the log output of unimplemented kernels.
Handle mixed data types (int and float) in tensor.cc.
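For illustration only (not part of the patch): a minimal usage sketch,
assuming a public wrapper Axpy(const Tensor &alpha, const Tensor &in,
Tensor *out) is exposed next to the existing scalar overload; the tensor
names and values below are hypothetical.

    #include "singa/core/tensor.h"

    int main() {
      // One-element tensor carrying the scaling factor.
      singa::Tensor alpha(singa::Shape{1});
      alpha.SetValue(2.0f);

      singa::Tensor x(singa::Shape{3}), y(singa::Shape{3});
      x.SetValue(1.0f);  // x = [1, 1, 1]
      y.SetValue(4.0f);  // y = [4, 4, 4]

      singa::Axpy(alpha, x, &y);  // y = alpha * x + y -> [6, 6, 6]
      return 0;
    }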
---
 python/singa/tensor.py            |  51 ++++-----
 src/core/tensor/tensor.cc         | 109 ++++++++++++------
 src/core/tensor/tensor_math.h     | 175 +++++++++++-------------------
 src/core/tensor/tensor_math_cpp.h | 161 ++++-----------------------
 test/python/test_tensor.py        |   3 +-
 tool/conda/singa/meta.yaml        |   2 +-
 6 files changed, 185 insertions(+), 316 deletions(-)

diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index e9e9ae76b9..963ad1a002 100755
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -674,72 +674,65 @@ def __add__(self, rhs):
         if isinstance(rhs, Tensor):
             return from_raw_tensor(singa.__add__(self.data, rhs.data))
         else:
-            return _call_singa_func(singa.AddFloat, self.data, rhs)
-
+            return _call_singa_func(singa.AddFloat, self.data, float(rhs))
+
     def __sub__(self, rhs):
         if isinstance(rhs, Tensor):
             return from_raw_tensor(singa.__sub__(self.data, rhs.data))
         else:
-            return _call_singa_func(singa.SubFloat, self.data, rhs)
-
+            return _call_singa_func(singa.SubFloat, self.data, float(rhs))
+
     def __mul__(self, rhs):
         if isinstance(rhs, Tensor):
             return from_raw_tensor(singa.__mul__(self.data, rhs.data))
         else:
-            return _call_singa_func(singa.MultFloat, self.data, rhs)
-
+            return _call_singa_func(singa.MultFloat, self.data, float(rhs))
+
     def __div__(self, rhs):
         if isinstance(rhs, Tensor):
             return from_raw_tensor(singa.__div__(self.data, rhs.data))
         else:
-            return _call_singa_func(singa.DivFloat, self.data, rhs)
-
+            return _call_singa_func(singa.DivFloat, self.data, float(rhs))
+
     def __truediv__(self, rhs):
-        if isinstance(rhs, Tensor):
-            return from_raw_tensor(singa.__div__(self.data, rhs.data))
-        else:
-            return _call_singa_func(singa.DivFloat, self.data, rhs)
-
+        return self.__div__(rhs)
+
     def __floordiv__(self, rhs):
-        if isinstance(rhs, Tensor):
-            tmp = from_raw_tensor(singa.__div__(self.data, rhs.data))
-            return _call_singa_func(singa.Floor, tmp.data)
-        else:
-            tmp = _call_singa_func(singa.DivFloat, self.data, rhs)
-            return _call_singa_func(singa.Floor, tmp.data)
+        tmp = self.__div__(rhs)
+        return _call_singa_func(singa.Floor, tmp.data)
 
     def __lt__(self, rhs):
         if isinstance(rhs, Tensor):
             return from_raw_tensor(singa.__lt__(self.data, rhs.data))
         else:
-            return _call_singa_func(singa.LTFloat, self.data, rhs)
-
+            return _call_singa_func(singa.LTFloat, self.data, float(rhs))
+
     def __le__(self, rhs):
         if isinstance(rhs, Tensor):
             return from_raw_tensor(singa.__le__(self.data, rhs.data))
         else:
-            return _call_singa_func(singa.LEFloat, self.data, rhs)
-
+            return _call_singa_func(singa.LEFloat, self.data, float(rhs))
+
     def __gt__(self, rhs):
         if isinstance(rhs, Tensor):
             return from_raw_tensor(singa.__gt__(self.data, rhs.data))
         else:
-            return _call_singa_func(singa.GTFloat, self.data, rhs)
-
+            return _call_singa_func(singa.GTFloat, self.data, float(rhs))
+
     def __ge__(self, rhs):
         if isinstance(rhs, Tensor):
             return from_raw_tensor(singa.__ge__(self.data, rhs.data))
         else:
-            return _call_singa_func(singa.GEFloat, self.data, rhs)
-
+            return _call_singa_func(singa.GEFloat, self.data, float(rhs))
+
     def __eq__(self, rhs):
         if isinstance(rhs, Tensor):
             return from_raw_tensor(singa.__eq__(self.data, rhs.data))
         elif rhs is None:
             return False
         else:
-            return _call_singa_func(singa.EQFloat, self.data, rhs)
-
+            return _call_singa_func(singa.EQFloat, self.data, float(rhs))
+
     def __radd__(self, lhs):
         lhs = float(lhs)
         one = Tensor(self.shape, self.device, self.dtype)
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index 08e5d412fa..46b82aac98 100644
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -1037,33 +1037,59 @@ Tensor SoftMaxBackward(const Tensor &in, int axis, const Tensor &fdout) {
   });                                                                      \
   } while (0)
 
-#define GenBinaryTensorFn(op, fn)                                          \
-  Tensor op(const Tensor &lhs, const Tensor &rhs) {                        \
-    if (lhs.shape() != rhs.shape()) {                                      \
-      auto lhs_ = Broadcast(lhs, rhs.shape());                             \
-      auto rhs_ = Broadcast(rhs, lhs.shape());                             \
-      Tensor ret(lhs_.shape(), lhs.device(), lhs.data_type());             \
-      fn(lhs_, rhs_, &ret);                                                \
-      return ret;                                                          \
-    } else {                                                               \
-      Tensor ret(lhs.shape(), lhs.device(), lhs.data_type());              \
-      fn(lhs, rhs, &ret);                                                  \
-      return ret;                                                          \
-    }                                                                      \
-  }                                                                        \
-  void fn(const Tensor &lhs, const Tensor &rhs, Tensor *ret) {             \
-    CHECK_EQ(lhs.device(), ret->device());                                 \
-    CHECK_EQ(rhs.device(), ret->device());                                 \
-    if (lhs.shape() != rhs.shape()) {                                      \
-      auto lhs_ = Broadcast(lhs, rhs.shape());                             \
-      auto rhs_ = Broadcast(rhs, lhs.shape());                             \
-      CHECK(lhs_.shape() == ret->shape());                                 \
-      EltwiseBinaryTensorFn(fn, lhs_, rhs_, ret);                          \
-    } else {                                                               \
-      CHECK(lhs.shape() == ret->shape());                                  \
-      EltwiseBinaryTensorFn(fn, lhs, rhs, ret);                            \
-    }                                                                      \
-  }  // namespace singa
+#define GenBinaryTensorFn(op, fn)                                          \
+  Tensor op(const Tensor &lhs, const Tensor &rhs) {                        \
+    if (lhs.shape() != rhs.shape()) {                                      \
+      if (lhs.data_type() == kFloat32 && rhs.data_type() == kFloat32) {    \
+        auto lhs_ = Broadcast(lhs, rhs.shape());                           \
+        auto rhs_ = Broadcast(rhs, lhs.shape());                           \
+        Tensor ret(lhs_.shape(), lhs.device(), lhs.data_type());           \
+        fn(lhs_, rhs_, &ret);                                              \
+        return ret;                                                        \
+      } else {                                                             \
+        /* lhs tensor and rhs tensor are not both in float, cast to float */ \
+        Tensor tmp_lhs = lhs.Clone().AsType(kFloat32);                     \
+        Tensor tmp_rhs = rhs.Clone().AsType(kFloat32);                     \
+        tmp_lhs = Broadcast(tmp_lhs, tmp_rhs.shape());                     \
+        tmp_rhs = Broadcast(tmp_rhs, tmp_lhs.shape());                     \
+        Tensor ret(tmp_lhs.shape(), tmp_lhs.device(), tmp_lhs.data_type()); \
+        fn(tmp_lhs, tmp_rhs, &ret);                                        \
+        /* if lhs and rhs are both int, cast back to int */                \
+        if (lhs.data_type() == kInt && rhs.data_type() == kInt)            \
+          return ret.Clone().AsType(kInt);                                 \
+        return ret;                                                        \
+      }                                                                    \
+    } else {                                                               \
+      if (lhs.data_type() == kFloat32 && rhs.data_type() == kFloat32) {    \
+        Tensor ret(lhs.shape(), lhs.device(), lhs.data_type());            \
+        fn(lhs, rhs, &ret);                                                \
+        return ret;                                                        \
+      } else {                                                             \
+        /* lhs tensor and rhs tensor are not both in float, cast to float */ \
+        Tensor tmp_lhs = lhs.Clone().AsType(kFloat32);                     \
+        Tensor tmp_rhs = rhs.Clone().AsType(kFloat32);                     \
+        Tensor ret(tmp_lhs.shape(), tmp_lhs.device(), tmp_lhs.data_type()); \
+        fn(tmp_lhs, tmp_rhs, &ret);                                        \
+        /* if lhs and rhs are both int, cast back to int */                \
+        if (lhs.data_type() == kInt && rhs.data_type() == kInt)            \
+          return ret.Clone().AsType(kInt);                                 \
+        return ret;                                                        \
+      }                                                                    \
+    }                                                                      \
+  }                                                                        \
+  void fn(const Tensor &lhs, const Tensor &rhs, Tensor *ret) {             \
+    CHECK_EQ(lhs.device(), ret->device());                                 \
+    CHECK_EQ(rhs.device(), ret->device());                                 \
+    if (lhs.shape() != rhs.shape()) {                                      \
+      auto lhs_ = Broadcast(lhs, rhs.shape());                             \
+      auto rhs_ = Broadcast(rhs, lhs.shape());                             \
+      CHECK(lhs_.shape() == ret->shape());                                 \
+      EltwiseBinaryTensorFn(fn, lhs_, rhs_, ret);                          \
+    } else {                                                               \
+      CHECK(lhs.shape() == ret->shape());                                  \
+      EltwiseBinaryTensorFn(fn, lhs, rhs, ret);                            \
+    }                                                                      \
+  }
 
 // broadcasting operations:
 // https://github.com/onnx/onnx/blob/master/docs/Broadcasting.md
@@ -1093,12 +1119,29 @@ GenBinaryTensorFn(ReLUBackward, ReLUBackward);
   } while (0)
 
 #define GenTensorScalarFn(op, fn)                                          \
-  template <typename SType>                                                \
-  Tensor op(const Tensor &in, const SType x) {                             \
-    Tensor ret(in.shape(), in.device(), in.data_type());                   \
-    fn(in, x, &ret);                                                       \
-    return ret;                                                            \
-  }                                                                        \
+  template <typename SType>                                                \
+  Tensor op(const Tensor &in, const SType x) {                             \
+    if (in.data_type() == kFloat32 && std::is_same<SType, float>::value) { \
+      Tensor ret(in.shape(), in.device(), in.data_type());                 \
+      fn(in, x, &ret);                                                     \
+      return ret;                                                          \
+    } else if (in.data_type() == kFloat32) {                               \
+      Tensor ret(in.shape(), in.device(), in.data_type());                 \
+      float tmp_x = x;                                                     \
+      fn(in, tmp_x, &ret);                                                 \
+      return ret;                                                          \
+    } else {                                                               \
+      /* tensor and scalar are not both in float, cast to float */         \
+      Tensor tmp_in = in.Clone().AsType(kFloat32);                         \
+      float tmp_x = x;                                                     \
+      Tensor ret(tmp_in.shape(), tmp_in.device(), tmp_in.data_type());     \
+      fn(tmp_in, tmp_x, &ret);                                             \
+      /* if tensor and scalar are both int, cast back to int */            \
+      if (in.data_type() == kInt && std::is_same<SType, int>::value)       \
+        return ret.Clone().AsType(kInt);                                   \
+      return ret;                                                          \
+    }                                                                      \
+  }                                                                        \
   template <typename SType>                                                \
   void fn(const Tensor &in, const SType x, Tensor *ret) {                  \
     EltwiseTensorScalarFn(fn, in, x, ret);                                 \
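For illustration only (not part of the patch): both macros above apply the
same promotion rule -- compute in float32 unless both operands are already
float32, and cast the result back to int only when every input was int. A
rough un-macro'd sketch of the binary case, where binary_op stands in for
any generated fn:

    // Sketch of the promotion rule encoded in GenBinaryTensorFn.
    Tensor PromoteAndApply(const Tensor &lhs, const Tensor &rhs) {
      if (lhs.data_type() == kFloat32 && rhs.data_type() == kFloat32)
        return binary_op(lhs, rhs);  // fast path, no casting
      // Mixed or integer inputs: do the arithmetic in float32.
      Tensor out = binary_op(lhs.Clone().AsType(kFloat32),
                             rhs.Clone().AsType(kFloat32));
      // Cast back only when both inputs were int.
      if (lhs.data_type() == kInt && rhs.data_type() == kInt)
        return out.Clone().AsType(kInt);
      return out;
    }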
diff --git a/src/core/tensor/tensor_math.h b/src/core/tensor/tensor_math.h
index 2e6a08ae6c..c294c4ea18 100644
--- a/src/core/tensor/tensor_math.h
+++ b/src/core/tensor/tensor_math.h
@@ -55,6 +55,9 @@ namespace singa {
 /// 7. Use size_t for the number of elements, rows or columns.
 /// 8. Use the same name for the Tensor and Tensor level math functions.
 
+#define LOG_FATAL(Op, DType, Lang) \
+  LOG(FATAL) << Op << " not Implemented for DType=" << typeid(DType).name() << " Lang=" << typeid(Lang).name()
+
 const std::string vec2str(const std::vector<int> &vec) {
   std::ostringstream vts;
   if (!vec.empty()) {
@@ -83,62 +86,63 @@ const std::string vec2str(const std::vector<size_t> &vec) {
 
 /// out[i] = |in[i]|
 template <typename DType, typename Lang>
 void Abs(const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Abs Not Implemented";
+  LOG_FATAL("Abs", DType, Lang);
 }
 
 template <typename DType, typename Lang>
 void Erf(const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Erf Not Implemented";
+  LOG_FATAL("Erf", DType, Lang);
 }
 
 template <typename DTypeSrc, typename DTypeDst, typename Lang>
 void CastCopy(const Tensor *src, Tensor *dst, Context *ctx) {
-  LOG(FATAL) << "CastCopy Not Implemented";
+  LOG(FATAL) << "CastCopy not Implemented for DTypeSrc=" << typeid(DTypeSrc).name()
+             << " DTypeDst=" << typeid(DTypeDst).name() << " Lang=" << typeid(Lang).name();
 }
 
 template <typename DType, typename Lang>
 void Ceil(const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Ceil Not Implemented";
+  LOG_FATAL("Ceil", DType, Lang);
 }
 
 template <typename DType, typename Lang>
 void Floor(const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Floor Not Implemented";
+  LOG_FATAL("Floor", DType, Lang);
 }
 
 template <typename DType, typename Lang>
 void Round(const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Round Not Implemented";
+  LOG_FATAL("Round", DType, Lang);
 }
 
 template <typename DType, typename Lang>
 void RoundE(const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Round Not Implemented";
+  LOG_FATAL("RoundE", DType, Lang);
 }
 
 /// out[i] = in[i] + x
 template <typename DType, typename Lang>
 void Add(const Tensor &in, const DType x, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Add Not Implemented";
+  LOG_FATAL("Add", DType, Lang);
 }
 
 /// out[i] = in1[i] + in2[i]
 template <typename DType, typename Lang>
 void Add(const Tensor &in1, const Tensor &in2, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Add-Pair Not Implemented";
+  LOG_FATAL("Add-Pair", DType, Lang);
 }
 
 /// Clamp every element into [low, high]
 /// if in[i]>high, then out[i]=high; if in[i]<low, then out[i]=low
 template <typename DType, typename Lang>
 void Clamp(const DType low, const DType high, const Tensor &in, Tensor *out,
            Context *ctx) {
-  LOG(FATAL) << "Clamp Not Implemented";
+  LOG_FATAL("Clamp", DType, Lang);
 }
 
 /// out[i] = x / in[i]
 template <typename DType, typename Lang>
 void Div(const DType x, const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Div Not Implemented";
+  LOG_FATAL("Div", DType, Lang);
 }
 
 /// out[i] = in[i] / x
@@ -151,140 +155,140 @@ void Div(const Tensor &in, const DType x, Tensor *out, Context *ctx) {
 /// out[i] = in1[i] / in2[i]
 template <typename DType, typename Lang>
 void Div(const Tensor &in1, const Tensor &in2, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Div-Pair Not Implemented";
+  LOG_FATAL("Div-Pair", DType, Lang);
 }
 
 /// out[i] = in[i] * x
 template <typename DType, typename Lang>
 void EltwiseMult(const Tensor &in, const DType x, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "EltwiseMult Not Implemented";
+  LOG_FATAL("EltwiseMult", DType, Lang);
 }
 
 /// out[i] = in1[i] * in2[i]
 template <typename DType, typename Lang>
 void EltwiseMult(const Tensor &in1, const Tensor &in2, Tensor *out,
                  Context *ctx) {
-  LOG(FATAL) << "EltwiseMult-Pair Not Implemented";
+  LOG_FATAL("EltwiseMult-Pair", DType, Lang);
 }
 
 /// out[i]=(in2[i]>0)?in1[i]:0.f
 template <typename DType, typename Lang>
 void ReLUBackward(const Tensor &in1, const Tensor &in2, Tensor *out,
                   Context *ctx) {
-  LOG(FATAL) << "ReLUBackward Not Implemented";
+  LOG_FATAL("ReLUBackward", DType, Lang);
 }
 
 /// Base is e, Neper number. out[i]=exp(in[i])
 template <typename DType, typename Lang>
 void Exp(const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Exp Not Implemented";
+  LOG_FATAL("Exp", DType, Lang);
 }
 
 /// out[i]=(in[i]<=x)?1.f:0.f
 template <typename DType, typename Lang>
 void LE(const Tensor &in, const DType x, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "LE Not Implemented";
+  LOG_FATAL("LE", DType, Lang);
 }
 
 /// out[i]=(in1[i]<=in2[i])?1.f:0.f
 template <typename DType, typename Lang>
 void LE(const Tensor &in1, const Tensor &in2, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Tensor-Tensor LE Not Implemented";
+  LOG_FATAL("Tensor <= Tensor", DType, Lang);
 }
 
 /// Natural logarithm, the base is e, Neper number out[i]=log(in[i]).
 template <typename DType, typename Lang>
 void Log(const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Log Not Implemented";
+  LOG_FATAL("Log", DType, Lang);
 }
 
 /// out[i]=(in[i]<x)?1.f:0.f
 template <typename DType, typename Lang>
 void LT(const Tensor &in, const DType x, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "LT Not Implemented";
+  LOG_FATAL("LT", DType, Lang);
 }
 
 /// out[i]=(in1[i]<in2[i])?1.f:0.f
 template <typename DType, typename Lang>
 void LT(const Tensor &in1, const Tensor &in2, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Tensor-Tensor LT Not Implemented";
+  LOG_FATAL("Tensor Tensor LT", DType, Lang);
 }
 
 /// out[i]=(in[i]>=x)?1.f:0.f
 template <typename DType, typename Lang>
 void GE(const Tensor &in, const DType x, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "GE Not Implemented";
+  LOG_FATAL("GE", DType, Lang);
 }
 
 /// out[i]=(in1[i]>=in2[i])?1.f:0.f
 template <typename DType, typename Lang>
 void GE(const Tensor &in1, const Tensor &in2, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Tensor-Tensor GE Not Implemented";
+  LOG_FATAL("Tensor Tensor GE", DType, Lang);
 }
 
 /// out[i]=(in[i]>x)?1.f:0.f
 template <typename DType, typename Lang>
 void GT(const Tensor &in, const DType x, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "GT Not Implemented";
+  LOG_FATAL("GT", DType, Lang);
 }
 
 /// out[i]=(in1[i]>in2[i])?1.f:0.f
 template <typename DType, typename Lang>
 void GT(const Tensor &in, const Tensor &in2, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Tensor-Tensor GT Not Implemented";
+  LOG_FATAL("Tensor Tensor GT", DType, Lang);
 }
 
 /// out[i]=(in[i]==x)?1.f:0.f
 template <typename DType, typename Lang>
 void EQ(const Tensor &in, const DType x, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "EQ Not Implemented";
+  LOG_FATAL("EQ", DType, Lang);
 }
 
 /// out[i]=(in[i]==in2[i])?1.f:0.f
 template <typename DType, typename Lang>
 void EQ(const Tensor &in, const Tensor &in2, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Tensor-Tensor EQ Not Implemented";
+  LOG_FATAL("Tensor Tensor EQ", DType, Lang);
 }
 
 /// out[i] = pow(in[i], x)
 template <typename DType, typename Lang>
 void Pow(const Tensor &in, const DType x, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Pow Not Implemented";
+  LOG_FATAL("Pow", DType, Lang);
 }
 
 /// out[i]=pow(in1[i], in2[i])
 template <typename DType, typename Lang>
 void Pow(const Tensor &in1, const Tensor &in2, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Pow-Pair Not Implemented";
+  LOG_FATAL("Tensor Tensor Pow", DType, Lang);
 }
 
 /// out[i]=max(0, in[i])
 template <typename DType, typename Lang>
 void ReLU(const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "ReLU Not Implemented";
+  LOG_FATAL("ReLU", DType, Lang);
 }
 
 /// out[i] = x
 template <typename DType, typename Lang>
 void Set(const DType x, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Set Not Implemented";
+  LOG_FATAL("Set", DType, Lang);
 }
 
 /// out[i]=sigmoid(in[i])
 template <typename DType, typename Lang>
 void Sigmoid(const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Sigmoid Not Implemented";
+  LOG_FATAL("Sigmoid", DType, Lang);
 }
 
 /// out[i] = log(exp(in[i]) + 1)
 template <typename DType, typename Lang>
 void SoftPlus(const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "SoftPlus Not Implemented";
+  LOG_FATAL("SoftPlus", DType, Lang);
 }
 
 /// out[i] = in[i] / (abs(in[i]) + 1)
 template <typename DType, typename Lang>
 void SoftSign(const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "SoftSign Not Implemented";
+  LOG_FATAL("SoftSign", DType, Lang);
 }
 
 /// out[i] = sign(in[i])
 template <typename DType, typename Lang>
 void Sign(const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Sign Not Implemented";
+  LOG_FATAL("Sign", DType, Lang);
 }
 
 /// out[i]=sqrt(in[i])
 template <typename DType, typename Lang>
 void Sqrt(const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Sqrt Not Implemented";
+  LOG_FATAL("Sqrt", DType, Lang);
 }
 
 /// out[i]=square(in[i])
Implemented"; + LOG_FATAL("Tensor Tensor Sub", DType, Lang); } /// sum all elements of in into out template void Sum(const Tensor &in, DType *out, Context *ctx) { - LOG(FATAL) << "Sum Not Implemented"; + LOG_FATAL("Sum", DType, Lang); } /// out[i]=fn(in[i]) @@ -316,8 +320,7 @@ void Sum(const Tensor &in, DType *out, Context *ctx) { template \ void fn(const Tensor &in, Tensor *out, Context *ctx) { \ std::string str = stringfn; \ - str += " Not Implemented"; \ - LOG(FATAL) << str; \ + LOG_FATAL(str, DType, Lang); \ } GenUnaryNotImplemented(Cos, "Cos"); @@ -339,7 +342,7 @@ GenUnaryNotImplemented(Atanh, "Atanh"); /// strides template void Transform(const Tensor &in, Tensor *out, Context *ctx) { - LOG(FATAL) << "Transform Not Implemented"; + LOG_FATAL("Transform", DType, Lang); } // ************************************** @@ -350,19 +353,19 @@ void Transform(const Tensor &in, Tensor *out, Context *ctx) { // If DType is not float, then convert the threshold to DType template void Bernoulli(const float p, Tensor *out, Context *ctx) { - LOG(FATAL) << "Bernoulli Not Implemented"; + LOG_FATAL("Bernoulli", DType, Lang); } // The random generator should be extracted from ctx. // If DType is not float, then convert the mean and std to DType template void Gaussian(const DType mean, const DType std, Tensor *out, Context *ctx) { - LOG(FATAL) << "Gaussian Not Implemented"; + LOG_FATAL("Gaussian", DType, Lang); } // The random generator should be extracted from ctx. // If DType is not float, then convert the low and high to DType template void Uniform(const DType low, const DType high, Tensor *out, Context *ctx) { - LOG(FATAL) << "Uniform Not Implemented"; + LOG_FATAL("Uniform", DType, Lang); } // ********************************************************* @@ -372,52 +375,52 @@ void Uniform(const DType low, const DType high, Tensor *out, Context *ctx) { /// outurn the index of the element with the max value. template void Amax(const Tensor &in, size_t *out, Context *ctx) { - LOG(FATAL) << "Amax Not Implemented"; + LOG_FATAL("Amax", DType, Lang); } /// outurn the index of the element with the min value. template void Amin(const Tensor &in, size_t *out, Context *ctx) { - LOG(FATAL) << "Amin Not Implemented"; + LOG_FATAL("Amin", DType, Lang); } /// out = sum |x| for all x in in template void Asum(const Tensor &in, DType *out, Context *ctx) { - LOG(FATAL) << "Asum Not Implemented"; + LOG_FATAL("Asum", DType, Lang); } /// out = alpha * in + out template void Axpy(const DType alpha, const Tensor &in, Tensor *out, Context *ctx) { - LOG(FATAL) << "Axpy Not Implemented"; + LOG_FATAL("Axpy", DType, Lang); } /// out = alpha * in + out template void Axpy(const Tensor &alpha, const Tensor &in, Tensor *out, Context *ctx) { - LOG(FATAL) << "Axpy Not Implemented"; + LOG_FATAL("Axpy Tensor alpha", DType, Lang); } /// out = ||in||_2^2, i.e, L2 norm. 
// *********************************************************
@@ -372,52 +375,52 @@ void Uniform(const DType low, const DType high, Tensor *out, Context *ctx) {
 
 /// Return the index of the element with the max value.
 template <typename DType, typename Lang>
 void Amax(const Tensor &in, size_t *out, Context *ctx) {
-  LOG(FATAL) << "Amax Not Implemented";
+  LOG_FATAL("Amax", DType, Lang);
 }
 
 /// Return the index of the element with the min value.
 template <typename DType, typename Lang>
 void Amin(const Tensor &in, size_t *out, Context *ctx) {
-  LOG(FATAL) << "Amin Not Implemented";
+  LOG_FATAL("Amin", DType, Lang);
 }
 /// out = sum |x| for all x in in
 template <typename DType, typename Lang>
 void Asum(const Tensor &in, DType *out, Context *ctx) {
-  LOG(FATAL) << "Asum Not Implemented";
+  LOG_FATAL("Asum", DType, Lang);
 }
 
 /// out = alpha * in + out
 template <typename DType, typename Lang>
 void Axpy(const DType alpha, const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Axpy Not Implemented";
+  LOG_FATAL("Axpy", DType, Lang);
 }
 
 /// out = alpha * in + out
 template <typename DType, typename Lang>
 void Axpy(const Tensor &alpha, const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Axpy Not Implemented";
+  LOG_FATAL("Axpy Tensor alpha", DType, Lang);
 }
 
 /// out = ||in||_2^2, i.e, L2 norm.
 template <typename DType, typename Lang>
 void Nrm2(const Tensor &in, float *out, Context *ctx) {
-  LOG(FATAL) << "Nrm2 Not Implemented";
+  LOG_FATAL("Nrm2", DType, Lang);
 }
 
 /// out *= x
 template <typename DType, typename Lang>
 void Scale(const DType x, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Scale Not Implemented";
+  LOG_FATAL("Scale", DType, Lang);
 }
 
 /// inner product of array in1 and in2
 template <typename DType, typename Lang>
 void Dot(const Tensor &in1, const Tensor &in2, DType *out, Context *ctx) {
-  LOG(FATAL) << "Dot Not Implemented";
+  LOG_FATAL("Inner-product Dot", DType, Lang);
 }
 
 template <typename DType, typename Lang>
 void Dot(const Tensor &in1, const Tensor &in2, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Dot Not Implemented";
+  LOG_FATAL("Dot", DType, Lang);
 }
 
 /// out = alpha * A * v + beta * out.
@@ -425,7 +428,7 @@ void Dot(const Tensor &in1, const Tensor &in2, Tensor *out, Context *ctx) {
 template <typename DType, typename Lang>
 void GEMV(const DType alpha, const Tensor &A, const Tensor &v,
           const DType beta, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "GEMV Not Implemented";
+  LOG_FATAL("GEMV", DType, Lang);
 }
 
 /// multiply a matrix with a diagonal matrix constructed using values from 'v'.
@@ -433,7 +436,7 @@ void GEMV(const DType alpha, const Tensor &A, const Tensor &v, const DType beta,
 template <typename DType, typename Lang>
 void DGMM(const bool side_right, const Tensor &M, const Tensor &v, Tensor *out,
           Context *ctx) {
-  LOG(FATAL) << "DGMM Not Implemented";
+  LOG_FATAL("DGMM", DType, Lang);
 }
 
 /// C = alpha * A * B + beta * C.
@@ -441,24 +444,24 @@ void DGMM(const bool side_right, const Tensor &M, const Tensor &v, Tensor *out,
 template <typename DType, typename Lang>
 void GEMM(const DType alpha, const Tensor &A, const Tensor &B,
          const DType beta, Tensor *C, Context *ctx) {
-  LOG(FATAL) << "GEMM Not Implemented";
+  LOG_FATAL("GEMM", DType, Lang);
 }
 
 template <typename DType, typename Lang>
 void GEMMBatched(const DType alpha, const Tensor &A, const Tensor &B,
                  const DType beta, Tensor *C, Context *ctx) {
-  LOG(FATAL) << "GEMM Batched Not Implemented";
+  LOG_FATAL("GEMMBatched", DType, Lang);
 }
 
 template <typename DType, typename Lang>
 void SoftMax(const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Not Implemented";
+  LOG_FATAL("SoftMax", DType, Lang);
 }
 
 template <typename DType, typename Lang>
 void SoftMaxBackward(const Tensor &in, Tensor *out, const Tensor &fdout,
                      Context *ctx) {
-  LOG(FATAL) << "Not Implemented";
+  LOG_FATAL("SoftMaxBackward", DType, Lang);
 }
 
 // yisen todo
@@ -466,68 +469,20 @@ template <typename DType, typename Lang>
 void ComputeCrossEntropy(bool int_target, const size_t batchsize,
                          const size_t dim, const Tensor &p, const Tensor &t,
                          Tensor *loss, Context *ctx) {
-  LOG(FATAL) << "Not Implemented";
+  LOG_FATAL("ComputeCrossEntropy", DType, Lang);
 }
 template <typename DType, typename Lang>
 void SoftmaxCrossEntropyBwd(bool int_target, const size_t batchsize,
                             const size_t dim, const Tensor &p, const Tensor &t,
                             Tensor *grad, Context *ctx) {
-  LOG(FATAL) << "Not Implemented";
+  LOG_FATAL("SoftmaxCrossEntropyBwd", DType, Lang);
 }
 
 template <typename DType, typename Lang>
 void RowMax(const Tensor &in, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-// **************************************
-// Matrix functions
-// **************************************
-/*
-/// Add the vector v to every column of A as the column of out
-template <typename DType, typename Lang>
-void AddCol(const size_t nrow, const size_t ncol, const Tensor &A,
-            const Tensor &v, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "AddCol Not Implemented";
-}
-// TODO(wangwei) unify AddRow and AddCol.
-/// Add the vector v to every row of A as the row of out
-template <typename DType, typename Lang>
-void AddRow(const size_t nrow, const size_t ncol, const Tensor &A,
-            const Tensor &v, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "AddRow Not Implemented";
-}
-/// outer-product.
-/// in1 and in2 are vectors of len m and n. out is matrix of shape m * n
-template <typename DType, typename Lang>
-void Outer(const size_t m, const size_t n, const Tensor &in1, const Tensor &in2,
-           Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Outer Not Implemented";
-}
-
-/// Sum the columns of the in matrix into a vector
-template <typename DType, typename Lang>
-void SumColumns(const size_t nrow, const size_t ncol, const Tensor &in,
-                Tensor *out, Context *ctx) {
-  LOG(FATAL) << "SumColumns Not Implemented";
-}
-template <typename DType, typename Lang>
-void Set(const DType x, Tensor *out, Context *ctx) {
-  LOG(FATAL) << "Not Implemented";
-}
-
-// TODO(wangwei) unify SumRow and SumCol.
-/// Sum the rows of the in matrix into a vector
-template <typename DType, typename Lang>
-void SumRows(const size_t nrow, const size_t ncol, const Tensor &in,
-             Tensor *out, Context *ctx) {
-  LOG(FATAL) << "SumRows Not Implemented";
+  LOG_FATAL("RowMax", DType, Lang);
 }
-*/
 
 }  // namespace singa
 
 #endif  // SINGA_CORE_MATH_H_
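For illustration only (not part of the patch): tensor_math.h declares
generic fallbacks that abort via LOG_FATAL; each backend supplies template
specializations for concrete <DType, Lang> pairs. A minimal sketch of that
pattern, loosely mirroring the CPU specializations in tensor_math_cpp.h
below (the loop body is a simplification, not the actual implementation,
and assumes <cmath> is available in the header):

    // A backend overrides the failing generic for one <DType, Lang> pair.
    template <>
    void Abs<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
      const float *inPtr = static_cast<const float *>(in.block()->data());
      float *outPtr = static_cast<float *>(out->block()->mutable_data());
      for (size_t i = 0; i < in.Size(); i++) outPtr[i] = fabsf(inPtr[i]);
    }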
diff --git a/src/core/tensor/tensor_math_cpp.h b/src/core/tensor/tensor_math_cpp.h
index 2c06f63241..b3113abc3e 100644
--- a/src/core/tensor/tensor_math_cpp.h
+++ b/src/core/tensor/tensor_math_cpp.h
@@ -785,20 +785,6 @@ void Asum(const Tensor &in, float *out, Context *ctx) {
   *out = cblas_sasum(in.Size(), inPtr, 1);  // not using strided traversal
 }
 
-// template <>
-// void Axpy<float, lang::Cpp>(const float alpha,
-//                             const Tensor &in, Tensor *out, Context *ctx) {
-//   // check input tensor for strides first
-//   if (in.stride() == out->stride()) {
-//     const float *inPtr = static_cast<const float *>(in.block()->data());
-//     float *outPtr = static_cast<float *>(out->block()->mutable_data());
-//     cblas_saxpy(in.Size(), alpha, inPtr, 1, outPtr, 1);
-//   } else {
-//     // LOG(FATAL) << "Axpy, input and output strides do not match.";
-//     EltwiseMult<float, lang::Cpp>(in, alpha, out, ctx);
-//   }
-// }
-
 template <>
 void Axpy<float, lang::Cpp>(const float alpha, const Tensor &in, Tensor *out,
                             Context *ctx) {
@@ -817,20 +803,25 @@ void Axpy(const float alpha, const Tensor &in, Tensor *out,
   }
 }
 
-// template <>
-// void Axpy<float, lang::Cpp>(const float alpha,
-//                             const Tensor &in, Tensor *out, Context *ctx) {
-//   // check input tensor for strides first
-//   if (in.stride() == out->stride()) {
-//     const float *inPtr = static_cast<const float *>(in.block()->data());
-//     float *outPtr = static_cast<float *>(out->block()->mutable_data());
-//     cblas_saxpy(in.Size(), alpha, inPtr, 1, outPtr, 1);
-//   } else if (out->transpose()) {
-//     LOG(FATAL) << "output is already transposed.";
-//   } else {
-//     LOG(FATAL) << "Axpy, input and output strides do not match.";
-//   }
-// }
+template <>
+void Axpy<float, lang::Cpp>(const Tensor &alpha, const Tensor &in, Tensor *out,
+                            Context *ctx) {
+  // check input tensor for strides first
+  const float *inPtr = static_cast<const float *>(in.block()->data());
+  float *outPtr = static_cast<float *>(out->block()->mutable_data());
+  const float a = *static_cast<const float *>(alpha.block()->data());
+
+  if (in.stride() == out->stride()) {
+    cblas_saxpy(in.Size(), a, inPtr, 1, outPtr, 1);
+  } else {
+    // LOG(FATAL) << "Axpy, input and output strides do not match.";
+    Tensor t(in.shape(), in.device(), in.data_type());
+    EltwiseMult<float, lang::Cpp>(in, a, &t, ctx);
+    float *tPtr = static_cast<float *>(t.block()->mutable_data());
+    cblas_saxpy(in.Size(), 1, tPtr, 1, outPtr, 1);
+  }
+}
+
 template <>
 void Dot<float, lang::Cpp>(const Tensor &in1, const Tensor &in2, float *out,
@@ -1148,121 +1139,7 @@ void RowMax(const Tensor &in, Tensor *out, Context *ctx) {
   }
 }
 
-// =========Matrix operations ================================================
-/*
-template <>
-void SoftMax<float, lang::Cpp>(const Tensor &in, Tensor *out, Context *ctx) {
-  CHECK_LE(in.nDim(), 2u)
-      << "Axis is required for SoftMax on multi dimemsional tensor";
-  out->CopyData(in);
-  size_t nrow = 1, ncol = in.Size(), size = ncol;
-  if (in.nDim() == 2u) {
-    nrow = in.shape(0);
-    ncol = size / nrow;
-    out->Reshape(Shape{nrow, ncol});
-  }
-  Tensor tmp = RowMax(*out);
-  SubColumn(tmp, out);
-  Exp(*out, out);
-
-  SumColumns(*out, &tmp);
-  DivColumn(tmp, out);
-  out->Reshape(in.shape());
-}
-
-template <>
-void AddCol<float, lang::Cpp>(const size_t nrow, const size_t ncol,
-                              const Tensor &A, const Tensor &v, Tensor *out,
-                              Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *APtr = static_cast<const float *>(A.data());
-  const float *vPtr = static_cast<const float *>(v.data());
-  for (size_t r = 0; r < nrow; r++) {
-    size_t offset = r * ncol;
-    for (size_t c = 0; c < ncol; c++) {
-      outPtr[offset + c] = APtr[offset + c] + vPtr[r];
-    }
-  }
-}
-
-template <>
-void AddRow<float, lang::Cpp>(const size_t nrow, const size_t ncol,
-                              const Tensor &A, const Tensor &v, Tensor *out,
-                              Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *APtr = static_cast<const float *>(A.data());
-  const float *vPtr = static_cast<const float *>(v.data());
-  for (size_t r = 0; r < nrow; r++) {
-    size_t offset = r * ncol;
-    for (size_t c = 0; c < ncol; c++) {
-      outPtr[offset + c] = APtr[offset + c] + vPtr[c];
-    }
-  }
-}
-template <>
-void Outer<float, lang::Cpp>(const size_t m, const size_t n, const Tensor &in1,
-                             const Tensor &in2, Tensor *out, Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *in1Ptr = static_cast<const float *>(in1.data());
-  const float *in2Ptr = static_cast<const float *>(in2.data());
-  for (size_t r = 0; r < m; r++) {
-    size_t offset = r * n;
-    for (size_t c = 0; c < n; c++) {
-      outPtr[offset + c] = in1Ptr[r] * in2Ptr[c];
-    }
-  }
-}
-template <>
-void Softmax<float, lang::Cpp>(const size_t nrow, const size_t ncol,
-                               const Tensor &in, Tensor *out, Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *inPtr = static_cast<const float *>(in.data());
-  float *bPtr = new float[ncol];
-  for (size_t r = 0; r < nrow; r++) {
-    size_t offset = r * ncol;
-    float denom = 0.f;
-    for (size_t c = 0; c < ncol; c++) {
-      bPtr[c] = exp(inPtr[offset + c]);
-      denom += bPtr[c];
-    }
-    for (size_t c = 0; c < ncol; c++) {
-      size_t idx = offset + c;
-      outPtr[idx] = bPtr[c] / denom;
-    }
-  }
-  delete bPtr;
-}
-
-template <>
-void SumColumns<float, lang::Cpp>(const size_t nrow, const size_t ncol,
-                                  const Tensor &in, Tensor *out, Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *inPtr = static_cast<const float *>(in.data());
-  for (size_t c = 0; c < ncol; c++) {
-    outPtr[c] = 0.f;
-  }
-  for (size_t r = 0; r < nrow; r++) {
-    size_t offset = r * ncol;
-    for (size_t c = 0; c < ncol; c++) {
-      outPtr[c] += inPtr[offset + c];
-    }
-  }
-}
-
-template <>
-void SumRows<float, lang::Cpp>(const size_t nrow, const size_t ncol,
-                               const Tensor &in, Tensor *out, Context *ctx) {
-  float *outPtr = static_cast<float *>(out->mutable_data());
-  const float *inPtr = static_cast<const float *>(in.data());
-  for (size_t r = 0; r < nrow; r++) {
-    size_t offset = r * ncol;
-    outPtr[r] = 0.f;
-    for (size_t c = 0; c < ncol; c++) {
-      outPtr[r] += inPtr[offset + c];
-    }
-  }
-}
-*/
 
 }  // namespace singa
 
 #endif  // SINGA_CORE_TENSOR_TENSOR_MATH_CPP_H_
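For illustration only (not part of the patch): on the fast path the new
overload reads the single element of alpha and issues one BLAS call. A
standalone sketch of the same arithmetic on raw buffers (assumes a CBLAS
implementation is linked; values are illustrative):

    #include <cblas.h>
    #include <cstdio>

    int main() {
      const float a = 2.0f;  // value read from the one-element alpha tensor
      float x[3] = {1.0f, 1.0f, 1.0f};
      float y[3] = {4.0f, 4.0f, 4.0f};
      cblas_saxpy(3, a, x, 1, y, 1);  // y = a * x + y
      std::printf("%.1f %.1f %.1f\n", y[0], y[1], y[2]);  // 6.0 6.0 6.0
      return 0;
    }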
diff --git a/test/python/test_tensor.py b/test/python/test_tensor.py
index 82d6d5cb0e..533541271b 100644
--- a/test/python/test_tensor.py
+++ b/test/python/test_tensor.py
@@ -19,6 +19,7 @@
 
 import math
 import unittest
+import random
 import numpy as np
 
 from singa import tensor
@@ -546,7 +547,7 @@ def _kint_float(self, dev=gpu_dev):
         x_val = np.random.randint(0, 10, (2, 3))
         x = tensor.from_numpy(x_val)
         x.to_device(dev)
-        scalar = np.random.random((1,))[0] * 100
+        scalar = random.random() * 100
         y = x + scalar
         self.assertEqual(y.dtype, tensor.float32)
         np.testing.assert_array_almost_equal(tensor.to_numpy(y), x_val + scalar)
diff --git a/tool/conda/singa/meta.yaml b/tool/conda/singa/meta.yaml
index cface0df3c..5cacf3b1b3 100644
--- a/tool/conda/singa/meta.yaml
+++ b/tool/conda/singa/meta.yaml
@@ -20,7 +20,7 @@
 # https://docs.conda.io/projects/conda-build/en/latest/resources/define-metadata.html#templating-with-jinja
 # {% set data = load_setup_py_data(setup_file='../../../python/singa/setup.py', from_recipe_dir=True) %}
 
-{% set version = "2.1.0.dev" %}
+{% set version = "3.2.0" %}
 
 package:
   name: singa