22 #ifndef SINGA_UTILS_PARAM_H_
23 #define SINGA_UTILS_PARAM_H_
28 #include "communication/msg.h"
29 #include "proto/job.pb.h"
30 #include "utils/blob.h"
/** Store the generator configuration: the given proto is copied into proto_. */
43 virtual void Init(
const ParamGenProto& proto) { proto_ = proto; }
/**
 * Create a Param object from its configuration.
 * NOTE(review): ownership of the returned raw pointer is not visible in this
 * header — confirm in param.cc whether the caller must delete it.
 */
95 static Param* Create(
const ParamProto& proto);
/** Copy the configuration into proto_; no other member is touched here. */
99 void Init(
const ParamProto& proto) { proto_ = proto; }
/** Setup the param object for the given shape (definition not in this header view). */
106 virtual void Setup(
const std::vector<int>& shape);
/**
 * Initialize parameter values. The second overload also takes a version
 * number (presumably assigned to the data after init — confirm in param.cc).
 */
112 virtual void InitValues();
113 virtual void InitValues(
int version);
/** Scale of the learning rate when updating this Param; read from proto_. */
138 inline float lr_scale()
const {
return proto_.lr_scale(); }
/** Scale of the weight decay when updating this Param; read from proto_. */
142 inline float wd_scale()
const {
return proto_.wd_scale(); }
/**
 * Parameter name, used for Param re-use in other models or for sharing
 * between layers. Returns a reference into proto_.
 */
147 inline const std::string&
name()
const {
return proto_.name(); }
/** Overwrite the name stored in proto_. */
148 inline void set_name(
const std::string&
name) { proto_.set_name(name); }
/**
 * Owner id, read from proto_: if this Param shares data from another Param,
 * this is that Param's id; otherwise it is this Param's own id.
 */
153 inline int owner()
const {
return proto_.owner(); }
/** ID, starting from 0 and ordered across all Params of the same neural net. */
157 inline int id()
const {
return proto_.id(); }
163 proto_.set_owner(
id);
/**
 * Param version. It is stored inside the data blob (data_), so all Param
 * objects sharing the same data blob also see the same version.
 */
170 inline int version()
const {
return data_->version(); }
/** Set the version on the data blob; visible to every Param sharing data_. */
171 inline void set_version(
int v) { data_->set_version(v); }
/** Set local_version_, a per-object field that is not stored in the blob. */
176 inline void set_local_version(
int v) { local_version_ = v; }
/**
 * Value of proto_.share_from() — presumably the name of the Param whose data
 * this one shares (empty if none); confirm against the proto definition.
 */
177 inline const std::string& share_from()
const {
return proto_.share_from(); }
/** Number of elements in the data blob (data_->count()). */
181 inline int size()
const {
return data_->count(); }
/** Read-only access to the data blob, which may be shared with other Params. */
182 inline const Blob<float>& data()
const {
return *data_; }
/** Mutable access to the data blob; the returned pointer is non-owning (data_ owns it). */
183 inline Blob<float>* mutable_data() {
return data_.get(); }
/** Read-only access to the gradient blob, a by-value member of this object. */
184 inline const Blob<float> &grad()
const {
return grad_; }
/** Mutable access to the gradient blob (non-owning pointer to grad_). */
185 inline Blob<float> *mutable_grad() {
return &grad_; }
/** Raw mutable CPU buffers of the data, gradient and history blobs. */
186 inline float* mutable_cpu_data() {
return data_->mutable_cpu_data(); }
187 inline float* mutable_cpu_grad() {
return grad_.mutable_cpu_data(); }
188 inline float* mutable_cpu_history() {
return history_.mutable_cpu_data(); }
/** Number of slices added to this Param (value of num_slices_). */
193 inline int num_slices()
const {
return num_slices_; }
/**
 * Message/request related functions.
 * Generate the message for a put request for slice slice_idx.
 * NOTE(review): the meaning of `copy` is not visible here — presumably whether
 * the payload is copied into the message; confirm in param.cc.
 */
221 virtual Msg*
GenPutMsg(
bool copy,
int slice_idx);
/** Generate the message for a get request, i.e., get parameters from a server. */
226 virtual Msg*
GenGetMsg(
bool copy,
int slice_idx);
/**
 * Generate the message for a synchronization request between server groups,
 * covering `size` values starting at `offset` (units presumably elements — confirm).
 */
239 virtual Msg*
GenSyncMsg(
int offset,
int size);
272 virtual const std::vector<Msg*>
/** Version tracked by this object itself (see set_local_version); -1 until set. */
308 int local_version_ = -1;
/** Start slice index, presumably returned by slice_start(); 0 by default. */
310 int slice_start_ = 0;
/** Per-slice offset and size, presumably appended by AddSlice(slice_id, size). */
313 std::vector<int> slice_offset_;
314 std::vector<int> slice_size_;
/**
 * Flags for outstanding get/update requests (presumably one per slice) and a
 * count of pending requests.
 * NOTE(review): std::vector<bool> is a packed-bit proxy container; elements
 * cannot be bound to bool& — fine for flags.
 */
