Apache Singa
A General Distributed Deep Learning Library
tensor.h
#ifndef SINGA_CORE_TENSOR_H_
#define SINGA_CORE_TENSOR_H_

#include <vector>
#include <tuple>
#include <memory>

#include "singa/core/common.h"
#include "singa/core/device.h"
#include "singa/proto/core.pb.h"
#include "singa/utils/logging.h"

using std::vector;
using std::tuple;
namespace singa {

typedef vector<size_t> Shape;

/// Hardcode the width of the types defined in DataType.
const size_t kDataWidth[] = {sizeof(float),  sizeof(float) / 2,
                             sizeof(int),    sizeof(char),
                             sizeof(double), sizeof(unsigned char)};

inline size_t SizeOf(DataType t) {
  static_assert(kNumDataType == sizeof(kDataWidth) / sizeof(size_t),
                "Num of data types not match num of data width");
  CHECK_GT(kNumDataType, t);
  return kDataWidth[t];
}
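SizeOf maps a DataType value to the width of one element in bytes, so a tensor's raw buffer size is its element count times SizeOf of its type. A minimal sketch of that arithmetic (illustrative only; the wrapper function is made up for this page):

    #include "singa/core/tensor.h"

    size_t BufferBytesSketch() {
      singa::Shape s = {2, 3};
      // 6 elements * 4 bytes per float32 element = 24 bytes
      return s[0] * s[1] * singa::SizeOf(singa::kFloat32);
    }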

/// A Tensor instance is a multi-dimensional array resident on a Device
/// (default device is the host CPU).
class Tensor {
 public:
  ~Tensor();
  Tensor();

  explicit Tensor(const Shape &shape, DataType dtype = kFloat32);

  Tensor(const Shape &shape, std::shared_ptr<Device> dev,
         DataType dtype = kFloat32);

  /// Copy constructor.
  Tensor(const Tensor &from);

  /// Move constructor.
  Tensor(Tensor &&from);

  // --------------------------------------------------------------------------
  // ---Following methods return info of the class without making any changes--
  // --------------------------------------------------------------------------

  /// For functions in xx_math.cc to access the block.
  Block *block() const { return block_; }

  std::shared_ptr<Device> device() const { return device_; }

  /// Return immutable Tensor values with the given type.
  template <typename SType>
  const SType *data() const {
    return static_cast<const SType *>(block()->data());
  }

  /// Data type, including kFloat16, kFloat32, kInt.
  const DataType data_type() const { return data_type_; }

  const Shape &shape() const { return shape_; }

  const size_t shape(const size_t idx) const {
    CHECK_LT(idx, shape_.size());
    return shape_.at(idx);
  }

  size_t nDim() const { return shape_.size(); }

  bool empty() const { return nDim() == 0; }

  /// Return true if the tensor holds a transposed (non-default) layout:
  /// the strides should decrease, except for dimensions with stride = 0
  /// due to broadcasting.
  bool transpose() const {
    if (!stride_.empty()) {
      auto last = stride_.front();
      for (auto s : stride_) {
        if (s > last && last > 0)
          return true;
        if (s > 0)
          last = s;
      }
    }
    return false;
  }

  const vector<int>& stride() const { return stride_; }

  /// Return true if the content of the tensor is initialized.
  bool initailized() const {
    return block_ != nullptr && block_->initialized();
  }

  /// Return the total number of elements.
  size_t Size() const {
    if (block_ == nullptr) return 0u;
    CHECK_EQ(block_->size() % SizeOf(data_type_), 0u);
    return block_->size() / SizeOf(data_type_);
  }

  /// Return memory size (i.e., bytes).
  size_t MemSize() const { return block_->size(); }

  /// Used by the SWIG code to convert a Tensor into a numpy array; requires
  /// the tensor to be on defaultDevice (the singleton host CPU device).
  template <typename SType>
  void GetValue(SType *value, const size_t num) {
    CHECK(device_ == defaultDevice);
    const SType* ptr = data<SType>();
    for (size_t i = 0; i < num; i++) value[i] = ptr[i];
  }

  /// Serialize data, shape and transpose to the protobuf object.
  void ToProto(singa::TensorProto *proto) const;

  /// Return average L1 norm.
  float L1() const;

  /// Return average L2 norm.
  float L2() const;

  // --------------------------------------------------------------------------
  // ---Following methods change the internal data
  // --------------------------------------------------------------------------

  /// Set each element of the tensor to be x.
  template <typename SType>
  void SetValue(const SType x);

  /// For initializing the tensor values: copy 'num' elements from 'src' into
  /// the internal memory starting at 'offset' (in elements).
  template <typename SType>
  void CopyDataFromHostPtr(const SType *src, const size_t num,
                           const size_t offset = 0);

  /// Copy data from another Tensor, which may be on a different device.
  void CopyData(const Tensor &other);

  /// Deserialize data, shape and transpose from the protobuf object.
  void FromProto(const singa::TensorProto &proto);

  /// TODO(wangwei) merge RepeatData into Repeat?
  void RepeatData(const vector<size_t>& repeats, int axis, int total_repeats,
                  const Tensor &other);

  // --------------------------------------------------------------------------
  // ---Following methods return a new Tensor without changing the original----
  // --------------------------------------------------------------------------

  Tensor Repeat(const vector<size_t>& repeats, int axis,
                std::shared_ptr<Device> device = nullptr);

  /// Return an exact copy of this Tensor, with its data deep-copied to the
  /// given device.
  Tensor Clone(std::shared_ptr<Device> device = nullptr) const;

  // --------------------------------------------------------------------------
  // ---Following methods change the tensor and return itself
  // --------------------------------------------------------------------------

  /// Copy assignment.
  Tensor &operator=(const Tensor &in);

  /// Move assignment.
  Tensor &operator=(Tensor &&in);

  Tensor &operator+=(const Tensor &in);

  Tensor &operator-=(const Tensor &in);

  Tensor &operator*=(const Tensor &in);

  Tensor &operator/=(const Tensor &in);

  // Scalar operations.

  template <typename SType>
  Tensor &operator+=(const SType x);

  template <typename SType>
  Tensor &operator-=(const SType x);

  template <typename SType>
  Tensor &operator*=(const SType x);

  template <typename SType>
  Tensor &operator/=(const SType x);

  /// Change the shape (and stride); the block may be reallocated.
  Tensor &Reshape(const Shape &shape);

  /// Resize the memory and return itself.
  Tensor& Resize(const Shape& shape);

  /// Matrix transpose. Valid only if shape.size() == 2.
  Tensor& T();

  /// Reverse the shape vector.
  Tensor& Transpose();

  /// Permute the dimensions according to 'axes'.
  Tensor& Transpose(const vector<size_t> &axes);

  /// Return a view of this tensor whose shape is broadcasted to be compatible
  /// with the given shape.
  Tensor& Broadcast(const Shape& shape);

  /// Reset the shape, device, and data type as the given tensor.
  Tensor& ResetLike(const Tensor &t);

  /// Reset the data type; the block is reallocated if the type changes.
  Tensor& AsType(const DataType type);

  /// Reset the device.
  Tensor& ToDevice(std::shared_ptr<Device> dev);

  /// Equivalent to ToDevice(host_dev).
  Tensor& ToHost();

 protected:
  // Generate strides automatically if the stride field is not passed.
  // Strides are row-major and measured in elements; e.g., a tensor of
  // shape {2, 3, 4} (24 elements) gets stride {12, 4, 1}.
  void generate_stride() {
    stride_.clear();
    if (shape_.size() == 0) {
      stride_.push_back(1);
      return;
    }

    size_t dim = Size();
    int cumulative_product = 1;
    for (size_t n = 0; n < shape_.size(); ++n) {
      cumulative_product = cumulative_product * shape_[n];
      stride_.push_back(dim / cumulative_product);
    }
  }

  void set_strides(const vector<int> new_strides) {
    stride_ = new_strides;
  }

 protected:
  DataType data_type_ = kFloat32;
  std::shared_ptr<Device> device_ = nullptr;
  /// Note: block_ (a chunk of memory on device or host, see common.h) is
  /// allocated in a lazy manner to avoid frequent malloc/free.
  Block *block_ = nullptr;
  Shape shape_ = {};
  vector<int> stride_ = {};
};  // end of tensor class
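Usage sketch for the Tensor class above: create, fill, combine, and read back on the host CPU. This is illustrative code written for this page (not from the SINGA sources); it assumes the default constructor arguments place the tensor on the host device, so GetValue is legal.

    #include "singa/core/tensor.h"

    void TensorBasicsSketch() {
      singa::Tensor a(singa::Shape{2, 3});     // 2x3 float32 tensor on the host
      a.SetValue(1.0f);                        // every element becomes 1.0

      const float src[6] = {0, 1, 2, 3, 4, 5};
      singa::Tensor b(singa::Shape{2, 3});
      b.CopyDataFromHostPtr(src, 6);           // initialize from host memory

      a += b;                                  // element-wise, in place

      float out[6];
      a.GetValue(out, 6);                      // copy values back to the host
      // out is now {1, 2, 3, 4, 5, 6}; a.Size() == 6, a.nDim() == 2
    }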


inline size_t Product(const Shape &shape, int start = 0, size_t len = 0) {
  if (len == 0) len = shape.size();
  if (len == 0) return 0;
  CHECK_LE(len, shape.size());
  size_t v = 1;
  for (unsigned int i = start; i < len; i++) v *= shape[i];
  return v;
}

inline void CheckDataTypeAndLang(const Tensor &in1, const Tensor &in2) {
  CHECK_EQ(in1.data_type(), in2.data_type());
  CHECK_EQ(in1.device()->lang(), in2.device()->lang());
}

template <typename FromType, typename ToType>
ToType TypeCast(const FromType &x) {
  // TODO(wangwei) cast fp16; prevent some casts, e.g., float to char
  return static_cast<ToType>(x);
}

Tensor Boradcast(const Shape& shape);

Tensor Reshape(const Tensor &in, const Shape &s);

Tensor Resize(const Tensor &in, const Shape &s);

/// Reverse the shape vector of 'in'.
Tensor Transpose(const Tensor& in);

/// Return a view of 'in' whose shape is broadcasted to be compatible with the
/// given shape.
Tensor Broadcast(const Tensor& in, const Shape& shape);

/// Permute the dimensions of 'in' according to 'axes'.
Tensor Transpose(const Tensor& in, const vector<size_t> &axes);

/// Copy 'num' elements of src to dst.
void CopyDataToFrom(Tensor *dst, const Tensor &src, const size_t num,
                    const size_t dst_offset = 0, const size_t src_offset = 0);

void RepeatDataToFrom(bool broadcast_flag, const vector<size_t>& repeats,
                      int axis, Tensor *dst, const Tensor &in,
                      const size_t num);

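The free functions above return new tensors and leave their inputs untouched, unlike the members of the same names. A short sketch of how they compose (illustrative only; shapes are arbitrary):

    void ReshapeTransposeSketch() {
      singa::Tensor m(singa::Shape{2, 3});
      m.SetValue(0.5f);

      singa::Tensor r = singa::Reshape(m, singa::Shape{3, 2});    // m is unchanged
      singa::Tensor t = singa::Transpose(r);                      // shape back to {2, 3}

      singa::Tensor v(singa::Shape{1, 3});
      singa::Tensor bv = singa::Broadcast(v, singa::Shape{2, 3}); // broadcast view
    }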
// =============Element-wise operations=========================================
Tensor Abs(const Tensor &in);
Tensor Exp(const Tensor &in);
Tensor Log(const Tensor &in);
Tensor ReLU(const Tensor &in);
Tensor Sigmoid(const Tensor &in);
Tensor Sign(const Tensor &in);
Tensor Sqrt(const Tensor &in);
Tensor Square(const Tensor &in);
Tensor Tanh(const Tensor &in);
Tensor Transform(const Tensor &in);

void Abs(const Tensor &in, Tensor *out);
void Exp(const Tensor &in, Tensor *out);
void Log(const Tensor &in, Tensor *out);
void ReLU(const Tensor &in, Tensor *out);
void Sigmoid(const Tensor &in, Tensor *out);
void Sign(const Tensor &in, Tensor *out);
void Sqrt(const Tensor &in, Tensor *out);
void Square(const Tensor &in, Tensor *out);
void Tanh(const Tensor &in, Tensor *out);
void Transform(const Tensor &in, Tensor *out);

/// Element-wise operation, out[i] = in[i]^x.
template <typename SType>
Tensor Pow(const Tensor &in, const SType x);
template <typename SType>
void Pow(const Tensor &in, const SType x, Tensor *out);
/// Element-wise operation, out[i] = base[i]^exp[i].
Tensor Pow(const Tensor &base, const Tensor &exp);
void Pow(const Tensor &base, const Tensor &exp, Tensor *out);
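Each element-wise operation comes in two flavours: one that allocates and returns a fresh tensor, and one that writes into a caller-prepared 'out' tensor, which must already have the proper shape, device and type. A brief illustrative sketch:

    void ElementwiseSketch(const singa::Tensor &x) {
      singa::Tensor y = singa::ReLU(x);        // new tensor
      singa::Tensor z = singa::Pow(y, 2.0f);   // z[i] = y[i]^2

      singa::Tensor out(x.shape(), x.device(), x.data_type());
      singa::Exp(x, &out);                     // result written into 'out'
    }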

/// Element-wise operation, out[i] = (in[i] < x) ? 1.f : 0.f.
template <typename SType>
Tensor operator<(const Tensor &in, const SType x);
template <typename SType>
void LT(const Tensor &in, const SType x, Tensor *out);

Tensor operator<(const Tensor &in1, const Tensor& in2);
void LT(const Tensor &in1, const Tensor& in2, Tensor *out);

/// Element-wise operation, out[i] = (in[i] <= x) ? 1.f : 0.f.
template <typename SType>
Tensor operator<=(const Tensor &in, const SType x);
template <typename SType>
void LE(const Tensor &in, const SType x, Tensor *out);

Tensor operator<=(const Tensor &in1, const Tensor& in2);
void LE(const Tensor &in1, const Tensor& in2, Tensor *out);

/// Element-wise operation, out[i] = (in[i] > x) ? 1.f : 0.f.
template <typename SType>
Tensor operator>(const Tensor &in, const SType x);
template <typename SType>
void GT(const Tensor &in, const SType x, Tensor *out);

Tensor operator>(const Tensor &in1, const Tensor& in2);
void GT(const Tensor &in1, const Tensor& in2, Tensor *out);

/// Element-wise operation, out[i] = (in[i] >= x) ? 1.f : 0.f.
template <typename SType>
Tensor operator>=(const Tensor &in, const SType x);
template <typename SType>
void GE(const Tensor &in, const SType x, Tensor *out);

Tensor operator>=(const Tensor &in1, const Tensor& in2);
void GE(const Tensor &in1, const Tensor& in2, Tensor *out);

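The comparison operators return masks of 1.f/0.f rather than booleans, so they combine directly with the arithmetic operators. A brief illustrative sketch:

    void MaskSketch(const singa::Tensor &x) {
      singa::Tensor mask = x > 0.0f;        // mask[i] = (x[i] > 0) ? 1.f : 0.f
      singa::Tensor clipped = x * mask;     // zero out the non-positive entries
    }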

Tensor operator+(const Tensor &lhs, const Tensor &rhs);
void Add(const Tensor &lhs, const Tensor &rhs, Tensor *out);
Tensor operator-(const Tensor &lhs, const Tensor &rhs);
void Sub(const Tensor &lhs, const Tensor &rhs, Tensor *out);
Tensor operator*(const Tensor &lhs, const Tensor &rhs);
void EltwiseMult(const Tensor &lhs, const Tensor &rhs, Tensor *out);
Tensor operator/(const Tensor &lhs, const Tensor &rhs);
void Div(const Tensor &lhs, const Tensor &rhs, Tensor *out);

template <typename SType>
Tensor operator+(const Tensor &in, const SType x);
template <typename SType>
void Add(const Tensor &in, const SType x, Tensor *out);

template <typename SType>
Tensor operator-(const Tensor &in, const SType x);
template <typename SType>
void Sub(const Tensor &in, const SType x, Tensor *out);

template <typename SType>
Tensor operator*(const Tensor &in, const SType x);
template <typename SType>
void EltwiseMult(const Tensor &in, const SType x, Tensor *out);

template <typename SType>
Tensor operator/(const Tensor &in, const SType x);
template <typename SType>
void Div(const Tensor &in, const SType x, Tensor *out);

/// Element-wise operation, out[i] = x / in[i].
template <typename SType>
Tensor Div(const SType x, const Tensor &in);
template <typename SType>
void Div(const SType x, const Tensor &in, Tensor *out);

/// Sum all elements of the tensor and return the result as a scalar.
template <typename SType = float>
SType Sum(const Tensor &in);
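As with the unary functions, every binary operation has an allocating operator form and an 'out'-parameter form, and Sum reduces a whole tensor to a scalar. A short illustrative sketch (the shapes of 'a' and 'b' are assumed to match):

    void ArithmeticSketch(const singa::Tensor &a, const singa::Tensor &b) {
      singa::Tensor c = a * b + 1.0f;       // element-wise multiply, then add 1
      float total = singa::Sum(c);          // sum of all elements, as float

      singa::Tensor d(a.shape(), a.device(), a.data_type());
      singa::Div(1.0f, c, &d);              // d[i] = 1 / c[i]
    }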


// ============Matrix (row/column) operations===================================

/// Average elements in the Tensor; currently only vectors and matrices are
/// supported.
Tensor Average(const Tensor &in, const int axis);

/// Add column 'v' to each column of matrix M.
void AddColumn(const Tensor &v, Tensor *M);
template <typename SType>
void AddColumn(const SType alpha, const SType beta, const Tensor &v,
               Tensor *out);
/// Add row 'v' to each row of matrix M; write the results into 'out'.
void AddRow(const Tensor &v, Tensor *out);
template <typename SType>
void AddRow(const SType alpha, const SType beta, const Tensor &v, Tensor *M);
/// Divide each column of matrix M by column 'v' (element-wise).
void DivColumn(const Tensor &v, Tensor *M);
/// Divide each row of matrix M by row 'v' (element-wise).
void DivRow(const Tensor &v, Tensor *M);
/// Multiply each column of matrix M by column 'v' (element-wise).
void MultColumn(const Tensor &v, Tensor *M);
/// Multiply each row of matrix M by row 'v' (element-wise).
void MultRow(const Tensor &v, Tensor *M);
/// Do softmax for each row; 'in' could be a 1-d or 2-d Tensor.
Tensor SoftMax(const Tensor &in);

Tensor RowMax(const Tensor &in);
void SoftMax(const Tensor &in, Tensor *out);
/// Subtract column 'v' from each column of matrix M.
void SubColumn(const Tensor &v, Tensor *M);
/// Subtract row 'v' from each row of matrix M.
void SubRow(const Tensor &v, Tensor *M);
/// Sum all columns of matrix M into a single column 'out'.
void SumColumns(const Tensor &M, Tensor *out);
/// Sum all rows of matrix M into a single row 'out'.
void SumRows(const Tensor &M, Tensor *out);

/// Sum the elements of 'in' along the given axis.
Tensor Sum(const Tensor &in, const int axis);

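These helpers broadcast a vector over the rows or columns of a matrix in place, which is the usual way to apply a bias or a per-row/per-column scale. A small illustrative sketch:

    void RowColumnSketch() {
      singa::Tensor M(singa::Shape{4, 3});      // 4x3 matrix
      M.SetValue(2.0f);

      singa::Tensor bias(singa::Shape{3});      // one value per column
      bias.SetValue(0.5f);
      singa::AddRow(bias, &M);                  // add 'bias' to every row of M

      singa::Tensor prob = singa::SoftMax(M);   // row-wise softmax

      singa::Tensor rowsum(singa::Shape{4});
      singa::SumColumns(M, &rowsum);            // sums across columns, one per row
    }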
// ================Random operations============================================

/// For each element x, set x = 1 if random() < p; otherwise x = 0.
template <typename SType>
void Bernoulli(const SType p, Tensor *out);
/// Fill Tensor 'out' following a Gaussian distribution with the given mean and
/// standard deviation.
template <typename SType>
void Gaussian(const SType mean, const SType std, Tensor *out);
/// Fill Tensor 'out' following a uniform distribution between 'low' and 'high'.
template <typename SType>
void Uniform(const SType low, const SType high, Tensor *out);

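The random fillers write into an existing tensor, so its shape and device determine how many values are drawn and where. An illustrative initialization sketch:

    void RandomInitSketch() {
      singa::Tensor w(singa::Shape{128, 64});
      singa::Gaussian(0.0f, 0.01f, &w);          // w ~ N(0, 0.01^2)

      singa::Tensor mask(singa::Shape{128, 64});
      singa::Bernoulli(0.8f, &mask);             // 1 with probability 0.8, else 0
    }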
// ================Blas operations==============================================
// TODO(wangwei) make amax/amin/asum a member function of tensor

/// out = alpha * in + out.
template <typename SType>
void Axpy(SType alpha, const Tensor &in, Tensor *out);

/// Do matrix-vector or matrix-matrix multiplication depending on the shapes of
/// A and B.
Tensor Mult(const Tensor &A, const Tensor &B);
void Mult(const Tensor &A, const Tensor &B, Tensor *C);
/// C = alpha * A * B + beta * C.
template <typename SType>
void Mult(const SType alpha, const Tensor &A, const Tensor &B, const SType beta,
          Tensor *C);

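Mult covers matrix-vector and matrix-matrix products, and Axpy is the classic scaled accumulate; together they are enough for a dense layer and a plain SGD update. A minimal illustrative sketch (the shapes in the comments are assumptions):

    void DenseForwardSketch(const singa::Tensor &x,     // {batch, in_dim}
                            const singa::Tensor &W,     // {in_dim, out_dim}
                            const singa::Tensor &bias)  // {out_dim}
    {
      singa::Tensor y = singa::Mult(x, W);   // y = x * W, shape {batch, out_dim}
      singa::AddRow(bias, &y);               // add the bias to every row
    }

    void SgdUpdateSketch(float lr, const singa::Tensor &grad, singa::Tensor *param) {
      singa::Axpy(-lr, grad, param);         // param = -lr * grad + param
    }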
// *****************
// Misc.
// *****************

/// Compute the cross entropy loss given the prediction probability 'p' and the
/// target (ground truth) labels 't'.
void ComputeCrossEntropy(const Tensor &p, const Tensor &t, Tensor *loss);

/// Compute dx, given the prediction probability 'p' (p = softmax(x)) and the
/// target (ground truth) labels 't'; the result is written into 'p'.
void SoftmaxCrossEntropyBwd(const Tensor &t, Tensor *p);

/// To be called by pysinga autograd operations; SWIG ignores the const
/// qualifier, see http://www.swig.org/Doc3.0/SWIGPlus.html#SWIGPlus_const.
Tensor CrossEntropyFwd(const Tensor& p, const Tensor& t);
Tensor SoftmaxCrossEntropyBwd(const Tensor& p, const Tensor& t);

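A sketch of how the loss helpers fit together for a softmax classifier. Illustrative only; it assumes 'logits' holds pre-softmax scores, 'target' holds the labels in the encoding the library expects, and that the per-sample loss tensor has one entry per row.

    void SoftmaxLossSketch(const singa::Tensor &logits, const singa::Tensor &target) {
      singa::Tensor prob = singa::SoftMax(logits);         // row-wise probabilities

      singa::Tensor loss(singa::Shape{logits.shape(0)});   // assumed: one loss per row
      singa::ComputeCrossEntropy(prob, target, &loss);

      singa::SoftmaxCrossEntropyBwd(target, &prob);        // 'prob' now holds dL/dlogits
    }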
/// Return a tensor consisting of rows [start, end) from 'in'.
Tensor CopyRows(const Tensor &in, const size_t start, const size_t end);
/// Alias of CopyRows.
Tensor SliceRows(const Tensor &in, const size_t start, const size_t end);
/// Slice the input tensor along the given axis to generate a new tensor.
Tensor SliceOn(const Tensor &in, const size_t start, const size_t end,
               int axis);
/// Return a tensor consisting of columns [start, end) from 'in'.
Tensor CopyColumns(const Tensor &in, const size_t start, const size_t end);
/// Alias of CopyColumns.
Tensor SliceColumns(const Tensor &in, const size_t start, const size_t end);
/// Return a tensor which is vertically stacked from the tensors in 'in'.
Tensor ConcatenateRows(const vector<Tensor> &in);
/// Return a tensor concatenated from the input tensors along the given axis.
Tensor ConcatOn(const std::vector<Tensor> &in, int axis);
/// Alias of ConcatenateRows.
Tensor ConcatRows(const vector<Tensor> &in);
/// Return a tensor which is horizontally stacked from the tensors in 'in'.
Tensor ConcatenateColumns(const vector<Tensor> &in);
/// Alias of ConcatenateColumns.
Tensor ConcatColumns(const vector<Tensor> &in);
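CopyRows/CopyColumns slice a 2-d tensor, and the Concat* functions stitch tensors back together. The sketch below splits a matrix into two halves and re-joins them (illustrative only; the shape is an assumption):

    void SliceConcatSketch(const singa::Tensor &data) {    // assume shape {100, 16}
      singa::Tensor top    = singa::CopyRows(data, 0, 50);
      singa::Tensor bottom = singa::CopyRows(data, 50, 100);

      std::vector<singa::Tensor> parts = {top, bottom};
      singa::Tensor rejoined = singa::ConcatOn(parts, 0);   // back to {100, 16}
    }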
}  // namespace singa

#endif  // SINGA_CORE_TENSOR_H_