#pragma once

#include <ATen/core/dispatch/Dispatcher.h>
#include "caffe2/core/operator.h"
#include <c10/util/ArrayRef.h>
#include <c10/util/Metaprogramming.h>
#include <ATen/core/ivalue.h>

namespace caffe2 {

// Wraps an operator that is registered with the c10 dispatcher so it can be
// executed as a regular caffe2 Operator<Context>. Instances are created via
// the REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_* macros at the end of this
// header.
template <class Context>
class C10OperatorWrapper final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  C10OperatorWrapper(
      const c10::OperatorHandle& op,
      const OperatorDef& operator_def,
      Workspace* ws)
      : Operator<Context>(operator_def, ws),
        op_(op),
        has_preallocated_outputs_(
            op_.schema().arguments().size() != 0 &&
            op_.schema().arguments().back().name() ==
                detail::PREALLOCATED_OUTPUT_ARGNAME) {
    // If the schema asks for preallocated outputs, that argument must be the
    // last one and must be an optional list of tensors.
    AT_ASSERT(
        !has_preallocated_outputs_ ||
        op_.schema().arguments().back().type()->isSubtypeOf(
            OptionalType::create(ListType::ofTensors())));

    AT_ASSERT(operator_def.output_size() == op_.schema().returns().size());
    AT_ASSERT(
        operator_def.input_size() + (has_preallocated_outputs_ ? 1 : 0) <=
        op_.schema().arguments().size());
  }
  bool RunOnDevice() override {
    // stack_ is cached as a member to avoid re-allocating it on every call,
    // so concurrent calls into the same operator instance must be serialized.
    std::lock_guard<std::mutex> lock(mutex_);

    pushInputs_();
    callKernel_();
    popOutputs_();

    return true;
  }

 private:
  // Converts the caffe2 inputs and arguments into IValues and pushes them
  // onto stack_ in schema order.
  void pushInputs_() {
    AT_ASSERT(stack_.size() == 0);
    stack_.reserve(
        op_.schema().arguments().size() + (has_preallocated_outputs_ ? 1 : 0));
    size_t input_tensor_index = 0;

    for (const auto& argument : op_.schema().arguments()) {
      if (argument.name() == detail::PREALLOCATED_OUTPUT_ARGNAME) {
        // The preallocated-outputs argument is only valid as the last schema
        // argument and must be an optional TensorList.
        AT_ASSERTM(
            has_preallocated_outputs_,
            "Error in caffe2->c10 wrapper: Operator schema has a parameter named ",
            detail::PREALLOCATED_OUTPUT_ARGNAME,
            ", but it's not at the end of the argument list");
        AT_ASSERTM(
            argument.type()->isSubtypeOf(
                OptionalType::create(ListType::ofTensors())),
            "Error in caffe2->c10 wrapper: Operator schema has a parameter named ",
            detail::PREALLOCATED_OUTPUT_ARGNAME,
            ", but it's not of type TensorList?");
        stack_.emplace_back(preallocated_outputs_());
      } else if (argument.type()->isSubtypeOf(TensorType::get())) {
        AT_ASSERTM(
            input_tensor_index < InputSize(),
            "Error in caffe2->c10 wrapper: Too few tensor arguments given (",
            InputSize(),
            "), operator schema expected more.");
        stack_.emplace_back(at::Tensor(C10Tensor(Input(input_tensor_index))));
        ++input_tensor_index;
      } else if (argument.type()->isSubtypeOf(ListType::ofTensors())) {
        AT_ASSERTM(
            input_tensor_index == 0,
            "Error in caffe2->c10 wrapper: Schema can only have either one or more Tensor inputs or one TensorList input.");
        stack_.emplace_back(ivalue::TensorList::create(array_inputs_()));
        input_tensor_index = InputSize();
      } else {
        stack_.emplace_back(get_nontensor_argument_(argument));
      }
    }
    AT_ASSERTM(
        input_tensor_index == InputSize(),
        "Error in caffe2->c10 wrapper: Number of caffe2 operator inputs (",
        InputSize(),
        ") doesn't match number of tensor arguments (",
        input_tensor_index,
        ") in the c10 operator schema.");
  }

  // Looks up the c10 kernel (once) and invokes it on stack_.
  void callKernel_() {
    AT_ASSERT(stack_.size() == op_.schema().arguments().size());
    if (!kernel_.has_value()) {
      // The kernel is resolved by dynamic dispatch on the first call and
      // cached for subsequent calls.
      kernel_ = c10::Dispatcher::singleton().lookup(op_, &stack_);
    }
    kernel_->call(&stack_);
  }
  // Moves the return values from stack_ into the caffe2 output tensors.
  void popOutputs_() {
    AT_ASSERT(stack_.size() == op_.schema().returns().size());
    for (size_t i = 0; i < op_.schema().returns().size(); ++i) {
      OperatorBase::SetOutputTensor(
          i, Tensor(C10Tensor(std::move(stack_[i]).toTensor())));
    }
    stack_.clear();
  }
  // Collects all caffe2 inputs into a vector, used when the schema takes a
  // single TensorList argument.
  std::vector<at::Tensor> array_inputs_() {
    std::vector<at::Tensor> result;
    result.reserve(InputSize());
    for (size_t i = 0; i < InputSize(); ++i) {
      result.emplace_back(Input(i));
    }
    return result;
  }
  // Collects the (possibly undefined) caffe2 output tensors so the kernel can
  // write into preallocated outputs.
  std::vector<at::Tensor> preallocated_outputs_() {
    std::vector<at::Tensor> result;
    result.reserve(OutputSize());
    for (size_t i = 0; i < OutputSize(); ++i) {
      result.emplace_back(OperatorBase::OutputTensorOrUndefined(i));
    }
    return result;
  }
  // Reads a non-tensor schema argument from the caffe2 OperatorDef arguments.
  IValue get_nontensor_argument_(const c10::Argument& argument) {
    if (argument.type()->isSubtypeOf(IntType::get())) {
      return get_nontensor_argument_<int>(
          argument.name(), argument.default_value());
    } else if (argument.type()->isSubtypeOf(FloatType::get())) {
      return get_nontensor_argument_<double>(
          argument.name(), argument.default_value());
    } else if (argument.type()->isSubtypeOf(BoolType::get())) {
      return get_nontensor_argument_<bool>(
          argument.name(), argument.default_value());
    } else {
      AT_ERROR(
          "Error in caffe2->c10 wrapper: Unsupported argument type ",
          argument.type()->str(),
          " in c10 operator schema");
    }
  }
  template <class T>
  IValue get_nontensor_argument_(
      const std::string& name,
      const c10::optional<IValue>& default_value) {
    if (default_value.has_value()) {
      return this->template GetSingleArgument<T>(name, default_value->to<T>());
    } else {
      AT_ASSERTM(
          this->template HasSingleArgumentOfType<T>(name),
          "Error in caffe2->c10 wrapper: Expected argument '",
          name,
          "' missing or wrong type.");
      return this->template GetSingleArgument<T>(name, 0);
    }
  }
  c10::OperatorHandle op_;

  // Kernel looked up from the c10 dispatcher on the first call and cached.
  c10::optional<c10::OpKernel> kernel_;

  // True iff the last schema argument is the list of preallocated caffe2
  // output tensors.
  bool has_preallocated_outputs_;

  // Kept as a member so the IValue stack doesn't have to be re-allocated on
  // every call; it is empty between calls.
  std::vector<IValue> stack_;
  std::mutex mutex_;
};
namespace detail {

// Creates the operator-creator function that the registration macros below
// install with caffe2: it instantiates a C10OperatorWrapper<Context> bound to
// the given c10 operator handle.
template <class Context>
inline std::function<
    std::unique_ptr<OperatorBase>(const OperatorDef&, Workspace*)>
createC10OperatorWrapper(const c10::OperatorHandle& op_handle) {
  return [op_handle](const OperatorDef& op_def, Workspace* ws) {
    return c10::guts::make_unique<C10OperatorWrapper<Context>>(
        op_handle, op_def, ws);
  };
}

} // namespace detail

} // namespace caffe2
// Registration macros: expose a c10 operator to caffe2 under the given name,
// routed through C10OperatorWrapper for the corresponding device context.
// Builds without the c10 dispatcher get the no-op definitions instead.
#ifndef CAFFE2_IS_XPLAT_BUILD
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(OperatorHandle, Name) \
  REGISTER_CPU_OPERATOR_CREATOR(                                            \
      Name, detail::createC10OperatorWrapper<CPUContext>(OperatorHandle))
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CUDA(OperatorHandle, Name) \
  REGISTER_CUDA_OPERATOR_CREATOR(                                             \
      Name, detail::createC10OperatorWrapper<CUDAContext>(OperatorHandle))
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_HIP(OperatorHandle, Name) \
  REGISTER_HIP_OPERATOR_CREATOR(                                            \
      Name, detail::createC10OperatorWrapper<HIPContext>(OperatorHandle))
#else
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(OperatorHandle, Name)
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CUDA(OperatorHandle, Name)
#define REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_HIP(OperatorHandle, Name)
#endif
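// Usage sketch (illustrative only; the operator and handle names below are
// hypothetical, not part of this header): given a c10 operator already
// registered with the c10 dispatcher and an OperatorHandle for it, exposing
// it to caffe2 under the name "MyC2Op" would look roughly like
//
//   // in some operator .cc file
//   // c10::OperatorHandle my_op_handle = ...; // obtained at c10 registration
//   REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(my_op_handle, MyC2Op)
//   REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CUDA(my_op_handle, MyC2Op)
//
// Afterwards, "MyC2Op" can appear in a caffe2 OperatorDef/NetDef and is
// executed through C10OperatorWrapper<Context>, which pushes the caffe2
// inputs onto an IValue stack, dispatches to the c10 kernel, and copies the
// results back into the caffe2 outputs.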