Caffe2 - C++ API
A deep learning, cross platform ML framework
register_c10_ops.cpp
1 #include <ATen/core/dispatch/Dispatcher.h>
2 #include <torch/csrc/jit/operator.h>
3 #include <torch/csrc/jit/tracer.h>
4 
5 namespace torch {
6 namespace jit {
7 namespace {
8 
9 at::Tensor unwrap(at::Tensor&& tensor) {
10  if (tensor.requires_grad()) {
11  throw std::runtime_error("Autograd not yet supported for c10 ops.");
12  }
13  return torch::autograd::Variable(std::move(tensor)).data();
14 }
15 
16 // TODO This currently only handles tensors with requires_grad==False correctly.
17 // It should also handle autograd.
18 Operator createOperatorFromC10(const c10::OperatorHandle& op) {
19  return Operator(op.schema(), [op](Stack& stack) {
20  const auto input_size = op.schema().arguments().size();
21  const auto output_size = op.schema().returns().size();
22 
23  Node* node = nullptr;
24 
25  // unwrap tensor inputs from variable
26  for (auto iter = stack.end() - input_size; iter != stack.end(); ++iter) {
27  // TODO Remove the .defined() check once we don't have undefined tensors on the stack anymore (@wanchaol is working on this)
28  if (iter->isTensor() && iter->toTensor().defined()) {
29  *iter = unwrap(std::move(*iter).toTensor());
30  } else if (iter->isTensorList()) {
31  for (auto& item : iter->toTensorList()->elements()) {
32  item = unwrap(std::move(item));
33  }
34  }
35  }
36 
37  if (jit::tracer::isTracing()) {
38  auto symbol = Symbol::fromQualString(op.schema().name());
39  const auto& graph = tracer::getTracingState()->graph;
40  node = graph->create(symbol, 0);
41  const auto& args = op.schema().arguments();
42  int i = 0;
43  for (auto iter = stack.end() - input_size; iter != stack.end();
44  ++iter, ++i) {
45  // TODO we need to refactor graph APIs (e.g., addInputs)
46  // appropriately; after that, we can get rid of the giant if-else
47  // block we will clean this tech debt together in the following PRs
48  auto type = args[i].type();
49  if (type->kind() == TypeKind::OptionalType) {
50  if (iter->isNone()) {
51  Value* none =
52  graph
53  ->insertNode(graph->createNone(
54  reinterpret_cast<OptionalType*>(args[i].type().get())
55  ->getElementType()))
56  ->output();
57  node->addInput(none);
58  continue;
59  } else {
60  type =
61  reinterpret_cast<OptionalType*>(type.get())->getElementType();
62  }
63  }
64  if (type->isSubclass(TypeKind::TensorType)) {
65  AT_ASSERT(iter->isTensor());
66  tracer::addInputs(node, args[i].name().c_str(), iter->toTensor());
67  } else if (type->kind() == TypeKind::FloatType) {
68  AT_ASSERT(iter->isDouble());
69  tracer::addInputs(node, args[i].name().c_str(), iter->toDouble());
70  } else if (type->kind() == TypeKind::IntType) {
71  AT_ASSERT(iter->isInt());
72  tracer::addInputs(node, args[i].name().c_str(), iter->toInt());
73  } else if (type->kind() == TypeKind::BoolType) {
74  AT_ASSERT(iter->isBool());
75  tracer::addInputs(node, args[i].name().c_str(), iter->toBool());
76  } else if (type->kind() == TypeKind::StringType) {
77  AT_ASSERT(iter->isString());
78  tracer::addInputs(
79  node, args[i].name().c_str(), iter->toStringRef());
80  } else if (type->kind() == TypeKind::ListType) {
81  const auto& elem_type =
82  reinterpret_cast<ListType*>(type.get())->getElementType();
83  if (elem_type->isSubclass(TypeKind::TensorType)) {
84  AT_ASSERT(iter->isTensorList());
85  tracer::addInputs(
86  node,
87  args[i].name().c_str(),
88  iter->toTensorList()->elements());
89  } else if (elem_type->kind() == TypeKind::FloatType) {
90  AT_ASSERT(iter->isDoubleList());
91  tracer::addInputs(
92  node,
93  args[i].name().c_str(),
94  iter->toDoubleList()->elements());
95  } else if (elem_type->kind() == TypeKind::IntType) {
96  AT_ASSERT(iter->isIntList());
97  tracer::addInputs(
98  node, args[i].name().c_str(), iter->toIntList()->elements());
99  } else if (elem_type->kind() == TypeKind::BoolType) {
100  AT_ASSERT(iter->isBoolList());
101  tracer::addInputs(
102  node, args[i].name().c_str(), iter->toBoolList()->elements());
103  } else {
104  throw std::runtime_error(
105  "unsupported input list type: " + elem_type->str());
106  }
107  } else {
108  throw std::runtime_error("unsupported input type: " + type->str());
109  }
110  }
111  graph->insertNode(node);
112  }
113 
114  c10::Dispatcher::singleton().lookup(op, &stack).call(&stack);
115 
116  // wrap tensor outputs as variable
117  for (auto iter = stack.end() - output_size; iter != stack.end(); ++iter) {
118  if (iter->isTensor()) {
119  *iter = torch::autograd::make_variable(std::move(*iter).toTensor());
120  }
121  }
122 
123  if (jit::tracer::isTracing()) {
124  int i = 0;
125  for (auto iter = stack.end() - output_size; iter != stack.end();
126  ++iter, ++i) {
127  const auto& type = op.schema().returns()[i].type();
128  if (type->isSubclass(TypeKind::TensorType)) {
129  AT_ASSERT(iter->isTensor());
130  tracer::addOutput(node, iter->toTensor());
131  } else if (type->kind() == TypeKind::ListType) {
132  const auto& elem_type =
133  reinterpret_cast<ListType*>(type.get())->getElementType();
134  if (elem_type->isSubclass(TypeKind::TensorType)) {
135  AT_ASSERT(iter->isTensorList());
136  tracer::addOutput(node, iter->toTensorList()->elements());
137  } else {
138  throw std::runtime_error(
139  "unsupported ouptut list type: " + elem_type->str());
140  }
141  } else {
142  throw std::runtime_error("unsupported output type: " + type->str());
143  }
144  }
145  }
146 
147  return 0;
148  });
149 }
150 
// Listener hooked into the c10 dispatcher: mirrors every c10 operator
// registration into the JIT operator registry so the JIT can call those ops.
class RegistrationListener final : public c10::OpRegistrationListener {
public:
  // Invoked by the dispatcher for each op (both ops already registered when
  // the listener is added and ops registered later); wraps the op as a JIT
  // Operator and registers it.
  void onOperatorRegistered(const c10::OperatorHandle& op) override {
    torch::jit::registerOperator(createOperatorFromC10(op));
  }

  // Deregistration is not mirrored yet — the JIT keeps its (now stale) entry.
  void onOperatorDeregistered(const c10::OperatorHandle& op) override {
    // TODO Do something like torch::jit::deregisterOperator(op.schema());
  }
};
161 
162 struct Registerer final {
163  Registerer() {
164  // this immediately calls the listener on all existing ops,
165  // and calls it in future whenever a new op is registered
166  c10::Dispatcher::singleton().addRegistrationListener(
167  c10::guts::make_unique<RegistrationListener>()
168  );
169  }
170 };
171 
172 // global instance to run its constructor on startup
173 Registerer registerer;
174 
175 }
176 }
177 }
void addRegistrationListener(std::unique_ptr< OpRegistrationListener > listener)
Add a listener that gets called whenever a new op is registered or an existing op is deregistered...
Definition: Dispatcher.cpp:79
OpKernel lookup(const OperatorHandle &op, const Stack *stack) const
Perform a dynamic dispatch and get the kernel for an operator.
Definition: Dispatcher.h:159
This is a handle to an operator schema registered with the dispatcher.
Definition: Dispatcher.h:139
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...
Definition: variable.h:85
Definition: jit_type.h:17
Implement this interface and register your instance with the dispatcher to get notified when operator...
Definition: Dispatcher.h:55
void call(Stack *stack) const
Call the operator kernel with the given arguments.
Definition: Dispatcher.h:37