Caffe2 - C++ API
A deep learning, cross platform ML framework
init.cpp
1 #include <torch/csrc/onnx/init.h>
2 #include <torch/csrc/onnx/onnx.h>
3 #include <onnx/onnx_pb.h>
4 
5 namespace torch { namespace onnx {
6 void initONNXBindings(PyObject* module) {
7  auto m = py::handle(module).cast<py::module>();
8  auto onnx = m.def_submodule("_onnx");
9  py::enum_<::ONNX_NAMESPACE::TensorProto_DataType>(onnx, "TensorProtoDataType")
10  .value("UNDEFINED", ::ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED)
11  .value("FLOAT", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT)
12  .value("UINT8", ::ONNX_NAMESPACE::TensorProto_DataType_UINT8)
13  .value("INT8", ::ONNX_NAMESPACE::TensorProto_DataType_INT8)
14  .value("UINT16", ::ONNX_NAMESPACE::TensorProto_DataType_UINT16)
15  .value("INT16", ::ONNX_NAMESPACE::TensorProto_DataType_INT16)
16  .value("INT32", ::ONNX_NAMESPACE::TensorProto_DataType_INT32)
17  .value("INT64", ::ONNX_NAMESPACE::TensorProto_DataType_INT64)
18  .value("STRING", ::ONNX_NAMESPACE::TensorProto_DataType_STRING)
19  .value("BOOL", ::ONNX_NAMESPACE::TensorProto_DataType_BOOL)
20  .value("FLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
21  .value("DOUBLE", ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE)
22  .value("UINT32", ::ONNX_NAMESPACE::TensorProto_DataType_UINT32)
23  .value("UINT64", ::ONNX_NAMESPACE::TensorProto_DataType_UINT64)
24  .value("COMPLEX64", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64)
25  .value("COMPLEX128", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128);
26 
27  py::enum_<OperatorExportTypes>(onnx, "OperatorExportTypes")
28  .value("ONNX", OperatorExportTypes::ONNX)
29  .value("ONNX_ATEN", OperatorExportTypes::ONNX_ATEN)
30  .value("ONNX_ATEN_FALLBACK", OperatorExportTypes::ONNX_ATEN_FALLBACK)
31  .value("RAW", OperatorExportTypes::RAW);
32 
33 #ifdef PYTORCH_ONNX_CAFFE2_BUNDLE
34  onnx.attr("PYTORCH_ONNX_CAFFE2_BUNDLE") = true;
35 #else
36  onnx.attr("PYTORCH_ONNX_CAFFE2_BUNDLE") = false;
37 #endif
38 }
39 }} // namespace torch::onnx
Definition: jit_type.h:17