Caffe2 - C++ API
A deep learning, cross platform ML framework
accuracy_op.cc
1 #include "caffe2/operators/accuracy_op.h"
2 
3 namespace caffe2 {
4 
5 template <>
6 bool AccuracyOp<float, CPUContext>::RunOnDevice() {
7  auto& X = Input(PREDICTION);
8  auto& label = Input(LABEL);
9 
10  CAFFE_ENFORCE_EQ(X.dim(), 2);
11  int N = X.dim32(0);
12  int D = X.dim32(1);
13  CAFFE_ENFORCE_EQ(label.dim(), 1);
14  CAFFE_ENFORCE_EQ(label.dim32(0), N);
15  auto* Y = Output(0, vector<int64_t>(), at::dtype<float>());
16  const auto* Xdata = X.data<float>();
17  const auto* labelData = label.data<int>();
18  const int top_k = top_k_;
19  int correct = 0;
20 
21  // it's equivalent to using a stable sorting algorithm to sort the
22  // classes (with their predictions as key) and then check whether
23  // the label is within the first top_k slots.
24  for (int i = 0; i < N; ++i) {
25  auto label_i = labelData[i];
26  auto label_pred = Xdata[i * D + label_i];
27  int ngt = 1;
28  for (int j = 0; j < D; ++j) {
29  auto pred = Xdata[i * D + j];
30  if ((pred > label_pred) || (pred == label_pred && j < label_i)) {
31  if (++ngt > top_k) {
32  break;
33  }
34  }
35  }
36  if (ngt <= top_k) {
37  ++correct;
38  }
39  }
40  CAFFE_ENFORCE_LE(correct, N);
41  *(Y->template mutable_data<float>()) = static_cast<float>(correct) / N;
42 
43  return true;
44 }
45 
46 REGISTER_CPU_OPERATOR(Accuracy, AccuracyOp<float, CPUContext>);
47 
48 OPERATOR_SCHEMA(Accuracy)
49  .NumInputs(2)
50  .NumOutputs(1)
51  .ScalarType(TensorProto::FLOAT)
52  .SetDoc(R"DOC(
53 Accuracy takes two inputs- predictions and labels, and returns a float
54 accuracy value for the batch. Predictions are expected in the form of 2-D tensor
55 containing a batch of scores for various classes, and labels are expected in the
56  form of 1-D tensor containing true label indices of samples in the batch. If
57 the score for the label index in the predictions is the highest among all
58 classes, it is considered a correct prediction.
59 )DOC")
60  .Arg(
61  "top_k",
62  "Count as correct by comparing the true label to the top k scoring "
63  "classes (default 1: only compare to the top scoring class i.e. argmax)")
64  .Input(
65  0,
66  "predictions",
67  "2-D tensor (Tensor<float>) of size "
68  "(num_batches x num_classes) containing scores")
69  .Input(
70  1,
71  "labels",
72  "1-D tensor (Tensor<float>) of size (num_batches) having "
73  "the indices of true labels")
74  .Output(
75  0,
76  "accuracy",
77  "1-D tensor (Tensor<float>) of size 1 containing "
78  "accuracy");
79 
80 SHOULD_NOT_DO_GRADIENT(Accuracy);
81 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:70