Apache SINGA
A distributed deep learning platform.
 All Classes Namespaces Files Functions Variables Typedefs Macros
param.h
1 /************************************************************
2 *
3 * Licensed to the Apache Software Foundation (ASF) under one
4 * or more contributor license agreements. See the NOTICE file
5 * distributed with this work for additional information
6 * regarding copyright ownership. The ASF licenses this file
7 * to you under the Apache License, Version 2.0 (the
8 * "License"); you may not use this file except in compliance
9 * with the License. You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing,
14 * software distributed under the License is distributed on an
15 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16 * KIND, either express or implied. See the License for the
17 * specific language governing permissions and limitations
18 * under the License.
19 *
20 *************************************************************/
21 
22 #ifndef SINGA_UTILS_PARAM_H_
23 #define SINGA_UTILS_PARAM_H_
24 
25 #include <memory>
26 #include <string>
27 #include <vector>
28 #include "communication/msg.h"
29 #include "proto/job.pb.h"
30 #include "utils/blob.h"
31 
32 namespace singa {
33 
// Body of ParamGenerator, the base parameter generator which initializes
// parameter values. NOTE(review): the `class ParamGenerator {` header line is
// not present in this listing; the class name is taken from the destructor below.
38  public:
// Factory: builds a generator from its proto configuration.
// NOTE(review): ownership of the returned raw pointer is not visible here — confirm.
39  static ParamGenerator* Create(const ParamGenProto& proto);
40 
// Virtual destructor so derived generators can be destroyed through a base pointer.
41  virtual ~ParamGenerator() {}
42 
// Stores a copy of the generator configuration in proto_.
43  virtual void Init(const ParamGenProto& proto) { proto_ = proto; }
// Fills `data` with initial values; overridden by the derived generators.
44  virtual void Fill(Blob<float>* data);
45 
46  protected:
// Configuration copied in by Init().
47  ParamGenProto proto_;
48 };
49 
// Parameter generator derived from ParamGenerator; only Fill() is overridden.
// NOTE(review): presumably fills the blob from a Gaussian distribution, as the
// name suggests — the implementation is not visible here; confirm.
50 class GaussianGen : public ParamGenerator {
51  public:
// Fills `data` with this generator's initial values.
52  void Fill(Blob<float>* data) override;
53 };
54 
56  public:
57  void Fill(Blob<float>* data) override;
58 };
59 
// Parameter generator derived from ParamGenerator; only Fill() is overridden.
// NOTE(review): presumably fills the blob from a uniform distribution, as the
// name suggests — the implementation is not visible here; confirm.
60 class UniformGen : public ParamGenerator {
61  public:
// Fills `data` with this generator's initial values.
62  void Fill(Blob<float>* data) override;
63 };
64 
66  public:
67  void Fill(Blob<float>* data) override;
68 };
69 
71  public:
72  void Fill(Blob<float>* data) override;
73 };
74 
// Base parameter class.
// Holds the parameter values (data_), their gradient and gradient history, the
// slice bookkeeping, and the message/request related functions declared below.
// NOTE(review): data_ defaults to nullptr and is dereferenced by version(),
// set_version(), size(), data() and mutable_cpu_data(); presumably Setup() or
// ShareFrom() assigns it before those are called — confirm.
93 class Param {
94  public:
// Factory: builds a Param from its proto configuration.
// NOTE(review): ownership of the returned raw pointer is not visible here — confirm.
95  static Param* Create(const ParamProto& proto);
96 
97  Param() {}
98  virtual ~Param() {}
// Stores a copy of the parameter configuration in proto_.
99  void Init(const ParamProto& proto) { proto_ = proto; }
// Setup param object for the given shape.
106  virtual void Setup(const std::vector<int>& shape);
// The block comment below documents both InitValues overloads; its @param
// applies to the InitValues(int version) overload only.
107  /*
108  * Fill the values according to init method, e.g., gaussian distribution.
109  *
110  * @param version initial version
111  */
112  virtual void InitValues();
113  virtual void InitValues(int version);
// Share the data blob from another Param object.
119  void ShareFrom(const Param& other);
// Init param values from a checkpoint blob.
123  void FromProto(const BlobProto& blob);
// Dump param values to a blob.
127  void ToProto(BlobProto* blob);
// Add a slice with the given ID and size.
134  void AddSlice(int slice_id, int size);
// Scale for the learning rate when updating parameters in this Param object.
138  inline float lr_scale() const { return proto_.lr_scale(); }
// Scale for the weight decay when updating parameters in this Param object.
142  inline float wd_scale() const { return proto_.wd_scale(); }
// Parameter name, used for Param re-use in other models or sharing between layers.
147  inline const std::string& name() const { return proto_.name(); }
148  inline void set_name(const std::string& name) { proto_.set_name(name); }
// If this Param shares data from another, owner is the id of that Param;
// otherwise it is this Param's own id.
153  inline int owner() const { return proto_.owner(); }
// IDs start from 0 and are ordered for all Params from the same neural net.
157  inline int id() const { return proto_.id(); }
// Set ID. Also resets the owner to the same id, i.e. the Param owns itself.
161  inline void set_id(int id) {
162  proto_.set_id(id);
163  proto_.set_owner(id);
164  }
// Param version; it is stored inside the data blob, so it is read from and
// written to data_ rather than kept in this object.
170  inline int version() const { return data_->version(); }
171  inline void set_version(int v) { data_->set_version(v); }
// Version tracked in this object (local_version_), separate from the blob version.
175  inline int local_version() const { return local_version_; }
176  inline void set_local_version(int v) { local_version_ = v; }
// Value of the share_from field of the configuration proto.
177  inline const std::string& share_from() const { return proto_.share_from(); }
// Number of values, taken from the data blob's count().
181  inline int size() const { return data_->count(); }
// Accessors for the value blob, the gradient blob, and their raw CPU buffers.
182  inline const Blob<float>& data() const { return *data_; }
183  inline Blob<float>* mutable_data() { return data_.get(); }
184  inline const Blob<float> &grad() const { return grad_; }
185  inline Blob<float> *mutable_grad() { return &grad_; }
186  inline float* mutable_cpu_data() { return data_->mutable_cpu_data(); }
187  inline float* mutable_cpu_grad() { return grad_.mutable_cpu_data(); }
188  inline float* mutable_cpu_history() { return history_.mutable_cpu_data(); }
// ID of the first slice, and the number of slices.
192  inline int slice_start() const { return slice_start_; }
193  inline int num_slices() const { return num_slices_; }
194 
// Below are message/request related functions.
// Generate the message for a put request.
221  virtual Msg* GenPutMsg(bool copy, int slice_idx);
// Generate the message for a get request, i.e., get parameters from a server.
226  virtual Msg* GenGetMsg(bool copy, int slice_idx);
// Generate the message for an update request, i.e., pass info to the server
// for parameter update.
232  virtual Msg* GenUpdateMsg(bool copy, int slice_idx);
// Generate the message for a synchronization request between server groups.
239  virtual Msg* GenSyncMsg(int offset, int size);
// Server handling function for put request.
248  virtual Msg* HandlePutMsg(Msg** msg, bool reserve);
// Server handling function for get request.
254  virtual Msg* HandleGetMsg(Msg** msg, bool reserve);
// Server parsing of update requests.
259  virtual void ParseUpdateMsgs(const std::vector<Msg*>& msgs);
// Generate the messages that respond to the update requests.
272  virtual const std::vector<Msg*>
273  GenUpdateResponseMsgs(std::vector<Msg*>* msgs, bool reserve);
// Server handling function for synchronization message.
279  virtual Msg* HandleSyncMsg(Msg** msg, bool reserve);
// Worker/Stub parsing function for get response.
286  virtual int ParseGetResponseMsg(Msg* msg, int slice_idx);
// Worker/Server parsing function for update response.
292  virtual int ParseUpdateResponseMsg(Msg* msg, int slice_idx);
// Server parsing function for synchronization response.
298  virtual int ParseSyncResponseMsg(Msg* msg, int slice_idx);
299 
300  protected:
// Implements the common code of ParseGetResponseMsg and ParseUpdateResponseMsg.
305  void ParseResponseMsg(Msg* msg, int slice_idx);
306 
307  protected:
// Backing field of local_version(); starts at -1.
308  int local_version_ = -1;
309  // the ID of the first slice
310  int slice_start_ = 0;
311  int num_slices_ = 0;
312  // offset and size of each slice
313  std::vector<int> slice_offset_;
314  std::vector<int> slice_size_;
315  // for debug checking
316  // since put request has no feedback, we do not track its pending status
317  std::vector<bool> pending_get_;
318  std::vector<bool> pending_update_;
319  int num_pending_requests_ = 0;
320  // data field
321  std::shared_ptr<Blob<float>> data_ = nullptr;
322  // gradient, history gradient of this parameter
323  Blob<float> grad_, history_;
// Configuration copied in by Init(); backs the name/id/owner/scale accessors.
324  ParamProto proto_;
325 };
326 
// ParamEntry is used for aggregating gradients of Params shared by workers
// from the same group. All data members are public and accessed directly.
335 class ParamEntry {
336  public:
337  ParamEntry() {}
// NOTE(review): presumably `total` initialises num_total and `p` is the first
// shared Param — the definition is not visible here; confirm.
338  ParamEntry(int total, Param* p);
// Associate the counter to a Param object; `local` flags a Param used by a
// local worker (presumably counted in num_local — confirm in the definition).
345  void AddParam(bool local, Param* p);
346  int next_version = -1; // next_version & num_update are directly used by stub
347  int num_update = 0;
// Local workers using the shared parameter.
348  int num_local = 0;
// Total workers using the shared parameter.
349  int num_total = 0;
// Non-owning pointers to the Param objects associated with this entry.
// NOTE(review): ownership is assumed to lie elsewhere — confirm.
350  std::vector<Param*> shares;
352 };
353 
/**
 * Pack a parameter ID and a slice ID into a single integer target.
 *
 * The slice ID occupies the low 16 bits and the parameter ID the bits above
 * them; ParamID() and SliceID() recover the two fields.
 *
 * The slice ID is masked to 16 bits so that an out-of-range (or negative)
 * value cannot spill into the parameter-ID field, and the shift is done in
 * unsigned arithmetic so that large parameter IDs do not overflow a signed
 * shift. For in-range inputs the result equals (param_id << 16) | slice_id.
 *
 * @param param_id ID of the parameter.
 * @param slice_id ID of the slice; only its low 16 bits are used.
 * @return the packed target value.
 */
inline constexpr int ParamTrgt(int param_id, int slice_id) {
  return static_cast<int>((static_cast<unsigned int>(param_id) << 16) |
                          (static_cast<unsigned int>(slice_id) & 0xFFFFu));
}
357 
/**
 * Extract the parameter ID from a packed target value.
 *
 * The parameter ID is stored in the bits above the low 16 bits of the target,
 * so it is recovered by shifting the target right by 16. Declared constexpr
 * so it can also be evaluated at compile time.
 *
 * @param param_trgt packed target holding a parameter ID and a slice ID.
 * @return the parameter ID field.
 */
inline constexpr int ParamID(int param_trgt) {
  return param_trgt >> 16;
}
361 
/**
 * Extract the slice ID from a packed target value.
 *
 * The slice ID is stored in the low 16 bits of the target, so it is recovered
 * by masking with 0xFFFF. The mask is a plain constant expression rather than
 * a function-local static, which lets the function be constexpr.
 *
 * @param param_trgt packed target holding a parameter ID and a slice ID.
 * @return the slice ID field, in the range [0, 65535].
 */
inline constexpr int SliceID(int param_trgt) {
  return param_trgt & ((1 << 16) - 1);
}
366 
367 } // namespace singa
368 
369 #endif // SINGA_UTILS_PARAM_H_
int slice_start() const
Definition: param.h:192
int num_total
total workers using the shared parameter
Definition: param.h:349
Definition: param.h:55
virtual Msg * HandlePutMsg(Msg **msg, bool reserve)
Server handling function for put request.
float wd_scale() const
Scale the weight decay when updating parameters in the Param object.
Definition: param.h:142
Base parameter class.
Definition: param.h:93
Definition: param.h:50
void AddParam(bool local, Param *p)
Associate the counter to a Param object.
virtual Msg * GenSyncMsg(int offset, int size)
Generate the message for a synchronization request between server groups.
ParamEntry is used for aggregating gradients of Params shared by workers from the same group...
Definition: param.h:335
const std::string & name() const
Parameter name used for Param re-use in other model or sharing between layers.
Definition: param.h:147
virtual Msg * GenUpdateMsg(bool copy, int slice_idx)
Generate the message for an update request, i.e., pass info to server for parameter update...
Definition: param.h:65
int num_local
local workers using the shared parameter
Definition: param.h:348
virtual Msg * HandleSyncMsg(Msg **msg, bool reserve)
Server handling function for synchronization message.
void ParseResponseMsg(Msg *msg, int slice_idx)
Implement the common code of ParseGetResponseMsg and ParseUpdateResponseMsg.
Definition: param.h:70
Definition: param.h:60
virtual const std::vector< Msg * > GenUpdateResponseMsgs(std::vector< Msg * > *msgs, bool reserve)
Generate the messages to respond to the update requests.
virtual Msg * GenPutMsg(bool copy, int slice_idx)
Below are message/request related functions.
void ShareFrom(const Param &other)
Share the data blob from other Param objects.
Base parameter generator which initializes parameter values.
Definition: param.h:37
virtual void Setup(const std::vector< int > &shape)
Setup param object.
int local_version() const
Definition: param.h:175
virtual int ParseGetResponseMsg(Msg *msg, int slice_idx)
Worker/Stub parsing function for get response.
void FromProto(const BlobProto &blob)
Init param values from checkpoint blob.
void ToProto(BlobProto *blob)
Dump param values to blob.
virtual void ParseUpdateMsgs(const std::vector< Msg * > &msgs)
Server parse update requests.
float lr_scale() const
Scale the learning rate when updating parameters in the Param object.
Definition: param.h:138
virtual Msg * HandleGetMsg(Msg **msg, bool reserve)
Server handling function for get request.
virtual int ParseSyncResponseMsg(Msg *msg, int slice_idx)
Server parsing function for synchronization response.
void set_id(int id)
Set ID.
Definition: param.h:161
virtual Msg * GenGetMsg(bool copy, int slice_idx)
Generate the message for a get request, i.e., get parameters from a server.
void AddSlice(int slice_id, int size)
Add a slice.
int id() const
IDs start from 0 and are ordered for all Params from the same neural net.
Definition: param.h:157
int owner() const
If it shares data from others, then owner is the id of that Param, otherwise it is its own id...
Definition: param.h:153
int size() const
Definition: param.h:181
virtual int ParseUpdateResponseMsg(Msg *msg, int slice_idx)
Worker/Server parsing function for update response.
int version() const
Param version is stored inside the data blob to enable all Param objs sharing the same values have th...
Definition: param.h:170