Apache SINGA
A distributed deep learning platform.
singa::Trainer Class Reference

Every running process has a training object which launches one or more worker (and server) threads.

#include <trainer.h>

Public Member Functions

void Start (bool resume, const SingaProto &singaConf, JobProto *jobConf)
 Entry point that constructs the workers and servers and launches one thread per worker/server.
 

Protected Member Functions

void Resume (JobProto *jobConf)
 Set the checkpoint field of the model configuration to resume training.
 
vector< Server * > CreateServers (int nthread, const JobProto &jobConf)
 Create server instances.
 
vector< Worker * > CreateWorkers (int nthread, const JobProto &jobConf)
 Create worker instances.
 
void SetupWorkerServer (const JobProto &jobConf, const vector< Worker * > &workers, const vector< Server * > &servers)
 Set up workers and servers.
 
void Run (const vector< Worker * > &workers, const vector< Server * > &servers)
 
void DisplayMetric (Msg **msg)
 Display metrics to log (standard output)
 
Dealer * CreateInterProcsDealer (int dst_procs)
 Create a socket to send messages to the specified process.
 
void HandleLocalMsg (std::queue< Msg * > *msg_queue, Msg **msg)
 Handle messages to local servers and local stub.
 
const vector< Msg * > HandleGet (ParamEntry *entry, Msg **msg)
 Generate a request message to Get the parameter object.
 
void HandleGetResponse (ParamEntry *entry, Msg **msg)
 
const vector< Msg * > HandleUpdate (ParamEntry *entry, Msg **msg)
 Generate a request message to Update the parameter object.
 
void HandleUpdateResponse (ParamEntry *entry, Msg **msg)
 
const vector< Msg * > HandlePut (ParamEntry *entry, Msg **msg)
 Generate a request message to Put the parameter object.
 
void GenMsgs (int type, int version, ParamEntry *entry, Msg *msg, vector< Msg * > *ret)
 Called by HandlePut, HandleUpdate and HandleGet functions.
 
int Hash (int grp_id, int param_id)
 Get a hash id for a Param object from a group.
 

Protected Attributes

int procs_id_
 
Router * router_
 
std::unordered_map< int, ParamEntry * > worker_shard_
 
vector< int > slice2server_
 map from slice to the server that updates it
 

Detailed Description

Every running process has a training object which launches one or more worker (and server) threads.

The main thread runs a loop to forward messages between workers and servers.
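
The forwarding loop can be pictured with a minimal in-memory sketch. Everything here is hypothetical: `DemoMsg`, `ForwardAll`, and the mailbox map are illustrative stand-ins, and the real Trainer exchanges `Msg` objects over sockets (see `router_` and `CreateInterProcsDealer`) rather than plain containers.

```cpp
#include <queue>
#include <unordered_map>
#include <vector>

// Hypothetical message type: only the fields needed to show routing.
struct DemoMsg {
  int dst;      // ID of the destination thread (worker or server)
  int payload;  // stand-in for the real message body
};

// Drains `inbox` and appends each message to the mailbox of its destination.
// Returns the number of messages forwarded.
// ASSUMPTION: mailboxes are plain vectors; the real class uses sockets.
int ForwardAll(std::queue<DemoMsg>* inbox,
               std::unordered_map<int, std::vector<DemoMsg>>* mailboxes) {
  int forwarded = 0;
  while (!inbox->empty()) {
    DemoMsg msg = inbox->front();
    inbox->pop();
    (*mailboxes)[msg.dst].push_back(msg);  // deliver to the addressed thread
    ++forwarded;
  }
  return forwarded;
}
```

In this sketch a single call drains the queue once; a long-running main thread would presumably repeat the same step until training finishes.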

Member Function Documentation

Dealer* singa::Trainer::CreateInterProcsDealer ( int  dst_procs)
protected

Create a socket to send messages to the specified process.

Parameters
dst_procs  the destination process (logical) ID
Returns
the newly created socket

vector<Server*> singa::Trainer::CreateServers ( int  nthread, const JobProto &  jobConf )
protected

Create server instances.

Parameters
nthread  total number of threads in the current process, used to assign each thread a local thread ID. The number of workers is extracted from Cluster.
jobConf  job configuration
Returns
server instances

vector<Worker*> singa::Trainer::CreateWorkers ( int  nthread, const JobProto &  jobConf )
protected

Create worker instances.

Parameters
nthread  total number of threads in the current process, used to assign each thread a local thread ID. The number of workers is extracted from Cluster.
jobConf  job configuration
Returns
worker instances

void singa::Trainer::GenMsgs ( int  type, int  version, ParamEntry *  entry, Msg *  msg, vector< Msg * > *  ret )
protected

Called by HandlePut, HandleUpdate and HandleGet functions.

Parameters
type  message type
version  param version
entry
msg
ret  generated messages

int singa::Trainer::Hash ( int  grp_id, int  param_id )
inline protected

Get a hash id for a Param object from a group.

Simply multiplies the group ID by the prime number 997 (assuming there are no more than 997 worker groups) and adds the owner param ID.
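
A minimal sketch of the formula described above, written as a free function for illustration. The name `HashParam` and the constant name `kPrime` are hypothetical; the actual member is `singa::Trainer::Hash`, and its body may differ in detail.

```cpp
// Stand-alone illustration of the hashing scheme described in the text.
// ASSUMPTION: the id is exactly grp_id * 997 + param_id, with no overflow
// handling, as the description suggests.
inline int HashParam(int grp_id, int param_id) {
  constexpr int kPrime = 997;  // the prime multiplier named in the documentation
  return grp_id * kPrime + param_id;
}
```

Under this formula, group 2 and param 3 map to 2 * 997 + 3 = 1997.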

void singa::Trainer::Resume ( JobProto *  jobConf)
protected

Set the checkpoint field of the model configuration to resume training.

The checkpoint folder is searched for the files of the latest checkpoint, which are added to the checkpoint field. The workers then load the param values from those checkpoint files.

Parameters
jobConf  job configuration
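
The "latest checkpoint" selection can be sketched as a pure function over file names. This is an illustration only: the helper `LatestCheckpointFiles` is hypothetical, and the file-name pattern `step<N>-worker<K>.bin` is an assumption made for the example, not something this page specifies.

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Returns the names that belong to the highest step number found in `names`.
// ASSUMPTION: checkpoint files are named "step<N>-worker<K>.bin"; names that
// do not match this pattern are ignored.
std::vector<std::string> LatestCheckpointFiles(
    const std::vector<std::string>& names) {
  int latest = -1;
  std::vector<std::string> result;
  for (const std::string& name : names) {
    int step = 0;
    int worker = 0;
    // sscanf returns the number of fields converted; both must be present.
    if (std::sscanf(name.c_str(), "step%d-worker%d.bin", &step, &worker) != 2)
      continue;
    if (step < 0) continue;  // ignore malformed negative step numbers
    if (step > latest) {     // found a newer checkpoint: restart the list
      latest = step;
      result.clear();
    }
    if (step == latest) result.push_back(name);
  }
  return result;
}
```

The returned names are what would then be added to the checkpoint field of the job configuration, so that each worker can load its param values from them.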

void singa::Trainer::SetupWorkerServer ( const JobProto &  jobConf, const vector< Worker * > &  workers, const vector< Server * > &  servers )
)
protected

Set up workers and servers.

For each worker, create and assign a neural net to it. For each server, create and assign a param shard to it. Then create the partition map from slice ID to server.

Parameters
jobConf  job configuration
workers
servers
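
The partition map from slice ID to server can be pictured as a vector indexed by slice ID, matching the `vector< int > slice2server_` attribute. The sketch below uses a round-robin policy purely as an assumption; this page does not say how slices are actually assigned, and `BuildSlice2Server` is a hypothetical helper.

```cpp
#include <vector>

// Builds a table where index = slice ID and value = server ID.
// ASSUMPTION: slices are dealt out round-robin; the real assignment policy
// is not described in this reference and may differ.
std::vector<int> BuildSlice2Server(int num_slices, int num_servers) {
  std::vector<int> slice2server;
  if (num_slices <= 0 || num_servers <= 0) return slice2server;
  slice2server.reserve(num_slices);
  for (int slice = 0; slice < num_slices; ++slice) {
    slice2server.push_back(slice % num_servers);
  }
  return slice2server;
}
```

With 5 slices and 2 servers this yields the table 0, 1, 0, 1, 0, so a lookup such as `slice2server[3]` gives the server responsible for slice 3.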

void singa::Trainer::Start ( bool  resume, const SingaProto &  singaConf, JobProto *  jobConf )
)

Entry point that constructs the workers and servers and launches one thread per worker/server.

Parameters
resume  if true, resume the training from the latest checkpoint files
singaConf  global singa configuration including zookeeper and
jobConf  job configuration, including cluster and model configuration

The documentation for this class was generated from the following file: