1 /************************************************************
2 *
3 * Licensed to the Apache Software Foundation (ASF) under one
4 * or more contributor license agreements. See the NOTICE file
5 * distributed with this work for additional information
6 * regarding copyright ownership. The ASF licenses this file
7 * to you under the Apache License, Version 2.0 (the
8 * "License"); you may not use this file except in compliance
9 * with the License. You may obtain a copy of the License at
10 *
11 * http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing,
14 * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
17 * specific language governing permissions and limitations
18 * under the License.
19 *
20 *************************************************************/
25 #include <memory>
26 #include <string>
27 #include <vector>
28 #include "communication/msg.h"
29 #include "proto/job.pb.h"
30 #include "utils/blob.h"
32 namespace singa {
38  public:
39  static ParamGenerator* Create(const ParamGenProto& proto);
41  virtual ~ParamGenerator() {}
43  virtual void Init(const ParamGenProto& proto) { proto_ = proto; }
44  virtual void Fill(Blob<float>* data);
46  protected:
47  ParamGenProto proto_;
48 };
50 class GaussianGen : public ParamGenerator {
51  public:
52  void Fill(Blob<float>* data) override;
53 };
56  public:
57  void Fill(Blob<float>* data) override;
58 };
60 class UniformGen : public ParamGenerator {
61  public:
62  void Fill(Blob<float>* data) override;
63 };
66  public:
67  void Fill(Blob<float>* data) override;
68 };
71  public:
72  void Fill(Blob<float>* data) override;
73 };
93 class Param {
94  public:
95  static Param* Create(const ParamProto& proto);
97  Param() {}
98  virtual ~Param() {}
99  void Init(const ParamProto& proto) { proto_ = proto; }
106  virtual void Setup(const std::vector<int>& shape);
107  /*
108  * Fill the values according to init method, e.g., gaussian distribution.
109  *
110  * @param version initial version
111  */
112  virtual void InitValues();
113  virtual void InitValues(int version);
119  void ShareFrom(const Param& other);
123  void FromProto(const BlobProto& blob);
127  void ToProto(BlobProto* blob);
134  void AddSlice(int slice_id, int size);
138  inline float lr_scale() const { return proto_.lr_scale(); }
142  inline float wd_scale() const { return proto_.wd_scale(); }
147  inline const std::string& name() const { return proto_.name(); }
148  inline void set_name(const std::string& name) { proto_.set_name(name); }
153  inline int owner() const { return proto_.owner(); }
157  inline int id() const { return proto_.id(); }
161  inline void set_id(int id) {
162  proto_.set_id(id);
163  proto_.set_owner(id);
164  }
170  inline int version() const { return data_->version(); }
171  inline void set_version(int v) { data_->set_version(v); }
175  inline int local_version() const { return local_version_; }
176  inline void set_local_version(int v) { local_version_ = v; }
177  inline const std::string& share_from() const { return proto_.share_from(); }
181  inline int size() const { return data_->count(); }
182  inline const Blob<float>& data() const { return *data_; }
183  inline Blob<float>* mutable_data() { return data_.get(); }
184  inline const Blob<float> &grad() const { return grad_; }
185  inline Blob<float> *mutable_grad() { return &grad_; }
186  inline float* mutable_cpu_data() { return data_->mutable_cpu_data(); }
187  inline float* mutable_cpu_grad() { return grad_.mutable_cpu_data(); }
188  inline float* mutable_cpu_history() { return history_.mutable_cpu_data(); }
192  inline int slice_start() const { return slice_start_; }
193  inline int num_slices() const { return num_slices_; }
221  virtual Msg* GenPutMsg(bool copy, int slice_idx);
226  virtual Msg* GenGetMsg(bool copy, int slice_idx);
232  virtual Msg* GenUpdateMsg(bool copy, int slice_idx);
239  virtual Msg* GenSyncMsg(int offset, int size);
248  virtual Msg* HandlePutMsg(Msg** msg, bool reserve);
254  virtual Msg* HandleGetMsg(Msg** msg, bool reserve);
259  virtual void ParseUpdateMsgs(const std::vector<Msg*>& msgs);
272  virtual const std::vector<Msg*>
273  GenUpdateResponseMsgs(std::vector<Msg*>* msgs, bool reserve);
279  virtual Msg* HandleSyncMsg(Msg** msg, bool reserve);
286  virtual int ParseGetResponseMsg(Msg* msg, int slice_idx);
292  virtual int ParseUpdateResponseMsg(Msg* msg, int slice_idx);
298  virtual int ParseSyncResponseMsg(Msg* msg, int slice_idx);
300  protected:
305  void ParseResponseMsg(Msg* msg, int slice_idx);
307  protected:
308  int local_version_ = -1;
309  // the ID of the first slice
310  int slice_start_ = 0;
311  int num_slices_ = 0;
312  // offset and size of each slice
313  std::vector<int> slice_offset_;
314  std::vector<int> slice_size_;
315  // for debug checking
316  // since put request has no feedback, we do not track its pending status
317  std::vector<bool> pending_get_;
318  std::vector<bool> pending_update_;
319  int num_pending_requests_ = 0;
320  // data field
321  std::shared_ptr<Blob<float>> data_ = nullptr;
322  // gradient, history gradient of this parameter
323  Blob<float> grad_, history_;
324  ParamProto proto_;
325 };
335 class ParamEntry {
336  public:
337  ParamEntry() {}
338  ParamEntry(int total, Param* p);
345  void AddParam(bool local, Param* p);
346  int next_version = -1; // next_version & num_update are directly used by stub
347  int num_update = 0;
348  int num_local = 0;
349  int num_total = 0;
350  std::vector<Param*> shares;
352 };
/**
 * Packs a parameter ID and a slice ID into one int key ("target").
 *
 * The parameter ID occupies the upper 16 bits and the slice ID the lower 16
 * bits of a 32-bit key.
 *
 * Packing is done on unsigned values, so a large or negative `param_id` never
 * triggers a signed-shift overflow. The slice ID is masked to 16 bits, so an
 * oversized slice ID can no longer spill into (and corrupt) the
 * parameter-ID bits.
 *
 * The final unsigned->int conversion wraps for keys above INT_MAX. That is
 * implementation-defined before C++20 but modular on mainstream compilers.
 *
 * @param param_id parameter ID. Expected in [0, 65535]; higher bits do not
 *        fit in the key.
 * @param slice_id slice ID. Only its low 16 bits are kept.
 * @return the packed key; decode it with ParamID() / SliceID().
 */
