1 #ifndef INCLUDE_TRAINER_WORKER_H_
2 #define INCLUDE_TRAINER_WORKER_H_
5 #include "neuralnet/neuralnet.h"
6 #include "proto/model.pb.h"
7 #include "utils/cluster.h"
8 #include "communication/socket.h"
9 #include "communication/msg.h"
/**
 * Pause length used while collecting updates.
 * NOTE(review): presumably the sleep interval between polls in
 * Collect()/CollectAll(); the unit is not visible in this header
 * (likely milliseconds) -- confirm against the implementation.
 */
constexpr int kCollectSleepTime{5};
20 Worker(
int thread_id,
int group_id,
int worker_id);
22 void Setup(
const ModelProto& model, shared_ptr<NeuralNet> train_net);
23 void set_test_net(shared_ptr<NeuralNet> test_net){
26 void set_validation_net(shared_ptr<NeuralNet> val_net){
27 validation_net_=val_net;
32 int Put(shared_ptr<Param> param,
int step);
33 int Get(shared_ptr<Param> param,
int step);
34 int Update(shared_ptr<Param> param,
int step);
35 int Collect(shared_ptr<Param> param,
int step);
36 int CollectAll(shared_ptr<NeuralNet> net,
int step);
52 virtual void TestOneBatch(shared_ptr<NeuralNet> net,
int step, Phase phase)=0;
59 void Test(shared_ptr<NeuralNet> net,
int nsteps,
const string &prefix);
78 return (modelproto_.display_frequency() > 0
79 && step >= modelproto_.display_after_steps()
80 && ((step - modelproto_.display_after_steps())
81 % modelproto_.display_frequency() == 0));
84 const bool DisplayDebugInfo(
const int step)
const {
85 return DisplayNow(step)&&modelproto_.debug()&&group_id_==0;
87 const void DisplayPerformance(
const Metric & perf,
const string& prefix);
94 return (step >= modelproto_.train_steps());
102 && modelproto_.checkpoint_frequency() > 0
103 && step >= modelproto_.checkpoint_after_steps()
104 && ((step - modelproto_.checkpoint_after_steps())
105 % modelproto_.checkpoint_frequency() == 0));
113 && modelproto_.test_frequency() > 0
114 && modelproto_.test_steps() > 0
115 && step >= modelproto_.test_after_steps()
116 && ((step - modelproto_.test_after_steps())
117 % modelproto_.test_frequency() == 0));
125 && modelproto_.validation_frequency() > 0
126 && modelproto_.validation_steps() > 0
127 && step >= modelproto_.validation_after_steps()
128 && ((step - modelproto_.validation_after_steps())
129 % modelproto_.validation_frequency() == 0));
145 int thread_id_, group_id_, worker_id_;
148 shared_ptr<NeuralNet> train_net_, test_net_, validation_net_;
149 shared_ptr<Dealer> layer_dealer_, param_dealer_;
150 Poller layer_poller_, param_poller_;
156 BPWorker(
int thread_id,
int group_id,
int worker_id):
Worker(thread_id, group_id, worker_id){}
158 virtual void TestOneBatch(shared_ptr<NeuralNet> net,
int step, Phase phase);
159 void Forward(shared_ptr<NeuralNet> net,
int step,
bool training);
160 void Backward(shared_ptr<NeuralNet> net,
int step);
191 #endif // INCLUDE_TRAINER_WORKER_H_
virtual void Run()
Main function of Worker.
The Worker class which runs the training algorithm.
Definition: worker.h:18
Definition: model.pb.h:316
virtual void TrainOneBatch(int step)=0
Train one mini-batch.
const bool StopNow(const int step) const
Return true if the stop condition is satisfied, e.g., the maximum number of steps has been reached...
Definition: worker.h:93
void Test(shared_ptr< NeuralNet > net, int nsteps, const string &prefix)
Test the performance of the learned model on the validation or test dataset.
virtual void TestOneBatch(shared_ptr< NeuralNet > net, int step, Phase phase)
Test/validate one mini-batch.
const bool ValidateNow(const int step)
Check whether it is time to run validation.
Definition: worker.h:123
const bool CheckpointNow(const int step) const
Check whether it is time to do a checkpoint.
Definition: worker.h:100
void ReceiveBlobs(shared_ptr< NeuralNet > net)
start training from scratch.
virtual void TrainOneBatch(int step)
Train one mini-batch.
void RunOneBatch(int step, Metric *perf=nullptr)
Check validation/test first, then run TrainOneBatch. Performance collects performance for the whole neur...
const bool DisplayNow(const int step) const
Pull data from layers resident on other nodes due to Model Partition.
Definition: worker.h:77
virtual void TestOneBatch(shared_ptr< NeuralNet > net, int step, Phase phase)=0
Test/validate one mini-batch.
const bool TestNow(const int step) const
Check whether it is time to run the test.
Definition: worker.h:111