Apache SINGA
A distributed deep learning platform.
trainer.h
#ifndef INCLUDE_TRAINER_TRAINER_H_
#define INCLUDE_TRAINER_TRAINER_H_
#include "proto/cluster.pb.h"
#include "proto/model.pb.h"
#include "utils/updater.h"
#include "utils/param.h"
#include "utils/singleton.h"
#include "utils/factory.h"
#include "neuralnet/neuralnet.h"
#include "trainer/worker.h"
#include "trainer/server.h"
#include "communication/socket.h"

namespace singa {
/**
 * Every running process has a training object which launches one or more
 * worker (and server) threads.
 */
class Trainer {
 public:
  /**
   * ParamInfo is used to construct a parameter shard.
   */
  class ParamInfo {
   public:
    ParamInfo(shared_ptr<Param> p, int local, int owner) :
        num_update(0), next_version(0), num_local(local), num_total(1),
        owner_procs(owner) {
      shares.push_back(p);
    }

    /**
     * Associate the counter to a Param object.
     */
    void AddParam(shared_ptr<Param> p, int local, int owner) {
      num_local += local;
      num_total += 1;
      if (owner > -1)
        owner_procs = owner;
      if (local > 0)
        shares.push_back(p);
    }

    int num_update, next_version;  //!< all counters are atomic
    int num_local;    //!< local workers using the shared parameter
    int num_total;    //!< total workers using the shared parameter
    int owner_procs;  //!< the procs id of the worker that owns the parameter
    vector<shared_ptr<Param>> shares;
  };

  typedef std::map<int, shared_ptr<ParamInfo>> ParamShard;

 public:
  /**
   * Start the training in one process.
   */
  void Start(const ModelProto& modelproto, const ClusterProto& clusterproto,
             int procs_id);

  // TODO add Resume() function to continue training from a previously stopped
  // point.

 protected:
  void Run(int nworkers, int nservers,
           const std::map<int, shared_ptr<ParamShard>>& shards);

  /**
   * Register default implementations for all base classes used in the system,
   * e.g., the Updater.
   */
  void RegisterDefaultClasses(const singa::ModelProto& proto);

  /**
   * Workers from the same group resident in the same process share the same
   * ParamShard, which contains a ParamInfo entry for each shared Param.
   */
  virtual int Sharding(int param_id);

  /**
   * Generate a request message to Get the parameter object.
   */
  virtual Msg* HandleGet(shared_ptr<ParamInfo> counter, Msg** msg);
  virtual Msg* HandleGetResponse(shared_ptr<ParamInfo> counter, Msg** msg);

  /**
   * Generate a request message to Update the parameter object.
   */
  virtual Msg* HandleUpdate(shared_ptr<ParamInfo> counter, Msg** msg);
  virtual int HandleUpdateResponse(shared_ptr<ParamInfo> counter, Msg** msg);

  /**
   * Generate a request message to Put the parameter object.
   */
  virtual Msg* HandlePut(shared_ptr<ParamInfo> counter, Msg** msg);
  virtual Msg* HandleConnect(Msg** msg);

 protected:
  int procs_id_;
  shared_ptr<Router> router_;
};
} /* singa */
#endif  // INCLUDE_TRAINER_TRAINER_H_
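A ParamShard aggregates the replicas of each shared Param into a single ParamInfo, keyed by an integer id. The sketch below shows one way a process could fold worker reports into a shard. It is a minimal illustration, not part of trainer.h: the helper RegisterParam and the choice of the Param's id as the map key are assumptions, and it follows the header's unqualified shared_ptr convention.

// Hypothetical helper (not in trainer.h) that folds one worker's Param
// replica into the process-wide shard; assumes the map key is the Param id.
void RegisterParam(Trainer::ParamShard* shard, shared_ptr<Param> p,
                   int param_id, int local, int owner) {
  auto it = shard->find(param_id);
  if (it == shard->end()) {
    // First replica seen for this id: create the counter.
    (*shard)[param_id] = make_shared<Trainer::ParamInfo>(p, local, owner);
  } else {
    // Later replicas bump the counters and may record the owning process.
    it->second->AddParam(p, local, owner);
  }
}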
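Sharding(param_id) maps a parameter id to the server responsible for it. The header leaves the policy to the implementation; one plausible round-robin choice, shown purely as an assumption rather than SINGA's actual mapping, is:

// Hedged sketch of one possible Sharding policy: round-robin over servers
// by parameter id. nservers is an assumed input, not a field of Trainer.
int ShardingByModulo(int param_id, int nservers) {
  return param_id % nservers;
}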
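The class comment says every running process launches one or more worker (and server) threads, with the main thread forwarding messages between them. A bare-bones sketch of that launch pattern follows; RunWorker and RunServer are hypothetical stand-ins for the real thread bodies, and the real Trainer::Run additionally routes messages via router_.

#include <thread>
#include <vector>

// Hypothetical stand-ins for the worker/server thread bodies.
void RunWorker(int id) { /* pull params, compute, push updates */ }
void RunServer(int id) { /* serve Get/Put/Update requests */ }

// Spawn nworkers + nservers threads, then wait for all of them to finish.
void LaunchAll(int nworkers, int nservers) {
  std::vector<std::thread> threads;
  for (int i = 0; i < nworkers; ++i)
    threads.emplace_back(RunWorker, i);
  for (int i = 0; i < nservers; ++i)
    threads.emplace_back(RunServer, i);
  for (auto& t : threads)
    t.join();
}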