#include "test/cpp/jit/test_base.h"
#include "test/cpp/jit/test_utils.h"

#include "torch/csrc/jit/custom_operator.h"

namespace torch {
namespace jit {

void testCustomOperators() {
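  // An operator registered without an explicit schema string: the
  // argument and return types are inferred from the lambda's signature,
  // and the arguments receive positional placeholder names (_0, _1).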
  {
    RegisterOperators reg({createOperator(
        "foo::bar", [](double a, at::Tensor b) { return a + b; })});

    auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar"));
    ASSERT_EQ(ops.size(), 1);

    auto& op = ops.front();
    ASSERT_EQ(op->schema().name(), "foo::bar");

    ASSERT_EQ(op->schema().arguments().size(), 2);
    ASSERT_EQ(op->schema().arguments()[0].name(), "_0");
    ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
    ASSERT_EQ(op->schema().arguments()[1].name(), "_1");
    ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);

    ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);

    Stack stack;
    push(stack, 2.0f, autograd::make_variable(at::ones(5)));
    op->getOperation()(stack);
    at::Tensor output;
    pop(stack, output);

    ASSERT_TRUE(output.allclose(autograd::make_variable(at::full(5, 3.0f))));
  }
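  // The same operator registered with an explicit schema string: the
  // argument names and types now come from the schema, not inference.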
  {
    RegisterOperators reg({createOperator(
        "foo::bar_with_schema(float a, Tensor b) -> Tensor",
        [](double a, at::Tensor b) { return a + b; })});

    auto& ops =
        getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema"));
    ASSERT_EQ(ops.size(), 1);

    auto& op = ops.front();
    ASSERT_EQ(op->schema().name(), "foo::bar_with_schema");

    ASSERT_EQ(op->schema().arguments().size(), 2);
    ASSERT_EQ(op->schema().arguments()[0].name(), "a");
    ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
    ASSERT_EQ(op->schema().arguments()[1].name(), "b");
    ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);

    ASSERT_EQ(op->schema().returns().size(), 1);
    ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);

    Stack stack;
    push(stack, 2.0f, autograd::make_variable(at::ones(5)));
    op->getOperation()(stack);
    at::Tensor output;
    pop(stack, output);

    ASSERT_TRUE(output.allclose(autograd::make_variable(at::full(5, 3.0f))));
  }
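  // List arguments and return values: std::vector parameters map to the
  // corresponding schema list types (int[], float[], Tensor[]).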
  {
    RegisterOperators reg({createOperator(
        "foo::lists(int[] ints, float[] floats, Tensor[] tensors) -> float[]",
        [](const std::vector<int64_t>& ints,
           const std::vector<double>& floats,
           std::vector<at::Tensor> tensors) { return floats; })});

    auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists"));
    ASSERT_EQ(ops.size(), 1);

    auto& op = ops.front();
    ASSERT_EQ(op->schema().name(), "foo::lists");

    ASSERT_EQ(op->schema().arguments().size(), 3);
    ASSERT_EQ(op->schema().arguments()[0].name(), "ints");
    ASSERT_TRUE(
        op->schema().arguments()[0].type()->isSubtypeOf(ListType::ofInts()));
    ASSERT_EQ(op->schema().arguments()[1].name(), "floats");
    ASSERT_TRUE(
        op->schema().arguments()[1].type()->isSubtypeOf(ListType::ofFloats()));
    ASSERT_EQ(op->schema().arguments()[2].name(), "tensors");
    ASSERT_TRUE(
        op->schema().arguments()[2].type()->isSubtypeOf(ListType::ofTensors()));

    ASSERT_EQ(op->schema().returns().size(), 1);
    ASSERT_TRUE(
        op->schema().returns()[0].type()->isSubtypeOf(ListType::ofFloats()));

    Stack stack;
    push(stack, std::vector<int64_t>{1, 2});
    push(stack, std::vector<double>{1.0, 2.0});
    push(stack, std::vector<at::Tensor>{autograd::make_variable(at::ones(5))});
    op->getOperation()(stack);
    std::vector<double> output;
    pop(stack, output);

    ASSERT_EQ(output.size(), 2);
    ASSERT_EQ(output[0], 1.0);
    ASSERT_EQ(output[1], 2.0);
  }
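  // RegisterOperators also accepts a schema string and implementation
  // directly, without an explicit createOperator call.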
  {
    RegisterOperators reg(
        "foo::lists2(Tensor[] tensors) -> Tensor[]",
        [](std::vector<at::Tensor> tensors) { return tensors; });

    auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
    ASSERT_EQ(ops.size(), 1);

    auto& op = ops.front();
    ASSERT_EQ(op->schema().name(), "foo::lists2");

    ASSERT_EQ(op->schema().arguments().size(), 1);
    ASSERT_EQ(op->schema().arguments()[0].name(), "tensors");
    ASSERT_TRUE(
        op->schema().arguments()[0].type()->isSubtypeOf(ListType::ofTensors()));

    ASSERT_EQ(op->schema().returns().size(), 1);
    ASSERT_TRUE(
        op->schema().returns()[0].type()->isSubtypeOf(ListType::ofTensors()));

    Stack stack;
    push(stack, std::vector<at::Tensor>{autograd::make_variable(at::ones(5))});
    op->getOperation()(stack);
    std::vector<at::Tensor> output;
    pop(stack, output);

    ASSERT_EQ(output.size(), 1);
    ASSERT_TRUE(output[0].allclose(autograd::make_variable(at::ones(5))));
  }
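  // Custom operators should be recorded by the tracer: running the
  // operation inside a tracing session must produce a graph node whose
  // kind matches the operator's qualified name.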
  {
    auto op = createOperator(
        "traced::op(float a, Tensor b) -> Tensor",
        [](double a, at::Tensor b) { return a + b; });

    std::shared_ptr<tracer::TracingState> state;
    std::tie(state, std::ignore) = tracer::enter({});

    Stack stack;
    push(stack, 2.0f, autograd::make_variable(at::ones(5)));
    op.getOperation()(stack);
    at::Tensor output = autograd::make_variable(at::empty({}));
    pop(stack, output);

    tracer::exit({IValue(output)});

    std::string op_name("traced::op");
    bool contains_traced_op = false;
    for (const auto& node : state->graph->nodes()) {
      if (std::string(node->kind().toQualString()) == op_name) {
        contains_traced_op = true;
        break;
      }
    }
    ASSERT_TRUE(contains_traced_op);
  }
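  // Schema/implementation mismatches must be rejected at registration
  // time: wrong argument count, wrong argument type, wrong return count,
  // and wrong return type each produce a descriptive error.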
166 "foo::bar_with_bad_schema(Tensor a) -> Tensor",
167 [](
double a,
at::Tensor b) {
return a + b; }),
168 "Inferred 2 argument(s) for operator implementation, " 169 "but the provided schema specified 1 argument(s).");
172 "foo::bar_with_bad_schema(Tensor a) -> Tensor",
173 [](
double a) {
return a; }),
174 "Inferred type for argument #0 was float, " 175 "but the provided schema specified type Tensor " 176 "for the argument in that position");
179 "foo::bar_with_bad_schema(float a) -> (float, float)",
180 [](
double a) {
return a; }),
181 "Inferred 1 return value(s) for operator implementation, " 182 "but the provided schema specified 2 return value(s).");
185 "foo::bar_with_bad_schema(float a) -> Tensor",
186 [](
double a) {
return a; }),
187 "Inferred type for return value #0 was float, " 188 "but the provided schema specified type Tensor " 189 "for the return value in that position");
  {
    auto op = createOperator(
        "traced::op(float[] f) -> int",
        [](const std::vector<double>& f) -> int64_t { return f.size(); });

    std::shared_ptr<tracer::TracingState> state;
    std::tie(state, std::ignore) = tracer::enter({});

    Stack stack;
    push(stack, std::vector<double>{1.0});
    ASSERT_THROWS_WITH(
        op.getOperation()(stack),
        "Tracing float lists currently not supported!");
  }
}

} // namespace jit
} // namespace torch