Apache SINGA
A distributed deep learning platform.
 All Classes Namespaces Files Functions Variables Typedefs Macros
trainer.h
1 /************************************************************
2 *
3 * Licensed to the Apache Software Foundation (ASF) under one
4 * or more contributor license agreements. See the NOTICE file
5 * distributed with this work for additional information
6 * regarding copyright ownership. The ASF licenses this file
7 * to you under the Apache License, Version 2.0 (the
8 * "License"); you may not use this file except in compliance
9 * with the License. You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing,
14 * software distributed under the License is distributed on an
15 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 * KIND, either express or implied. See the License for the
17 * specific language governing permissions and limitations
18 * under the License.
19 *
20 *************************************************************/
21 
22 #ifndef INCLUDE_TRAINER_TRAINER_H_
23 #define INCLUDE_TRAINER_TRAINER_H_
24 #include <unordered_map>
25 #include <queue>
26 #include "proto/job.pb.h"
27 #include "proto/singa.pb.h"
28 #include "utils/param.h"
29 #include "utils/singleton.h"
30 #include "utils/factory.h"
31 #include "neuralnet/neuralnet.h"
32 #include "trainer/worker.h"
33 #include "trainer/server.h"
34 #include "communication/socket.h"
35 
36 namespace singa {
44 class Trainer{
45  public:
46  ~Trainer();
55  void Start(bool resume, const SingaProto& singaConf, JobProto* jobConf);
56 
57  protected:
67  void Resume(JobProto* jobConf);
76  vector<Server*> CreateServers(int nthread, const JobProto& jobConf);
85  vector<Worker*> CreateWorkers(int nthread, const JobProto& jobConf);
86 
97  void SetupWorkerServer(
98  const JobProto& jobConf,
99  const vector<Worker*>& workers,
100  const vector<Server*>& servers);
101 
102  void Run(const vector<Worker*>& workers, const vector<Server*>& servers);
106  void DisplayMetric(Msg** msg);
112  Dealer* CreateInterProcsDealer(int dst_procs);
116  void HandleLocalMsg(std::queue<Msg*>* msg_queue, Msg** msg);
117 
121  const vector<Msg*> HandleGet(ParamEntry* entry, Msg** msg);
122  void HandleGetResponse(ParamEntry* entry, Msg** msg);
123 
127  const vector<Msg*> HandleUpdate(ParamEntry* entry, Msg** msg);
128  void HandleUpdateResponse(ParamEntry* entry, Msg** msg);
129 
133  const vector<Msg*> HandlePut(ParamEntry* entry, Msg** msg);
134 
143  void GenMsgs(int type, int version, ParamEntry* entry,
144  Msg* msg, vector<Msg*> *ret);
151  inline int Hash(int grp_id, int param_id) {
152  return grp_id * 997 + param_id;
153  }
154 
155  protected:
156  int procs_id_;
157  Router *router_;
158  std::unordered_map<int, ParamEntry*> worker_shard_;
160  vector<int> slice2server_;
161 };
162 } /* singa */
163 #endif // INCLUDE_TRAINER_TRAINER_H_
void Start(bool resume, const SingaProto &singaConf, JobProto *jobConf)
Entry function which constructs the workers and servers, and launches one thread per worker/server...
const vector< Msg * > HandlePut(ParamEntry *entry, Msg **msg)
Generate a request message to Put the parameter object.
const vector< Msg * > HandleUpdate(ParamEntry *entry, Msg **msg)
Generate a request message to Update the parameter object.
Msg is used to transfer Param info (gradient or value), feature blobs, etc., between workers, stubs and servers.
Definition: msg.h:91
vector< Worker * > CreateWorkers(int nthread, const JobProto &jobConf)
Create worker instances.
ParamEntry is used for aggregating gradients of Params shared by workers from the same group...
Definition: param.h:335
void HandleLocalMsg(std::queue< Msg * > *msg_queue, Msg **msg)
Handle messages to local servers and local stub.
Dealer * CreateInterProcsDealer(int dst_procs)
Create a socket to send msg to the specified process.
Definition: socket.h:91
void Resume(JobProto *jobConf)
Set the checkpoint field of the model configuration to resume training.
std::unordered_map< int, ParamEntry * > worker_shard_
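map from slice to the server that updates it (NOTE: this description appears to belong to slice2server_, declared at trainer.h:160; worker_shard_ itself is declared as a map from an int key to a ParamEntry* — confirm against the comments in trainer.h)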
Definition: trainer.h:158
int Hash(int grp_id, int param_id)
Get a hash id for a Param object from a group.
Definition: trainer.h:151
vector< Server * > CreateServers(int nthread, const JobProto &jobConf)
Create server instances.
Definition: socket.h:125
void SetupWorkerServer(const JobProto &jobConf, const vector< Worker * > &workers, const vector< Server * > &servers)
Setup workers and servers.
Every running process has a training object which launches one or more worker (and server) threads...
Definition: trainer.h:44
const vector< Msg * > HandleGet(ParamEntry *entry, Msg **msg)
Generate a request message to Get the parameter object.
void DisplayMetric(Msg **msg)
Display metrics to log (standard output)
void GenMsgs(int type, int version, ParamEntry *entry, Msg *msg, vector< Msg * > *ret)
Called by HandlePut, HandleUpdate and HandleGet functions.