Apache Singa
A General Distributed Deep Learning Library
metric.h
1 
19 #ifndef SINGA_MODEL_METRIC_H_
20 #define SINGA_MODEL_METRIC_H_
21 #include "singa/core/tensor.h"
22 #include "singa/proto/model.pb.h"
23 namespace singa {
24 
31 // template <typename T = Tensor>
32 class Metric {
33  public:
34  // TODO(wangwei) call Setup using a default MetricConf.
35  Metric() = default;
36  virtual ~Metric() {}
37  virtual void ToDevice(std::shared_ptr<Device> device) {}
38  void Setup(const string& conf) {
39  MetricConf metric;
40  metric.ParseFromString(conf);
41  Setup(metric);
42  }
43 
45  virtual void Setup(const MetricConf& conf) {}
46 
48  virtual Tensor Forward(const Tensor& prediction, const Tensor& target) = 0;
49 
51  float Evaluate(const Tensor& prediction, const Tensor& target) {
52  const Tensor metric = Forward(prediction, target);
53  return Sum<float>(metric) / (1.0f * metric.Size());
54  }
55 };
59 class Accuracy : public Metric {
60  public:
62  void Setup(const MetricConf& conf) override { top_k_ = conf.top_k(); }
63 
68  Tensor Forward(const Tensor& prediction, const Tensor& target);
69 
70  private:
72  Tensor Match(const Tensor& prediction, const vector<int>& target);
75  size_t top_k_ = 1;
76 };
77 
78 
79 } // namespace singa
80 
81 #endif // SINGA_MODEL_METRIC_H_
void Setup(const MetricConf &conf) override
Set meta fields from user configurations.
Definition: metric.h:62
float Evaluate(const Tensor &prediction, const Tensor &target)
Compute the metric value averaged over all samples (in a batch)
Definition: metric.h:51
virtual void Setup(const MetricConf &conf)
Set meta fields from user configurations.
Definition: metric.h:45
A Tensor instance is a multi-dimensional array resident on a Device (default device is the host CPU)...
Definition: tensor.h:56
virtual Tensor Forward(const Tensor &prediction, const Tensor &target)=0
Compute the metric for each data sample.
Compute the accuracy of the prediction, which is matched against the ground truth labels.
Definition: metric.h:59
The base metric class, which declares the APIs for computing the performance evaluation metrics given...
Definition: metric.h:32
Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements...
Definition: common.h:48
size_t Size() const
Return number of total elements.
Definition: tensor.h:128