Caffe2 - C++ API
A deep learning, cross-platform ML framework
filler.cc
1 #include "caffe2/operators/experimental/c10/schemas/filler.h"
2 #include <ATen/core/dispatch/OpSchemaRegistration.h>
3 #include "caffe2/core/operator_c10wrapper.h"
4 #include "caffe2/utils/cast.h"
5 
7 using c10::C10Tensor;
8 using c10::ivalue::IntList;
10 
11 namespace caffe2 {
12 namespace ops {
13 // TODO Parse schema strings instead of creating FunctionSchema manually
14 C10_DEFINE_OP_SCHEMA(
15  ConstantFill,
16  FunctionSchema(
17  "_c10_experimental::ConstantFill",
18  "",
19  (std::vector<c10::Argument>{
20  c10::Argument("inputs", ListType::ofTensors()),
21  c10::Argument("output"),
22  c10::Argument("shape", ListType::ofInts()),
23  c10::Argument("extra_shape", ListType::ofInts()),
24  c10::Argument("input_as_shape", BoolType::get()),
25  c10::Argument("dtype", IntType::get()),
26  c10::Argument("value", NumberType::get())}),
27  (std::vector<c10::Argument>{})));
28 C10_DEFINE_OP_SCHEMA(
29  UniformFill,
30  FunctionSchema(
31  "_c10_experimental::ConstantFill",
32  "",
33  (std::vector<c10::Argument>{
34  c10::Argument("inputs", ListType::ofTensors()),
35  c10::Argument("output"),
36  c10::Argument("shape", ListType::ofInts()),
37  c10::Argument("extra_shape", ListType::ofInts()),
38  c10::Argument("input_as_shape", BoolType::get()),
39  c10::Argument("min", FloatType::get()),
40  c10::Argument("max", FloatType::get())}),
41  (std::vector<c10::Argument>{})));
42 C10_DEFINE_OP_SCHEMA(
44  FunctionSchema(
45  "_c10_experimental::ConstantFill",
46  "",
47  (std::vector<c10::Argument>{
48  c10::Argument("inputs", ListType::ofTensors()),
49  c10::Argument("output"),
50  c10::Argument("shape", ListType::ofInts()),
51  c10::Argument("extra_shape", ListType::ofInts()),
52  c10::Argument("input_as_shape", BoolType::get()),
53  c10::Argument("values"),
54  }),
55  (std::vector<c10::Argument>{})));
56 C10_DEFINE_OP_SCHEMA(
57  GivenTensorIntFill,
58  FunctionSchema(
59  "_c10_experimental::ConstantFill",
60  "",
61  (std::vector<c10::Argument>{
62  c10::Argument("inputs", ListType::ofTensors()),
63  c10::Argument("output"),
64  c10::Argument("shape", ListType::ofInts()),
65  c10::Argument("extra_shape", ListType::ofInts()),
66  c10::Argument("input_as_shape", BoolType::get()),
67  c10::Argument("values"),
68  }),
69  (std::vector<c10::Argument>{})));
70 C10_DEFINE_OP_SCHEMA(
71  GivenTensorInt64Fill,
72  FunctionSchema(
73  "_c10_experimental::ConstantFill",
74  "",
75  (std::vector<c10::Argument>{
76  c10::Argument("inputs", ListType::ofTensors()),
77  c10::Argument("output"),
78  c10::Argument("shape", ListType::ofInts()),
79  c10::Argument("extra_shape", ListType::ofInts()),
80  c10::Argument("input_as_shape", BoolType::get()),
81  c10::Argument("values"),
82  }),
83  (std::vector<c10::Argument>{})));
84 }
85 }
86 
87 namespace caffe2 {
88 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
89  ops::ConstantFill(),
90  C10ConstantFill_DontUseThisOpYet)
91 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
92  ops::UniformFill(),
93  C10UniformFill_DontUseThisOpYet)
94 
95 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
96  ops::GivenTensorFill(),
97  C10GivenTensorFill_DontUseThisOpYet)
98 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
99  ops::GivenTensorIntFill(),
100  C10GivenTensorIntFill_DontUseThisOpYet)
101 REGISTER_C10_OPERATOR_FOR_CAFFE2_DISPATCH_CPU(
102  ops::GivenTensorInt64Fill(),
103  C10GivenTensorInt64Fill_DontUseThisOpYet)
104 } // namespace caffe2
This is a minimal Tensor class for use in c10 code.
Definition: Tensor.h:18
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:40
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13