Apache SINGA
A distributed deep learning platform.
singa::Trainer Class Reference

Every running process has a Trainer object that launches one or more worker (and server) threads.

#include <trainer.h>

Classes

class  ParamInfo
 ParamInfo is used to construct a parameter shard.
 

Public Types

typedef std::map< int, shared_ptr< ParamInfo > > ParamShard
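The ParamShard typedef maps a parameter id to a shared ParamInfo. The following is a minimal sketch, assuming a simplified ParamInfo with a single hypothetical `num_local` field (the real fields are not documented on this page) and a hypothetical helper `AddParam`. It shows how Params registered under the same id end up sharing one ParamInfo entry:

```cpp
#include <cassert>
#include <map>
#include <memory>

// Hypothetical stand-in for singa::Trainer::ParamInfo; only one assumed field.
struct ParamInfo {
  int num_local = 0;  // number of local Param objects sharing this entry (assumed)
};

// Same shape as the documented typedef: parameter id -> shared ParamInfo.
typedef std::map<int, std::shared_ptr<ParamInfo>> ParamShard;

// Hypothetical helper: register one Param under param_id. The first Param
// with a given id creates the entry; later ones reuse it and bump the count.
std::shared_ptr<ParamInfo> AddParam(ParamShard* shard, int param_id) {
  assert(shard != nullptr);
  // operator[] inserts an empty shared_ptr if the id is not present yet.
  std::shared_ptr<ParamInfo>& slot = (*shard)[param_id];
  if (!slot) slot = std::make_shared<ParamInfo>();
  slot->num_local++;
  return slot;
}
```

Sharing one `shared_ptr` per id keeps a single bookkeeping object per parameter, no matter how many workers in the group use it.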
 

Public Member Functions

void Start (const ModelProto &modelproto, const ClusterProto &clusterproto, int procs_id)
 Start the training in one process.
 

Protected Member Functions

void Run (int nworkers, int nservers, const std::map< int, shared_ptr< ParamShard >> &shards)
 
void RegisterDefaultClasses (const singa::ModelProto &proto)
 Register default implementations for all base classes used in the system, e.g., the Updater, BaseMsg, etc.
 
virtual int Sharding (int param_id)
 Workers from the same group residing in the same process share the same ParamShard, which contains ParamCounters for the Param objects used/updated by these workers.
 
virtual Msg * HandleGet (shared_ptr< ParamInfo > counter, Msg **msg)
 Generate a request message to Get the parameter object.
 
virtual Msg * HandleGetResponse (shared_ptr< ParamInfo > counter, Msg **msg)
 
virtual Msg * HandleUpdate (shared_ptr< ParamInfo > counter, Msg **msg)
 Generate a request message to Update the parameter object.
 
virtual int HandleUpdateResponse (shared_ptr< ParamInfo > counter, Msg **msg)
 
virtual Msg * HandlePut (shared_ptr< ParamInfo > counter, Msg **msg)
 Generate a request message to Put the parameter object.
 
virtual Msg * HandleConnect (Msg **msg)
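The Get, Update, and Put handlers each turn a worker request into a message bound for the server that maintains the parameter. The following is a minimal sketch of that routing step only. It assumes a hypothetical `Msg` struct, `MsgType` enum, and `HandleRequest` function, and it places parameters on servers by `param_id % nservers`. SINGA's real Msg class, its handler bodies, and any aggregation done through the ParamInfo counter are not shown on this page and are not modeled here.

```cpp
#include <cassert>
#include <memory>

// Hypothetical message types; not SINGA's actual constants.
enum MsgType { kGet, kUpdate, kPut, kConnect };

// Hypothetical message layout.
struct Msg {
  MsgType type;
  int param_id;
  int dst;  // destination server id (assumed field)
};

// Assumed placement policy: parameters spread over servers by id.
int ServerFor(int param_id, int nservers) {
  assert(param_id >= 0 && nservers > 0);
  return param_id % nservers;
}

// Turn a worker's request into the request forwarded to the owning server.
// Returns nullptr for message types this sketch does not forward.
std::unique_ptr<Msg> HandleRequest(const Msg& in, int nservers) {
  switch (in.type) {
    case kGet:     // cf. HandleGet: fetch the parameter's current values
    case kUpdate:  // cf. HandleUpdate: send gradients for the parameter
    case kPut:     // cf. HandlePut: place the parameter on the server
      return std::unique_ptr<Msg>(
          new Msg{in.type, in.param_id, ServerFor(in.param_id, nservers)});
    default:
      return nullptr;
  }
}
```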
 

Protected Attributes

int procs_id_
 
shared_ptr< Router > router_
 

Detailed Description

Every running process has a Trainer object that launches one or more worker (and server) threads.

The main thread runs a loop to forward messages between workers and servers.
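This forwarding design can be sketched with standard C++ threads. The sketch below is not SINGA's implementation. It assumes a hypothetical `Msg` struct, a hand-written blocking `Queue`, and a `RunSketch` function that stand in for SINGA's Router and Run(). It illustrates only the pattern: workers and servers never talk directly, and the main thread moves every message into its destination's inbox.

```cpp
#include <cassert>
#include <condition_variable>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Hypothetical message: src/dst are thread ids within this process.
struct Msg { int src; int dst; int payload; };

// Minimal blocking queue used as each thread's inbox.
class Queue {
 public:
  void Push(Msg m) {
    { std::lock_guard<std::mutex> lk(mu_); q_.push(m); }
    cv_.notify_one();
  }
  Msg Pop() {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return !q_.empty(); });
    Msg m = q_.front();
    q_.pop();
    return m;
  }
 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<Msg> q_;
};

// Ids 0..nworkers-1 are workers; the next nservers ids are servers (assumed).
// Each worker sends one request and blocks for the reply; servers reply with
// payload * 10 as a toy computation. Returns the sum of all replies.
int RunSketch(int nworkers, int nservers) {
  assert(nworkers > 0 && nservers > 0);
  const int total = nworkers + nservers;
  Queue router;                     // the main thread's inbox
  std::vector<Queue> inbox(total);  // one inbox per worker/server thread
  std::vector<int> results(nworkers, 0);
  std::vector<std::thread> threads;
  for (int w = 0; w < nworkers; ++w) {
    threads.emplace_back([&, w] {
      router.Push({w, nworkers + w % nservers, w});  // request to a server
      results[w] = inbox[w].Pop().payload;           // wait for the reply
    });
  }
  for (int s = nworkers; s < total; ++s) {
    threads.emplace_back([&, s] {
      for (;;) {
        Msg m = inbox[s].Pop();
        if (m.src < 0) break;                     // shutdown from main thread
        router.Push({s, m.src, m.payload * 10});  // reply to the sender
      }
    });
  }
  // Main thread: forward messages; each request yields exactly one reply.
  for (int forwarded = 0; forwarded < 2 * nworkers; ++forwarded) {
    Msg m = router.Pop();
    inbox[m.dst].Push(m);
  }
  for (int s = nworkers; s < total; ++s) inbox[s].Push({-1, s, 0});
  for (auto& t : threads) t.join();
  int sum = 0;
  for (int r : results) sum += r;
  return sum;
}
```

For example, `RunSketch(4, 2)` has workers 0 to 3 send payloads 0 to 3 and collect 0, 10, 20, and 30, so it returns 60. Routing everything through one loop keeps the worker and server code free of any knowledge about where their peers live.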

Member Function Documentation

void singa::Trainer::RegisterDefaultClasses ( const singa::ModelProto & proto)
protected

Register default implementations for all base classes used in the system, e.g., the Updater, BaseMsg, etc.

All built-in layer implementations are registered here. For other base classes, the base class name (a string) is used as the key and the implementation class as the value, e.g., <"Updater", SGDUpdater>.
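The key/value convention can be sketched as a small factory keyed by base-class name. This is a minimal sketch that assumes a hypothetical `DefaultRegistry` template and a stripped-down `Updater` interface. SINGA's actual registration API is not shown on this page.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>

// Hypothetical, stripped-down base class and default implementation.
class Updater {
 public:
  virtual ~Updater() = default;
  virtual std::string Name() const = 0;
};
class SGDUpdater : public Updater {
 public:
  std::string Name() const override { return "SGDUpdater"; }
};

// Hypothetical registry: base-class name -> creator of its default implementation.
template <typename Base>
class DefaultRegistry {
 public:
  template <typename Impl>
  void Register(const std::string& base_name) {
    assert(!base_name.empty());
    creators_[base_name] = [] { return std::unique_ptr<Base>(new Impl()); };
  }
  // Returns a new instance of the registered implementation, or nullptr.
  std::unique_ptr<Base> Create(const std::string& base_name) const {
    auto it = creators_.find(base_name);
    if (it == creators_.end()) return nullptr;
    return it->second();
  }
 private:
  std::map<std::string, std::function<std::unique_ptr<Base>()>> creators_;
};
```

Usage mirrors the <"Updater", SGDUpdater> pair: `reg.Register<SGDUpdater>("Updater")` followed by `reg.Create("Updater")`. Storing creators rather than instances means each lookup yields a fresh object.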

virtual int singa::Trainer::Sharding ( int  param_id)
protected virtual

Workers from the same group residing in the same process share the same ParamShard, which contains ParamCounters for the Param objects used/updated by these workers.

Shared Param objects are associated with the same ParamCounter.

Returns
the ID of the server that maintains the parameter.

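Because Sharding() is virtual, the placement policy can be overridden. The following is a minimal sketch that assumes a hypothetical `TrainerSketch` class. The class holds the server count and uses round-robin placement (`param_id` modulo the number of servers). The default policy SINGA uses is not stated on this page. Placement depends only on the parameter id, so every Param mapped to the same id is sent to the same server.

```cpp
#include <cassert>

// Hypothetical trainer skeleton with only what a Sharding() override needs.
class TrainerSketch {
 public:
  explicit TrainerSketch(int nservers) : nservers_(nservers) {
    assert(nservers > 0);
  }
  virtual ~TrainerSketch() = default;
  // Round-robin placement: spread parameters evenly over servers by id.
  virtual int Sharding(int param_id) const {
    assert(param_id >= 0);
    return param_id % nservers_;
  }
 protected:
  int nservers_;
};

// Example override: keep every parameter on server 0.
class SingleServerTrainer : public TrainerSketch {
 public:
  using TrainerSketch::TrainerSketch;
  int Sharding(int) const override { return 0; }
};
```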
void singa::Trainer::Start ( const ModelProto & modelproto,
const ClusterProto & clusterproto,
int  procs_id 
)

Start the training in one process.

Parameters
modelproto
clusterproto

The documentation for this class was generated from the following file: