Apache SINGA
A distributed deep learning platform.
 All Classes Namespaces Files Functions Variables Typedefs Macros
tensor_expr_ext.h
Go to the documentation of this file.
1 #ifndef MSHADOW_TENSOR_EXPR_EXT_H
2 #define MSHADOW_TENSOR_EXPR_EXT_H
3 
9 namespace mshadow{
10  // Declaration of expressions goes here
11  namespace expr{
20  template<typename Device, int dimdst, int dimcast>
21  struct Broadcast1DExp: public MakeTensorExp< Broadcast1DExp<Device,dimdst,dimcast>,Tensor<Device,1>,dimdst>{
26  this->shape_ = shape;
27  }
28  };
29 
37  template<typename SrcExp, int srcdim>
38  struct UnpackPatchToColXExp: public MakeTensorExp< UnpackPatchToColXExp<SrcExp,srcdim>, SrcExp, 2>{
40  const SrcExp& img_;
52  UnpackPatchToColXExp( const SrcExp &img, index_t psize, index_t pstride )
53  :img_(img), psize_(psize), pstride_(pstride){
55  utils::Assert( imshape[0] >= psize && imshape[1] >= psize, "UnpackPatchToCol:image shape smaller than patch size");
56  this->i_channel_ = imshape[2];
57  this->i_height_ = imshape[1];
58  this->i_width_ = imshape[0];
59  // calculate number of batches
60  const index_t num = imshape.ProdShape( 3, srcdim );
61  const index_t o_height = ( i_height_ - psize ) / pstride + 1;
62  const index_t o_width = ( i_width_ - psize ) / pstride + 1;
63  this->shape_[0] = o_height * o_width * num;
64  this->shape_[1] = psize * psize * imshape[2];
65  }
66  };
67 
74  template<typename Device, int dstdim>
75  struct PackColToPatchXExp: public MakeTensorExp< PackColToPatchXExp<Device,dstdim>, Tensor<Device,2>, dstdim>{
83  PackColToPatchXExp( const Tensor<Device,2> &mat, Shape<dstdim> imshape, index_t psize, index_t pstride )
84  :mat_(mat), psize_(psize), pstride_(pstride){
85  this->shape_ = imshape;
86  const index_t o_height = ( imshape[1] - psize ) / pstride + 1;
87  const index_t o_width = ( imshape[0] - psize ) / pstride + 1;
88  utils::Assert( mat.shape[0] == o_height * o_width * imshape.ProdShape(3,dstdim), "PackColToPatchExp: mat.shape[0] mismatch" );
89  utils::Assert( mat.shape[1] == psize * psize * imshape[2], "PackColToPatchExp: mat.shape[1] mismatch" );
90  }
91  };
92 
101  template<typename SrcExp, int dimdst, int dimsrc>
102  struct ReshapeExp: public MakeTensorExp< ReshapeExp<SrcExp,dimdst,dimsrc>, SrcExp, dimdst>{
104  const SrcExp& src_;
108  ReshapeExp( const SrcExp &src, Shape<dimdst> shape ):src_(src){
110  utils::Assert( ishape.Size() == shape.Size(), "reshape size must match" );
111  ishape0_ = ishape[0];
112  this->shape_ = shape;
113  }
114  };
115 
126  template<typename SrcExp,int dimsrc, int a1, int a2>
127  struct SwapAxisExp: public MakeTensorExp< SwapAxisExp<SrcExp,dimsrc,a1,a2>, SrcExp, dimsrc>{
129  const SrcExp& src_;
131  SwapAxisExp( const SrcExp &src ):src_(src){
133  std::swap( this->shape_[a1], this->shape_[a2] );
134  }
135  };
136 
147  template<typename EType, typename Reducer,int dimkeep>
148  struct ReduceTo1DExp: public Exp< ReduceTo1DExp<EType,Reducer, dimkeep>, type::kComplex >{
150  const EType& src_;
154  ReduceTo1DExp( const EType& src, real_t scale ):src_(src),scale_(scale){}
155  };
156 
163  template<typename Reducer, typename SrcExp, int srcdim>
164  struct PoolingExp: public MakeTensorExp< PoolingExp<Reducer, SrcExp,srcdim>, SrcExp, srcdim> {
166  const SrcExp& src_;
176  PoolingExp( const SrcExp &src, index_t ksize, index_t kstride )
177  : src_(src), ksize_(ksize), kstride_(kstride) {
179  utils::Assert( sshape[0] >= ksize && sshape[1] >= ksize, "pool: kernel must be smaller than image" );
180  this->src_height_ = sshape[1];
181  this->src_width_ = sshape[0];
182  this->shape_ = sshape;
183  this->shape_[1] = (src_height_ - ksize) / kstride + 1;
184  this->shape_[0] = (src_width_ - ksize) / kstride + 1;
185  }
187  PoolingExp( const SrcExp &src, Shape<2> pshape, index_t ksize, index_t kstride )
188  : src_(src), ksize_(ksize), kstride_(kstride) {
190  utils::Assert( sshape[0] >= ksize && sshape[1] >= ksize, "pool: kernel must be smaller than image" );
191  this->src_height_ = sshape[1];
192  this->src_width_ = sshape[0];
193  this->shape_ = sshape;
194  this->shape_[1] = pshape[1];
195  this->shape_[0] = pshape[0];
196  }
197  };
198 
204  template<typename Reducer, typename Device>
205  struct UnPoolingExp: public MakeTensorExp< UnPoolingExp<Reducer, Device>, Tensor<Device,4>, 4> {
217  UnPoolingExp( const Tensor<Device,4> &data_src, const Tensor<Device,4> &data_pooled,
218  const Tensor<Device,4> &grad_pooled, index_t ksize, index_t kstride )
219  : data_src_(data_src), data_pooled_(data_pooled), grad_pooled_(grad_pooled),
220  ksize_(ksize), kstride_(kstride) {
221  utils::Assert( grad_pooled.shape == data_pooled.shape, "UnPoolingExp: pooled shape mismatch" );
222  utils::Assert( grad_pooled.shape[2] == data_src.shape[2], "UnPoolingExp: pool and src shape mismatch" );
223  utils::Assert( grad_pooled.shape[3] == data_src.shape[3], "UnPoolingExp: pool and src shape mismatch" );
224  this->shape_ = data_src_.shape;
225  }
226  };
227 
233  template<typename SrcExp, int srcdim>
234  struct PaddingExp : public MakeTensorExp<PaddingExp<SrcExp, srcdim>, SrcExp, srcdim> {
236  const SrcExp& src_;
244  PaddingExp( const SrcExp &src, index_t pad )
245  : src_(src), pad_(pad) {
247  src_height_ = this->shape_[1];
248  src_width_ = this->shape_[0];
249  this->shape_[1] += pad * 2; // height
250  this->shape_[0] += pad * 2; // width
251  }
252  };
253 
259  template<typename SrcExp, int srcdim>
260  struct CroppingExp : public MakeTensorExp< CroppingExp<SrcExp, srcdim>, SrcExp, srcdim> {
262  const SrcExp& src_;
270  CroppingExp(const SrcExp &src, Shape<2> cshape ): src_(src) {
272  utils::Assert(this->shape_[1] >= cshape[1], "CroppingExp: height requirement not met");
273  utils::Assert(this->shape_[0] >= cshape[0], "CroppingExp: width requirement not met");
274  pad_height_ = (this->shape_[1] - cshape[1]) / 2;
275  pad_width_ = (this->shape_[0] - cshape[0]) / 2;
276  src_height_ = this->shape_[1];
277  this->shape_[1] = cshape[1]; // width
278  this->shape_[0] = cshape[0]; // height
279  }
281  CroppingExp(const SrcExp &src, Shape<2> cshape, index_t start_height, index_t start_width )
282  : src_(src), pad_height_(start_height), pad_width_(start_width) {
284  utils::Assert(this->shape_[1] >= cshape[1], "CroppingExp: height requirement not met");
285  utils::Assert(this->shape_[0] >= cshape[0], "CroppingExp: width requirement not met");
286  src_height_ = this->shape_[1];
287  this->shape_[1] = cshape[1]; // width
288  this->shape_[0] = cshape[0]; // height
289  }
290 
291  }; // struct CroppingExp
292 
293 
299  template<typename SrcExp, int srcdim>
300  struct MirroringExp : public MakeTensorExp<MirroringExp<SrcExp, srcdim>, SrcExp, srcdim> {
302  const SrcExp& src_;
304  MirroringExp( const SrcExp &src ): src_(src) {
306  }
307  };
308 
315  template<typename Reducer, typename SrcExp, int srcdim>
316  struct ChannelPoolingExp: public MakeTensorExp< ChannelPoolingExp<Reducer, SrcExp,srcdim>, SrcExp, srcdim> {
318  const SrcExp& src_;
322  ChannelPoolingExp( const SrcExp &src, index_t nsize ): src_(src), nsize_(nsize){
323  utils::Assert( nsize % 2 == 1, "ChannelPoolingExp: local size must be odd, to make it symmetric" );
325  utils::Assert( this->shape_[2] >= nsize_, "ChannelPoolingExp: local size need to be smaller than number of channels" );
326  }
327  };
328  }; // namespace expr
329 
330 
331  // Declaration of all functions go here
332  namespace expr{
334  template<typename E, typename R,int d>
336  return ReduceTo1DExp<E,R,d>( e.src_, e.scale_*scale );
337  }
339  template<typename E, typename R,int d>
341  return ReduceTo1DExp<E,R,d>( e.src_, e.scale_*scale );
342  }
343 
353  template<int dimcast,typename Device,int dimdst>
356  utils::Assert( src.shape[0] == shape[dimcast], "broadcast, shape mismatch" );
357  return Broadcast1DExp<Device,dimdst,dimcast>( src, shape );
358  }
359 
376  template<typename SrcExp, int etype>
378  TypeCheckPass< ExpInfo<SrcExp>::kDim >= 3 >::Error_Expression_Does_Not_Meet_Dimension_Req();
379  return UnpackPatchToColXExp<SrcExp, ExpInfo<SrcExp>::kDim >( img.self(), psize, pstride );
380  }
381 
391  template<typename Device, int dstdim>
393  utils::Assert( imshape[0] >= psize && imshape[1] >= psize, "PackColToPatch:image shape smaller than patch size");
394  return PackColToPatchXExp<Device,dstdim>( mat, imshape, psize, pstride );
395  }
405  template<typename SrcExp, int etype, int dimdst>
407  return ReshapeExp< SrcExp,dimdst, ExpInfo<SrcExp>::kDim >( src.self(), oshape );
408  }
409 
419  template<int a1, int a2, typename SrcExp, int etype>
421  typedef ExpInfo<SrcExp> Info;
422  TypeCheckPass< Info::kDim>=a1+1 && Info::kDim >= a2+1 && a1+1 <= a2 >::Error_Expression_Does_Not_Meet_Dimension_Req();
424  }
425 
434  template<int dimkeep, typename SrcExp, int etype>
436  return ReduceTo1DExp<SrcExp,red::sum,dimkeep>( exp.self(), 1.0f );
437  }
438 
449  template<typename Reducer, typename SrcExp, int etype>
451  TypeCheckPass< ExpInfo<SrcExp>::kDim >= 2 >::Error_Expression_Does_Not_Meet_Dimension_Req();
452  return PoolingExp<Reducer,SrcExp, ExpInfo<SrcExp>::kDim >(src.self(), ksize, kstride);
453  }
465  template<typename Reducer, typename SrcExp, int etype>
467  TypeCheckPass< ExpInfo<SrcExp>::kDim >= 2 >::Error_Expression_Does_Not_Meet_Dimension_Req();
468  return PoolingExp<Reducer,SrcExp, ExpInfo<SrcExp>::kDim >(src.self(), pshape, ksize, kstride);
469  }
481  template<typename Reducer, typename Device>
482  inline UnPoolingExp<Reducer, Device> unpool( const Tensor<Device,4>&data_src, const Tensor<Device,4> &data_pooled,
483  const Tensor<Device,4> &grad_pooled, index_t ksize, index_t kstride ) {
484  return UnPoolingExp<Reducer, Device>(data_src, data_pooled, grad_pooled,ksize, kstride); // only wraps the tensors; the UnPoolingExp constructor asserts the shape contracts
485  }
486 
495  template<typename SrcExp, int etype>
497  TypeCheckPass< ExpInfo<SrcExp>::kDim >= 2 >::Error_Expression_Does_Not_Meet_Dimension_Req();
499  }
500 
509  template<typename SrcExp, int etype>
511  TypeCheckPass< ExpInfo<SrcExp>::kDim >= 2 >::Error_Expression_Does_Not_Meet_Dimension_Req();
512  return CroppingExp<SrcExp, ExpInfo<SrcExp>::kDim>(src.self(), oshape);
513  }
524  template<typename SrcExp, int etype>
525  inline CroppingExp<SrcExp, ExpInfo<SrcExp>::kDim> crop( const Exp<SrcExp, etype> &src, Shape<2> oshape, index_t start_height, index_t start_width ) { // crop an oshape window whose top-left corner is (start_height, start_width)
526  TypeCheckPass< ExpInfo<SrcExp>::kDim >= 2 >::Error_Expression_Does_Not_Meet_Dimension_Req();
527  return CroppingExp<SrcExp, ExpInfo<SrcExp>::kDim>(src.self(), oshape, start_height, start_width); // NOTE(review): only oshape <= input shape is asserted; start + oshape <= input shape is not checked
528  }
529 
537  template<typename SrcExp, int etype>
539  TypeCheckPass< ExpInfo<SrcExp>::kDim >= 2 >::Error_Expression_Does_Not_Meet_Dimension_Req();
541  }
542 
552  template<typename Reducer, typename SrcExp, int etype>
554  TypeCheckPass< ExpInfo<SrcExp>::kDim >= 3 >::Error_Expression_Does_Not_Meet_Dimension_Req();
556  }
557  // short cut functions
565  template<typename Device>
567  return broadcast<0>( src, Shape2( nrow, src.shape[0] ) );
568  }
576  template<typename SrcExp, int etype>
578  return sumall_except_dim<0>( exp );
579  }
580 
581  }; // namespace expr
582 }; // namespace mshadow
583 
584 // ==================================================
585 // implementations afterwards,
586 // no need to read if only use the functions
587 // --------------------------------------------------
588 namespace mshadow{
589  namespace expr{
590  template<typename SV, typename Device, typename EType, typename Reducer, int dimkeep>
591  struct ExpComplexEngine< SV, Device, 1, ReduceTo1DExp<EType,Reducer,dimkeep> >{
592  inline static void Eval( Tensor<Device,1> &dst, const ReduceTo1DExp<EType,Reducer,dimkeep> &exp ){
594  MapReduceKeepHighDim<SV,Reducer,dimkeep>( dst, exp.src_, exp.scale_ );
595  }
596  };
597 
598  template<typename SV, typename Device, typename EType, typename Reducer>
599  struct ExpComplexEngine< SV, Device, 1, ReduceTo1DExp<EType,Reducer,0> >{ // specialization for dimkeep == 0: keep only the lowest dimension
600  inline static void Eval( Tensor<Device,1> &dst, const ReduceTo1DExp<EType,Reducer,0> &exp ){
601  MapReduceKeepLowest<SV,Reducer>( dst, exp.src_, exp.scale_ ); // dedicated routine instead of the generic MapReduceKeepHighDim path
602  }
603  };
604  }; // namespace expr
605 
606  namespace expr{
608  template<typename Device, int dimdst, int dimcast>
609  struct Plan< Broadcast1DExp<Device,dimdst,dimcast> >{
610  public:
612  : dptr_( e.src_.dptr ),
613  ystride_( e.shape_.ProdShape(1,dimcast) ),
614  length_(e.shape_[dimcast]){
616  }
617  MSHADOW_XINLINE real_t Eval( index_t y, index_t x ) const{
618  return dptr_[ (y / ystride_) % length_ ];
619  }
620  private:
621  const real_t *dptr_;
622  const index_t ystride_, length_;
623  };
624 
626  template<typename Device, int dimdst>
627  struct Plan< Broadcast1DExp<Device,dimdst,0> >{ // broadcast along dimension 0: the value depends only on the column index
628  public:
629  Plan( const Broadcast1DExp<Device,dimdst,0> &e ): dptr_( e.src_.dptr ){}
630  MSHADOW_XINLINE real_t Eval( index_t y, index_t x ) const{
631  return dptr_[ x ]; // row index y is ignored, so no ystride_/length_ arithmetic is needed
632  }
633  private:
634  const real_t *dptr_;
635  };
636  }; // namespace expr
637 
638  namespace expr{
639  template<typename SrcExp, int srcdim>
640  struct Plan< UnpackPatchToColXExp<SrcExp,srcdim> >{
641  public:
643  :src_(MakePlan(e.img_)),psize_(e.psize_), pstride_(e.pstride_),
644  i_channel_(e.i_channel_), i_height_(e.i_height_), i_width_(e.i_width_),
645  o_height_(( i_height_ - psize_ ) / pstride_ + 1),
646  o_width_ (( i_width_ - psize_ ) / pstride_ + 1){
647  }
648  MSHADOW_XINLINE real_t Eval( index_t i, index_t j ) const{
649  const index_t x_offset = i % psize_;
650  const index_t idivp = i / psize_;
651  const index_t y_offset = idivp % psize_;
652  const index_t c = idivp / psize_;
653  const index_t x = (j % o_width_) * pstride_ + x_offset;
654  const index_t jdivw = j / o_width_;
655  const index_t y = (jdivw % o_height_) * pstride_ + y_offset;
656  const index_t n = jdivw / o_height_;
657 
658  if( x < i_width_ && y < i_height_ ){
659  return src_.Eval( ( n * i_channel_ + c ) * i_height_ + y, x );
660  }else{
661  return 0.0f;
662  }
663  }
664  private:
665  Plan<SrcExp> src_;
666  const index_t psize_, pstride_, i_channel_, i_height_, i_width_, o_height_, o_width_;
667  };
668 
669  template<typename Device, int dstdim>
670  struct Plan< PackColToPatchXExp<Device, dstdim> >{
671  public:
673  :mat_(e.mat_), psize_(e.psize_), pstride_(e.pstride_),
674  i_channel_(e.shape_[2]), i_height_(e.shape_[1]),
675  o_width_(( e.shape_[0] - psize_ ) / pstride_ + 1),
676  o_height_(( e.shape_[1] - psize_ ) / pstride_ + 1){
677  // note: i/o convention are same as unpack
678  }
679  MSHADOW_XINLINE real_t Eval( index_t i, index_t j ) const{
680  using namespace std;
681  const index_t y = i % i_height_;
682  const index_t idivh = i / i_height_;
683  const index_t c = idivh % i_channel_;
684  const index_t n = idivh / i_channel_;
685  const index_t x = j;
686  const index_t py_min = y < psize_ ? 0 : (y-psize_+pstride_)/pstride_;
687  const index_t px_min = x < psize_ ? 0 : (x-psize_+pstride_)/pstride_;
688  const index_t py_max = min( (y+pstride_)/pstride_, o_height_);
689  const index_t px_max = min( (x+pstride_)/pstride_, o_width_ );
690  real_t res = 0.0f;
691  for( index_t py = py_min; py < py_max; ++py ){
692  for( index_t px = px_min; px < px_max; ++px ){
693  res += mat_[ (c * psize_ + y - py*pstride_) * psize_ + x - px*pstride_ ][ (n * o_height_ + py) * o_width_+px ];
694  }
695  }
696  return res;
697  }
698  private:
699  Tensor<Device,2> mat_;
700  const index_t psize_, pstride_, i_channel_, i_height_, o_width_, o_height_;
701  };
702  };
703 
704  namespace expr{
705  template<typename SrcExp, int dimdst, int dimsrc>
706  struct Plan< ReshapeExp<SrcExp,dimdst,dimsrc> >{
707  public:
709  : src_(MakePlan(e.src_)), oshape0_(e.shape_[0]), ishape0_(e.ishape0_){
710  }
711  MSHADOW_XINLINE real_t Eval( index_t y, index_t x ) const{
712  const index_t idx = y * oshape0_ + x;
713  return src_.Eval( idx / ishape0_, idx % ishape0_ );
714  }
715  private:
716  Plan<SrcExp> src_;
717  const index_t oshape0_, ishape0_;
718  };
719  // special work plan for 1 dimensional data
720  template<typename SrcExp,int dimdst>
721  struct Plan< ReshapeExp<SrcExp,dimdst,1> >{
722  public:
724  : src_(MakePlan(e.src_)), oshape0_(e.shape_[0]){
725  }
726  MSHADOW_XINLINE real_t Eval( index_t y, index_t x ) const{
727  return src_.Eval( 0, y * oshape0_ + x );
728  }
729  private:
730  Plan<SrcExp> src_;
731  const index_t oshape0_;
732  };
733  };
734 
735  namespace expr{
736  template<typename SrcExp,int dimsrc, int a1, int a2>
737  struct Plan< SwapAxisExp<SrcExp,dimsrc,a1,a2> >{
738  public:
740  : src_(MakePlan(e.src_)),
741  shape1_( e.shape_.ProdShape( 1, a1 ) ),
742  shape2_( e.shape_[a1] ),
743  shape3_( e.shape_.ProdShape( a1+1, a2 ) ),
744  shape4_( e.shape_[a2] ){
745  }
746  MSHADOW_XINLINE real_t Eval( index_t i, index_t j ) const{
747  const index_t y = i % shape1_;
748  i /= shape1_;
749  const index_t z = i % shape2_;
750  i /= shape2_;
751  const index_t c = i % shape3_;
752  i /= shape3_;
753  const index_t n = i % shape4_;
754  // swap z and n
755  return src_.Eval( ((((i/shape4_)*shape2_ + z) * shape3_+c) * shape4_ + n ) * shape1_ + y, j );
756  }
757  private:
758  Plan<SrcExp> src_;
759  const index_t shape1_, shape2_, shape3_, shape4_;
760  };
761 
762  template<typename SrcExp,int dimsrc, int a2>
763  struct Plan< SwapAxisExp<SrcExp,dimsrc,0,a2> >{
764  public:
766  : src_(MakePlan(e.src_)),
767  shape0_( e.shape_[0] ),
768  shape1_( e.shape_.ProdShape(1,a2) ),
769  shape2_( e.shape_[a2] ){
770  }
771  MSHADOW_XINLINE real_t Eval( index_t i, index_t x ) const{
772  // swap x and z
773  const index_t y = i % shape1_;
774  i /= shape1_;
775  const index_t z = i % shape2_;
776  const index_t n = i / shape2_;
777  return src_.Eval( ( n*shape0_ + x ) * shape1_ + y , z );
778  }
779  private:
780  Plan<SrcExp> src_;
781  const index_t shape0_, shape1_, shape2_;
782  };
783  };
784 
785  namespace expr{
786  template<typename Reducer, typename SrcExp, int srcdim>
787  struct Plan< PoolingExp< Reducer, SrcExp, srcdim> > {
788  public:
790  : src_( MakePlan( e.src_ ) ), ksize_(e.ksize_), kstride_(e.kstride_),
791  src_height_(e.src_height_),src_width_(e.src_width_), new_height_(e.shape_[1]) {
792  }
793  MSHADOW_XINLINE real_t Eval(index_t i, index_t j) const {
794  using namespace std;
795  const index_t py = i % new_height_;
796  const index_t y_start = py * kstride_;
797  const index_t y_end = min( y_start + ksize_, src_height_ );
798  const index_t px = j;
799  const index_t x_start = px * kstride_;
800  const index_t x_end = min( x_start + ksize_, src_width_ );
801  const index_t c = i / new_height_;
802 
803  real_t res = Reducer::kInitV;
804  for (index_t y = y_start; y < y_end; ++y) {
805  for (index_t x = x_start; x < x_end; ++x) {
806  Reducer::Reduce( res, src_.Eval( c*src_height_+y, x ) );
807  }
808  }
809  return res;
810  }
811  private:
812  Plan<SrcExp> src_;
813  const index_t ksize_, kstride_;
814  const index_t src_height_, src_width_;
815  const index_t new_height_;
816  };
817 
818  template<typename Reducer, typename Device>
819  struct Plan<UnPoolingExp<Reducer, Device> > {
820  public:
822  : data_src_(e.data_src_), data_pooled_(e.data_pooled_), grad_pooled_(e.grad_pooled_),
823  ksize_(e.ksize_), kstride_(e.kstride_) {}
824  MSHADOW_XINLINE real_t Eval(index_t i, index_t j) const {
825  using namespace std;
826  const index_t x = j;
827  const index_t y = i % data_src_.shape[1];
828  const index_t c = i / data_src_.shape[1];
829  const real_t vsrc = data_src_[0][c][y][x];
830 
831  const index_t py_min = y < ksize_ ? 0 : (y-ksize_+kstride_)/kstride_;
832  const index_t px_min = x < ksize_ ? 0 : (x-ksize_+kstride_)/kstride_;
833  const index_t py_max = min( (y+kstride_)/kstride_, data_pooled_.shape[1]);
834  const index_t px_max = min( (x+kstride_)/kstride_, data_pooled_.shape[0]);
835 
836  real_t val = 0;
837  for( index_t py = py_min; py < py_max; ++py ){
838  for( index_t px = px_min; px < px_max; ++px ){
839  val += Reducer::PartialGrad(vsrc, data_pooled_[0][c][py][px]) * grad_pooled_[0][c][py][px];
840  }
841  }
842  return val;
843  }
844  private:
845  Tensor<Device, 4> data_src_, data_pooled_, grad_pooled_;
846  const index_t ksize_;
847  const index_t kstride_;
848  };
849  }; // namespace expr
850 
851  namespace expr{
852  template<typename SrcExp, int srcdim>
853  struct Plan< PaddingExp<SrcExp, srcdim> > {
854  public:
856  : src_(MakePlan(e.src_)), pad_(e.pad_), new_height_(e.shape_[1]),
857  src_height_(e.src_height_), src_width_(e.src_width_) {}
858  MSHADOW_XINLINE real_t Eval(index_t i, index_t j) const {
859  const index_t x = j;
860  const index_t y = i % new_height_;
861  const index_t c = i / new_height_;
862  if (y < pad_ || x < pad_) return 0.0f;
863  const index_t h = y - pad_;
864  const index_t w = x - pad_;
865  if (h < src_height_ && w < src_width_) {
866  return src_.Eval(c * src_height_ + h, w);
867  } else {
868  return 0.0f;
869  }
870  }
871  private:
872  Plan<SrcExp> src_;
873  const index_t pad_;
874  const index_t new_height_;
875  const index_t src_height_;
876  const index_t src_width_;
877  };
878 
879  template<typename SrcExp, int srcdim>
880  struct Plan<CroppingExp<SrcExp, srcdim> > {
881  public:
883  : src_(MakePlan(e.src_)), pad_height_(e.pad_height_),pad_width_(e.pad_width_),
884  new_height_(e.shape_[1]), src_height_(e.src_height_) {}
885  MSHADOW_XINLINE real_t Eval(index_t i, index_t j) const {
886  const index_t x = j;
887  const index_t y = i % new_height_;
888  const index_t c = i / new_height_;
889  const index_t h = y + pad_height_;
890  const index_t w = x + pad_width_;
891  return src_.Eval(c * src_height_ + h, w);
892  }
893  private:
894  Plan<SrcExp> src_;
895  const index_t pad_height_, pad_width_;
896  const index_t new_height_;
897  const index_t src_height_;
898  };
899 
900  template<typename SrcExp, int srcdim>
901  struct Plan< MirroringExp<SrcExp, srcdim> > {
902  public:
904  : src_(MakePlan(e.src_)), width_(e.shape_[0]){}
905  MSHADOW_XINLINE real_t Eval(index_t i, index_t j) const {
906  return src_.Eval( i, width_ - j - 1 );
907  }
908  private:
909  Plan<SrcExp> src_;
910  const index_t width_;
911  };
912  }; // namespace expr
913 
914  namespace expr{
915  template<typename Reducer, typename SrcExp, int srcdim>
916  struct Plan< ChannelPoolingExp< Reducer, SrcExp, srcdim> > {
917  public:
919  : src_( MakePlan( e.src_ ) ), channel_(e.shape_[2]),
920  height_(e.shape_[1]),width_(e.shape_[0]), hnsize_(e.nsize_/2){
921  }
922  MSHADOW_XINLINE real_t Eval(index_t i, index_t j) const {
923  using namespace std;
924  const index_t y = i % height_;
925  i /= height_;
926  const index_t c = i % channel_;
927  const index_t n = i / channel_;
928  const index_t x = j;
929  const index_t cstart = c < hnsize_ ? 0 : c - hnsize_;
930  const index_t cend = min( c + hnsize_ + 1, channel_ );
931  real_t res = Reducer::kInitV;
932  for( index_t cc = cstart; cc < cend; ++ cc ){
933  Reducer::Reduce( res, src_.Eval( (n*channel_+cc)*height_ + y, x ) );
934  }
935  return res;
936  }
937  private:
938  Plan<SrcExp> src_;
939  const index_t channel_, height_, width_, hnsize_;
940  };
941  };
942 }; // namespace mshadow
943 
944 #if MSHADOW_USE_SSE
945 // implementations of SSE support, if possible
946 #include "tensor_sse-inl.hpp"
947 namespace mshadow{
948  namespace expr{
949  template<int dimdst>
950  struct SSECheck< Broadcast1DExp<cpu,dimdst,0> >{ // dim-0 broadcast reads dptr_[x] contiguously, so the SSE plan can be used
951  const static bool kPass = true;
952  };
953  template<int dimdst>
954  struct SSEAlignCheck<2, Broadcast1DExp<cpu,dimdst,0> >{ // alignment depends only on the broadcast source pointer
955  inline static bool Check( const Broadcast1DExp<cpu,dimdst,0> &exp ){
956  return sse2::CheckAlign( exp.src_.dptr );
957  }
958  };
959  template<int dimdst>
960  class SSEPlan< Broadcast1DExp<cpu,dimdst,0> >{
961  public:
963  :dptr_(t.src_.dptr){}
964  MSHADOW_CINLINE sse2::FVec<real_t> EvalSSE( index_t y, index_t x ) const{
965  return sse2::FVec<real_t>( &dptr_[ x ] );
966  }
967  MSHADOW_CINLINE real_t Eval( index_t y, index_t x ) const{
968  return dptr_[ x ];
969  }
970  private:
971  const real_t *dptr_;
972  };
973  };
974 };
975 #endif
976 
977 #endif
978 
const SrcExp & src_
source operand
Definition: tensor_expr_ext.h:302
ChannelPoolingExp< Reducer, SrcExp, ExpInfo< SrcExp >::kDim > chpool(const Exp< SrcExp, etype > &src, index_t nsize)
channel pooling, do reduction over (local nearby) channels, used to implement local response normaliz...
Definition: tensor_expr_ext.h:553
broadcast Tensor1D into a higher dimension Tensor input: Tensor<Device,1>: ishape[0] output: Tensor<D...
Definition: tensor_expr_ext.h:21
DotExp< TA, TB, ltrans, rtrans > operator*(const DotExp< TA, TB, ltrans, rtrans > &lhs, real_t rhs)
dot operator def
Definition: tensor_expr.h:206
unsigned index_t
type that will be used for index
Definition: tensor_base.h:123
Shape< dim > shape_
the shape of this expression
Definition: tensor_expr_engine-inl.hpp:22
PoolingExp(const SrcExp &src, Shape< 2 > pshape, index_t ksize, index_t kstride)
constructor, specify shape
Definition: tensor_expr_ext.h:187
index_t src_height_
source tensor height
Definition: tensor_expr_ext.h:240
UnPoolingExp< Reducer, Device > unpool(const Tensor< Device, 4 > &data_src, const Tensor< Device, 4 > &data_pooled, const Tensor< Device, 4 > &grad_pooled, index_t ksize, index_t kstride)
unpooling gradient for 4D: backpropagate the gradient value; reverse operation of pooling ...
Definition: tensor_expr_ext.h:482
This part of code gives plan that can be used to carry out execution.
Definition: tensor_expr_engine-inl.hpp:33
const SrcExp & src_
source operand
Definition: tensor_expr_ext.h:236
PoolingExp< Reducer, SrcExp, ExpInfo< SrcExp >::kDim > pool(const Exp< SrcExp, etype > &src, index_t ksize, index_t kstride)
pooling subregion results together
Definition: tensor_expr_ext.h:450
Definition: tensor_expr_engine-inl.hpp:201
bool CheckAlign(size_t pitch)
check if a pointer is aligned
Definition: tensor_sse-inl.hpp:52
index_t psize_
patch size
Definition: tensor_expr_ext.h:79
channel pooling expression, do reduction over (local nearby) channels, used to implement local respon...
Definition: tensor_expr_ext.h:316
shape of a tensor IMPORTANT NOTE: this shape is different from numpy.shape shape[0] gives the lowest ...
Definition: tensor.h:23
index_t ksize_
kernel size
Definition: tensor_expr_ext.h:213
ReshapeExp(const SrcExp &src, Shape< dimdst > shape)
constructor
Definition: tensor_expr_ext.h:108
index_t kstride_
kernel stride
Definition: tensor_expr_ext.h:215
const SrcExp & src_
source operand
Definition: tensor_expr_ext.h:166
real_t scale_
scale factor forwarded to the reduction (multiplied when the expression is scaled)
Definition: tensor_expr_ext.h:152
const Tensor< Device, 4 > & grad_pooled_
gradient data of the pooled part, to be propagated down
Definition: tensor_expr_ext.h:211
PaddingExp(const SrcExp &src, index_t pad)
constructor
Definition: tensor_expr_ext.h:244
index_t pstride_
patch stride
Definition: tensor_expr_ext.h:81
PackColToPatchXExp< Device, dstdim > pack_col2patch(const Tensor< Device, 2 > &mat, Shape< dstdim > imshape, index_t psize, index_t pstride)
reverse operation of unpack_patch2col, can be used to implement deconvolution
Definition: tensor_expr_ext.h:392
void Assert(bool exp)
assert an expression is true
Definition: tensor_base.h:285
ReduceTo1DExp< SrcExp, red::sum, 0 > sum_rows(const Exp< SrcExp, etype > &exp)
an expression that sums over rows of a matrix
Definition: tensor_expr_ext.h:577
PackColToPatchXExp(const Tensor< Device, 2 > &mat, Shape< dstdim > imshape, index_t psize, index_t pstride)
constructor
Definition: tensor_expr_ext.h:83
unpack local (overlap) patches of image to column of mat, can be used to implement convolution...
Definition: tensor_expr_ext.h:38
MirroringExp< SrcExp, ExpInfo< SrcExp >::kDim > mirror(const Exp< SrcExp, etype > &src)
mirroring expression, mirror images in width
Definition: tensor_expr_ext.h:538
index_t src_width_
source tensor width
Definition: tensor_expr_ext.h:242
unpooling expr reverse operation of pooling, used to pass gradient back
Definition: tensor_expr_ext.h:205
support of sse2 optimization of some operations
CroppingExp(const SrcExp &src, Shape< 2 > cshape)
constructor
Definition: tensor_expr_ext.h:270
static check for SSE support: if an expression E cannot be evaluated using SSE, then kPass = false ...
Definition: tensor_sse-inl.hpp:357
float real_t
type that will be used for content
Definition: tensor_base.h:118
float vector real type, used for vectorization
Definition: tensor_sse-inl.hpp:88
reverse operation of UnpackPatchToCol, used to backprop gradient back; this is a version supporting mu...
Definition: tensor_expr_ext.h:75
index_t src_height_
src height
Definition: tensor_expr_ext.h:268
const SubType & self(void) const
Definition: tensor_expr.h:52
padding expression, pad an image with zeros
Definition: tensor_expr_ext.h:234
const Tensor< Device, 1 > src_
source operand
Definition: tensor_expr_ext.h:23
index_t pad_width_
pad width
Definition: tensor_expr_ext.h:266
reduction to a 1-dimensional tensor; input: Tensor<Device,k>: ishape; output: Tensor<Device,1>: shape[0] = ishape[dimkeep];
Definition: tensor_expr_ext.h:148
device name CPU
Definition: tensor.h:185
UnpackPatchToColXExp(const SrcExp &img, index_t psize, index_t pstride)
constructor
Definition: tensor_expr_ext.h:52
const Tensor< Device, 4 > & data_src_
source input, corresponds to src in pooling
Definition: tensor_expr_ext.h:207
PoolingExp(const SrcExp &src, index_t ksize, index_t kstride)
constructor
Definition: tensor_expr_ext.h:176
static type inference template, used to get the dimension of each expression, if ExpInfo<E>::kDim == ...
Definition: tensor_expr_engine-inl.hpp:153
const SrcExp & src_
source expression
Definition: tensor_expr_ext.h:129
UnpackPatchToColXExp< SrcExp, ExpInfo< SrcExp >::kDim > unpack_patch2col(const Exp< SrcExp, etype > &img, index_t psize, index_t pstride)
unpack local (overlap) patches of image to column of mat, can be used to implement convolution after ...
Definition: tensor_expr_ext.h:377
const Tensor< Device, 2 > & mat_
source operand
Definition: tensor_expr_ext.h:77
definitions of how expressions should be evaluated
UnPoolingExp(const Tensor< Device, 4 > &data_src, const Tensor< Device, 4 > &data_pooled, const Tensor< Device, 4 > &grad_pooled, index_t ksize, index_t kstride)
constructor
Definition: tensor_expr_ext.h:217
Definition: tensor_sse-inl.hpp:250
MSHADOW_XINLINE Shape< 2 > Shape2(index_t s1, index_t s0)
construct a two-dimensional shape; stride will equal s0
Definition: tensor.h:152
SwapAxisExp(const SrcExp &src)
constructor
Definition: tensor_expr_ext.h:131
Definition: tensor.h:276
index_t pstride_
patch stride
Definition: tensor_expr_ext.h:44
index_t i_height_
height of img
Definition: tensor_expr_ext.h:48
Broadcast1DExp(const Tensor< Device, 1 > &src, Shape< dimdst > shape)
constructor
Definition: tensor_expr_ext.h:25
MSHADOW_XINLINE index_t ProdShape(int dimstart, int dimend) const
Definition: tensor.h:104
Broadcast1DExp< Device, 2, 0 > repmat(const Tensor< Device, 1 > &src, index_t nrow)
an expression that replicates a 1-dimensional tensor nrow times
Definition: tensor_expr_ext.h:566
index_t ksize_
kernel size
Definition: tensor_expr_ext.h:168
index_t src_height_
source height shape[1]
Definition: tensor_expr_ext.h:172
crop expression, cut off the boundary region, reverse operation of padding
Definition: tensor_expr_ext.h:260
index_t pad_height_
pad height
Definition: tensor_expr_ext.h:264
MSHADOW_XINLINE real_t Eval(index_t y, index_t x) const
evaluate the expression at index [y][x]; to be implemented by SubType
index_t pad_
pad size
Definition: tensor_expr_ext.h:238
Definition: tensor_expr_engine-inl.hpp:215
mirror expression, mirror an image in width
Definition: tensor_expr_ext.h:300
engine that evaluates complex expressions
Definition: tensor_expr_engine-inl.hpp:390
SwapAxisExp< SrcExp, ExpInfo< SrcExp >::kDim, a1, a2 > swapaxis(const Exp< SrcExp, etype > &src)
an expression that swaps two axes (a1, a2) of a tensor
Definition: tensor_expr_ext.h:420
index_t i_width_
width of img
Definition: tensor_expr_ext.h:50
index_t ishape0_
lowest dimension (shape[0]) of the input
Definition: tensor_expr_ext.h:106
reshape the content to another shape input: Tensor<Device,dimsrc>: ishape output: Tensor<Device...
Definition: tensor_expr_ext.h:102
Shape< dimension > shape
shape of the tensor
Definition: tensor.h:217
#define MSHADOW_CINLINE
cpu force inline
Definition: tensor_base.h:101
CroppingExp< SrcExp, ExpInfo< SrcExp >::kDim > crop(const Exp< SrcExp, etype > &src, Shape< 2 > oshape)
reverse operation of padding: cut off boundaries, crop output from center of input ...
Definition: tensor_expr_ext.h:510
ReduceTo1DExp< SrcExp, red::sum, dimkeep > sumall_except_dim(const Exp< SrcExp, etype > &exp)
a sum over all dimensions, except dimkeep
Definition: tensor_expr_ext.h:435
base class for expression
Definition: tensor_expr.h:49
index_t psize_
patch size
Definition: tensor_expr_ext.h:42
const SrcExp & src_
source operand
Definition: tensor_expr_ext.h:318
index_t nsize_
neighbor size
Definition: tensor_expr_ext.h:320
MSHADOW_XINLINE size_t Size(void) const
Definition: tensor.h:82
const Tensor< Device, 4 > & data_pooled_
result of pooled data, corresponds to result of pooling
Definition: tensor_expr_ext.h:209
ReduceTo1DExp(const EType &src, real_t scale)
construct a reduction expression from src and scale
Definition: tensor_expr_ext.h:154
const SrcExp & src_
source expression
Definition: tensor_expr_ext.h:104
a general class that allows extension that makes tensors of some shape
Definition: tensor_expr_engine-inl.hpp:20
const SrcExp & src_
source operand
Definition: tensor_expr_ext.h:262
const SrcExp & img_
source operand
Definition: tensor_expr_ext.h:40
const EType & src_
source operand
Definition: tensor_expr_ext.h:150
Definition: tensor_sse-inl.hpp:381
swap two axes of a tensor; input: Tensor<Device,dim>: ishape; output: Tensor<Device,dim>: oshape[a1],oshape[a2] = ishape[a2],ishape[a1]
Definition: tensor_expr_ext.h:127
PaddingExp< SrcExp, ExpInfo< SrcExp >::kDim > pad(const Exp< SrcExp, etype > &src, index_t pad)
padding expression, pad an image with zeros on boundaries; padding affects shape[0] and shape[1]
Definition: tensor_expr_ext.h:496
index_t kstride_
kernel stride
Definition: tensor_expr_ext.h:170
Broadcast1DExp< Device, dimdst, dimcast > broadcast(const Tensor< Device, 1 > &src, Shape< dimdst > shape)
an expression that replicates a 1-dimensional tensor along dimension dimcast
Definition: tensor_expr_ext.h:354
ReshapeExp< SrcExp, dimdst, ExpInfo< SrcExp >::kDim > reshape(const Exp< SrcExp, etype > &src, Shape< dimdst > oshape)
an expression that reshapes a tensor to another shape
Definition: tensor_expr_ext.h:406
CroppingExp(const SrcExp &src, Shape< 2 > cshape, index_t start_height, index_t start_width)
constructor
Definition: tensor_expr_ext.h:281
index_t i_channel_
number of input channel
Definition: tensor_expr_ext.h:46
MirroringExp(const SrcExp &src)
constructor
Definition: tensor_expr_ext.h:304
index_t src_width_
source width shape[0]
Definition: tensor_expr_ext.h:174
pooling expression, do reduction over local patches of an image
Definition: tensor_expr_ext.h:164
ChannelPoolingExp(const SrcExp &src, index_t nsize)
constructor
Definition: tensor_expr_ext.h:322