Apache Singa
A General Distributed Deep Learning Library
loss.h
1 
19 #ifndef SINGA_MODEL_LOSS_H_
20 #define SINGA_MODEL_LOSS_H_
21 #include <stack>
22 #include "singa/proto/model.pb.h"
23 #include "singa/core/tensor.h"
24 namespace singa {
25 
30 // template <typename T = Tensor>
31 class Loss {
32 public:
33  Loss() = default;
34  void Setup(const string &conf) {
35  LossConf loss;
36  loss.ParseFromString(conf);
37  Setup(loss);
38  }
39  virtual ~Loss() {};
40  virtual void ToDevice(std::shared_ptr<Device> device) {}
42  virtual void Setup(const LossConf &conf) {}
43 
46  virtual Tensor Forward(int flag, const Tensor &prediction,
47  const Tensor &target) = 0;
48 
52  float Evaluate(int flag, const Tensor &prediction, const Tensor &target) {
53  Tensor loss = Forward(flag, prediction, target);
54  return Sum<float>(loss) / (1.0f * loss.Size());
55  }
56 
58  virtual Tensor Backward() = 0;
59 };
60 
61 // ============= Mean Squared Error ===========================================
63 class MSE : public Loss {
64  public:
69  Tensor Forward(int flag, const Tensor& prediction,
70  const Tensor& target) override;
71 
74  Tensor Backward() override;
75 
76  private:
77  // to buffer intermediate data, i.e., prediction-target
78  std::stack<Tensor> buf_;
79 };
80 
81 
82 // ===============Softamx Cross Entropy =======================================
84 class SoftmaxCrossEntropy : public Loss {
85  public:
99  Tensor Forward(int flag, const Tensor& prediction,
100  const Tensor& target) override;
101 
104  Tensor Backward() override;
105 
106  private:
107  // to buffer intermediate data, i.e., probability for each category and
108  // the target (ground truth)
109  std::stack<Tensor> buf_;
110 };
111 
112 } // namespace singa
113 
114 #endif // SINGA_MODEL_LOSS_H_
virtual Tensor Forward(int flag, const Tensor &prediction, const Tensor &target)=0
Compute the loss values for each sample/instance given the prediction and the target.
Softmax + cross entropy for multi-category classification.
Definition: loss.h:84
virtual Tensor Backward()=0
Compute the gradients of the loss values w.r.t. the prediction.
A Tensor instance is a multi-dimensional array resident on a Device (default device is the host CPU)...
Definition: tensor.h:56
The base loss class, which declares the APIs for computing the objective score (loss) for a pair of p...
Definition: loss.h:31
Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements...
Definition: common.h:48
size_t Size() const
Return number of total elements.
Definition: tensor.h:128
MSE is for mean squared error or squared euclidean distance.
Definition: loss.h:63
float Evaluate(int flag, const Tensor &prediction, const Tensor &target)
Average loss values for all samples in the mini-batch. It calls Forward() internally.
Definition: loss.h:52
virtual void Setup(const LossConf &conf)
Set meta fields from user configurations.
Definition: loss.h:42