Layer¶
Python API¶
Python layers wrap the C++ layers to provide simpler construction APIs.
Example usages:
from singa import layer
from singa import tensor
from singa import device
layer.engine = 'cudnn' # to use cudnn layers
dev = device.create_cuda_gpu()
# create a convolution layer
conv = layer.Conv2D('conv', 32, 3, 1, pad=1, input_sample_shape=(3, 32, 32))
# init param values
w, b = conv.param_values()
w.gaussian(0, 0.01)
b.set_value(0)
conv.to_device(dev) # move the layer data onto a CudaGPU device
x = tensor.Tensor((3, 32, 32), dev)
x.uniform(-1, 1)
y = conv.forward(True, x)
dy = tensor.Tensor()
dy.reset_like(y)
dy.set_value(0.1)
# dp is a list of tensors for parameter gradients
dx, dp = conv.backward(True, dy)  # True (kTrain) indicates the training phase
-
singa.layer.engine = 'cudnn'¶
engine is the prefix of the layer identifier. The value can be one of ['cudnn', 'singacpp', 'singacuda', 'singacl'], for layers implemented using the cudnn library, Cpp, Cuda and OpenCL, respectively. For example, the CudnnConvolution layer is identified by 'cudnn_convolution', and 'singacpp_convolution' is for the Convolution layer. Some layers' implementations use only Tensor functions and are therefore transparent to the underlying devices; such layers have multiple identifiers, e.g., singacpp_dropout, singacuda_dropout and singacl_dropout all stand for the Dropout layer. In addition, each such layer has an extra identifier with the 'singa' prefix, i.e., 'singa_dropout' also stands for the Dropout layer.
engine is case insensitive. Each Python layer creates the correct specific layer according to the engine attribute.
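For example, a sketch that selects the Cpp implementations before any layer is constructed (assuming a build that includes the 'singacpp' layers):
from singa import layer
layer.engine = 'singacpp'  # subsequent layers use Cpp kernels, e.g., 'singacpp_convolution'
conv = layer.Conv2D('conv', 32, 3, 1, pad=1, input_sample_shape=(3, 32, 32))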
-
class singa.layer.Layer(name, conf=None, **kwargs)¶ Bases: object
Base Python layer class.
- Typically, the life cycle of a layer instance includes (see the sketch after the parameter list below):
1. construct the layer without input_sample_shapes and go to step 2, or construct it with input_sample_shapes and go to step 3
2. call setup to create the parameters and set up other meta fields
3. call forward or access layer members
4. call backward and get parameters for update
- Parameters
name (str) – layer name
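A minimal sketch of this life cycle, assuming a CPU build where the 'singacpp' engine is available; Dense and the shapes are used only for illustration:
from singa import layer
layer.engine = 'singacpp'
fc = layer.Dense('fc', 10)           # step 1: construct without input_sample_shape
fc.setup((784,))                     # step 2: create the parameters
print(fc.param_names())              # step 3: access layer members
print(fc.get_output_sample_shape())  # expected (10,)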
-
setup(in_shapes)¶ Call the C++ setup function to create params and set some meta data.
- Parameters
in_shapes – if the layer accepts a single input Tensor, in_shapes is a single tuple specifying the input Tensor shape; if the layer accepts multiple input Tensors (e.g., the concatenation layer), in_shapes is a tuple of tuples, each for one input Tensor
-
caffe_layer()¶ Create a singa layer based on caffe layer configuration.
-
get_output_sample_shape()¶ Called after setup to get the shape of the output sample(s).
- Returns
a tuple for a single output Tensor or a list of tuples if this layer has multiple outputs
-
param_names()¶
- Returns
a list of strings, one for the name of one parameter Tensor
-
param_values()¶ Return param value tensors.
Parameter tensors are not stored as layer members because the cpp Tensor could be moved onto a different device when the layer's device changes, which would result in inconsistency.
- Returns
a list of tensors, one for each parameter
-
forward(flag, x)¶ Forward propagate through this layer.
- Parameters
flag – True (kTrain) for training; False (kEval) for evaluation; other values for future use.
x (Tensor or list<Tensor>) – an input tensor if the layer is connected from a single layer; a list of tensors if the layer is connected from multiple layers.
- Returns
a tensor if the layer is connected to a single layer; a list of tensors if the layer is connected to multiple layers;
-
backward(flag, dy)¶ Backward propagate gradients through this layer.
- Parameters
flag (int) – for future use.
dy (Tensor or list<Tensor>) – the gradient tensor(s) of y w.r.t. the objective loss
- Returns
<dx, <dp1, dp2..>>, dx is a (set of) tensor(s) for the gradient of x, dpi is the gradient of the i-th parameter
-
to_device(device)¶ Move layer state tensors onto the given device.
- Parameters
device – swig converted device, created using singa.device
-
as_type(dtype)¶
-
class singa.layer.Dummy(name, input_sample_shape=None)¶ Bases: singa.layer.Layer
A dummy layer that does nothing but just forwards/backwards the data (the input/output is a single tensor).
-
get_output_sample_shape()¶ Called after setup to get the shape of the output sample(s).
- Returns
a tuple for a single output Tensor or a list of tuples if this layer has multiple outputs
-
setup(input_sample_shape)¶ Call the C++ setup function to create params and set some meta data.
- Parameters
in_shapes – if the layer accepts a single input Tensor, in_shapes is a single tuple specifying the input Tensor shape; if the layer accepts multiple input Tensors (e.g., the concatenation layer), in_shapes is a tuple of tuples, each for one input Tensor
-
forward(flag, x)¶ Return the input x.
-
backward(flag, dy)¶ Return dy, []
-
-
class singa.layer.Conv2D(name, nb_kernels, kernel=3, stride=1, border_mode='same', cudnn_prefer='fastest', workspace_byte_limit=1024, data_format='NCHW', use_bias=True, W_specs=None, b_specs=None, pad=None, input_sample_shape=None)¶ Bases: singa.layer.Layer
Construct a layer for 2D convolution.
- Parameters
nb_kernels (int) – num of kernels, which is also the number of channels of the output Tensor
kernel – an integer or a pair of integers for kernel height and width
stride – an integer or a pair of integers for stride height and width
border_mode (string) – padding mode, case insensitive; 'valid' -> padding is 0 for height and width; 'same' -> padding is half of the kernel size (floor), and the kernel size must be an odd number
cudnn_prefer (string) – the preferred algorithm for cudnn convolution which could be ‘fastest’, ‘autotune’, ‘limited_workspace’ and ‘no_workspace’
workspace_byte_limit (int) – max workspace size in MB (the default is 1024 MB)
data_format (string) – either ‘NCHW’ or ‘NHWC’
use_bias (bool) – True or False
pad – an integer or a pair of integers for padding height and width
W_specs (dict) – used to specify the weight matrix specs; fields include: 'name' for parameter name; 'lr_mult' for learning rate multiplier; 'decay_mult' for weight decay multiplier; 'init' for the init method, which could be 'gaussian', 'uniform', 'xavier' or ''; 'std', 'mean', 'high', 'low' for the corresponding init methods; TODO(wangwei) 'clamp' for gradient constraint, value is scalar; 'regularizer' for regularization, currently supports 'l2'
b_specs (dict) – hyper-parameters for bias vector, similar as W_specs
name (string) – layer name.
input_sample_shape – 3d tuple for the shape of the input Tensor without the batchsize, e.g., (channel, height, width) or (height, width, channel)
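A hedged construction sketch, assuming a CPU build with the 'singacpp' engine; the layer name 'conv1' and the shapes are illustrative, and the expected output shape follows from the 'valid' padding rule above:
from singa import layer
layer.engine = 'singacpp'
conv = layer.Conv2D('conv1', nb_kernels=16, kernel=3, stride=1,
                    border_mode='valid', input_sample_shape=(3, 32, 32))
# 'valid' uses zero padding, so each spatial dimension shrinks by kernel - 1
print(conv.get_output_sample_shape())  # expected (16, 30, 30)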
-
setup(in_shape)¶ Set up the kernel, stride and padding; then call the C++ setup function to create params and set some meta data.
- Parameters
in_shape – a tuple of int for the input sample shape
-
class singa.layer.Conv1D(name, nb_kernels, kernel=3, stride=1, border_mode='same', cudnn_prefer='fastest', workspace_byte_limit=1024, use_bias=True, W_specs={'init': 'Xavier'}, b_specs={'init': 'Constant', 'value': 0}, pad=None, input_sample_shape=None)¶ Bases: singa.layer.Conv2D
Construct a layer for 1D convolution.
Most of the args are the same as those for Conv2D, except that kernel, stride and pad are scalars instead of tuples, and input_sample_shape is a tuple with a single value for the input feature length.
-
get_output_sample_shape()¶ Called after setup to get the shape of the output sample(s).
- Returns
a tuple for a single output Tensor or a list of tuples if this layer has multiple outputs
-
-
class singa.layer.Pooling2D(name, mode, kernel=3, stride=2, border_mode='same', pad=None, data_format='NCHW', input_sample_shape=None)¶ Bases: singa.layer.Layer
2D pooling layer providing max/avg pooling.
All args are the same as those for Conv2D, except the following one
- Parameters
mode – pooling type, model_pb2.PoolingConf.MAX or model_pb2.PoolingConf.AVE
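For instance, a max pooling layer that halves the spatial dimensions; a sketch assuming the 'singacpp' engine and illustrative shapes:
from singa import layer
layer.engine = 'singacpp'
pool = layer.MaxPooling2D('pool1', kernel=2, stride=2,
                          input_sample_shape=(16, 32, 32))
print(pool.get_output_sample_shape())  # expected (16, 16, 16)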
-
setup(in_shape)¶ Set up the kernel, stride and padding; then call the C++ setup function to create params and set some meta data.
- Parameters
in_shape – a tuple of int for the input sample shape
-
class singa.layer.MaxPooling2D(name, kernel=3, stride=2, border_mode='same', pad=None, data_format='NCHW', input_sample_shape=None)¶ Bases: singa.layer.Pooling2D
-
class singa.layer.AvgPooling2D(name, kernel=3, stride=2, border_mode='same', pad=None, data_format='NCHW', input_sample_shape=None)¶ Bases: singa.layer.Pooling2D
-
class singa.layer.MaxPooling1D(name, kernel=3, stride=2, border_mode='same', pad=None, data_format='NCHW', input_sample_shape=None)¶ Bases: singa.layer.MaxPooling2D
-
get_output_sample_shape()¶ Called after setup to get the shape of the output sample(s).
- Returns
a tuple for a single output Tensor or a list of tuples if this layer has multiple outputs
-
-
class singa.layer.AvgPooling1D(name, kernel=3, stride=2, border_mode='same', pad=None, data_format='NCHW', input_sample_shape=None)¶ Bases: singa.layer.AvgPooling2D
-
get_output_sample_shape()¶ Called after setup to get the shape of the output sample(s).
- Returns
a tuple for a single output Tensor or a list of tuples if this layer has multiple outputs
-
-
class singa.layer.BatchNormalization(name, momentum=0.9, beta_specs=None, gamma_specs=None, input_sample_shape=None)¶ Bases: singa.layer.Layer
Batch-normalization.
- Parameters
momentum (float) – momentum for the running average of mean and variance.
beta_specs (dict) – dictionary with the fields for the beta param: 'name' for parameter name; 'lr_mult' for learning rate multiplier; 'decay_mult' for weight decay multiplier; 'init' for the init method, which could be 'gaussian', 'uniform', 'xavier' or ''; 'std', 'mean', 'high', 'low' for the corresponding init methods; 'clamp' for gradient constraint, value is scalar; 'regularizer' for regularization, currently supports 'l2'
gamma_specs (dict) – similar to beta_specs, but for the gamma param.
name (string) – layer name
input_sample_shape (tuple) – with at least one integer
-
class singa.layer.L2Norm(name, input_sample_shape, epsilon=1e-08)¶ Bases: singa.layer.Layer
Normalize each sample to have L2 norm = 1
-
get_output_sample_shape()¶ Called after setup to get the shape of the output sample(s).
- Returns
a tuple for a single output Tensor or a list of tuples if this layer has multiple outputs
-
forward(is_train, x)¶ Forward propagate through this layer.
- Parameters
is_train – True (kTrain) for training; False (kEval) for evaluation; other values for future use.
x (Tensor or list<Tensor>) – an input tensor if the layer is connected from a single layer; a list of tensors if the layer is connected from multiple layers.
- Returns
a tensor if the layer is connected to a single layer; a list of tensors if the layer is connected to multiple layers;
-
backward(is_train, dy)¶ Backward propagate gradients through this layer.
- Parameters
is_train (int) – for future use.
dy (Tensor or list<Tensor>) – the gradient tensor(s) of y w.r.t. the objective loss
- Returns
<dx, <dp1, dp2..>>, dx is a (set of) tensor(s) for the gradient of x, dpi is the gradient of the i-th parameter
-
-
class singa.layer.LRN(name, size=5, alpha=1, beta=0.75, mode='cross_channel', k=1, input_sample_shape=None)¶ Bases: singa.layer.Layer
Local response normalization.
- Parameters
size (int) – number of channels to be crossed for normalization.
mode (string) – ‘cross_channel’
input_sample_shape (tuple) – 3d tuple, (channel, height, width)
-
class singa.layer.Dense(name, num_output, use_bias=True, W_specs=None, b_specs=None, W_transpose=False, input_sample_shape=None)¶ Bases: singa.layer.Layer
Apply linear/affine transformation, also called inner-product or fully connected layer.
- Parameters
num_output (int) – output feature length.
use_bias (bool) – add a bias vector or not to the transformed feature
W_specs (dict) – specs for the weight matrix: 'name' for parameter name; 'lr_mult' for learning rate multiplier; 'decay_mult' for weight decay multiplier; 'init' for the init method, which could be 'gaussian', 'uniform', 'xavier' or ''; 'std', 'mean', 'high', 'low' for the corresponding init methods; 'clamp' for gradient constraint, value is scalar; 'regularizer' for regularization, currently supports 'l2'
b_specs (dict) – specs for the bias vector, same fields as W_specs.
W_transpose (bool) – if true, output = x*W.T + b
input_sample_shape (tuple) – input feature length
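A usage sketch mirroring the Conv2D example at the top of this page, assuming the 'singacpp' engine; the layer name and shapes are illustrative:
from singa import layer
from singa import tensor
layer.engine = 'singacpp'
fc = layer.Dense('fc', 10, input_sample_shape=(784,))
w, b = fc.param_values()
w.gaussian(0, 0.01)
b.set_value(0)
x = tensor.Tensor((4, 784))  # a batch of 4 samples
x.uniform(-1, 1)
y = fc.forward(True, x)      # y has shape (4, 10)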
-
class singa.layer.Dropout(name, p=0.5, input_sample_shape=None)¶ Bases: singa.layer.Layer
Dropout layer.
- Parameters
p (float) – probability of dropping out an element, i.e., setting it to 0
name (string) – layer name
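A minimal sketch, assuming the 'singacpp' engine (recall from the engine notes above that 'singa_dropout' is also a valid identifier for this layer); the shapes are illustrative:
from singa import layer
from singa import tensor
layer.engine = 'singacpp'
drop = layer.Dropout('drop', p=0.25, input_sample_shape=(100,))
x = tensor.Tensor((4, 100))
x.set_value(1.0)
y = drop.forward(True, x)  # in training, roughly 25% of the elements are zeroed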
-
class singa.layer.Activation(name, mode='relu', input_sample_shape=None)¶ Bases: singa.layer.Layer
Activation layers.
- Parameters
name (string) – layer name
mode (string) – ‘relu’, ‘sigmoid’, or ‘tanh’
input_sample_shape (tuple) – shape of a single sample
-
class singa.layer.Softmax(name, axis=1, input_sample_shape=None)¶ Bases: singa.layer.Layer
Apply softmax.
- Parameters
axis (int) – reshape the input as a matrix with the dimensions [0, axis) as the rows and [axis, -1) as the columns.
input_sample_shape (tuple) – shape of a single sample
-
class singa.layer.Flatten(name, axis=1, input_sample_shape=None)¶ Bases: singa.layer.Layer
Reshape the input tensor into a matrix.
- Parameters
axis (int) – reshape the input as a matrix with the dimensions [0, axis) as the rows and [axis, -1) as the columns.
input_sample_shape (tuple) – shape for a single sample
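For example, flattening each (3, 32, 32) sample into a 3072-dimensional vector; a sketch assuming the 'singacpp' engine:
from singa import layer
layer.engine = 'singacpp'
flat = layer.Flatten('flat', axis=1, input_sample_shape=(3, 32, 32))
print(flat.get_output_sample_shape())  # expected (3072,)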
-
class singa.layer.Merge(name, input_sample_shape=None)¶ Bases: singa.layer.Layer
Sum all input tensors.
- Parameters
input_sample_shape – sample shape of the input. The sample shape of all inputs should be the same.
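A sketch summing two tensors of the same shape, assuming the 'singacpp' engine and illustrative shapes:
from singa import layer
from singa import tensor
layer.engine = 'singacpp'
merge = layer.Merge('merge', input_sample_shape=(64,))
x1 = tensor.Tensor((4, 64))
x2 = tensor.Tensor((4, 64))
x1.set_value(1.0)
x2.set_value(2.0)
y = merge.forward(True, [x1, x2])  # element-wise sum; every entry is 3.0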
-
setup(in_shape)¶ Call the C++ setup function to create params and set some meta data.
- Parameters
in_shapes – if the layer accepts a single input Tensor, in_shapes is a single tuple specifying the input Tensor shape; if the layer accepts multiple input Tensors (e.g., the concatenation layer), in_shapes is a tuple of tuples, each for one input Tensor
-
get_output_sample_shape()¶ Called after setup to get the shape of the output sample(s).
- Returns
a tuple for a single output Tensor or a list of tuples if this layer has multiple outputs
-
forward(flag, inputs)¶ Merge all input tensors by summation.
TODO(wangwei) do element-wise merge operations, e.g., avg, count
- Parameters
flag – not used
inputs (list) – a list of tensors
- Returns
A single tensor as the sum of all input tensors
-
class singa.layer.Split(name, num_output, input_sample_shape=None)¶ Bases: singa.layer.Layer
Replicate the input tensor.
- Parameters
num_output (int) – number of output tensors to generate.
input_sample_shape – includes a single integer for the input sample feature size.
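A sketch replicating one tensor into two, assuming the 'singacpp' engine and illustrative shapes:
from singa import layer
from singa import tensor
layer.engine = 'singacpp'
split = layer.Split('split', 2, input_sample_shape=(64,))
x = tensor.Tensor((4, 64))
x.uniform(-1, 1)
y1, y2 = split.forward(True, x)  # two copies of x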
-
setup(in_shape)¶ Call the C++ setup function to create params and set some meta data.
- Parameters
in_shapes – if the layer accepts a single input Tensor, in_shapes is a single tuple specifying the input Tensor shape; if the layer accepts multiple input Tensors (e.g., the concatenation layer), in_shapes is a tuple of tuples, each for one input Tensor
-
get_output_sample_shape()¶ Called after setup to get the shape of the output sample(s).
- Returns
a tuple for a single output Tensor or a list of tuples if this layer has multiple outputs
-
forward(flag, input)¶ Replicate the input tensor into multiple tensors.
- Parameters
flag – not used
input – a single input tensor
- Returns
a list of output tensors (each one is a copy of the input)
-
backward(flag, grads)¶ Sum all grad tensors to generate a single output tensor.
- Parameters
grads (list of Tensor) –
- Returns
a single tensor as the sum of all grads
-
class singa.layer.Concat(name, axis, input_sample_shapes=None)¶ Bases: singa.layer.Layer
Concatenate tensors vertically (axis = 0) or horizontally (axis = 1).
Currently, only tensors with 2 dimensions are supported.
- Parameters
axis (int) – 0 to concatenate rows; 1 to concatenate columns
input_sample_shapes – a list of sample shape tuples, one per input tensor
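A sketch concatenating two 2-D tensors along the column axis, assuming the 'singacpp' engine and illustrative shapes:
from singa import layer
from singa import tensor
layer.engine = 'singacpp'
cat = layer.Concat('concat', 1, input_sample_shapes=[(3,), (4,)])
x1 = tensor.Tensor((2, 3))
x2 = tensor.Tensor((2, 4))
x1.set_value(1.0)
x2.set_value(2.0)
y = cat.forward(True, [x1, x2])  # y has shape (2, 7)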
-
forward(flag, inputs)¶ Concatenate all input tensors.
- Parameters
flag – same as Layer::forward()
inputs – a list of tensors
- Returns
a single concatenated tensor
-
class singa.layer.Slice(name, axis, slice_point, input_sample_shape=None)¶ Bases: singa.layer.Layer
Slice the input tensor into multiple sub-tensors vertically (axis=0) or horizontally (axis=1).
- Parameters
axis (int) – 0 to slice rows; 1 to slice columns
slice_point (list) – positions along the axis at which to slice; there are n-1 points for n sub-tensors
input_sample_shape – input tensor sample shape
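A sketch slicing the columns of a 2-D tensor at position 2, assuming the 'singacpp' engine and illustrative shapes:
from singa import layer
from singa import tensor
layer.engine = 'singacpp'
sli = layer.Slice('slice', 1, [2], input_sample_shape=(5,))
x = tensor.Tensor((4, 5))
x.uniform(-1, 1)
y1, y2 = sli.forward(True, x)  # y1 has shape (4, 2), y2 has shape (4, 3)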
-
get_output_sample_shape()¶ Called after setup to get the shape of the output sample(s).
- Returns
a tuple for a single output Tensor or a list of tuples if this layer has multiple outputs
-
forward(flag, x)¶ Slice the input tensor on the given axis.
- Parameters
flag – same as Layer::forward()
x – a single input tensor
- Returns
a list of output tensors
-
backward(flag, grads)¶ Concatenate all grad tensors to generate a single output tensor.
- Parameters
flag – same as Layer::backward()
grads – a list of tensors, one for the gradient of one sliced tensor
- Returns
- a single tensor for the gradient of the original input, and an empty list.
-
class singa.layer.RNN(name, hidden_size, rnn_mode='lstm', dropout=0.0, num_stacks=1, input_mode='linear', bidirectional=False, param_specs=None, input_sample_shape=None)¶ Bases: singa.layer.Layer
Recurrent layer with 4 types of units, namely lstm, gru, tanh and relu.
- Parameters
hidden_size – hidden feature size, the same for all stacks of layers.
rnn_mode – decides the rnn unit, which could be one of 'lstm', 'gru', 'tanh' and 'relu'; refer to the cudnn manual for each mode
num_stacks – num of stacked rnn layers; this is different from the unrolled sequence length
input_mode – 'linear' converts the input feature x by a linear transformation to get a feature vector of size hidden_size; 'skip' does nothing but requires the input feature size to equal hidden_size
bidirectional – True for a bidirectional RNN
param_specs – config for initializing the RNN parameters.
input_sample_shape – includes a single integer for the input sample feature size.
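A forward-pass sketch with an LSTM, assuming a CUDA build since this layer is implemented via cudnn; the sequence length, batch size and feature sizes are illustrative:
from singa import device
from singa import layer
from singa import tensor
dev = device.create_cuda_gpu()
lstm = layer.LSTM('lstm', hidden_size=32, input_sample_shape=(64,))
lstm.to_device(dev)
lstm.param_values()[0].uniform(-0.08, 0.08)  # init the single RNN parameter tensor
xs = []
for _ in range(3):                   # an unrolled sequence of 3 positions
    x = tensor.Tensor((4, 64), dev)  # batch of 4, input feature length 64
    x.uniform(-1, 1)
    xs.append(x)
hx = tensor.Tensor()  # dummy initial hidden state
cx = tensor.Tensor()  # dummy initial cell state (lstm only)
y1, y2, y3, hy, cy = lstm.forward(True, xs + [hx, cx])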
-
forward(flag, inputs)¶ Forward inputs through the RNN.
- Parameters
flag – True (kTrain) for training; False (kEval) for evaluation; other values for future use.
inputs – <x1, x2, ..., xn, hx, cx>, where xi is the input tensor for the i-th position, with shape (batch_size, input_feature_length); the batch_size of xi must be >= that of xi+1; hx is the initial hidden state, of shape (num_stacks * bidirection?2:1, batch_size, hidden_size); cx is the initial cell state tensor of the same shape as hx, valid only for lstm (for other RNNs there is no cx). Both hx and cx could be dummy tensors without shape and data.
- Returns
- <y1, y2, ..., yn, hy, cy>, where yi is the output tensor for the i-th position, with shape (batch_size, hidden_size * bidirection?2:1); hy is the final hidden state tensor; cy is the final cell state tensor, used only for lstm.
-
backward(flag, grad)¶ Backward gradients through the RNN.
- Parameters
flag – for future use.
grad – <dy1, dy2, ..., dyn, dhy, dcy>, where dyi is the gradient for the i-th output, with shape (batch_size, hidden_size * bidirection?2:1); dhy is the gradient for the final hidden state, with shape (num_stacks * bidirection?2:1, batch_size, hidden_size); dcy is the gradient for the final cell state, valid only for lstm (for other RNNs there is no dcy). Both dhy and dcy could be dummy tensors without shape and data.
- Returns
- <dx1, dx2, ..., dxn, dhx, dcx>, where dxi is the gradient tensor for the i-th input, with shape (batch_size, input_feature_length); dhx is the gradient for the initial hidden state; dcx is the gradient for the initial cell state, valid only for lstm.
-
class singa.layer.LSTM(name, hidden_size, dropout=0.0, num_stacks=1, input_mode='linear', bidirectional=False, param_specs=None, input_sample_shape=None)¶ Bases: singa.layer.RNN
-
class singa.layer.GRU(name, hidden_size, dropout=0.0, num_stacks=1, input_mode='linear', bidirectional=False, param_specs=None, input_sample_shape=None)¶ Bases: singa.layer.RNN
-
singa.layer.get_layer_list()¶ Return a list of strings which include the identifiers (tags) of all supported layers
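For example, to inspect which layer identifiers the current build supports (the exact list depends on how SINGA was compiled):
from singa import layer
print(layer.get_layer_list())  # e.g., ['cudnn_convolution', 'singacpp_dropout', ...]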