Caffe2 - C++ API
A deep learning, cross platform ML framework
multi_class_accuracy_op.cc
1 #include "caffe2/operators/multi_class_accuracy_op.h"
2 
3 namespace caffe2 {
4 
5 template <>
6 bool MultiClassAccuracyOp<float, CPUContext>::RunOnDevice() {
7  auto& X = Input(PREDICTION);
8  auto& label = Input(LABEL);
9 
10  DCHECK_EQ(X.dim(), 2);
11  // amount, number of instances
12  int N = X.dim32(0);
13  // dimension, number of classes
14  int D = X.dim32(1);
15  DCHECK_EQ(label.dim(), 1);
16  DCHECK_EQ(label.dim32(0), N);
17  auto* Y0 = Output(0, {D}, at::dtype<float>());
18  auto* Y1 = Output(1, {D}, at::dtype<int>());
19 
20  const auto* Xdata = X.data<float>();
21  const auto* labeldata = label.data<int>();
22  auto* accuracies = Y0->template mutable_data<float>();
23  auto* amounts = Y1->template mutable_data<int>();
24  std::fill(accuracies, accuracies + D, 0);
25  std::fill(amounts, amounts + D, 0);
26 
27  for (int i = 0; i < N; ++i) {
28  float maxval = std::numeric_limits<float>::lowest();
29  int maxid = 0;
30  for (int j = 0; j < D; ++j) {
31  if (Xdata[i * D + j] > maxval) {
32  maxval = Xdata[i * D + j];
33  maxid = j;
34  }
35  }
36  int labelid = labeldata[i];
37  DCHECK_LT(labelid, D);
38  if (maxid == labelid) {
39  accuracies[labelid]++;
40  }
41  amounts[labelid]++;
42  }
43 
44  for (int i = 0; i < D; ++i) {
45  int amount = amounts[i];
46  if (amount) {
47  accuracies[i] /= amount;
48  }
49  }
50 
51  return true;
52 }
53 
54 REGISTER_CPU_OPERATOR(
55  MultiClassAccuracy, MultiClassAccuracyOp<float, CPUContext>);
56 
57 OPERATOR_SCHEMA(MultiClassAccuracy)
58  .NumInputs(2)
59  .NumOutputs(2)
60  .SetDoc(R"DOC(
61 Respectively compute accuracy score for each class given a number of instances
62 and predicted scores of each class for each instance.
63 )DOC")
64  .Input(
65  0,
66  "prediction",
67  "2-D float tensor (N,D,) of predicted scores of each class for "
68  "each data. N is the number of instances, i.e., batch size. D is number of "
69  "possible classes/labels.")
70  .Input(
71  1,
72  "labels",
73  "1-D int tensor (N,) of labels for each instance.")
74  .Output(
75  0,
76  "accuracies",
77  "1-D float tensor (D,) of accuracy for each class. If a class has no "
78  "instance in the batch, its accuracy score is set to zero.")
79  .Output(
80  1,
81  "amounts",
82  "1-D int tensor (D,) of number of instances for each class in the batch.");
83 
// Accuracy is computed from an argmax and is not meant to be differentiated;
// the exact behavior when a gradient is requested is defined by the
// SHOULD_NOT_DO_GRADIENT macro (not visible in this file).
84 SHOULD_NOT_DO_GRADIENT(MultiClassAccuracy);
85 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:70