317 std::vector<bool> pending_get_;
318 std::vector<bool> pending_update_;
319 int num_pending_requests_ = 0;
/** Data blob held through shared_ptr so other Params can share it (see ShareFrom). */
321 std::shared_ptr<Blob<float>> data_ =
nullptr;
/** Gradient and history blobs, owned by value by this object. */
323 Blob<float> grad_, history_;
/** Next version for this entry; initialized to -1 (exact update rule not visible here). */
346 int next_version = -1;
/**
 * Params sharing this entry, presumably registered via AddParam().
 * NOTE(review): raw Param* — presumably non-owning; confirm.
 */
350 std::vector<Param*> shares;
/**
 * Pack a Param id and a slice id into one int message target: the slice id
 * occupies the low 16 bits and the param id the bits above them.
 *
 * The shift is done on unsigned values so a param_id whose shifted value does
 * not fit in int (negative, or >= 2^15) no longer triggers signed-shift
 * undefined behavior, and slice_id is masked to 16 bits so an out-of-range
 * slice id can no longer overwrite param-id bits. For 0 <= param_id < 2^15 and
 * 0 <= slice_id < 2^16 the result is identical to the previous implementation.
 * Assumes 32-bit int/unsigned.
 *
 * @param param_id id of the Param
 * @param slice_id slice index; only its low 16 bits are kept
 * @return packed target, decodable with ParamID() and SliceID()
 */
inline int ParamTrgt(int param_id, int slice_id) {
  const unsigned packed = (static_cast<unsigned>(param_id) << 16)
      | (static_cast<unsigned>(slice_id) & 0xFFFFu);
  // Value-preserving when packed <= INT_MAX; larger values wrap
  // (implementation-defined in C++17, two's complement on mainstream compilers).
  return static_cast<int>(packed);
}

/**
 * Extract the Param id (bits above the low 16) from a packed target.
 * Right shift of a negative int is arithmetic on mainstream compilers, so
 * ParamID(ParamTrgt(p, s)) == p for p in [-2^15, 2^15).
 */
inline int ParamID(int param_trgt) {
  return param_trgt >> 16;
}

/** Extract the slice id (low 16 bits) from a packed target. */
inline int SliceID(int param_trgt) {
  constexpr int kMask = (1 << 16) - 1;
  return param_trgt & kMask;
}
369 #endif // SINGA_UTILS_PARAM_H_
int slice_start() const
Definition: param.h:192
int num_total
Total number of workers using the shared parameter.
Definition: param.h:349
virtual Msg * HandlePutMsg(Msg **msg, bool reserve)
Server handling function for put request.
float wd_scale() const
Scale the weight decay when updating parameters in the Param object.
Definition: param.h:142
Base parameter class.
Definition: param.h:93
void AddParam(bool local, Param *p)
Associate the counter to a Param object.
virtual Msg * GenSyncMsg(int offset, int size)
Generate the message for a synchronization request between server groups.
ParamEntry is used for aggregating gradients of Params shared by workers from the same group...
Definition: param.h:335
const std::string & name() const
Parameter name used for Param re-use in other model or sharing between layers.
Definition: param.h:147
virtual Msg * GenUpdateMsg(bool copy, int slice_idx)
Generate the message for an update request, i.e., pass info to the server for parameter update...
int num_local
Number of local workers using the shared parameter.
Definition: param.h:348
virtual Msg * HandleSyncMsg(Msg **msg, bool reserve)
Server handling function for synchronization message.
void ParseResponseMsg(Msg *msg, int slice_idx)
Implements the common code of ParseGetResponseMsg and ParseUpdateResponseMsg.
virtual const std::vector< Msg * > GenUpdateResponseMsgs(std::vector< Msg * > *msgs, bool reserve)
Generate the messages to response the update requests.
virtual Msg * GenPutMsg(bool copy, int slice_idx)
Below are message/request related functions.
void ShareFrom(const Param &other)
Share the data blob from other Param objects.
Base parameter generator which initializes parameter values.
Definition: param.h:37
virtual void Setup(const std::vector< int > &shape)
Setup param object.
int local_version() const
Definition: param.h:175
virtual int ParseGetResponseMsg(Msg *msg, int slice_idx)
Worker/Stub parsing function for get response.
void FromProto(const BlobProto &blob)
Init param values from checkpoint blob.
void ToProto(BlobProto *blob)
Dump param values to blob.
virtual void ParseUpdateMsgs(const std::vector< Msg * > &msgs)
Server parse update requests.
float lr_scale() const
Scale the learning rate when updating parameters in the Param object.
Definition: param.h:138
virtual Msg * HandleGetMsg(Msg **msg, bool reserve)
Server handling function for get request.
virtual int ParseSyncResponseMsg(Msg *msg, int slice_idx)
Server parsing function for synchronization response.
void set_id(int id)
Set ID.
Definition: param.h:161
virtual Msg * GenGetMsg(bool copy, int slice_idx)
Generate the message for a get request, i.e., get parameters from a server.
void AddSlice(int slice_id, int size)
Add a slice.
int id() const
IDs start from 0 and are ordered across all Params from the same neural net.
Definition: param.h:157
int owner() const
If it shares data from others, then owner is the id of that Param; otherwise it is its own id...
Definition: param.h:153
int size() const
Definition: param.h:181
virtual int ParseUpdateResponseMsg(Msg *msg, int slice_idx)
Worker/Server parsing function for update response.
int version() const
Param version is stored inside the data blob to enable all Param objects sharing the same values to have th...
Definition: param.h:170