Caffe2 - C++ API
A deep learning, cross-platform ML framework
device.cc
1 #include "caffe2/opt/device.h"
2 #include "caffe2/core/logging.h"
3 #include "nomnigraph/Graph/Algorithms.h"
4 
5 using namespace nom;
6 using namespace nom::repr;
7 
8 std::vector<NNGraph::EdgeRef> getInputEdges(
9  const NNGraph::SubgraphType& sg,
10  const NNGraph& g) {
11  std::vector<NNGraph::EdgeRef> inputTensorEdges;
12  for (const auto& node : sg.getNodes()) {
13  NOM_REQUIRE_OR_CONT(nn::is<NeuralNetOperator>(node));
14  NOM_REQUIRE_OR_CONT(nn::hasInputs(node));
15 
16  // Check if tensor's parents are in the sg
17  for (const auto& input : nn::getInputs(node)) {
18  NOM_REQUIRE_OR_CONT(
19  !nn::hasProducer(input) || !sg.hasNode(nn::getProducer(input)));
20  inputTensorEdges.emplace_back(g.getEdge(input, node));
21  }
22  }
23  return inputTensorEdges;
24 }
25 
26 std::vector<NNGraph::EdgeRef> getOutputEdges(
27  const NNGraph::SubgraphType& sg,
28  const NNGraph& g) {
29  std::vector<NNGraph::EdgeRef> outputTensorEdges;
30  for (const auto& node : sg.getNodes()) {
31  NOM_REQUIRE_OR_CONT(nn::is<NeuralNetOperator>(node));
32 
33  for (const auto& output : nn::getOutputs(node)) {
34  auto consumers = nn::getConsumers(output);
35  for (const auto& consumer : consumers) {
36  NOM_REQUIRE_OR_CONT(!sg.hasNode(consumer));
37  outputTensorEdges.emplace_back(g.getEdge(node, output));
38  }
39  NOM_REQUIRE_OR_CONT(consumers.size() == 0);
40  outputTensorEdges.emplace_back(g.getEdge(node, output));
41  }
42  }
43  return outputTensorEdges;
44 }
45 
46 namespace caffe2 {
47 namespace opt {
48 
49 void insertCopies(
50  NNModule* nn,
51  std::function<bool(NNGraph::NodeRef)> supported,
52  std::function<NNGraph::NodeRef(NNGraph&)> copyToFn,
53  std::function<NNGraph::NodeRef(NNGraph&)> copyFromFn) {
54  auto matches = nom::algorithm::binaryMatch(&nn->dataFlow, supported);
55 
56  // We're doing a lot of inplace mutation so this is necessary.
57  std::set<NNGraph::EdgeRef> changedEdges;
58 
59  for (const auto& match : matches) {
60  for (const auto& edge : getInputEdges(match, nn->dataFlow)) {
61  NOM_REQUIRE_OR_CONT(changedEdges.count(edge) == 0);
62  auto input = edge->tail();
63  NNGraph::NodeRef newInput = nullptr;
64 
65  // First we check if there already is a copyNode that we can reuse.
66  auto copyNode = copyToFn(nn->dataFlow);
67  auto copyOp = nn::get<NeuralNetOperator>(copyNode);
68 
69  // Rectify redudancies.
70  for (const auto& consumer : nn::getConsumers(input)) {
71  auto consumerOp = nn::get<NeuralNetOperator>(consumer);
72  // We already have a copy node, let's reuse it.
73  if (consumerOp->getKind() == copyOp->getKind()) {
74  nn->dataFlow.deleteNode(copyNode);
75  copyNode = consumer;
76  newInput = nn::getOutputs(copyNode).front();
77  break;
78  }
79  }
80 
81  // Second, we may have found the out-edge of a previous match.
82  auto copyFromNode = copyFromFn(nn->dataFlow);
83  auto copyFromOp = nn::get<NeuralNetOperator>(copyFromNode);
84  do {
85  NOM_REQUIRE_OR_CONT(nn::hasProducer(input));
86  const auto& producer = nn::getProducer(input);
87  const auto& producerOp = nn::get<NeuralNetOperator>(producer);
88  NOM_REQUIRE_OR_CONT(producerOp->getKind() == copyFromOp->getKind());
89  NOM_REQUIRE_OR_CONT(nn::hasInputs(producer));
90  auto oldInputs = nn::getInputs(producer);
91  NOM_REQUIRE_OR_CONT(oldInputs.size() == 1);
92  nn->dataFlow.deleteNode(copyNode);
93  newInput = oldInputs.front();
94  } while (false);
95  nn->dataFlow.deleteNode(copyFromNode);
96 
97  // Third, we may have to insert a copy operation
98  // if the above checks failed.
99  if (!newInput) {
100  auto data = nn::get<NeuralNetData>(input);
101  newInput = nn->dataFlow.createNode(
102  util::make_unique<repr::Tensor>(data->getName() + "_opencl_0"));
103  nn->dataFlow.createEdge(input, copyNode);
104  nn->dataFlow.createEdge(copyNode, newInput);
105  }
106  // Finally, swap our input node to reflect a tensor already
107  // on the device.
108  input->removeOutEdge(edge);
109  edge->setTail(newInput);
110  newInput->addOutEdge(edge);
111 
112  changedEdges.insert(edge);
113  }
114 
115  for (const auto& edge : getOutputEdges(match, nn->dataFlow)) {
116  NOM_REQUIRE_OR_CONT(changedEdges.count(edge) == 0);
117  auto output = edge->head();
118 
119  auto copyNode = copyFromFn(nn->dataFlow);
120  auto data = nn::get<NeuralNetData>(output);
121 
122  auto newOutput = nn->dataFlow.createNode(
123  util::make_unique<repr::Tensor>(data->getName() + "_opencl_0"));
124 
125  output->removeInEdge(edge);
126  edge->setHead(newOutput);
127 
128  changedEdges.insert(edge);
129 
130  nn->dataFlow.createEdge(newOutput, copyNode);
131  nn->dataFlow.createEdge(copyNode, output);
132 
133  // We may have broken some consumers that are actually in the match.
134  for (auto consumer : nn::getConsumers(output)) {
135  if (match.getNodes().count(consumer)) {
136  auto brokenEdge = nn->dataFlow.getEdge(output, consumer);
137  output->removeOutEdge(brokenEdge);
138  brokenEdge->setTail(newOutput);
139  newOutput->addOutEdge(brokenEdge);
140  }
141  }
142  }
143  }
144 }
145 
146 } // namespace opt
147 } // namespace caffe2
NodeRef createNode(T &&data)
Creates a node and retains ownership of it.
Definition: Graph.h:240
Definition: Dot.h:16
EdgeRef getEdge(NodeRef tail, NodeRef head) const
Get a reference to the edge between two nodes if it exists.
Definition: Graph.h:452
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
void deleteNode(NodeRef n)
Deletes a node from the graph.
Definition: Graph.h:460
A simple graph implementation.
Definition: Graph.h:29
void removeInEdge(EdgeRef e)
Removes an edge by reference to known in-edges.
Definition: Graph.h:106
EdgeRef createEdge(NodeRef tail, NodeRef head, U...data)
Creates a directed edge and retains ownership of it.
Definition: Graph.h:415