19 #ifndef SINGA_MODEL_METRIC_H_ 20 #define SINGA_MODEL_METRIC_H_ 21 #include "singa/core/tensor.h" 22 #include "singa/proto/model.pb.h" 37 virtual void ToDevice(std::shared_ptr<Device> device) {}
38 void Setup(
const string& conf) {
40 metric.ParseFromString(conf);
45 virtual void Setup(
const MetricConf& conf) {}
53 return Sum<float>(metric) / (1.0f * metric.
Size());
62 void Setup(
const MetricConf& conf)
override { top_k_ = conf.top_k(); }
72 Tensor Match(
const Tensor& prediction,
const vector<int>& target);
81 #endif // SINGA_MODEL_METRIC_H_ void Setup(const MetricConf &conf) override
Set meta fields from user configurations.
Definition: metric.h:62
float Evaluate(const Tensor &prediction, const Tensor &target)
Compute the metric value averaged over all samples (in a batch).
Definition: metric.h:51
virtual void Setup(const MetricConf &conf)
Set meta fields from user configurations.
Definition: metric.h:45
A Tensor instance is a multi-dimensional array resident on a Device (default device is the host CPU)...
Definition: tensor.h:56
virtual Tensor Forward(const Tensor &prediction, const Tensor &target)=0
Compute the metric for each data sample.
Compute the accuracy of the prediction, which is matched against the ground-truth labels.
Definition: metric.h:59
The base metric class, which declares the APIs for computing the performance evaluation metrics given...
Definition: metric.h:32
Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements...
Definition: common.h:48
size_t Size() const
Return number of total elements.
Definition: tensor.h:128