Caffe2 - C++ API
A deep learning, cross platform ML framework
operator_sets.h
1 #pragma once
2 
3 #include "onnx/defs/schema.h"
4 
5 namespace ONNX_NAMESPACE {
6 
7 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
8  PyTorch,
9  1,
10  SparseLengthsSumFused8BitRowwise);
11 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsSum);
12 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, SparseLengthsWeightedSum);
13 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, BatchGather);
14 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, DotProduct);
15 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, FCTransposed);
16 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, BatchMatMul);
17 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(PyTorch, 1, ExpandDims);
18 
19 // Iterate over schema from ai.onnx.pytorch domain opset 1
21  public:
22  static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
23  fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
24  PyTorch, 1, SparseLengthsSumFused8BitRowwise)>());
25  fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
26  PyTorch, 1, SparseLengthsSum)>());
27  fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
28  PyTorch, 1, SparseLengthsWeightedSum)>());
29  fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
30  PyTorch, 1, BatchGather)>());
31  fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
32  PyTorch, 1, DotProduct)>());
33  fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
34  PyTorch, 1, FCTransposed)>());
35  fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
36  PyTorch, 1, BatchMatMul)>());
37  fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(
38  PyTorch, 1, ExpandDims)>());
39  }
40 };
41 
42 inline void RegisterPyTorchOperatorSetSchema() {
43  RegisterOpSetSchema<OpSet_PyTorch_ver1>();
44 }
45 
46 } // namespace ONNX_NAMESPACE