Apache SINGA
A distributed deep learning platform.
singa::Param Class Reference

Base parameter class.

#include <param.h>

Public Member Functions

void Init (const ParamProto &proto)
 
virtual void Setup (const std::vector< int > &shape)
 Setup param object.
 
virtual void InitValues ()
 
virtual void InitValues (int version)
 
void ShareFrom (const Param &other)
 Share the data blob from other Param objects.
 
void FromProto (const BlobProto &blob)
 Init param values from checkpoint blob.
 
void ToProto (BlobProto *blob)
 Dump param values to blob.
 
void AddSlice (int slice_id, int size)
 Add a slice.
 
float lr_scale () const
 Scale the learning rate when updating parameters in the Param object.
 
float wd_scale () const
 Scale the weight decay when updating parameters in the Param object.
 
const std::string & name () const
 Parameter name used for Param re-use in other model or sharing between layers.
 
void set_name (const std::string &name)
 
int owner () const
 If this Param shares data from another Param, owner is the id of that Param; otherwise it is this Param's own id.
 
int id () const
 IDs start from 0 and are ordered across all Param objects from the same neuralnet.
 
void set_id (int id)
 Set ID.
 
int version () const
 Param version is stored inside the data blob so that all Param objects sharing the same values have the same version.
 
void set_version (int v)
 
int local_version () const
 
void set_local_version (int v)
 
const std::string & share_from () const
 
int size () const
 
const Blob< float > & data () const
 
Blob< float > * mutable_data ()
 
const Blob< float > & grad () const
 
Blob< float > * mutable_grad ()
 
float * mutable_cpu_data ()
 
float * mutable_cpu_grad ()
 
float * mutable_cpu_history ()
 
int slice_start () const
 
int num_slices () const
 
virtual Msg * GenPutMsg (bool copy, int slice_idx)
 Generate the message for a put request, i.e., put parameters to a server.
 
virtual Msg * GenGetMsg (bool copy, int slice_idx)
 Generate the message for a get request, i.e., get parameters from a server.
 
virtual Msg * GenUpdateMsg (bool copy, int slice_idx)
 Generate the message for an update request, i.e., pass info to server for parameter update.
 
virtual Msg * GenSyncMsg (int offset, int size)
 Generate the message for a synchronization request between server groups.
 
virtual Msg * HandlePutMsg (Msg **msg, bool reserve)
 Server handling function for put request.
 
virtual Msg * HandleGetMsg (Msg **msg, bool reserve)
 Server handling function for get request.
 
virtual void ParseUpdateMsgs (const std::vector< Msg * > &msgs)
 Server parse update requests.
 
virtual const std::vector< Msg * > GenUpdateResponseMsgs (std::vector< Msg * > *msgs, bool reserve)
 Generate the messages to respond to the update requests.
 
virtual Msg * HandleSyncMsg (Msg **msg, bool reserve)
 Server handling function for synchronization message.
 
virtual int ParseGetResponseMsg (Msg *msg, int slice_idx)
 Worker/Stub parsing function for get response.
 
virtual int ParseUpdateResponseMsg (Msg *msg, int slice_idx)
 Worker/Server parsing function for update response.
 
virtual int ParseSyncResponseMsg (Msg *msg, int slice_idx)
 Server parsing function for synchronization response.
 

Static Public Member Functions

static Param * Create (const ParamProto &proto)
 

Protected Member Functions

void ParseResponseMsg (Msg *msg, int slice_idx)
 Implement the common code of ParseGetResponseMsg and ParseUpdateResponseMsg.
 

Protected Attributes

int local_version_ = -1
 
int slice_start_ = 0
 
int num_slices_ = 0
 
std::vector< int > slice_offset_
 
std::vector< int > slice_size_
 
std::vector< bool > pending_get_
 
std::vector< bool > pending_update_
 
int num_pending_requests_ = 0
 
std::shared_ptr< Blob< float > > data_ = nullptr
 
Blob< float > grad_
 
Blob< float > history_
 
ParamProto proto_
 

Detailed Description

Base parameter class.

The Param object is a set of parameters, e.g., the (sub) weight matrix or (sub) bias vector.

It has a gradient Blob and a data Blob for gradients and parameter values. Since some layers (or neuralnets) share parameter values, the data Blob is held through a shared pointer, which can be assigned to the data field of many Param objects.

It provides access methods like data() and grad(). It also provides functions for generating and parsing the messages that transfer Param objects between worker and worker, worker and server, and server and server.

Param objects are of different sizes, which makes it hard to achieve load balance among servers. Hence, we slice large Param objects into small pieces. At the server side, one slice is a Param object.
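
The slicing idea above can be sketched as follows. This is a minimal illustration, not SINGA's slicing policy: the helper EvenSliceSizes is hypothetical, and it assumes a plain even split of one Param into a fixed number of contiguous pieces.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper, not part of SINGA: split `total` floats into
// `num_slices` contiguous pieces whose sizes differ by at most one.
std::vector<int> EvenSliceSizes(int total, int num_slices) {
  assert(total >= 0 && num_slices > 0);
  std::vector<int> sizes(num_slices, total / num_slices);
  // Spread the remainder over the leading slices.
  for (int i = 0; i < total % num_slices; ++i) ++sizes[i];
  return sizes;
}
```

With pieces of near-equal size, each server can be handed a similar number of floats regardless of how large the original Param objects are.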

Member Function Documentation

void singa::Param::AddSlice ( int  slice_id,
int  size 
)

Add a slice.

Parameters
slice_id
size  number of floats for this slice
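
A minimal sketch of the bookkeeping that adding a slice might involve. The struct SliceTable is hypothetical; its field names mirror the protected attributes slice_start_, num_slices_, slice_offset_ and slice_size_, but the logic (in-order insertion, offsets accumulated from sizes) is an assumption rather than the actual implementation.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for the slice bookkeeping of a Param; the logic is
// an assumption, only the field names come from the attribute list.
struct SliceTable {
  int slice_start = 0;            // ID of the first slice
  int num_slices = 0;             // number of slices added so far
  std::vector<int> slice_offset;  // start position (in floats) of each slice
  std::vector<int> slice_size;    // number of floats in each slice

  void AddSlice(int slice_id, int size) {
    assert(size > 0);
    int offset = 0;
    if (num_slices == 0) {
      slice_start = slice_id;
    } else {
      // Assumption: slices are added in order with consecutive IDs.
      assert(slice_id == slice_start + num_slices);
      offset = slice_offset.back() + slice_size.back();
    }
    slice_offset.push_back(offset);
    slice_size.push_back(size);
    ++num_slices;
  }
};
```

Under these assumptions, slice_idx values passed to the message functions would index into slice_offset and slice_size, while slice_start maps a local index back to a global slice ID.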
virtual Msg* singa::Param::GenGetMsg ( bool  copy,
int  slice_idx 
)
virtual

Generate the message for a get request, i.e., get parameters from a server.

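The basic communication workflows are as follows:

         | Put          | Get                 | Update                 | Sync
---------+--------------+---------------------+------------------------+---------------------
Generate | (stub)       | (stub)              | (stub)                 | (server)
Message  | GenPutMsg    | GenGetMsg           | GenUpdateMsg           | GenSyncMsg
---------+--------------+---------------------+------------------------+---------------------
Handle   | (server)     | (server)            | (server)               | (server)
Message  | HandlePutMsg | HandleGetMsg        | ParseUpdateMsgs        | HandleSyncMsg
         |              |                     | GenUpdateResponseMsgs  |
---------+--------------+---------------------+------------------------+---------------------
Handle   |              | (stub)              | (stub)                 | (server)
Response |              | ParseGetResponseMsg | ParseUpdateResponseMsg | ParseSyncResponseMsg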

This function is called at worker/stub side.

Parameters
copy  decides whether to copy the parameter values from the server.
slice_idx  index of the slice from which the message is generated.
Returns
generated message without setting src, dst, target fields.
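
The workflow table for the four request kinds can also be read as a lookup from request kind to the side that generates it and the method that parses its response. The sketch below encodes only that reading; the enum Request, the struct Workflow and the function WorkflowFor are hypothetical names introduced for illustration, while the method names are those listed in this class reference.

```cpp
#include <cassert>
#include <string>

// Hypothetical enum for the four request kinds in the workflow table.
enum class Request { kPut, kGet, kUpdate, kSync };

// Which side generates a request and which Param method parses the response.
struct Workflow {
  std::string generator;        // "stub" or "server"
  std::string gen_method;       // method that builds the request
  std::string response_parser;  // empty when no response is parsed
};

Workflow WorkflowFor(Request r) {
  switch (r) {
    case Request::kPut:
      return {"stub", "GenPutMsg", ""};
    case Request::kGet:
      return {"stub", "GenGetMsg", "ParseGetResponseMsg"};
    case Request::kUpdate:
      return {"stub", "GenUpdateMsg", "ParseUpdateResponseMsg"};
    case Request::kSync:
      return {"server", "GenSyncMsg", "ParseSyncResponseMsg"};
  }
  return {};  // unreachable for valid enum values
}
```

Note that put is the only request kind in the table without a response-handling step on the requesting side.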
virtual Msg* singa::Param::GenPutMsg ( bool  copy,
int  slice_idx 
)
virtual

Below are message/request related functions.

Generate the message for a put request, i.e., put parameters to a server

This function is called at worker/stub side.

Parameters
copy  decides whether to copy the parameter values from the server.
slice_idx  index of the slice from which the message is generated.
Returns
generated message without setting src, dst, target fields.
virtual Msg* singa::Param::GenSyncMsg ( int  offset,
int  size 
)
virtual

Generate the message for a synchronization request between server groups.

This function is called at server side where the Param is actually a slice of an original Param object.

virtual Msg* singa::Param::GenUpdateMsg ( bool  copy,
int  slice_idx 
)
virtual

Generate the message for an update request, i.e., pass info to server for parameter update.

This function is called at worker/stub side.

Parameters
copy  decides whether to copy the parameter values from the server.
slice_idx  index of the slice from which the message is generated.
Returns
generated message without setting src, dst, target fields.
virtual const std::vector<Msg*> singa::Param::GenUpdateResponseMsgs ( std::vector< Msg * > *  msgs,
bool  reserve 
)
virtual

Generate the messages to respond to the update requests.

This function is called at the server side, where the Param is actually a slice of an original Param object.

Parameters
msgs  for synchronous training, multiple processes may host workers that share the same Param (slice) object; their update requests are buffered and handled together. For asynchronous training, there is only one request in msgs.
Returns
response messages
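
To illustrate how buffered update requests might be handled together, the sketch below sums the gradients of all requests for one slice. It is an assumption-laden illustration: SumGradients is a hypothetical helper, each inner vector stands in for the payload of one request message, and whether the server sums, averages, or otherwise combines gradients is not stated in this reference.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: sum the gradients carried by buffered update requests.
// Each inner vector stands in for the payload of one request message.
std::vector<float> SumGradients(
    const std::vector<std::vector<float>>& requests) {
  assert(!requests.empty());
  std::vector<float> total(requests.front().size(), 0.0f);
  for (const auto& grad : requests) {
    assert(grad.size() == total.size());  // all requests target the same slice
    for (std::size_t i = 0; i < grad.size(); ++i) total[i] += grad[i];
  }
  return total;
}
```

In the asynchronous case the buffer holds a single request, so the same code path simply returns that request's gradient unchanged.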
virtual Msg* singa::Param::HandleGetMsg ( Msg **  msg,
bool  reserve 
)
virtual

Server handling function for get request.

virtual Msg* singa::Param::HandlePutMsg ( Msg **  msg,
bool  reserve 
)
virtual

Server handling function for put request.

Parameters
msg  request
reserve  if true, the msg space is reserved for the calling function; otherwise the msg should be freed inside the function.
Returns
response message
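
The reserve convention can be sketched with a toy message type. Everything here is hypothetical: FakeMsg is not SINGA's Msg class, HandleRequest is not a Param method, and resetting the caller's pointer to nullptr after freeing is an assumption suggested by the Msg ** parameter type.

```cpp
#include <cassert>
#include <string>

// Hypothetical message type; SINGA's real Msg class is not modelled here.
struct FakeMsg {
  std::string payload;
};

// Builds a response and applies the `reserve` rule to the request.
FakeMsg* HandleRequest(FakeMsg** msg, bool reserve) {
  assert(msg != nullptr && *msg != nullptr);
  FakeMsg* response = new FakeMsg{"ack:" + (*msg)->payload};
  if (!reserve) {
    delete *msg;     // the handler owns the request and frees it
    *msg = nullptr;  // assumption: clear the caller's pointer as well
  }
  return response;
}
```

Passing a pointer to the pointer lets the handler both free the request and signal to the caller that it is gone; with reserve set to true the caller keeps ownership and can reuse the message.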
virtual Msg* singa::Param::HandleSyncMsg ( Msg **  msg,
bool  reserve 
)
virtual

Server handling function for synchronization message.

int singa::Param::local_version ( ) const
inline
Returns
the version of the parameter value local to a worker
virtual int singa::Param::ParseGetResponseMsg ( Msg *  msg,
int  slice_idx 
)
virtual

Worker/Stub parsing function for get response.

Parameters
msg
slice_idx  index for the slice
void singa::Param::ParseResponseMsg ( Msg *  msg,
int  slice_idx 
)
protected

Implement the common code of ParseGetResponseMsg and ParseUpdateResponseMsg.

virtual int singa::Param::ParseSyncResponseMsg ( Msg *  msg,
int  slice_idx 
)
virtual

Server parsing function for synchronization response.

virtual void singa::Param::ParseUpdateMsgs ( const std::vector< Msg * > &  msgs)
virtual

Server parse update requests.

virtual int singa::Param::ParseUpdateResponseMsg ( Msg *  msg,
int  slice_idx 
)
virtual

Worker/Server parsing function for update response.

virtual void singa::Param::Setup ( const std::vector< int > &  shape)
virtual

Setup param object.

Parameters
conf  param configuration, including the learning rate multiplier, etc.
shape  one value per dimension
void singa::Param::ShareFrom ( const Param &  other)

Share the data blob from other Param objects.

Parameters
other  the Param object whose owner owns the data blob
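
A minimal sketch of sharing a data blob through a shared pointer. MiniParam is a hypothetical miniature, with a std::vector<float> standing in for Blob<float>; copying the owner id from the other object (so that chains of sharing resolve to the original owner) is an assumption consistent with the description of owner() and of the other parameter.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical miniature of a Param: only the pieces needed to show sharing.
struct MiniParam {
  int id;
  int owner;  // id of the Param that owns the data
  std::shared_ptr<std::vector<float>> data;

  MiniParam(int id_, int size)
      : id(id_),
        owner(id_),
        data(std::make_shared<std::vector<float>>(size, 0.0f)) {}

  void ShareFrom(const MiniParam& other) {
    owner = other.owner;  // follow the chain to the real owner
    data = other.data;    // both objects now point at one buffer
  }
};
```

After sharing, a write through any of the objects is visible through all of them, and the buffer is released only when the last object holding it goes away.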
int singa::Param::size ( ) const
inline
Returns
num of floats.
int singa::Param::slice_start ( ) const
inline
Returns
slice start ID
int singa::Param::version ( ) const
inline

Param version is stored inside the data blob so that all Param objects sharing the same values have the same version.

Returns
the param version
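
The difference between the shared version and the per-object local version can be sketched as below. VersionedBlob and ParamView are hypothetical types; the initial value of -1 for the local version follows the local_version_ attribute listed above, while using -1 as the initial blob version is an assumption.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical blob that carries its own version counter.
struct VersionedBlob {
  int version = -1;  // assumption: -1 means "no values received yet"
  std::vector<float> values;
};

// Hypothetical Param view: shared blob plus a per-object local version.
struct ParamView {
  std::shared_ptr<VersionedBlob> data;
  int local_version = -1;

  int version() const { return data->version; }
  void set_version(int v) { data->version = v; }
};
```

Because the version lives in the blob, setting it through one view changes what every sharing view reports, whereas local_version stays private to each object.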

The documentation for this class was generated from the following file: