Apache SINGA
A distributed deep learning platform.
tensor_gpu-inl.hpp
#ifndef MSHADOW_TENSOR_GPU_INL_HPP
#define MSHADOW_TENSOR_GPU_INL_HPP

#include "tensor.h"

#if !(MSHADOW_USE_CUDA)
namespace mshadow {
  // do nothing if no GPU operation is involved
  inline void InitTensorEngine(int dev_id) {
  }
  inline void ShutdownTensorEngine(void) {
  }
} // namespace mshadow
#else
namespace mshadow {
  #if (MSHADOW_USE_NVML)
  inline int AutoSelectDevice(int device_count) {
    // TODO: NVML device IDs and CUDA device IDs are not consistent,
    // so fall back to device 0 for now
    return 0;
  }
  #endif
  inline void InitTensorEngine(int dev_id) {
    cudaDeviceProp prop;
    int device_id = 0;
    int device_count = 0;
    cudaGetDeviceCount(&device_count);
    utils::Assert(device_count > 0, "Cannot find CUDA device; please check CUDA configuration");
    if (dev_id < 0) {
      #if (MSHADOW_USE_NVML)
      device_id = AutoSelectDevice(device_count);
      #endif
    } else {
      device_id = dev_id;
    }
    utils::Assert(device_id < device_count, "Incorrect device ID");
    utils::Assert(cudaSetDevice(device_id) == cudaSuccess, "cannot set device");
    cudaGetDeviceProperties(&prop, device_id);
    printf("Use CUDA Device %d: %s\n", device_id, prop.name);
    cublasInit();
  }
  inline void ShutdownTensorEngine(void) {
    cublasShutdown();
  }

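  // Usage sketch (not part of the original header): InitTensorEngine must run
  // before any GPU tensor operation, since it binds the CUDA device and sets
  // up cuBLAS state; ShutdownTensorEngine releases that state.
  //
  //   mshadow::InitTensorEngine(0);      // bind CUDA device 0, init cuBLAS
  //   // ... allocate tensors and launch GPU operations here ...
  //   mshadow::ShutdownTensorEngine();   // tear down cuBLAS before exit
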
  template<int dim>
  inline void AllocSpace(Tensor<gpu,dim> &obj, bool pad) {
    size_t pitch;
    // a common choice of CUDA memory alignment unit is 32
    if (pad && obj.shape[0] >= MSHADOW_MIN_PAD_RATIO * 32) {
      // pad each row of the flattened 2D view to the pitch chosen by CUDA
      cudaError_t err = cudaMallocPitch((void**)&obj.dptr, &pitch,
                                        obj.shape[0] * sizeof(real_t),
                                        obj.FlatTo2D().shape[1]);
      utils::Assert(err == cudaSuccess, cudaGetErrorString(err));
      obj.shape.stride_ = static_cast<index_t>(pitch / sizeof(real_t));
    } else {
      // no padding: allocate one contiguous row holding the whole tensor
      obj.shape.stride_ = obj.shape[0];
      cudaError_t err = cudaMallocPitch((void**)&obj.dptr, &pitch,
                                        obj.shape.Size() * sizeof(real_t), 1);
      utils::Assert(err == cudaSuccess, cudaGetErrorString(err));
    }
  }
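  // Worked example (illustrative numbers, not from the original source): for a
  // Tensor<gpu,2> with shape[1] == 100 rows and shape[0] == 60 floats per row,
  // and assuming the pad condition above holds, cudaMallocPitch may round each
  // 240-byte row up to a 512-byte pitch, giving
  // stride_ = 512 / sizeof(real_t) = 128; row i then starts at
  // dptr + i * stride_, with the trailing 68 floats of each row unused padding.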

  template<int dim>
  inline void FreeSpace(Tensor<gpu,dim> &obj) {
    cudaFree(obj.dptr);
    obj.dptr = NULL;
  }

  template<typename A, typename B, int dim>
  inline void Copy(Tensor<A,dim> _dst, Tensor<B,dim> _src, cudaMemcpyKind kind) {
    utils::Assert(_dst.shape == _src.shape, "Copy: shape mismatch");
    Tensor<A,2> dst = _dst.FlatTo2D();
    Tensor<B,2> src = _src.FlatTo2D();
    // copy row by row, so source and destination may have different strides
    cudaError_t err = cudaMemcpy2D(dst.dptr, dst.shape.stride_ * sizeof(real_t),
                                   src.dptr, src.shape.stride_ * sizeof(real_t),
                                   dst.shape[0] * sizeof(real_t),
                                   dst.shape[1], kind);
    utils::Assert(err == cudaSuccess, cudaGetErrorString(err));
  }
  template<int dim>
  inline void Copy(Tensor<cpu,dim> dst, const Tensor<gpu,dim> &src) {
    Copy(dst, src, cudaMemcpyDeviceToHost);
  }
  template<int dim>
  inline void Copy(Tensor<gpu,dim> dst, const Tensor<gpu,dim> &src) {
    Copy(dst, src, cudaMemcpyDeviceToDevice);
  }
  template<int dim>
  inline void Copy(Tensor<gpu,dim> dst, const Tensor<cpu,dim> &src) {
    Copy(dst, src, cudaMemcpyHostToDevice);
  }
} // namespace mshadow

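// Usage sketch (not part of the original header; Shape2 and the cpu-side
// AllocSpace/FreeSpace come from tensor.h and tensor_cpu-inl.hpp):
//
//   mshadow::Tensor<mshadow::cpu,2> h(mshadow::Shape2(100, 60));
//   mshadow::Tensor<mshadow::gpu,2> d(mshadow::Shape2(100, 60));
//   mshadow::AllocSpace(h);
//   mshadow::AllocSpace(d);
//   mshadow::Copy(d, h);   // host -> device, cudaMemcpyHostToDevice
//   mshadow::Copy(h, d);   // device -> host, cudaMemcpyDeviceToHost
//   mshadow::FreeSpace(h);
//   mshadow::FreeSpace(d);
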
#ifdef __CUDACC__
// the following part is included only if the compiler is nvcc
#include "cuda/tensor_gpu-inl.cuh"

namespace mshadow {
  template<typename Saver, typename E, int dim>
  inline void MapPlan(Tensor<gpu,dim> _dst, const expr::Plan<E> &plan) {
    cuda::MapPlan<Saver>(_dst.FlatTo2D(), plan);
  }

  template<typename Saver, int dim, typename E, int etype>
  inline void MapExp(Tensor<gpu,dim> dst, const expr::Exp<E,etype> &exp) {
    using namespace expr;
    TypeCheckPass< TypeCheck<gpu,dim,E>::kMapPass >::Error_All_Tensor_in_Exp_Must_Have_Same_Type();
    Shape<dim> eshape = ShapeCheck<dim,E>::Check(exp.self());
    utils::Assert(eshape[0] == 0 || eshape == dst.shape,
                  "Assignment: shape of tensors in expression is not consistent with target");
    MapPlan<Saver>(dst, MakePlan(exp.self()));
  }

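  // Note (behavior of the wider library, assumed rather than shown here):
  // users rarely call MapExp directly; an expression assignment such as
  //
  //   dst = a + b * 2.0f;   // a, b, dst: Tensor<gpu,2> of identical shape
  //
  // is routed by mshadow's expression engine to MapExp<sv::saveto>, which
  // type-checks the expression, verifies shapes, and launches the kernel
  // through cuda::MapPlan.
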
  template<typename Saver, typename Reducer, typename E, int etype>
  inline void MapReduceKeepLowest(Tensor<gpu,1> dst, const expr::Exp<E,etype> &exp, real_t scale) {
    using namespace expr;
    TypeCheckPass< TypeCheck<gpu,1,E>::kRedPass >::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
    Shape<2> eshape = ShapeCheck< ExpInfo<E>::kDim, E >::Check(exp.self()).FlatTo2D();
    utils::Assert(eshape[0] == dst.shape[0], "reduction dimension does not match");
    utils::Assert(eshape[1] != 0, "cannot reduce over an empty tensor");
    cuda::MapReduceKeepLowest<Saver,Reducer>(dst, MakePlan(exp.self()), scale, eshape);
  }

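  // Usage sketch (sumall_except_dim comes from mshadow's expression
  // extensions, an assumption not visible in this file): keep dimension 0 of
  // a 100 x 60 GPU tensor d and reduce everything else into 60 sums:
  //
  //   mshadow::Tensor<mshadow::gpu,1> s(mshadow::Shape1(60));
  //   mshadow::AllocSpace(s);
  //   s = mshadow::expr::sumall_except_dim<0>(d);
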
  template<typename Saver, typename Reducer, int dimkeep, typename E, int etype>
  inline void MapReduceKeepHighDim(Tensor<gpu,1> dst, const expr::Exp<E,etype> &exp, real_t scale) {
    using namespace expr;
    TypeCheckPass< TypeCheck<gpu,dimkeep,E>::kRedPass >::Error_TypeCheck_Not_Pass_For_Reduce_Exp();
    typedef Shape< ExpInfo<E>::kDim > EShape;
    EShape eshape = ShapeCheck< ExpInfo<E>::kDim, E >::Check(exp.self());
    utils::Assert(eshape[dimkeep] == dst.shape[0], "reduction dimension does not match");
    // reshape into an equivalent 4D form, placing the kept dimension at axis 2
    Shape<4> pshape = Shape4(eshape.ProdShape(dimkeep+1, EShape::kMaxShape), eshape[dimkeep],
                             eshape.ProdShape(1, dimkeep), eshape[0]);
    // call the equivalent reduction that keeps dimension 2
    cuda::MapReduceKeepDim2<Saver,Reducer>(dst, MakePlan(exp.self()), scale, pshape);
  }

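  // Worked example (illustrative): reducing a Tensor<gpu,3> with
  // (eshape[2], eshape[1], eshape[0]) = (4, 5, 6) while keeping dimkeep = 2
  // yields pshape = Shape4(1, 4, 5, 6): an empty product above the kept axis,
  // the 4 kept slices, 5 middle rows to reduce, and 6 elements per row.
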
  inline void Softmax(Tensor<gpu,2> dst, const Tensor<gpu,2> &src) {
    cuda::Softmax(dst, src);
  }
} // namespace mshadow

#endif // __CUDACC__

#endif // MSHADOW_USE_CUDA
#endif // MSHADOW_TENSOR_GPU_INL_HPP
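
Taken together, a minimal end-to-end use of this header looks like the sketch below. It is a hedged illustration, not code from the SINGA tree: Shape2, element indexing, and AllocSpace's default pad argument come from the rest of mshadow rather than this file, and the sketch must be compiled as a .cu file with nvcc (so the __CUDACC__ section above is active) and linked against cuBLAS.

// example_softmax.cu -- illustrative sketch only; build roughly as:
//   nvcc example_softmax.cu -lcublas
#include <cstdio>
#include "tensor.h"

int main(void) {
  using namespace mshadow;
  InitTensorEngine(0);                      // bind CUDA device 0, init cuBLAS

  // host and device tensors: 2 rows of 3 scores each (hypothetical sizes)
  Tensor<cpu,2> h_energy(Shape2(2, 3)), h_prob(Shape2(2, 3));
  Tensor<gpu,2> d_energy(Shape2(2, 3)), d_prob(Shape2(2, 3));
  AllocSpace(h_energy); AllocSpace(h_prob);
  AllocSpace(d_energy); AllocSpace(d_prob);

  for (index_t i = 0; i < 2; ++i)           // fill host input
    for (index_t j = 0; j < 3; ++j)
      h_energy[i][j] = static_cast<real_t>(i + j);

  Copy(d_energy, h_energy);                 // host -> device
  Softmax(d_prob, d_energy);                // row-wise softmax on the GPU
  Copy(h_prob, d_prob);                     // device -> host
  printf("p[0][0] = %f\n", h_prob[0][0]);

  FreeSpace(h_energy); FreeSpace(h_prob);
  FreeSpace(d_energy); FreeSpace(d_prob);
  ShutdownTensorEngine();
  return 0;
}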