Caffe2 - C++ API
A deep learning, cross platform ML framework
accuracy_op.cc
1 
17 #include "caffe2/operators/accuracy_op.h"
18 
19 namespace caffe2 {
20 
21 template <>
22 bool AccuracyOp<float, CPUContext>::RunOnDevice() {
23  auto& X = Input(PREDICTION);
24  auto& label = Input(LABEL);
25  auto* Y = Output(0);
26  CAFFE_ENFORCE_EQ(X.ndim(), 2);
27  int N = X.dim32(0);
28  int D = X.dim32(1);
29  CAFFE_ENFORCE_EQ(label.ndim(), 1);
30  CAFFE_ENFORCE_EQ(label.dim32(0), N);
31  Y->Resize(vector<TIndex>());
32  const auto* Xdata = X.data<float>();
33  const auto* labelData = label.data<int>();
34  const int top_k = top_k_;
35  int correct = 0;
36 
37  // it's equivalent to using a stable sorting algorithm to sort the
38  // classes (with their predictions as key) and then check whether
39  // the label is within the first top_k slots.
40  for (int i = 0; i < N; ++i) {
41  auto label_i = labelData[i];
42  auto label_pred = Xdata[i * D + label_i];
43  int ngt = 1;
44  for (int j = 0; j < D; ++j) {
45  auto pred = Xdata[i * D + j];
46  if ((pred > label_pred) || (pred == label_pred && j < label_i)) {
47  if (++ngt > top_k) {
48  break;
49  }
50  }
51  }
52  if (ngt <= top_k) {
53  ++correct;
54  }
55  }
56  CAFFE_ENFORCE_LE(correct, N);
57  *(Y->mutable_data<float>()) = static_cast<float>(correct) / N;
58 
59  return true;
60 }
61 
62 REGISTER_CPU_OPERATOR(Accuracy, AccuracyOp<float, CPUContext>);
63 
64 OPERATOR_SCHEMA(Accuracy)
65  .NumInputs(2)
66  .NumOutputs(1)
67  .ScalarType(TensorProto::FLOAT)
68  .SetDoc(R"DOC(
69 Accuracy takes two inputs- predictions and labels, and returns a float
70 accuracy value for the batch. Predictions are expected in the form of 2-D tensor
71 containing a batch of scores for various classes, and labels are expected in the
72  form of 1-D tensor containing true label indices of samples in the batch. If
73 the score for the label index in the predictions is the highest among all
74 classes, it is considered a correct prediction.
75 )DOC")
76  .Arg(
77  "top_k",
78  "Count as correct by comparing the true label to the top k scoring "
79  "classes (default 1: only compare to the top scoring class i.e. argmax)")
80  .Input(0, "predictions", "2-D tensor (Tensor<float>) of size "
81  "(num_batches x num_classes) containing scores")
82  .Input(1, "labels", "1-D tensor (Tensor<int>) of size (num_batches) having "
83  "the indices of true labels")
84  .Output(0, "accuracy", "1-D tensor (Tensor<float>) of size 1 containing "
85  "accuracy");
86 
87 SHOULD_NOT_DO_GRADIENT(Accuracy);
88 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.