Caffe2 - C++ API
A deep learning, cross platform ML framework
fusion.cc
1 #include "caffe2/opt/fusion.h"
2 #include "caffe2/core/logging.h"
3 #include "caffe2/opt/converter.h"
4 #include "caffe2/opt/passes.h"
5 
6 namespace caffe2 {
7 namespace opt {
8 
9 using namespace nom;
10 
11 // $$ X_{bn} = \frac{s(X - m)}{\sqrt{\sigma + \epsilon}} + b_{bn}$$
12 // $$ X_{conv} = X * W + b_{conv} $$
13 // thus, substituting $X$ with $X_{conv}$ in the BN equation we get:
14 // $$X_{bn} = X * \frac{sW}{\sqrt{\sigma + \epsilon}} + \frac{s(b_{conv} -
15 // m)}{\sqrt{\sigma + \epsilon}} + b_{bn}$$ or
16 // $$ W' = W\frac{s}{\sqrt{\sigma + \epsilon}}$$
17 // $$ b' = (b_{conv} - m)\frac{s}{\sqrt{\sigma + \epsilon}} + b_{bn}$$
18 bool fuseConvBNHelper(repr::NNModule* nn, caffe2::Workspace* ws) {
19  size_t convOrder = 0;
20  for (auto node_pair : repr::nn::dataIterator<repr::Conv>(nn->dataFlow)) {
21  repr::NNGraph::NodeRef convNode;
22  repr::Conv* conv;
23  std::tie(conv, convNode) = node_pair;
24 
25  auto output = repr::nn::getOutputs(convNode).front();
26  auto consumers = repr::nn::getConsumers(output);
27  NOM_REQUIRE_OR_CONT(consumers.size() == 1);
28 
29  auto consumer = consumers.front();
30  NOM_REQUIRE_OR_CONT(repr::nn::is<repr::BatchNormalization>(consumer));
31 
32  auto bnNode = consumer;
33  auto bn = repr::nn::get<repr::BatchNormalization>(bnNode);
34  auto bnOutputs = nn::getOutputs(bnNode);
35  NOM_REQUIRE_OR_CONT(bnOutputs.size() == 1);
36  auto bnOutput = bnOutputs.front();
37 
38  auto convInputs = repr::nn::getInputs(convNode);
39  if (convInputs.size() < 2) {
40  continue;
41  }
42 
43  auto bnInputs = repr::nn::getInputs(bnNode);
44  CAFFE_ENFORCE(
45  bnInputs.size() >= 5, "Invalid batch normalization input size");
46 
47 #define EXPOSE_TENSOR_DATA(name, index, inputs) \
48  auto name = repr::nn::get<repr::Tensor>(inputs[index]); \
49  assert(ws->HasBlob(name->getName()) && "Blob not in workspace"); \
50  auto name##Tensor = BlobGetMutableTensor(ws->GetBlob(name->getName()), CPU); \
51  auto name##Data = name##Tensor->mutable_data<float>();
52 
53  EXPOSE_TENSOR_DATA(filter, 1, convInputs);
54 
55  EXPOSE_TENSOR_DATA(scale, 1, bnInputs);
56  EXPOSE_TENSOR_DATA(biasBN, 2, bnInputs);
57  EXPOSE_TENSOR_DATA(mean, 3, bnInputs);
58  EXPOSE_TENSOR_DATA(variance, 4, bnInputs);
59 
60  if (convInputs.size() == 2) {
61  NOM_REQUIRE_OR_CONT(conv->getMutableAnnotation() != nullptr);
62  auto annotation =
63  dyn_cast<caffe2::Caffe2Annotation>(conv->getMutableAnnotation());
64  NOM_REQUIRE_OR_CONT(annotation != nullptr);
65  auto op = annotation->getOperatorDef();
66  auto convName = op.name();
67 
68  while (true) {
69  auto convBiasName = convName + "_bias" + to_string(convOrder);
70  if (!ws->HasBlob(convBiasName)) {
71  auto convBiasTensor = make_unique<repr::Tensor>(convBiasName);
72  convBiasTensor->setType(repr::Tensor::DataType::Float);
73  auto convBiasNode = nn->dataFlow.createNode(
74  unique_dyn_cast<repr::NeuralNetData>(convBiasTensor));
75  nn->inputs.insert(convBiasNode);
76  nn->dataFlow.createEdge(convBiasNode, convNode);
77 
78  auto* blob = ws->CreateBlob(convBiasName);
79  caffe2::TensorCPU* tensor = BlobGetMutableTensor(blob, caffe2::CPU);
80  CHECK_NOTNULL(tensor);
81  // Get output channel
82  size_t c = filterTensor->dim32(0);
83  tensor->Resize(c);
84  float* tensor_data = tensor->mutable_data<float>();
85  memset(tensor_data, 0, tensor->nbytes());
86  break;
87  }
88  convOrder++;
89  }
90  }
91 
92  convInputs = repr::nn::getInputs(convNode);
93  EXPOSE_TENSOR_DATA(biasConv, 2, convInputs);
94 
95 #undef EXPOSE_TENSOR_DATA
96 
97  // Assume M{CHW,HWC}
98  auto chwDim = filterTensor->size_from_dim(1);
99  for (auto c = 0; c < filterTensor->dim32(0); ++c) {
100  float coeff =
101  scaleData[c] / std::sqrt(varianceData[c] + bn->getEpsilon());
102  for (auto i = 0; i < chwDim; ++i) {
103  filterData[c * chwDim + i] *= coeff;
104  }
105  auto bias = (biasConvData[c] - meanData[c]) * coeff + biasBNData[c];
106  biasConvData[c] = bias;
107  }
108 
109  nn->dataFlow.deleteNode(output);
110  nn->dataFlow.createEdge(convNode, bnOutput);
111  nn->dataFlow.deleteNode(bnNode);
112  return true;
113  }
114  return false;
115 }
116 
117 void fuseConvBN(nom::repr::NNModule* nn, caffe2::Workspace* ws) {
118  while (fuseConvBNHelper(nn, ws)) {
119  }
120 }
121 
122 REGISTER_WS_OPT_PASS_FROM_FUNC(FuseConvBN, fuseConvBN);
123 
124 } // namespace opt
125 } // namespace caffe2
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:100
NodeRef createNode(T &&data)
Creates a node and retains ownership of it.
Definition: Graph.h:240
Definition: Dot.h:16
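Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of all these objects and deals with the scaffolding logistics.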
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
Definition: blob.h:13
void deleteNode(NodeRef n)
Deletes a node from the graph.
Definition: Graph.h:460
bool HasBlob(const string &name) const
Checks if a blob with the given name is present in the current workspace.
Definition: workspace.h:179
EdgeRef createEdge(NodeRef tail, NodeRef head, U...data)
Creates a directed edge and retains ownership of it.
Definition: Graph.h:415