1 #ifndef INCLUDE_UTILS_PARAM_H_
2 #define INCLUDE_UTILS_PARAM_H_
7 #include "proto/model.pb.h"
8 #include "utils/blob.h"
9 #include "communication/msg.h"
/** Default constructor: initializes data_ to nullptr, so no value Blob is attached yet. */
14 Param():data_(
nullptr){}
/**
 * Virtual hook that returns a Msg pointer; by its name it presumably builds a
 * "get" request (implementation not visible here — confirm).
 * @param arg opaque caller-supplied pointer, defaults to nullptr; its meaning
 *            is not visible from this declaration.
 * @return pointer to a Msg; ownership is not visible from this declaration.
 * NOTE(review): a default argument on a virtual function is selected by the
 * static type at the call site, not by the overrider.
 */
17 virtual Msg* GenGetMsg(
void* arg=
nullptr);
/**
 * Virtual hook that returns a Msg pointer; by its name it presumably builds a
 * "put" request (implementation not visible here — confirm).
 * @param arg opaque caller-supplied pointer, defaults to nullptr; its meaning
 *            is not visible from this declaration.
 * @return pointer to a Msg; ownership is not visible from this declaration.
 * NOTE(review): a default argument on a virtual function is selected by the
 * static type at the call site, not by the overrider.
 */
18 virtual Msg* GenPutMsg(
void* arg=
nullptr);
/**
 * Virtual hook that returns a Msg pointer; by its name it presumably builds an
 * "update" request (implementation not visible here — confirm).
 * @param arg opaque caller-supplied pointer, defaults to nullptr; its meaning
 *            is not visible from this declaration.
 * @return pointer to a Msg; ownership is not visible from this declaration.
 * NOTE(review): a default argument on a virtual function is selected by the
 * static type at the call site, not by the overrider.
 */
19 virtual Msg* GenUpdateMsg(
void* arg=
nullptr);
/**
 * Virtual hook that returns a Msg pointer; by its name it presumably builds a
 * "sync" request (implementation not visible here — confirm).
 * @param arg opaque caller-supplied pointer, defaults to nullptr; its meaning
 *            is not visible from this declaration.
 * @return pointer to a Msg; ownership is not visible from this declaration.
 * NOTE(review): a default argument on a virtual function is selected by the
 * static type at the call site, not by the overrider.
 */
20 virtual Msg* GenSyncMsg(
void* arg=
nullptr);
/**
 * Virtual handler that, by its name, processes a "get" message
 * (implementation not visible here — confirm).
 * @param msg pointer to the caller's Msg pointer. NOTE(review): the double
 *            pointer suggests the callee may consume or reset *msg — confirm.
 * @return pointer to a Msg; presumably the reply — TODO confirm.
 */
22 virtual Msg* HandleGetMsg(
Msg** msg);
/**
 * Virtual handler that, by its name, processes a "put" message
 * (implementation not visible here — confirm).
 * @param msg pointer to the caller's Msg pointer. NOTE(review): the double
 *            pointer suggests the callee may consume or reset *msg — confirm.
 * @return pointer to a Msg; presumably the reply — TODO confirm.
 */
23 virtual Msg* HandlePutMsg(
Msg** msg);
/**
 * Virtual parser that, by its name, reads an "update" message
 * (implementation not visible here — confirm).
 * @param msg pointer to the caller's Msg pointer. NOTE(review): the double
 *            pointer suggests the callee may consume or reset *msg — confirm.
 * @return an int whose meaning is not visible from this declaration.
 */
24 virtual int ParseUpdateMsg(
Msg** msg);
/**
 * Virtual hook that returns a Msg pointer; by its name it presumably builds the
 * response to an "update" request (implementation not visible here — confirm).
 * @param arg opaque caller-supplied pointer, defaults to nullptr; its meaning
 *            is not visible from this declaration.
 * @return pointer to a Msg; ownership is not visible from this declaration.
 * NOTE(review): a default argument on a virtual function is selected by the
 * static type at the call site, not by the overrider.
 */
