Apache SINGA
A distributed deep learning platform .
 All Classes Namespaces Files Functions Variables Typedefs Enumerator Macros
updater.h
1 #ifndef INCLUDE_UTILS_UPDATER_H_
2 #define INCLUDE_UTILS_UPDATER_H_
3 #include "proto/model.pb.h"
4 #include "utils/param.h"
5 
6 namespace singa{
10 class Updater{
11  public:
12  virtual void Init(const UpdaterProto &proto){
13  proto_=proto;
14  }
15  virtual void Update(int step, shared_ptr<Param> param, float grad_scale=1.0f)=0;
16 
17  float GetLearningRate(int step);
18  protected:
19  UpdaterProto proto_;
20 };
21 class SGDUpdater : public Updater{
22  public:
23  virtual void Init(const UpdaterProto& proto);
24  virtual void Update(int step, shared_ptr<Param> param, float grad_scale=1.0f);
25 
26  protected:
27  float base_lr_;
28  float momentum_;
29  float weight_decay_;
30 };
31 class NesterovUpdater : public Updater{
32  public:
33  virtual void Init(const UpdaterProto& proto);
34  virtual void Update(int step, shared_ptr<Param> param, float grad_scale=1.0f);
35 
36  protected:
37  float base_lr_;
38  float momentum_;
39  float weight_decay_;
40 };
41 class AdaGradUpdater : public Updater{
42  public:
43  virtual void Init(const UpdaterProto& proto);
44  virtual void Update(int step, shared_ptr<Param> param, float grad_scale=1.0f);
45 
46  protected:
47  float base_lr_;
48  float delta_;
49  float weight_decay_;
50 };
51 
52 class RMSPropUpdater : public Updater{
53  public:
54  virtual void Init(const UpdaterProto& proto);
55  virtual void Update(int step, shared_ptr<Param> param, float grad_scale=1.0f);
56 
57  protected:
58  float base_lr_;
59  float delta_;
60  float rho_;
61  float weight_decay_;
62 };
63 
64 /*
65 class AdaDeltaUpdater : public Updater{
66  public:
67  virtual void Init(const UpdaterProto& proto);
68  virtual void Update(int step, shared_ptr<Param> param, float grad_scale=1.0f);
69 
70  protected:
71  float rho_;
72  float delta_;
73  float weight_decay_;
74 };
75 */
76 }
77 
78 #endif // INCLUDE_UTILS_UPDATER_H_
Definition: model.pb.h:3432
Definition: updater.h:52
Definition: updater.h:31
Definition: updater.h:41
Updater for Param.
Definition: updater.h:10
Definition: updater.h:21