19 #ifndef SINGA_MODEL_UPDATER_H_ 20 #define SINGA_MODEL_UPDATER_H_ 22 #include "singa/model/optimizer.h" 23 #include "singa/core/device.h" 24 #include "singa/core/tensor.h" 25 #include "singa/utils/logging.h" 30 #include <condition_variable> 33 #include <unordered_map> 44 virtual void Setup(
const OptimizerConf& conf);
/// Forward Register() to Optimizer.
/// Takes the parameter `name` and its `specs` (a ParamSpec), both by const reference.
/// Declared virtual, so a derived updater may replace it.
46 virtual void Register(
const string& name,
const ParamSpec& specs);
/// Returns the value of the `opt_` member as a raw Optimizer pointer.
/// NOTE(review): `opt_` is declared outside this excerpt; who owns the
/// pointee (and whether it may be null) is not visible here — confirm.
49 Optimizer* GetOptimizer() {
return opt_; }
/// Copy assignment from another Updater is explicitly deleted, so an
/// Updater cannot be copy-assigned.
53 void operator=(
const Updater&) =
delete;
65 :
Updater(opt), total_num_{total_num}, dev_(dev) {}
/// Overrides the base-class Register() with the same (name, specs) signature.
/// NOTE(review): the definition is not part of this listing; presumably it
/// prepares per-parameter state in addition to registering with the
/// Optimizer — confirm against the implementation file.
68 virtual void Register(
const string& name,
const ParamSpec& specs)
override;
72 virtual void Apply(
int step,
const string& name,
Tensor& grad,
75 template <
typename T1,
typename T2>
77 size_t operator() (
const std::pair<T1, T2>& p)
const {
78 auto h1 = std::hash<T1>{}(p.first);
79 auto h2 = std::hash<T2>{}(p.second);
/// Shared pointer to a Device; the constructor's initializer list sets it
/// from its `dev` argument.
85 std::shared_ptr<Device> dev_;
/// One atomic int per string key.
/// NOTE(review): presumably keyed by parameter name — confirm.
86 std::unordered_map<std::string, std::atomic<int>> dev_index_;
/// One plain int per string key.
/// NOTE(review): the name suggests a per-key completion counter/flag that is
/// paired with mtx_ and to_updater_all_finished_ — confirm against the
/// implementation.
87 std::unordered_map<std::string, int> to_updater_finished_;
/// Tensors keyed by an (int, string) pair; key_hasher<int, std::string>
/// supplies the hash because the key type is a std::pair.
/// NOTE(review): presumably (index, parameter name) -> gradient — confirm.
88 std::unordered_map<std::pair<int, std::string>,
Tensor,
89 key_hasher<int, std::string>> grad_buffer_;
/// Two separate string-keyed Tensor maps declared in one statement.
90 std::unordered_map<std::string, Tensor> sum_, param_buffer_;
/// One mutex per string key.
91 std::unordered_map<std::string, std::mutex> mtx_;
/// One condition variable per string key.
92 std::unordered_map<std::string, std::condition_variable>
93 to_updater_all_finished_;
97 #endif // SINGA_MODEL_UPDATER_H_ virtual void Apply(int step, const string &name, Tensor &grad, Tensor &value)
Forward Apply() to Optimizer.
A Tensor instance is a multi-dimensional array resident on a Device (default device is the host CPU)...
Definition: tensor.h:56
The basic Updater class just forwards all method calls to the wrapped Optimizer.
Definition: updater.h:39
std::shared_ptr< Device > defaultDevice
a singleton CppDevice as the host for all devices.
virtual void Setup(const OptimizerConf &conf)
Forward Setup() to Optimizer.
virtual void Register(const string &name, const ParamSpec &specs)
Forward Register() to Optimizer.
Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements...
Definition: common.h:48
The base class for gradient descent algorithms used to update the model parameters in order to optimi...
Definition: optimizer.h:41
LocalUpdater does gradient aggregation and updates the gradient by calling the wrapped Optimizer on a specific ...
Definition: updater.h:61