1 #include <torch/csrc/onnx/init.h> 2 #include <torch/csrc/onnx/onnx.h> 3 #include <onnx/onnx_pb.h> 6 void initONNXBindings(PyObject* module) {
7 auto m = py::handle(module).cast<py::module>();
8 auto onnx = m.def_submodule(
"_onnx");
9 py::enum_<::ONNX_NAMESPACE::TensorProto_DataType>(onnx,
"TensorProtoDataType")
10 .value(
"UNDEFINED", ::ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED)
11 .value(
"FLOAT", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT)
12 .value(
"UINT8", ::ONNX_NAMESPACE::TensorProto_DataType_UINT8)
13 .value(
"INT8", ::ONNX_NAMESPACE::TensorProto_DataType_INT8)
14 .value(
"UINT16", ::ONNX_NAMESPACE::TensorProto_DataType_UINT16)
15 .value(
"INT16", ::ONNX_NAMESPACE::TensorProto_DataType_INT16)
16 .value(
"INT32", ::ONNX_NAMESPACE::TensorProto_DataType_INT32)
17 .value(
"INT64", ::ONNX_NAMESPACE::TensorProto_DataType_INT64)
18 .value(
"STRING", ::ONNX_NAMESPACE::TensorProto_DataType_STRING)
19 .value(
"BOOL", ::ONNX_NAMESPACE::TensorProto_DataType_BOOL)
20 .value(
"FLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
21 .value(
"DOUBLE", ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE)
22 .value(
"UINT32", ::ONNX_NAMESPACE::TensorProto_DataType_UINT32)
23 .value(
"UINT64", ::ONNX_NAMESPACE::TensorProto_DataType_UINT64)
24 .value(
"COMPLEX64", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64)
25 .value(
"COMPLEX128", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128);
27 py::enum_<OperatorExportTypes>(onnx,
"OperatorExportTypes")
28 .value(
"ONNX", OperatorExportTypes::ONNX)
29 .value(
"ONNX_ATEN", OperatorExportTypes::ONNX_ATEN)
30 .value(
"ONNX_ATEN_FALLBACK", OperatorExportTypes::ONNX_ATEN_FALLBACK)
31 .value(
"RAW", OperatorExportTypes::RAW);
33 #ifdef PYTORCH_ONNX_CAFFE2_BUNDLE 34 onnx.attr(
"PYTORCH_ONNX_CAFFE2_BUNDLE") =
true;
36 onnx.attr(
"PYTORCH_ONNX_CAFFE2_BUNDLE") =
false;