// NOTE(review): extraction artifact — the integers embedded below are original
// source line numbers from a code-browser export, not program text. Original
// lines 4-9 (this helper's signature and opening — presumably something like
// `at::Tensor unwrap(at::Tensor&& tensor)` — TODO confirm against the original
// file) are missing from this capture.
1 #include <ATen/core/dispatch/Dispatcher.h> 2 #include <torch/csrc/jit/operator.h> 3 #include <torch/csrc/jit/tracer.h> 10 if (tensor.requires_grad()) {
// c10-dispatched ops have no autograd support here, so a gradient-requiring
// tensor is rejected outright rather than silently losing its history.
11 throw std::runtime_error(
"Autograd not yet supported for c10 ops.");
// Builds a JIT Operator whose implementation forwards to the c10 dispatcher.
// The lambda captures the OperatorHandle `op` by value and, per call:
//   1. unwraps Variable tensors on the stack into plain tensors,
//   2. if tracing, records a graph node and each input on it,
//   3. dispatches through c10::Dispatcher,
//   4. re-wraps tensor outputs as Variables and, if tracing, records outputs.
// NOTE(review): extraction artifact — the embedded integers are original line
// numbers; several original lines (e.g. 4-18, 22-25, 33-36, 44-47, 50-60, ...)
// are missing from this capture, so closing braces and some statements
// (including the declaration of `node` and the loop counter `i`) are absent.
19 return Operator(op.schema(), [op](Stack& stack) {
// Inputs occupy the top `input_size` stack slots; outputs will occupy the
// top `output_size` slots after the dispatcher call.
20 const auto input_size = op.schema().arguments().size();
21 const auto output_size = op.schema().returns().size();
// --- step 1: unwrap tensor (and tensor-list) inputs from Variables ---
26 for (
auto iter = stack.end() - input_size; iter != stack.end(); ++iter) {
// The .defined() guard skips undefined tensors that may sit on the stack.
28 if (iter->isTensor() && iter->toTensor().defined()) {
29 *iter = unwrap(std::move(*iter).toTensor());
30 }
else if (iter->isTensorList()) {
// Unwrap every element of a tensor list in place.
31 for (
auto& item : iter->toTensorList()->elements()) {
32 item = unwrap(std::move(item));
// --- step 2: when tracing, create a node and record each input ---
37 if (jit::tracer::isTracing()) {
// The node's symbol is derived from the op's qualified schema name.
38 auto symbol = Symbol::fromQualString(op.schema().name());
39 const auto& graph = tracer::getTracingState()->graph;
40 node = graph->create(symbol, 0);
41 const auto& args = op.schema().arguments();
// Walk the input slots in parallel with the schema's argument list
// (the index variable `i` is declared on a missing original line).
43 for (
auto iter = stack.end() - input_size; iter != stack.end();
48 auto type = args[i].type();
// Optional arguments: a None value becomes a createNone node typed
// with the optional's element type (surrounding lines missing here).
49 if (type->kind() == TypeKind::OptionalType) {
53 ->insertNode(graph->createNone(
54 reinterpret_cast<OptionalType*>(args[i].type().get())
61 reinterpret_cast<OptionalType*
>(type.get())->getElementType();
// Dispatch on the (possibly unwrapped-optional) argument type; each
// branch asserts the IValue's runtime tag matches the schema type.
64 if (type->isSubclass(TypeKind::TensorType)) {
65 AT_ASSERT(iter->isTensor());
66 tracer::addInputs(node, args[i].name().c_str(), iter->toTensor());
67 }
else if (type->kind() == TypeKind::FloatType) {
68 AT_ASSERT(iter->isDouble());
69 tracer::addInputs(node, args[i].name().c_str(), iter->toDouble());
70 }
else if (type->kind() == TypeKind::IntType) {
71 AT_ASSERT(iter->isInt());
72 tracer::addInputs(node, args[i].name().c_str(), iter->toInt());
73 }
else if (type->kind() == TypeKind::BoolType) {
74 AT_ASSERT(iter->isBool());
75 tracer::addInputs(node, args[i].name().c_str(), iter->toBool());
76 }
else if (type->kind() == TypeKind::StringType) {
77 AT_ASSERT(iter->isString());
79 node, args[i].name().c_str(), iter->toStringRef());
80 }
else if (type->kind() == TypeKind::ListType) {
// List arguments: dispatch again on the element type.
81 const auto& elem_type =
82 reinterpret_cast<ListType*
>(type.get())->getElementType();
83 if (elem_type->isSubclass(TypeKind::TensorType)) {
84 AT_ASSERT(iter->isTensorList());
87 args[i].name().c_str(),
88 iter->toTensorList()->elements());
89 }
else if (elem_type->kind() == TypeKind::FloatType) {
90 AT_ASSERT(iter->isDoubleList());
93 args[i].name().c_str(),
94 iter->toDoubleList()->elements());
95 }
else if (elem_type->kind() == TypeKind::IntType) {
96 AT_ASSERT(iter->isIntList());
98 node, args[i].name().c_str(), iter->toIntList()->elements());
99 }
else if (elem_type->kind() == TypeKind::BoolType) {
100 AT_ASSERT(iter->isBoolList());
102 node, args[i].name().c_str(), iter->toBoolList()->elements());
// Any other list element type is unsupported for tracing.
104 throw std::runtime_error(
105 "unsupported input list type: " + elem_type->str());
// Any other argument type is unsupported for tracing.
108 throw std::runtime_error(
"unsupported input type: " + type->str());
// Node is attached to the graph only after all inputs are recorded.
111 graph->insertNode(node);
// --- step 3: dynamic dispatch through the c10 dispatcher ---
114 c10::Dispatcher::singleton().
lookup(op, &stack).
call(&stack);
// --- step 4: re-wrap tensor outputs as autograd Variables ---
117 for (
auto iter = stack.end() - output_size; iter != stack.end(); ++iter) {
118 if (iter->isTensor()) {
119 *iter = torch::autograd::make_variable(std::move(*iter).toTensor());
// When tracing, record each output on the node created above
// (loop index `i` declared on a missing original line).
123 if (jit::tracer::isTracing()) {
125 for (
auto iter = stack.end() - output_size; iter != stack.end();
127 const auto& type = op.schema().returns()[i].type();
128 if (type->isSubclass(TypeKind::TensorType)) {
129 AT_ASSERT(iter->isTensor());
130 tracer::addOutput(node, iter->toTensor());
131 }
else if (type->kind() == TypeKind::ListType) {
132 const auto& elem_type =
133 reinterpret_cast<ListType*
>(type.get())->getElementType();
134 if (elem_type->isSubclass(TypeKind::TensorType)) {
135 AT_ASSERT(iter->isTensorList());
136 tracer::addOutput(node, iter->toTensorList()->elements());
// Only tensor lists are traceable as list outputs.
// NOTE(review): "ouptut" in the message below is a typo in the
// original string literal; left untouched here since this edit
// changes comments only.
138 throw std::runtime_error(
139 "unsupported ouptut list type: " + elem_type->str());
// Any other return type is unsupported for tracing.
142 throw std::runtime_error(
"unsupported output type: " + type->str());
// Registration entry point: wrap the c10 op and hand it to the JIT registry.
154 torch::jit::registerOperator(createOperatorFromC10(op));
// Static registration hook: constructing the file-scope `registerer` installs
// a RegistrationListener with the c10 dispatcher so that ops registered later
// are also exposed to the JIT. NOTE(review): extraction artifact — original
// lines 163-166 and 168-172 (the constructor body and listener wiring) are
// missing from this capture; behavior described above is inferred from the
// visible make_unique<RegistrationListener>() call — confirm against the
// original file.
162 struct Registerer final {
167 c10::guts::make_unique<RegistrationListener>()
173 Registerer registerer;
void addRegistrationListener(std::unique_ptr< OpRegistrationListener > listener)
Add a listener that gets called whenever a new op is registered or an existing op is deregistered...
OpKernel lookup(const OperatorHandle &op, const Stack *stack) const
Perform a dynamic dispatch and get the kernel for an operator.
This is a handle to an operator schema registered with the dispatcher.
Variable: A Variable augments a Tensor with the ability to interact in our autograd machinery...
Implement this interface and register your instance with the dispatcher to get notified when operator...
void call(Stack *stack) const
Call the operator kernel with the given arguments.