inline int ParamTrgt(int param_id, int slice_id) {
  const unsigned packed = (static_cast<unsigned>(param_id) << 16) |
                          (static_cast<unsigned>(slice_id) & 0xFFFFu);
  return static_cast<int>(packed);
}
/**
 * Extracts the parameter ID (the upper 16 bits) from a key built by
 * ParamTrgt().
 *
 * The shift is applied to the key's unsigned representation. Keys whose
 * parameter ID has bit 15 set are negative as ints; an arithmetic right
 * shift would sign-extend them into a negative ID. Shifting the unsigned
 * representation returns the original non-negative ID instead.
 *
 * @param param_trgt packed key.
 * @return the upper 16 bits of the key, in [0, 65535] (assuming 32-bit int).
 */
inline int ParamID(int param_trgt) {
  return static_cast<int>(static_cast<unsigned>(param_trgt) >> 16);
}
/**
 * Extracts the slice ID from a key built by ParamTrgt().
 *
 * @param param_trgt packed key.
 * @return the key with everything but its lowest 16 bits cleared.
 */
inline int SliceID(int param_trgt) {
  constexpr int kLow16Bits = 0xFFFF;
  return param_trgt & kLow16Bits;
}
367 } // namespace singa
369 #endif // SINGA_UTILS_PARAM_H_
