Apache Singa
A General Distributed Deep Learning Library
updater.h
1 
19 #ifndef SINGA_MODEL_UPDATER_H_
20 #define SINGA_MODEL_UPDATER_H_
21 
22 #include "singa/model/optimizer.h"
23 #include "singa/core/device.h"
24 #include "singa/core/tensor.h"
25 #include "singa/utils/logging.h"
26 
27 #include <memory>
28 #include <vector>
29 #include <mutex>
30 #include <condition_variable>
31 #include <string>
32 #include <utility>
33 #include <unordered_map>
34 #include <atomic>
35 
36 namespace singa {
/// Basic Updater: forwards every method call to the wrapped Optimizer.
///
/// The class is non-copyable (copy constructor and copy assignment are
/// deleted below). Only the Optimizer pointer is stored; the destructor has
/// an empty body, so the pointee is not deleted by this class.
39 class Updater {
40  public:
    /// Wrap `opt`. `explicit` prevents implicit conversion from Optimizer*.
41  explicit Updater(Optimizer* opt) : opt_{opt} {}
    /// Virtual so that derived updaters can be destroyed through an Updater*.
42  virtual ~Updater() {}
    /// Forward Setup() to Optimizer.
44  virtual void Setup(const OptimizerConf& conf);
    /// Forward Register() to Optimizer.
46  virtual void Register(const string& name, const ParamSpec& specs);
    /// Forward Apply() to Optimizer.
    /// `grad` and `value` are taken by non-const reference.
    // NOTE(review): presumably `value` is updated in place from `grad` --
    // confirm against the implementation, which is not in this header.
48  virtual void Apply(int step, const string& name, Tensor& grad, Tensor& value);
    /// Return the wrapped Optimizer pointer as passed to the constructor.
49  Optimizer* GetOptimizer() { return opt_; }
50 
51  // No copy allowed.
52  Updater(const Updater&) = delete;
53  void operator=(const Updater&) = delete;
54 
55  protected:
    /// The wrapped Optimizer; accessible to derived classes.
56  Optimizer* opt_;
57 };
58 
/// LocalUpdater does gradient aggregation and applies updates by calling the
/// wrapped Optimizer.
// NOTE(review): the generated description is truncated ("on a specific ...");
// presumably the aggregation/update runs on `dev_` -- confirm against the
// implementation file.
61 class LocalUpdater : public Updater {
62  public:
    /// @param total_num stored in total_num_.
    /// @param opt       Optimizer passed through to the Updater base.
    /// @param dev       Device stored in dev_; defaults to defaultDevice.
    // NOTE(review): total_num is presumably the number of contributors whose
    // gradients are aggregated per parameter -- confirm with callers.
63  LocalUpdater(int total_num, Optimizer* opt,
64  std::shared_ptr<Device> dev = defaultDevice)
65  : Updater(opt), total_num_{total_num}, dev_(dev) {}
    // NOTE(review): `virtual` is redundant alongside `override`.
66  virtual ~LocalUpdater() override {}
    /// Overrides Updater::Register(); defined outside this header.
68  virtual void Register(const string& name, const ParamSpec& specs) override;
    /// Overrides Updater::Apply(); defined outside this header.
72  virtual void Apply(int step, const string& name, Tensor& grad,
73  Tensor& value) override;
74  private:
    /// Hash functor for std::pair<T1, T2> keys: hashes each member with
    /// std::hash and combines the two results with XOR.
    // NOTE(review): a plain XOR is a weak combiner (e.g. equal member hashes
    // cancel to 0); a mixing combiner would spread keys better.
75  template <typename T1, typename T2>
76  struct key_hasher {
77  size_t operator() (const std::pair<T1, T2>& p) const {
78  auto h1 = std::hash<T1>{}(p.first);
79  auto h2 = std::hash<T2>{}(p.second);
80  return h1 ^ h2;
81  }
82  };
83 
    // The purposes below are inferred from names and types only; the code
    // that uses these members is not in this header -- TODO confirm.
    /// Value of the constructor's total_num argument.
84  int total_num_;
    /// Device passed to the constructor.
85  std::shared_ptr<Device> dev_;
    /// Atomic integer per string key (presumably keyed by parameter name).
86  std::unordered_map<std::string, std::atomic<int>> dev_index_;
    /// Integer per string key (presumably a per-parameter completion count).
87  std::unordered_map<std::string, int> to_updater_finished_;
    /// Tensor per (int, string) key, hashed with key_hasher.
88  std::unordered_map<std::pair<int, std::string>, Tensor,
89  key_hasher<int, std::string>> grad_buffer_;
    /// Two separate string-keyed Tensor maps.
90  std::unordered_map<std::string, Tensor> sum_, param_buffer_;
    /// One mutex per string key.
91  std::unordered_map<std::string, std::mutex> mtx_;
    /// One condition variable per string key.
92  std::unordered_map<std::string, std::condition_variable>
93  to_updater_all_finished_;
94 };
95 } // namespace singa
96 
97 #endif // SINGA_MODEL_UPDATER_H_
virtual void Apply(int step, const string &name, Tensor &grad, Tensor &value)
Forward Apply() to Optimizer.
A Tensor instance is a multi-dimensional array resident on a Device (default device is the host CPU)...
Definition: tensor.h:56
The basic Updater class simply forwards all method calls to the wrapped Optimizer.
Definition: updater.h:39
std::shared_ptr< Device > defaultDevice
a singleton CppDevice as the host for all devices.
virtual void Setup(const OptimizerConf &conf)
Forward Setup() to Optimizer.
virtual void Register(const string &name, const ParamSpec &specs)
Forward Register() to Optimizer.
Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements...
Definition: common.h:48
The base class for gradient descent algorithms used to update the model parameters in order to optimi...
Definition: optimizer.h:41
LocalUpdater performs gradient aggregation and applies the gradient update by calling the wrapped Optimizer on a specific ...
Definition: updater.h:61