22 #ifndef INCLUDE_TRAINER_TRAINER_H_
23 #define INCLUDE_TRAINER_TRAINER_H_
24 #include <unordered_map>
26 #include "proto/job.pb.h"
27 #include "proto/singa.pb.h"
28 #include "utils/param.h"
29 #include "utils/singleton.h"
30 #include "utils/factory.h"
31 #include "neuralnet/neuralnet.h"
32 #include "trainer/worker.h"
33 #include "trainer/server.h"
34 #include "communication/socket.h"
/**
 * Entry function which constructs the workers and servers, and launches one
 * thread per worker/server.
 *
 * @param resume    presumably selects resuming from a checkpoint rather than
 *                  a fresh start — TODO confirm against the definition.
 * @param singaConf SingaProto configuration, taken by const reference.
 * @param jobConf   job configuration, passed as a mutable pointer.
 */
55 void Start(
bool resume,
const SingaProto& singaConf, JobProto* jobConf);
/**
 * Set the checkpoint field of the model configuration to resume training.
 *
 * @param jobConf job configuration, passed as a mutable pointer so the
 *                checkpoint field can be set on it.
 */
67 void Resume(JobProto* jobConf);
/**
 * Create server instances.
 *
 * @param nthread presumably the number of server threads/instances to
 *                create — TODO confirm against the definition.
 * @param jobConf job configuration, taken by const reference.
 * @return the created Server pointers; ownership is not visible from this
 *         declaration.
 */
76 vector<Server*>
CreateServers(
int nthread,
const JobProto& jobConf);
/**
 * Create worker instances.
 *
 * @param nthread presumably the number of worker threads/instances to
 *                create — TODO confirm against the definition.
 * @param jobConf job configuration, taken by const reference.
 * @return the created Worker pointers; ownership is not visible from this
 *         declaration.
 */
85 vector<Worker*>
CreateWorkers(
int nthread,
const JobProto& jobConf);
98 const JobProto& jobConf,
99 const vector<Worker*>& workers,
100 const vector<Server*>& servers);
/**
 * NOTE(review): no description of Run is visible in this listing; presumably
 * it runs the given workers and servers — confirm against the definition.
 *
 * @param workers worker instances, taken by const reference.
 * @param servers server instances, taken by const reference.
 */
102 void Run(
const vector<Worker*>& workers,
const vector<Server*>& servers);
144 Msg* msg, vector<Msg*> *ret);
/**
 * Get a hash id for a Param object from a group.
 *
 * The two ids are combined as grp_id * 997 + param_id. The result is unique
 * only while param_id stays in [0, 997): for example (0, 997) and (1, 0)
 * produce the same value.
 *
 * @param grp_id   id of the group
 * @param param_id id of the Param object
 * @return the combined hash id
 */
inline int Hash(int grp_id, int param_id) {
  return grp_id * 997 + param_id;
}
160 vector<int> slice2server_;
163 #endif // INCLUDE_TRAINER_TRAINER_H_
void Start(bool resume, const SingaProto &singaConf, JobProto *jobConf)
Entry function which constructs the workers and servers, and launches one thread per worker/server...
const vector< Msg * > HandlePut(ParamEntry *entry, Msg **msg)
Generate a request message to Put the parameter object.
const vector< Msg * > HandleUpdate(ParamEntry *entry, Msg **msg)
Generate a request message to Update the parameter object.
Msg is used to transfer Param info (gradient or value), feature blobs, etc., between workers, stubs and servers.
Definition: msg.h:91
vector< Worker * > CreateWorkers(int nthread, const JobProto &jobConf)
Create worker instances.
ParamEntry is used for aggregating gradients of Params shared by workers from the same group...
Definition: param.h:335
void HandleLocalMsg(std::queue< Msg * > *msg_queue, Msg **msg)
Handle messages to local servers and local stub.
Dealer * CreateInterProcsDealer(int dst_procs)
Create a socket to send msg to the specified process.
void Resume(JobProto *jobConf)
Set the checkpoint field of the model configuration to resume training.
std::unordered_map< int, ParamEntry * > worker_shard_
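Map from slice to the server that updates it. (Note: this description likely belongs to the adjacent member slice2server_ rather than to worker_shard_, whose type maps an int key to a ParamEntry *.)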
Definition: trainer.h:158
int Hash(int grp_id, int param_id)
Get a hash id for a Param object from a group.
Definition: trainer.h:151
vector< Server * > CreateServers(int nthread, const JobProto &jobConf)
Create server instances.
void SetupWorkerServer(const JobProto &jobConf, const vector< Worker * > &workers, const vector< Server * > &servers)
Setup workers and servers.
Every running process has a training object which launches one or more worker (and server) threads...
Definition: trainer.h:44
const vector< Msg * > HandleGet(ParamEntry *entry, Msg **msg)
Generate a request message to Get the parameter object.
void DisplayMetric(Msg **msg)
Display metrics to log (standard output)
void GenMsgs(int type, int version, ParamEntry *entry, Msg *msg, vector< Msg * > *ret)
Called by HandlePut, HandleUpdate and HandleGet functions.