Apache SINGA
A distributed deep learning platform.
 All Classes Namespaces Files Functions Variables Typedefs Macros
updater.h
1 /************************************************************
2 *
3 * Licensed to the Apache Software Foundation (ASF) under one
4 * or more contributor license agreements. See the NOTICE file
5 * distributed with this work for additional information
6 * regarding copyright ownership. The ASF licenses this file
7 * to you under the Apache License, Version 2.0 (the
8 * "License"); you may not use this file except in compliance
9 * with the License. You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing,
14 * software distributed under the License is distributed on an
15 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 * KIND, either express or implied. See the License for the
17 * specific language governing permissions and limitations
18 * under the License.
19 *
20 *************************************************************/
21 
22 #ifndef SINGA_UTILS_UPDATER_H_
23 #define SINGA_UTILS_UPDATER_H_
24 
25 #include "proto/job.pb.h"
26 #include "utils/param.h"
27 
28 namespace singa {
36 class LRGenerator {
37  public:
38  static LRGenerator* Create(const LRGenProto& proto);
39 
40  virtual ~LRGenerator() {}
41 
42  virtual void Init(const LRGenProto& proto) { proto_ = proto; }
47  virtual float Get(int step) { return proto_.base_lr(); }
48 
49  protected:
50  LRGenProto proto_;
51 };
52 
53 class FixedStepLRGen : public LRGenerator {
54  public:
55  float Get(int step) override;
56  private:
57  int last_idx_ = 0;
58 };
59 
60 class StepLRGen : public LRGenerator {
61  public:
62  float Get(int step) override;
63 };
64 
65 class LinearLRGen : public LRGenerator {
66  public:
67  float Get(int step) override;
68 };
69 
70 class ExpLRGen : public LRGenerator {
71  public:
72  float Get(int step) override;
73 };
74 
75 class InvLRGen : public LRGenerator {
76  public:
77  float Get(int step) override;
78 };
79 
80 class InvTLRGen : public LRGenerator {
81  public:
82  float Get(int step) override;
83 };
84 
88 class Updater {
89  public:
90  static Updater* Create(const UpdaterProto& proto);
91 
92  virtual ~Updater() {}
93 
94  virtual void Init(const UpdaterProto &proto);
95  virtual void Update(int step, Param* param, float grad_scale) = 0;
96 
97  protected:
98  UpdaterProto proto_;
99  LRGenerator* lr_gen_;
100  float weight_decay_;
101  float momentum_;
102 };
103 
104 class SGDUpdater : public Updater {
105  public:
106  void Update(int step, Param* param, float grad_scale) override;
107 };
108 
109 class AdaGradUpdater : public Updater {
110  public:
111  void Update(int step, Param* param, float grad_scale) override;
112 };
113 
114 
115 class NesterovUpdater : public Updater {
116  public:
117  void Update(int step, Param* param, float grad_scale) override;
118 };
119 
120 /*
121 class RMSPropUpdater : public Updater {
122  public:
123  virtual void Update(int step, Param* param, float grad_scale);
124 
125  protected:
126  float base_lr_;
127  float delta_;
128  float rho_;
129  float weight_decay_;
130 };
131 
132 class AdaDeltaUpdater : public Updater {
133  public:
134  virtual void Update(int step, Param* param, float grad_scale);
135 
136  protected:
137  float rho_;
138  float delta_;
139  float weight_decay_;
140 };
141 */
142 
143 } // namespace singa
144 
145 #endif // SINGA_UTILS_UPDATER_H_
float Get(int step) override
float Get(int step) override
Definition: updater.h:75
float Get(int step) override
Base learning rate generator.
Definition: updater.h:36
Base parameter class.
Definition: param.h:93
virtual float Get(int step)
Definition: updater.h:47
float Get(int step) override
Definition: updater.h:60
Definition: updater.h:53
Definition: updater.h:70
Definition: updater.h:115
Definition: updater.h:109
Updater for Param.
Definition: updater.h:88
Definition: updater.h:104
float Get(int step) override
Definition: updater.h:80
float Get(int step) override
Definition: updater.h:65