Apache Singa
A General Distributed Deep Learning Library
The base metric class, which declares the APIs for computing the performance evaluation metrics given the prediction of the model and the ground truth, i.e., the target.
#include <metric.h>
Public Member Functions
virtual void ToDevice(std::shared_ptr<Device> device)
void Setup(const string &conf)
virtual void Setup(const MetricConf &conf)
    Set meta fields from user configurations.
virtual Tensor Forward(const Tensor &prediction, const Tensor &target) = 0
    Compute the metric for each data sample.
float Evaluate(const Tensor &prediction, const Tensor &target)
    Compute the metric value averaged over all samples (in a batch).
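The split between Forward (one value per sample) and Evaluate (the batch average) can be illustrated with a small self-contained sketch. It does not use SINGA: FakeTensor (a std::vector<float> standing in for Tensor), MetricSketch and AccuracySketch are hypothetical names introduced here, and the accuracy rule (argmax of per-class scores compared with an integer label) is an assumption chosen for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical stand-in for singa::Tensor (not the real class):
// a flat, row-major buffer of floats.
using FakeTensor = std::vector<float>;

// Sketch of the documented contract: Forward() is pure virtual and returns
// one metric value per sample; Evaluate() averages those values over the batch.
class MetricSketch {
 public:
  virtual ~MetricSketch() = default;

  // Per-sample metric values, one entry per data sample.
  virtual FakeTensor Forward(const FakeTensor& prediction,
                             const FakeTensor& target) = 0;

  // Average of Forward() over all samples in the batch.
  float Evaluate(const FakeTensor& prediction, const FakeTensor& target) {
    const FakeTensor per_sample = Forward(prediction, target);
    if (per_sample.empty()) return 0.0f;  // sketch choice: empty batch -> 0
    const float sum =
        std::accumulate(per_sample.begin(), per_sample.end(), 0.0f);
    return sum / static_cast<float>(per_sample.size());
  }
};

// Hypothetical accuracy metric: prediction holds batch x num_classes scores,
// target holds one class index per sample (stored as a float).
class AccuracySketch : public MetricSketch {
 public:
  explicit AccuracySketch(std::size_t num_classes) : num_classes_(num_classes) {}

  FakeTensor Forward(const FakeTensor& prediction,
                     const FakeTensor& target) override {
    assert(num_classes_ > 0 &&
           prediction.size() == target.size() * num_classes_);
    FakeTensor correct(target.size(), 0.0f);
    for (std::size_t i = 0; i < target.size(); ++i) {
      const float* row = prediction.data() + i * num_classes_;
      std::size_t best = 0;  // argmax over the class scores of sample i
      for (std::size_t c = 1; c < num_classes_; ++c) {
        if (row[c] > row[best]) best = c;
      }
      // 1 if the predicted class matches the target label, else 0.
      correct[i] = (best == static_cast<std::size_t>(target[i])) ? 1.0f : 0.0f;
    }
    return correct;
  }

 private:
  std::size_t num_classes_;
};
```

With scores {0.9, 0.1, 0.2, 0.8, 0.6, 0.4} for three samples over two classes and targets {0, 1, 1}, Forward yields {1, 1, 0} and Evaluate yields about 0.667. Keeping Evaluate non-virtual, as in the listing above, means a subclass only has to supply the per-sample logic.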
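The target type is described as a template argument T. For data samples with a single label, T could be a 1-d tensor (or vector<int>); if each data sample has multiple labels, T could be vector<vector<int>>, with one vector per sample. Note that the member functions listed above take both the prediction and the target as Tensor objects.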
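The idea of the target type as a template argument can be illustrated with a class template whose parameter T is the batch-level target container. This is a hypothetical sketch, not SINGA's class: TemplatedMetricSketch and SampleScore are illustrative names, and the multi-label score (the fraction of a sample's target labels that also appear among its predicted labels) is an assumed choice.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Single-label sample: 1 if the predicted label equals the target label.
inline float SampleScore(int prediction, int target) {
  return prediction == target ? 1.0f : 0.0f;
}

// Multi-label sample (assumed scoring): fraction of target labels that
// also appear among the predicted labels; an empty target scores 1.
inline float SampleScore(const std::vector<int>& prediction,
                         const std::vector<int>& target) {
  if (target.empty()) return 1.0f;
  std::size_t found = 0;
  for (int label : target) {
    if (std::find(prediction.begin(), prediction.end(), label) !=
        prediction.end()) {
      ++found;
    }
  }
  return static_cast<float>(found) / static_cast<float>(target.size());
}

// T holds one element per sample, e.g. std::vector<int> (single label)
// or std::vector<std::vector<int>> (multiple labels per sample).
template <typename T>
class TemplatedMetricSketch {
 public:
  // One score per sample, chosen by overload on the element type of T.
  std::vector<float> Forward(const T& prediction, const T& target) const {
    assert(prediction.size() == target.size());
    std::vector<float> scores;
    scores.reserve(target.size());
    for (std::size_t i = 0; i < target.size(); ++i) {
      scores.push_back(SampleScore(prediction[i], target[i]));
    }
    return scores;
  }

  // Average of the per-sample scores over the batch.
  float Evaluate(const T& prediction, const T& target) const {
    const std::vector<float> scores = Forward(prediction, target);
    if (scores.empty()) return 0.0f;  // sketch choice: empty batch -> 0
    return std::accumulate(scores.begin(), scores.end(), 0.0f) /
           static_cast<float>(scores.size());
  }
};
```

For example, TemplatedMetricSketch<std::vector<int>> scores predictions {3, 1, 2} against targets {3, 0, 2} as {1, 0, 1}, while the multi-label instantiation scores {{1, 2}, {0}} against {{1, 2}, {0, 3}} as {1, 0.5}, for an average of 0.75.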