Apache SINGA
A distributed deep learning platform.
 All Classes Namespaces Files Functions Variables Typedefs Macros
worker.h
1 /************************************************************
2 *
3 * Licensed to the Apache Software Foundation (ASF) under one
4 * or more contributor license agreements. See the NOTICE file
5 * distributed with this work for additional information
6 * regarding copyright ownership. The ASF licenses this file
7 * to you under the Apache License, Version 2.0 (the
8 * "License"); you may not use this file except in compliance
9 * with the License. You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing,
14 * software distributed under the License is distributed on an
15 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 * KIND, either express or implied. See the License for the
17 * specific language governing permissions and limitations
18 * under the License.
19 *
20 *************************************************************/
21 
22 #ifndef SINGA_TRAINER_WORKER_H_
23 #define SINGA_TRAINER_WORKER_H_
24 #include "neuralnet/neuralnet.h"
25 #include "proto/job.pb.h"
26 #include "communication/socket.h"
27 
28 namespace singa {
/// Sleep 5 milliseconds if the Param is not updated to the expected version.
constexpr int kCollectSleepTime = 5;
42 class Worker {
43  public:
44  static Worker* Create(const JobProto& proto);
50  virtual void Init(int thread_id, int grp_id, int id);
51  virtual ~Worker();
55  void Setup(const JobProto& job, shared_ptr<NeuralNet> train_net,
56  shared_ptr<NeuralNet> valid_net, shared_ptr<NeuralNet> test_net);
62  void Run();
79  void InitLocalParams();
80 
90  void Checkpoint(int step, shared_ptr<NeuralNet> net);
96  void Test(int nsteps, Phase phase, shared_ptr<NeuralNet> net);
101  virtual void TrainOneBatch(int step, Metric* perf)=0;
105  virtual void TestOneBatch(int step, Phase phase, shared_ptr<NeuralNet> net,
106  Metric* perf)=0;
113  void Report(const string& prefix, const Metric & perf);
114 
120  int Put(Param* param, int step);
128  int Get(Param* param, int step);
134  int Update(Param* param, int step);
141  int Collect(Param* param, int step);
145  int CollectAll(shared_ptr<NeuralNet> net, int step);
149  void ReceiveBlobs(
150  bool data, bool grad, BridgeLayer* layer, shared_ptr<NeuralNet> net);
154  void SendBlobs(
155  bool data, bool grad, BridgeLayer* layer, shared_ptr<NeuralNet> net);
156 
160  inline bool DisplayNow(int step) const;
164  inline bool DisplayDebugInfo(int step) const;
168  inline bool StopNow(int step) const;
172  inline bool CheckpointNow(int step) const;
177  inline bool TestNow(int step) const;
182  inline bool ValidateNow(int step) const;
183 
187  int grp_id() const { return grp_id_;}
188 
192  int id() const { return id_;}
193 
194  protected:
195  int thread_id_, grp_id_, id_;
196  int step_;
197  JobProto job_conf_;
198  shared_ptr<NeuralNet> train_net_, test_net_, validation_net_;
199  Dealer* layer_dealer_, *dealer_;
200 };
201 
202 class BPWorker: public Worker{
203  public:
204  ~BPWorker(){}
205  void Init(int thread_id, int grp_id, int id) override;
206  void TrainOneBatch(int step, Metric* perf) override;
207  void TestOneBatch(int step, Phase phase, shared_ptr<NeuralNet> net,
208  Metric* perf) override;
209 
210  void Forward(int step, Phase phase, shared_ptr<NeuralNet> net, Metric* perf);
211  void Backward(int step, shared_ptr<NeuralNet> net);
212 };
213 
214 class CDWorker: public Worker{
215  public:
216  void TrainOneBatch(int step, Metric* perf) override;
217  void TestOneBatch(int step, Phase phase, shared_ptr<NeuralNet> net,
218  Metric* perf) override;
219 };
220 
/**
 * Pack a group ID and a layer ID into a single int.
 *
 * The group ID is shifted into the bits above the low 16 and OR-ed with the
 * layer ID; BlobGrp() and BlobLayer() undo the packing.
 */
inline int BlobTrgt(int grp, int layer) {
  constexpr int kLayerBits = 16;
  const int grp_field = grp << kLayerBits;
  return grp_field | layer;
}
224 
/**
 * Extract the group ID from a packed blob target, i.e. the bits above the
 * low 16.
 */
inline int BlobGrp(int blob_trgt) {
  constexpr int kLayerBits = 16;
  const int grp = blob_trgt >> kLayerBits;
  return grp;
}
228 
/**
 * Extract the layer ID from a packed blob target, i.e. its low 16 bits.
 */
inline int BlobLayer(int blob_trgt) {
  // A compile-time constant: the mask never changes, so it must not live in
  // mutable function-local static storage.
  constexpr int kLayerMask = (1 << 16) - 1;
  return blob_trgt & kLayerMask;
}
233 } // namespace singa
234 
235 #endif // SINGA_TRAINER_WORKER_H_
void TrainOneBatch(int step, Metric *perf) override
Train one mini-batch.
bool TestNow(int step) const
Check whether it is time to do a test.
int Collect(Param *param, int step)
Block until the param is updated since sending the update request.
The Worker class which runs the training algorithm.
Definition: worker.h:42
void InitLocalParams()
Init all local params (i.e., params from layers resident in this worker).
Base parameter class.
Definition: param.h:93
int grp_id() const
Definition: worker.h:187
int CollectAll(shared_ptr< NeuralNet > net, int step)
Call Collect for every param of net.
bool DisplayDebugInfo(int step) const
Check whether it is time to display training info, e.g., loss and precision.
Definition: worker.h:202
void ReceiveBlobs(bool data, bool grad, BridgeLayer *layer, shared_ptr< NeuralNet > net)
Receive blobs from other workers due to model partitions.
void TestOneBatch(int step, Phase phase, shared_ptr< NeuralNet > net, Metric *perf) override
Test/validate one mini-batch.
virtual void Init(int thread_id, int grp_id, int id)
Definition: socket.h:91
bool DisplayNow(int step) const
Check whether it is time to display training info, e.g., loss and precision.
void Test(int nsteps, Phase phase, shared_ptr< NeuralNet > net)
Test the performance of the learned model on the validation or test dataset.
int id() const
worker ID within the worker group.
Definition: worker.h:192
void Init(int thread_id, int grp_id, int id) override
int Update(Param *param, int step)
Update Param.
int Put(Param *param, int step)
Put Param to server.
void Setup(const JobProto &job, shared_ptr< NeuralNet > train_net, shared_ptr< NeuralNet > valid_net, shared_ptr< NeuralNet > test_net)
Setup members.
void TrainOneBatch(int step, Metric *perf) override
Train one mini-batch.
Definition: connection_layer.h:33
virtual void TestOneBatch(int step, Phase phase, shared_ptr< NeuralNet > net, Metric *perf)=0
Test/validate one mini-batch.
void Run()
Main function of Worker.
void Checkpoint(int step, shared_ptr< NeuralNet > net)
Checkpoint all params owned by the worker from the first group onto disk.
int Get(Param *param, int step)
Get Param with a specific version from the server. If the current version >= the requested version...
bool StopNow(int step) const
Check whether it is time to stop.
void Report(const string &prefix, const Metric &perf)
Report performance to the stub.
Performance metrics.
Definition: common.h:85
void SendBlobs(bool data, bool grad, BridgeLayer *layer, shared_ptr< NeuralNet > net)
Send blobs to other workers due to model partitions.
virtual void TrainOneBatch(int step, Metric *perf)=0
Train one mini-batch.
const int kCollectSleepTime
Sleep 5 milliseconds if the Param is not updated to the expected version.
Definition: worker.h:30
bool CheckpointNow(int step) const
Check whether it is time to do a checkpoint.
Definition: worker.h:214
void TestOneBatch(int step, Phase phase, shared_ptr< NeuralNet > net, Metric *perf) override
Test/validate one mini-batch.
bool ValidateNow(int step) const
Check whether it is time to do validation.