Caffe2 - C++ API
A deep learning, cross-platform ML framework
multi_class_accuracy_op.cc
1 
17 #include "caffe2/operators/multi_class_accuracy_op.h"
18 
19 namespace caffe2 {
20 
21 template <>
22 bool MultiClassAccuracyOp<float, CPUContext>::RunOnDevice() {
23  auto& X = Input(PREDICTION);
24  auto& label = Input(LABEL);
25  auto* Y0 = Output(0);
26  auto* Y1 = Output(1);
27  DCHECK_EQ(X.ndim(), 2);
28  // amount, number of instances
29  int N = X.dim32(0);
30  // dimension, number of classes
31  int D = X.dim32(1);
32  DCHECK_EQ(label.ndim(), 1);
33  DCHECK_EQ(label.dim32(0), N);
34  Y0->Resize(D);
35  Y1->Resize(D);
36 
37  const auto* Xdata = X.data<float>();
38  const auto* labeldata = label.data<int>();
39  auto* accuracies = Y0->mutable_data<float>();
40  auto* amounts = Y1->mutable_data<int>();
41  std::fill(accuracies, accuracies + D, 0);
42  std::fill(amounts, amounts + D, 0);
43 
44  for (int i = 0; i < N; ++i) {
45  float maxval = std::numeric_limits<float>::lowest();
46  int maxid = 0;
47  for (int j = 0; j < D; ++j) {
48  if (Xdata[i * D + j] > maxval) {
49  maxval = Xdata[i * D + j];
50  maxid = j;
51  }
52  }
53  int labelid = labeldata[i];
54  DCHECK_LT(labelid, D);
55  if (maxid == labelid) {
56  accuracies[labelid]++;
57  }
58  amounts[labelid]++;
59  }
60 
61  for (int i = 0; i < D; ++i) {
62  int amount = amounts[i];
63  if (amount) {
64  accuracies[i] /= amount;
65  }
66  }
67 
68  return true;
69 }
70 
71 REGISTER_CPU_OPERATOR(
72  MultiClassAccuracy, MultiClassAccuracyOp<float, CPUContext>);
73 
74 OPERATOR_SCHEMA(MultiClassAccuracy)
75  .NumInputs(2)
76  .NumOutputs(2)
77  .SetDoc(R"DOC(
78 Respectively compute accuracy score for each class given a number of instances
79 and predicted scores of each class for each instance.
80 )DOC")
81  .Input(
82  0,
83  "prediction",
84  "2-D float tensor (N,D,) of predicted scores of each class for "
85  "each data. N is the number of instances, i.e., batch size. D is number of "
86  "possible classes/labels.")
87  .Input(
88  1,
89  "labels",
90  "1-D int tensor (N,) of labels for each instance.")
91  .Output(
92  0,
93  "accuracies",
94  "1-D float tensor (D,) of accuracy for each class. If a class has no "
95  "instance in the batch, its accuracy score is set to zero.")
96  .Output(
97  1,
98  "amounts",
99  "1-D int tensor (D,) of number of instances for each class in the batch.");
100 
101 SHOULD_NOT_DO_GRADIENT(MultiClassAccuracy);
102 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.