Caffe2 - C++ API
A deep-learning, cross-platform ML framework
test_custom_operators.h
1 #pragma once
2 
3 #include "test/cpp/jit/test_base.h"
4 #include "test/cpp/jit/test_utils.h"
5 
6 #include "torch/csrc/jit/custom_operator.h"
7 
8 namespace torch {
9 namespace jit {
10 namespace test {
11 
// Test suite for the custom-operator registration API
// (torch::jit::createOperator / RegisterOperators). Covers: schema
// inference from a lambda signature, explicit schema strings, list
// (std::vector) argument/return types, interaction with the tracer,
// and the diagnostics produced when a schema string contradicts the
// implementation's signature.
void testCustomOperators() {
  {
    // Register with only a qualified name; argument and return types
    // are inferred from the lambda's signature.
    RegisterOperators reg({createOperator(
        "foo::bar", [](double a, at::Tensor b) { return a + b; })});
    auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::bar"));
    ASSERT_EQ(ops.size(), 1);

    auto& op = ops.front();
    ASSERT_EQ(op->schema().name(), "foo::bar");

    // With no schema string, argument names fall back to positional
    // placeholders _0, _1, ...
    ASSERT_EQ(op->schema().arguments().size(), 2);
    ASSERT_EQ(op->schema().arguments()[0].name(), "_0");
    ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
    ASSERT_EQ(op->schema().arguments()[1].name(), "_1");
    ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);

    ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);

    // Invoke via the interpreter calling convention:
    // push inputs, run the operation, pop the result.
    Stack stack;
    push(stack, 2.0f, autograd::make_variable(at::ones(5)));
    op->getOperation()(stack);
    at::Tensor output;
    pop(stack, output);

    // 2.0 + ones(5) == full(5, 3.0)
    ASSERT_TRUE(output.allclose(autograd::make_variable(at::full(5, 3.0f))));
  }
  {
    // Same operator, but registered with an explicit schema string, so
    // the declared argument names ("a", "b") are preserved.
    RegisterOperators reg({createOperator(
        "foo::bar_with_schema(float a, Tensor b) -> Tensor",
        [](double a, at::Tensor b) { return a + b; })});

    auto& ops =
        getAllOperatorsFor(Symbol::fromQualString("foo::bar_with_schema"));
    ASSERT_EQ(ops.size(), 1);

    auto& op = ops.front();
    ASSERT_EQ(op->schema().name(), "foo::bar_with_schema");

    ASSERT_EQ(op->schema().arguments().size(), 2);
    ASSERT_EQ(op->schema().arguments()[0].name(), "a");
    ASSERT_EQ(op->schema().arguments()[0].type()->kind(), TypeKind::FloatType);
    ASSERT_EQ(op->schema().arguments()[1].name(), "b");
    ASSERT_EQ(op->schema().arguments()[1].type()->kind(), TypeKind::TensorType);

    ASSERT_EQ(op->schema().returns().size(), 1);
    ASSERT_EQ(op->schema().returns()[0].type()->kind(), TypeKind::TensorType);

    Stack stack;
    push(stack, 2.0f, autograd::make_variable(at::ones(5)));
    op->getOperation()(stack);
    at::Tensor output;
    pop(stack, output);

    ASSERT_TRUE(output.allclose(autograd::make_variable(at::full(5, 3.0f))));
  }
  {
    // Check that lists work well.
    // std::vector parameters/returns map onto the matching ListType.
    RegisterOperators reg({createOperator(
        "foo::lists(int[] ints, float[] floats, Tensor[] tensors) -> float[]",
        [](const std::vector<int64_t>& ints,
           const std::vector<double>& floats,
           std::vector<at::Tensor> tensors) { return floats; })});

    auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists"));
    ASSERT_EQ(ops.size(), 1);

    auto& op = ops.front();
    ASSERT_EQ(op->schema().name(), "foo::lists");

    ASSERT_EQ(op->schema().arguments().size(), 3);
    ASSERT_EQ(op->schema().arguments()[0].name(), "ints");
    ASSERT_TRUE(
        op->schema().arguments()[0].type()->isSubtypeOf(ListType::ofInts()));
    ASSERT_EQ(op->schema().arguments()[1].name(), "floats");
    ASSERT_TRUE(
        op->schema().arguments()[1].type()->isSubtypeOf(ListType::ofFloats()));
    ASSERT_EQ(op->schema().arguments()[2].name(), "tensors");
    ASSERT_TRUE(
        op->schema().arguments()[2].type()->isSubtypeOf(ListType::ofTensors()));

    ASSERT_EQ(op->schema().returns().size(), 1);
    ASSERT_TRUE(
        op->schema().returns()[0].type()->isSubtypeOf(ListType::ofFloats()));

    Stack stack;
    push(stack, std::vector<int64_t>{1, 2});
    push(stack, std::vector<double>{1.0, 2.0});
    push(stack, std::vector<at::Tensor>{autograd::make_variable(at::ones(5))});
    op->getOperation()(stack);
    std::vector<double> output;
    pop(stack, output);

    // The lambda echoes back its `floats` argument unchanged.
    ASSERT_EQ(output.size(), 2);
    ASSERT_EQ(output[0], 1.0);
    ASSERT_EQ(output[1], 2.0);
  }
  {
    // Same list handling, but registered through the
    // RegisterOperators(schema, lambda) convenience constructor
    // instead of an explicit createOperator call.
    RegisterOperators reg(
        "foo::lists2(Tensor[] tensors) -> Tensor[]",
        [](std::vector<at::Tensor> tensors) { return tensors; });

    auto& ops = getAllOperatorsFor(Symbol::fromQualString("foo::lists2"));
    ASSERT_EQ(ops.size(), 1);

    auto& op = ops.front();
    ASSERT_EQ(op->schema().name(), "foo::lists2");

    ASSERT_EQ(op->schema().arguments().size(), 1);
    ASSERT_EQ(op->schema().arguments()[0].name(), "tensors");
    ASSERT_TRUE(
        op->schema().arguments()[0].type()->isSubtypeOf(ListType::ofTensors()));

    ASSERT_EQ(op->schema().returns().size(), 1);
    ASSERT_TRUE(
        op->schema().returns()[0].type()->isSubtypeOf(ListType::ofTensors()));

    Stack stack;
    push(stack, std::vector<at::Tensor>{autograd::make_variable(at::ones(5))});
    op->getOperation()(stack);
    std::vector<at::Tensor> output;
    pop(stack, output);

    ASSERT_EQ(output.size(), 1);
    ASSERT_TRUE(output[0].allclose(autograd::make_variable(at::ones(5))));
  }
  {
    // Running a custom op while tracing should record a node whose
    // kind matches the operator's qualified name.
    auto op = createOperator(
        "traced::op(float a, Tensor b) -> Tensor",
        [](double a, at::Tensor b) { return a + b; });

    std::shared_ptr<tracer::TracingState> state;
    std::tie(state, std::ignore) = tracer::enter({});

    Stack stack;
    push(stack, 2.0f, autograd::make_variable(at::ones(5)));
    op.getOperation()(stack);
    at::Tensor output = autograd::make_variable(at::empty({}));
    pop(stack, output);

    tracer::exit({IValue(output)});

    // Scan the traced graph for a node of kind "traced::op".
    std::string op_name("traced::op");
    bool contains_traced_op = false;
    for (const auto& node : state->graph->nodes()) {
      if (std::string(node->kind().toQualString()) == op_name) {
        contains_traced_op = true;
        break;
      }
    }
    ASSERT_TRUE(contains_traced_op);
  }
  {
    // Schema/implementation mismatches must be rejected with precise
    // diagnostics: wrong argument count, wrong argument type, wrong
    // return-value count, and wrong return-value type, respectively.
    ASSERT_THROWS_WITH(
        createOperator(
            "foo::bar_with_bad_schema(Tensor a) -> Tensor",
            [](double a, at::Tensor b) { return a + b; }),
        "Inferred 2 argument(s) for operator implementation, "
        "but the provided schema specified 1 argument(s).");
    ASSERT_THROWS_WITH(
        createOperator(
            "foo::bar_with_bad_schema(Tensor a) -> Tensor",
            [](double a) { return a; }),
        "Inferred type for argument #0 was float, "
        "but the provided schema specified type Tensor "
        "for the argument in that position");
    ASSERT_THROWS_WITH(
        createOperator(
            "foo::bar_with_bad_schema(float a) -> (float, float)",
            [](double a) { return a; }),
        "Inferred 1 return value(s) for operator implementation, "
        "but the provided schema specified 2 return value(s).");
    ASSERT_THROWS_WITH(
        createOperator(
            "foo::bar_with_bad_schema(float a) -> Tensor",
            [](double a) { return a; }),
        "Inferred type for return value #0 was float, "
        "but the provided schema specified type Tensor "
        "for the return value in that position");
  }
  {
    // vector<double> is not supported yet.
    // Tracing an op that takes a float list should fail loudly.
    auto op = createOperator(
        "traced::op(float[] f) -> int",
        [](const std::vector<double>& f) -> int64_t { return f.size(); });

    std::shared_ptr<tracer::TracingState> state;
    std::tie(state, std::ignore) = tracer::enter({});

    Stack stack;
    push(stack, std::vector<double>{1.0});

    ASSERT_THROWS_WITH(
        op.getOperation()(stack),
        "Tracing float lists currently not supported!");

    // The op threw mid-trace, so abandon the trace rather than exit it.
    tracer::abandon();
  }
}
210 } // namespace test
211 } // namespace jit
212 } // namespace torch
Definition: module.cpp:17
Definition: jit_type.h:17