1 #ifndef INCLUDE_TRAINER_TRAINER_H_
2 #define INCLUDE_TRAINER_TRAINER_H_
3 #include "proto/cluster.pb.h"
4 #include "proto/model.pb.h"
5 #include "utils/updater.h"
6 #include "utils/param.h"
7 #include "utils/singleton.h"
8 #include "utils/factory.h"
9 #include "neuralnet/neuralnet.h"
10 #include "trainer/worker.h"
11 #include "trainer/server.h"
12 #include "communication/socket.h"
43 ParamInfo(shared_ptr<Param> p,
int local,
int owner):
57 void AddParam(shared_ptr<Param> p,
int local,
int owner){
71 vector<shared_ptr<Param>> shares;
74 typedef std::map<int, shared_ptr<ParamInfo>> ParamShard;
90 void Run(
int nworkers,
int nservers,
91 const std::map<
int, shared_ptr<ParamShard>>& shards);
119 virtual Msg* HandleGetResponse(shared_ptr<ParamInfo>counter,
Msg** msg);
125 virtual int HandleUpdateResponse(shared_ptr<ParamInfo>counter,
Msg** msg);
131 virtual Msg* HandleConnect(
Msg** msg);
135 shared_ptr<Router> router_;
138 #endif // INCLUDE_TRAINER_TRAINER_H_
virtual Msg * HandlePut(shared_ptr< ParamInfo >counter, Msg **msg)
Generate a request message to Put the parameter object.
Definition: cluster.pb.h:41
int next_version
all counters are atomic
Definition: trainer.h:66
void Start(const ModelProto &modelproto, const ClusterProto &clusterproto, int procs_id)
Start the training in one process.
int num_total
total number of workers that use the shared parameter
Definition: trainer.h:69
Definition: model.pb.h:316
int num_local
number of local workers that use the shared parameter
Definition: trainer.h:68
int owner_procs
the process (procs) ID of the worker that owns the parameter
Definition: trainer.h:70
virtual int Sharding(int param_id)
Workers from the same group resident in the same process share the same ParamShard which contains Par...
Every running process has a training object which launches one or more worker (and server) threads...
Definition: trainer.h:22
ParamInfo is used to construct a parameter shard.
Definition: trainer.h:41
void RegisterDefaultClasses(const singa::ModelProto &proto)
Register default implementations for all base classes used in the system, e.g., the Updater...
virtual Msg * HandleUpdate(shared_ptr< ParamInfo >counter, Msg **msg)
Generate a request message to Update the parameter object.
void AddParam(shared_ptr< Param > p, int local, int owner)
Associate the counter with a Param object.
Definition: trainer.h:57
virtual Msg * HandleGet(shared_ptr< ParamInfo >counter, Msg **msg)
Generate a request message to Get the parameter object.