25 virtual Msg* GenUpdateResponseMsg(
void* arg=
nullptr);
/**
 * Virtual handler that, by its name, processes a "sync" message
 * (implementation not visible here — confirm).
 * @param msg pointer to the caller's Msg pointer. NOTE(review): the double
 *            pointer suggests the callee may consume or reset *msg — confirm.
 * @return pointer to a Msg; presumably the reply — TODO confirm.
 */
26 virtual Msg* HandleSyncMsg(
Msg** msg);
/**
 * Virtual parser that, by its name, reads the response to a "get" request
 * (implementation not visible here — confirm).
 * @param msg pointer to the caller's Msg pointer. NOTE(review): the double
 *            pointer suggests the callee may consume or reset *msg — confirm.
 * @return an int whose meaning is not visible from this declaration.
 */
28 virtual int ParseGetResponseMsg(
Msg** msg);
/**
 * Virtual parser that, by its name, reads the response to a "put" request
 * (implementation not visible here — confirm).
 * @param msg pointer to the caller's Msg pointer. NOTE(review): the double
 *            pointer suggests the callee may consume or reset *msg — confirm.
 * @return an int whose meaning is not visible from this declaration.
 */
29 virtual int ParsePutResponseMsg(
Msg** msg);
/**
 * Virtual parser that, by its name, reads the response to an "update" request
 * (implementation not visible here — confirm).
 * @param msg pointer to the caller's Msg pointer. NOTE(review): the double
 *            pointer suggests the callee may consume or reset *msg — confirm.
 * @return an int whose meaning is not visible from this declaration.
 */
30 virtual int ParseUpdateResponseMsg(
Msg** msg);
/**
 * Virtual parser that, by its name, reads the response to a "sync" request
 * (implementation not visible here — confirm).
 * @param msg pointer to the caller's Msg pointer. NOTE(review): the double
 *            pointer suggests the callee may consume or reset *msg — confirm.
 * @return an int whose meaning is not visible from this declaration.
 */
31 virtual int ParseSyncResponseMsg(
Msg** msg);
/**
 * Set up the param shape.
 * @param proto ParamProto message for this parameter; presumably stored into
 *              proto_ — TODO confirm against the implementation.
 * @param shape dimensions of the parameter, passed as a vector of ints.
 * @param fan_in integer whose use is not visible here; presumably the fan-in
 *               used when scaling initial values — TODO confirm.
 */
36 virtual void Setup(
const ParamProto& proto,
const std::vector<int>& shape,
int fan_in);
/**
 * Virtual initializer; the implementation is not visible here.
 * @param v defaults to 0; presumably the starting version number
 *          (cf. set_version(int v)) — TODO confirm.
 */
40 virtual void Init(
int v=0);
41 void ShareData(shared_ptr<Param> other){
42 proto_.set_owner(other->owner());
44 CHECK(std::equal(data_->shape().begin(), data_->shape().end(),
45 other->data_->shape().begin()));
48 float learning_rate_multiplier() {
49 return proto_.learning_rate_multiplier();
51 float weight_decay_multiplier() {
52 return proto_.weight_decay_multiplier();
59 const std::string& name() {
67 return proto_.owner();
78 return data_->version();
80 void set_version(
int v) {
81 data_->set_version(v);
87 return data_->count();
115 float* mutable_cpu_data(){
116 return data_->mutable_cpu_data();
118 float* mutable_cpu_grad(){
119 return grad_.mutable_cpu_data();
121 float* mutable_cpu_history(){
122 return history_.mutable_cpu_data();
/**
 * Shared handle to the Blob that the accessors above dereference (shape,
 * version, count, mutable_cpu_data); null after default construction.
 */
129 shared_ptr<Blob<float>> data_;
172 #endif // INCLUDE_UTILS_PARAM_H_
std::string name_
name of the parameter used to share weights between neural nets
Definition: param.h:128
const int owner() const
if the Param shares data with others, then owner is the id of that param.
Definition: param.h:66
Blob< float > grad_
gradient, history gradient of this parameter
Definition: param.h:131
const Blob< float > & data()
Return const mem address for the content of this parameter.
Definition: param.h:92
virtual void Setup(const ParamProto &proto, const std::vector< int > &shape, int fan_in)
setup param shape
const Blob< float > & grad()
Return gradient of this parameter.
Definition: param.h:101
Definition: model.pb.h:764
int size() const
Definition: param.h:86