1 #include "caffe2/opt/device.h" 2 #include "caffe2/core/logging.h" 3 #include "nomnigraph/Graph/Algorithms.h" 8 std::vector<NNGraph::EdgeRef> getInputEdges(
11 std::vector<NNGraph::EdgeRef> inputTensorEdges;
12 for (
const auto& node : sg.getNodes()) {
13 NOM_REQUIRE_OR_CONT(nn::is<NeuralNetOperator>(node));
14 NOM_REQUIRE_OR_CONT(nn::hasInputs(node));
17 for (
const auto& input : nn::getInputs(node)) {
19 !nn::hasProducer(input) || !sg.hasNode(nn::getProducer(input)));
20 inputTensorEdges.emplace_back(g.
getEdge(input, node));
23 return inputTensorEdges;
26 std::vector<NNGraph::EdgeRef> getOutputEdges(
29 std::vector<NNGraph::EdgeRef> outputTensorEdges;
30 for (
const auto& node : sg.getNodes()) {
31 NOM_REQUIRE_OR_CONT(nn::is<NeuralNetOperator>(node));
33 for (
const auto& output : nn::getOutputs(node)) {
34 auto consumers = nn::getConsumers(output);
35 for (
const auto& consumer : consumers) {
36 NOM_REQUIRE_OR_CONT(!sg.hasNode(consumer));
37 outputTensorEdges.emplace_back(g.
getEdge(node, output));
39 NOM_REQUIRE_OR_CONT(consumers.size() == 0);
40 outputTensorEdges.emplace_back(g.
getEdge(node, output));
43 return outputTensorEdges;
54 auto matches = nom::algorithm::binaryMatch(&nn->dataFlow, supported);
57 std::set<NNGraph::EdgeRef> changedEdges;
59 for (
const auto& match : matches) {
60 for (
const auto& edge : getInputEdges(match, nn->dataFlow)) {
61 NOM_REQUIRE_OR_CONT(changedEdges.count(edge) == 0);
62 auto input = edge->tail();
66 auto copyNode = copyToFn(nn->dataFlow);
67 auto copyOp = nn::get<NeuralNetOperator>(copyNode);
70 for (
const auto& consumer : nn::getConsumers(input)) {
71 auto consumerOp = nn::get<NeuralNetOperator>(consumer);
73 if (consumerOp->getKind() == copyOp->getKind()) {
76 newInput = nn::getOutputs(copyNode).front();
82 auto copyFromNode = copyFromFn(nn->dataFlow);
83 auto copyFromOp = nn::get<NeuralNetOperator>(copyFromNode);
85 NOM_REQUIRE_OR_CONT(nn::hasProducer(input));
86 const auto& producer = nn::getProducer(input);
87 const auto& producerOp = nn::get<NeuralNetOperator>(producer);
88 NOM_REQUIRE_OR_CONT(producerOp->getKind() == copyFromOp->getKind());
89 NOM_REQUIRE_OR_CONT(nn::hasInputs(producer));
90 auto oldInputs = nn::getInputs(producer);
91 NOM_REQUIRE_OR_CONT(oldInputs.size() == 1);
93 newInput = oldInputs.front();
100 auto data = nn::get<NeuralNetData>(input);
102 util::make_unique<repr::Tensor>(data->getName() +
"_opencl_0"));
108 input->removeOutEdge(edge);
109 edge->setTail(newInput);
110 newInput->addOutEdge(edge);
112 changedEdges.insert(edge);
115 for (
const auto& edge : getOutputEdges(match, nn->dataFlow)) {
116 NOM_REQUIRE_OR_CONT(changedEdges.count(edge) == 0);
117 auto output = edge->head();
119 auto copyNode = copyFromFn(nn->dataFlow);
120 auto data = nn::get<NeuralNetData>(output);
123 util::make_unique<repr::Tensor>(data->getName() +
"_opencl_0"));
126 edge->setHead(newOutput);
128 changedEdges.insert(edge);
134 for (
auto consumer : nn::getConsumers(output)) {
135 if (match.getNodes().count(consumer)) {
136 auto brokenEdge = nn->dataFlow.
getEdge(output, consumer);
137 output->removeOutEdge(brokenEdge);
138 brokenEdge->setTail(newOutput);
139 newOutput->addOutEdge(brokenEdge);
NodeRef createNode(T &&data)
Creates a node and retains ownership of it.
EdgeRef getEdge(NodeRef tail, NodeRef head) const
Get a reference to the edge between two nodes if it exists.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
void deleteNode(NodeRef n)
Deletes a node from the graph.
A simple graph implementation.
void removeInEdge(EdgeRef e)
Removes an edge by reference to known in-edges.
EdgeRef createEdge(NodeRef tail, NodeRef head, U...data)
Creates a directed edge and retains ownership of it.