Caffe2 - C++ API
A deep learning, cross platform ML framework
caffe2::OpSchema Class Reference

A class to record the schema of an op.

#include <operator_schema.h>

Data Structures

struct  Argument
 
struct  Cost
 

Public Types

typedef std::function< vector< TensorShape >(const OperatorDef &, const vector< TensorShape > &)> TensorInferenceFunctionType
 
typedef std::function< struct Cost(const OperatorDef &, const vector< TensorShape > &)> CostInferenceFunctionType
 Type of a function that takes an OperatorDef and a series of input shapes and returns, by value, a struct Cost holding the total "cost" of running the operator.
 
using DeviceInferenceFunctionType = std::function< std::pair< std::vector< DeviceOption >, std::vector< DeviceOption >>(const OperatorDef &def)>
 Returns the required device location of inputs and outputs.
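The following standalone sketch illustrates the two function types above. It does not include the Caffe2 headers. OperatorDef, TensorShape and Cost are simplified stand-ins: the real OperatorDef and TensorShape are protobuf messages, and the real Cost may carry more fields. ElementwiseCost is a hypothetical cost model for a binary elementwise op on float32 tensors of equal shape.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Simplified stand-ins for the Caffe2 types (assumption: not the real definitions).
struct OperatorDef {
  std::string type;
};
struct TensorShape {
  std::vector<int64_t> dims;
};
struct Cost {
  uint64_t flops = 0;
  uint64_t bytes_read = 0;
  uint64_t bytes_written = 0;
};

// Same signatures as the OpSchema typedefs, built on the stand-ins above.
using TensorInferenceFunctionType = std::function<std::vector<TensorShape>(
    const OperatorDef&, const std::vector<TensorShape>&)>;
using CostInferenceFunctionType =
    std::function<Cost(const OperatorDef&, const std::vector<TensorShape>&)>;

// Number of elements in a tensor of the given shape.
uint64_t NumElements(const TensorShape& shape) {
  uint64_t n = 1;
  for (int64_t d : shape.dims) n *= static_cast<uint64_t>(d);
  return n;
}

// Hypothetical cost model: one flop per element, both float32 inputs read
// once, one float32 output written once.
const CostInferenceFunctionType ElementwiseCost =
    [](const OperatorDef&, const std::vector<TensorShape>& in) {
      assert(in.size() == 2);
      const uint64_t n = NumElements(in[0]);
      Cost cost;
      cost.flops = n;
      cost.bytes_read = 2 * n * sizeof(float);
      cost.bytes_written = n * sizeof(float);
      return cost;
    };
```

Under these assumptions, two 2x3 inputs give 6 flops, 48 bytes read and 24 bytes written.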
 

Public Member Functions

 OpSchema (const string &type, const string &file, const int line)
 
const string & file () const
 Returns the file that the op schema is registered from.
 
int line () const
 Returns the line in file that the op schema is registered from.
 
const char * doc () const
 Returns the docstring of the op schema.
 
bool Verify (const OperatorDef &def) const
 Verifies if an operator definition protobuf matches the pattern specified in the schema.
 
OpSchema & NumInputs (int n)
 Exactly n inputs.
 
OpSchema & NumInputs (int min, int max)
 The number of inputs must be in the range [min, max], inclusive.
 
OpSchema & NumInputs (set< int > allowed_input_nums)
 The number of inputs must be one of the values in allowed_input_nums.
 
OpSchema & NumInputs (std::function< bool(int)> func)
 The number of inputs is checked with the specified function.
 
OpSchema & NumOutputs (int n)
 Exactly n outputs.
 
OpSchema & NumOutputs (int min, int max)
 The number of outputs must be in the range [min, max], inclusive.
 
OpSchema & NumOutputs (set< int > allowed_output_nums)
 The number of outputs must be one of the values in allowed_output_nums.
 
OpSchema & NumOutputs (std::function< bool(int)> func)
 The number of outputs is checked with the specified function.
 
OpSchema & NumInputsOutputs (std::function< bool(int, int)> func)
 Relationship between inputs and outputs is checked with a specified function.
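Each of the input/output count overloads above reduces to a predicate on the count. The sketch below models that idea with a hypothetical ArityRule class, which is not part of Caffe2. Its Allows method stands in for the count checks performed by Verify, under the assumption that every registered constraint must hold.

```cpp
#include <cassert>
#include <functional>
#include <set>
#include <utility>

// Hypothetical model of the arity checks (not Caffe2 code): every overload
// is reduced to a predicate on the number of inputs and/or outputs.
class ArityRule {
 public:
  ArityRule& NumInputs(int n) { return NumInputs(n, n); }
  ArityRule& NumInputs(int min, int max) {
    assert(min <= max);
    inputs_ok_ = [min, max](int x) { return min <= x && x <= max; };
    return *this;
  }
  ArityRule& NumInputs(std::set<int> allowed) {
    inputs_ok_ = [allowed](int x) { return allowed.count(x) > 0; };
    return *this;
  }
  ArityRule& NumInputs(std::function<bool(int)> func) {
    inputs_ok_ = std::move(func);
    return *this;
  }
  // The NumOutputs overloads are symmetric; two are shown.
  ArityRule& NumOutputs(int n) { return NumOutputs(n, n); }
  ArityRule& NumOutputs(int min, int max) {
    assert(min <= max);
    outputs_ok_ = [min, max](int y) { return min <= y && y <= max; };
    return *this;
  }
  ArityRule& NumInputsOutputs(std::function<bool(int, int)> func) {
    in_out_ok_ = std::move(func);
    return *this;
  }

  // Like the count checks in Verify(): every registered constraint must hold.
  bool Allows(int num_inputs, int num_outputs) const {
    return inputs_ok_(num_inputs) && outputs_ok_(num_outputs) &&
           in_out_ok_(num_inputs, num_outputs);
  }

 private:
  // By default any count is accepted.
  std::function<bool(int)> inputs_ok_ = [](int) { return true; };
  std::function<bool(int)> outputs_ok_ = [](int) { return true; };
  std::function<bool(int, int)> in_out_ok_ = [](int, int) { return true; };
};
```

For example, a rule built with NumInputs(1, 8).NumOutputs(2) accepts 3 inputs with 2 outputs. It rejects 0 inputs with 2 outputs, and it rejects 3 inputs with 1 output.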
 
OpSchema & OutputCalculator (std::function< int(int)> calc)
 Sets the output calculator, a user-defined function that computes the number of outputs from the number of inputs.
 
OpSchema & SameNumberOfOutput ()
 Set the number of outputs to be the same as the number of inputs.
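The sketch below shows one way CalculateOutput, OutputCalculator and SameNumberOfOutput could fit together, using a hypothetical OutputCountModel class. Two details are assumptions of this sketch, not guarantees about Caffe2: the precedence (a fixed output count first, then the calculator) and the -1 sentinel meaning "cannot determine".

```cpp
#include <cassert>
#include <functional>
#include <limits>
#include <utility>

// Hypothetical model (not Caffe2 code) of how CalculateOutput() could combine
// a fixed output count with an optional OutputCalculator.
class OutputCountModel {
 public:
  OutputCountModel& NumOutputs(int min, int max) {
    assert(min <= max);
    min_output_ = min;
    max_output_ = max;
    return *this;
  }
  OutputCountModel& OutputCalculator(std::function<int(int)> calc) {
    calc_ = std::move(calc);
    return *this;
  }
  // SameNumberOfOutput() expressed as a calculator that echoes its argument.
  OutputCountModel& SameNumberOfOutput() {
    return OutputCalculator([](int n) { return n; });
  }
  // Assumed precedence: a fixed count wins, then the calculator; -1 means
  // the number of outputs cannot be determined.
  int CalculateOutput(int num_input) const {
    if (min_output_ == max_output_) return min_output_;
    if (calc_) return calc_(num_input);
    return -1;
  }

 private:
  int min_output_ = 0;
  int max_output_ = std::numeric_limits<int>::max();
  std::function<int(int)> calc_;
};
```

Modeling SameNumberOfOutput as a calculator keeps a single code path in CalculateOutput instead of a separate flag.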
 
OpSchema & AllowInplace (std::function< bool(int, int)> inplace)
 
OpSchema & AllowInplace (set< std::pair< int, int >> inplace)
 
OpSchema & AllowOneToOneInplace ()
 
OpSchema & EnforceInplace (std::function< bool(int, int)> inplace)
 
OpSchema & EnforceInplace (set< std::pair< int, int >> inplace)
 
OpSchema & EnforceOneToOneInplace ()
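The AllowInplace and EnforceInplace families carry no descriptions here. The sketch below shows one reading of them. A pair (i, o) says that input i may (allow) or must (enforce) share its blob with output o. The set and one-to-one overloads are then shorthands for particular predicates. All names (InplaceRules, FromPairs, OneToOne, CheckInplace) are hypothetical, and vectors of blob names stand in for an OperatorDef.

```cpp
#include <cassert>
#include <functional>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical in-place rules (not Caffe2 code): allowed(i, o) means input i
// may share a blob with output o; enforced(i, o) means it must.
struct InplaceRules {
  std::function<bool(int, int)> allowed = [](int, int) { return false; };
  std::function<bool(int, int)> enforced = [](int, int) { return false; };
};

// Shorthand in the spirit of the set< std::pair< int, int >> overloads.
std::function<bool(int, int)> FromPairs(std::set<std::pair<int, int>> pairs) {
  return [pairs](int i, int o) { return pairs.count({i, o}) > 0; };
}

// Shorthand in the spirit of the *OneToOneInplace variants: input i with output i.
bool OneToOne(int i, int o) { return i == o; }

// Checks blob names against the rules: a shared blob must be allowed or
// enforced, and an enforced pair must actually share a blob.
bool CheckInplace(const InplaceRules& rules,
                  const std::vector<std::string>& inputs,
                  const std::vector<std::string>& outputs) {
  for (int i = 0; i < static_cast<int>(inputs.size()); ++i) {
    for (int o = 0; o < static_cast<int>(outputs.size()); ++o) {
      const bool same_blob = inputs[i] == outputs[o];
      if (same_blob && !rules.allowed(i, o) && !rules.enforced(i, o)) return false;
      if (!same_blob && rules.enforced(i, o)) return false;
    }
  }
  return true;
}
```

Take allowed = FromPairs({{0, 0}}). Writing input 0's blob back as output 0 passes, but reusing input 1's blob as output 0 fails.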
 
OpSchema & TensorInferenceFunction (TensorInferenceFunctionType function)
 Sets the tensor inference function, which is a std::function object defined in operator_schema.h.
 
OpSchema & InheritOnnxSchema (const std::string &onnx_schema_name)
 Sets the corresponding ONNX schema name.
 
OpSchema & InheritOnnxSchema ()
 Shortcut to InheritOnnxSchema(type_).
 
OpSchema & IdenticalTypeAndShape ()
 Sets the tensor inference function so that the outputs have the same types and shapes as the inputs.
 
OpSchema & IdenticalTypeAndShapeOfInput (int idx)
 
OpSchema & IdenticalTypeAndShapeOfInputDim (int idx, int dim)
 
OpSchema & IdenticalTypeAndShapeOfMultipleInputs (const vector< int > &indices)
 
OpSchema & ScalarType (::caffe2::TensorProto_DataType dt)
 
vector< TensorShape > InferTensor (const OperatorDef &def, const vector< TensorShape > &input_type_shape) const
 Infers the output types and shapes from the op schema, given the op definition and the input types and shapes.
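A tensor inference function maps the op definition and the input shapes to the output shapes. The sketch below uses simplified stand-in types. IdenticalTypeAndShapeOfInput here follows the member of the same name only in spirit: it is assumed to produce one output that copies input idx. MatMulShape is a hypothetical inference function for a 2-D matrix multiply, of the kind that could be passed to TensorInferenceFunction.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Simplified stand-ins for Caffe2's protobuf types (assumption).
struct OperatorDef {
  std::string type;
};
struct TensorShape {
  std::vector<int64_t> dims;
  int data_type = 1;  // hypothetical type code
};
using InferFn = std::function<std::vector<TensorShape>(
    const OperatorDef&, const std::vector<TensorShape>&)>;

// In the spirit of IdenticalTypeAndShapeOfInput(idx): a single output that
// copies the type and shape of input idx.
InferFn IdenticalTypeAndShapeOfInput(int idx) {
  return [idx](const OperatorDef&, const std::vector<TensorShape>& in) {
    return std::vector<TensorShape>{in.at(static_cast<std::size_t>(idx))};
  };
}

// Hypothetical inference for a 2-D matrix multiply:
// (M x K) * (K x N) -> (M x N), typed like input 0.
std::vector<TensorShape> MatMulShape(const OperatorDef&,
                                     const std::vector<TensorShape>& in) {
  assert(in.size() == 2);
  assert(in[0].dims.size() == 2 && in[1].dims.size() == 2);
  assert(in[0].dims[1] == in[1].dims[0]);
  TensorShape out;
  out.data_type = in[0].data_type;
  out.dims = {in[0].dims[0], in[1].dims[1]};
  return {out};
}
```

With inputs of shape 4x3 and 3x5, MatMulShape reports a single 4x5 output.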
 
OpSchema & CostInferenceFunction (CostInferenceFunctionType function)
 Registers the cost inference function.
 
bool HasCostInferenceFunction () const
 
struct Cost InferCost (const OperatorDef &def, const vector< TensorShape > &input_tensor_shape) const
 
OpSchema & SetDoc (const string &doc)
 
OpSchema & Arg (const char *name, const char *description, bool required=false)
 
OpSchema & Input (const int n, const char *name, const char *description)
 
OpSchema & Output (const int n, const char *name, const char *description)
 
OpSchema & FillUsing (std::function< void(OpSchema &)> populator)
 
OpSchema & Private ()
 
OpSchema & InputsCanCrossDevices ()
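The documentation setters (SetDoc, Arg, Input, Output) and FillUsing can be modeled with a hypothetical, documentation-only DocSchema class. The sketch assumes two behaviors. First, Input and Output store a (name, description) pair at slot n. Second, FillUsing simply invokes the populator on the schema, which lets several ops share one documentation function.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical documentation-only schema (not Caffe2 code) with setters that
// return *this so they can be chained.
class DocSchema {
 public:
  struct Argument {
    std::string name;
    std::string description;
    bool required;
  };
  using Slots = std::vector<std::pair<std::string, std::string>>;

  DocSchema& SetDoc(const std::string& doc) {
    doc_ = doc;
    return *this;
  }
  DocSchema& Arg(const char* name, const char* description,
                 bool required = false) {
    args_.push_back({name, description, required});
    return *this;
  }
  DocSchema& Input(int n, const char* name, const char* description) {
    SetSlot(&inputs_, n, name, description);
    return *this;
  }
  DocSchema& Output(int n, const char* name, const char* description) {
    SetSlot(&outputs_, n, name, description);
    return *this;
  }
  // Runs a reusable populator so several ops can share one doc function.
  DocSchema& FillUsing(const std::function<void(DocSchema&)>& populator) {
    if (populator) populator(*this);
    return *this;
  }

  const std::string& doc() const { return doc_; }
  const std::vector<Argument>& args() const { return args_; }
  const Slots& input_desc() const { return inputs_; }
  const Slots& output_desc() const { return outputs_; }

 private:
  // Stores (name, description) at index n, growing the list if needed.
  static void SetSlot(Slots* slots, int n, const char* name, const char* desc) {
    assert(n >= 0);
    const std::size_t idx = static_cast<std::size_t>(n);
    if (slots->size() <= idx) slots->resize(idx + 1);
    (*slots)[idx] = std::make_pair(std::string(name), std::string(desc));
  }

  std::string doc_;
  std::vector<Argument> args_;
  Slots inputs_;
  Slots outputs_;
};
```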
 
int CalculateOutput (int num_input) const
 Returns the number of outputs for a given number of inputs, if this schema supports computing it.
 
const std::string & onnx_schema () const
 
int min_input () const
 
int max_input () const
 
int min_output () const
 
int max_output () const
 
bool num_inputs_allowed (int x) const
 
bool num_outputs_allowed (int x) const
 
bool num_inputs_outputs_allowed (int x, int y) const
 
int inf () const
 
bool inplace_enforced (int x, int y) const
 
const std::vector< Argument > & args () const
 
const std::vector< std::pair< const char *, const char * > > & input_desc () const
 
const std::vector< std::pair< const char *, const char * > > & output_desc () const
 
bool private_op ()
 
bool inputs_can_cross_devices () const
 
OpSchema & DeviceInferenceFunction (DeviceInferenceFunctionType function)
 
std::pair< std::vector< DeviceOption >, std::vector< DeviceOption > > InferDevice (const OperatorDef &def) const
 Infer required device location of an op's inputs and outputs.
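The following sketch shows device inference using simplified stand-ins for DeviceOption and OperatorDef. The default shown places every input and output on the op's own device option. This is an assumption about typical behavior, not a statement of Caffe2's implementation. CopyToCpuDeviceInference is a hypothetical custom function for an op whose outputs live on a different device from its inputs.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Simplified stand-ins for Caffe2's protobuf messages (assumption).
struct DeviceOption {
  int device_type = 0;  // hypothetical codes: 0 = CPU, 1 = GPU
  int device_id = 0;
};
struct OperatorDef {
  int input_size = 0;
  int output_size = 0;
  DeviceOption device_option;
};

// (input devices, output devices), matching DeviceInferenceFunctionType's result.
using DevicePlacement =
    std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>>;

// Assumed default: every input and output lives on the op's own device.
DevicePlacement DefaultDeviceInference(const OperatorDef& def) {
  return {std::vector<DeviceOption>(def.input_size, def.device_option),
          std::vector<DeviceOption>(def.output_size, def.device_option)};
}

// Hypothetical custom function for a copy-to-CPU op: inputs stay on the op's
// device, outputs go to a default-constructed (CPU) DeviceOption.
DevicePlacement CopyToCpuDeviceInference(const OperatorDef& def) {
  return {std::vector<DeviceOption>(def.input_size, def.device_option),
          std::vector<DeviceOption>(def.output_size, DeviceOption{})};
}
```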
 
OpSchema & WeightedValueKeyLengthInputFillers (size_t value_index, size_t key_index, size_t length_index, size_t weight_index)
 
OpSchema & ValueKeyLengthInputFillers (size_t value_index, size_t key_index, size_t length_index)
 
OpSchema & ValueLengthInputFillers (size_t value_index, size_t length_index)
 
OpSchema & DisallowInputFillers ()
 
std::vector< TensorFiller > InputFillers (const std::vector< std::vector< int64_t >> &shapes) const
 

Static Public Member Functions

static TensorInferenceFunctionType NeedsAllInputShapes (TensorInferenceFunctionType f)
 Wraps a tensor inference function so that it returns an unknown shape for every output if any of the inputs has an unknown shape.
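Below is a minimal sketch of such a wrapper. It assumes a TensorShape stand-in with an unknown_shape flag and an OperatorDef reduced to its output count. The function name matches the static member, but the code is an illustration, not the Caffe2 implementation.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Stand-ins (assumption): a shape with an "unknown" flag, and an op
// definition reduced to its number of outputs.
struct TensorShape {
  std::vector<int64_t> dims;
  bool unknown_shape = false;
};
struct OperatorDef {
  int output_size = 1;
};
using InferFn = std::function<std::vector<TensorShape>(
    const OperatorDef&, const std::vector<TensorShape>&)>;

// If any input shape is unknown, skip f and report every output as unknown;
// otherwise defer to f.
InferFn NeedsAllInputShapes(InferFn f) {
  return [f](const OperatorDef& def, const std::vector<TensorShape>& in) {
    for (const TensorShape& s : in) {
      if (s.unknown_shape) {
        std::vector<TensorShape> out(def.output_size);
        for (TensorShape& o : out) o.unknown_shape = true;
        return out;
      }
    }
    return f(def, in);
  };
}
```

Wrapping keeps individual inference functions free of repeated "is every input known?" checks.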
 

Friends

CAFFE2_API friend std::ostream & operator<< (std::ostream &out, const OpSchema &schema)
 

Detailed Description

A class to record the schema of an op.

OpSchema records the common interface of an op specified by its name. This is optional for each operator implemented in Caffe2 but is strongly recommended.

To register an OpSchema, use the macro OPERATOR_SCHEMA(name) and chain calls to the member functions of this class. For example, the schema of an op that takes two inputs and produces one output, where the first input and the output may share the same blob (in-place computation), can be written as:

OPERATOR_SCHEMA(name)
    .NumInputs(2).NumOutputs(1).AllowInplace({{0, 0}});
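The registration pattern can be illustrated without Caffe2 using a hypothetical MiniSchema class, a registry function, and a MINI_OPERATOR_SCHEMA macro. The sketch shows why the example chains calls. Each setter returns *this, and the macro binds the result to a file-scope reference, so the whole chain runs during static initialization. The real OPERATOR_SCHEMA macro and registry differ in detail.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <utility>

// Hypothetical miniature of schema registration (not Caffe2 code).
class MiniSchema {
 public:
  MiniSchema& NumInputs(int n) {
    num_inputs_ = n;
    return *this;
  }
  MiniSchema& NumOutputs(int n) {
    num_outputs_ = n;
    return *this;
  }
  MiniSchema& AllowInplace(std::set<std::pair<int, int>> pairs) {
    inplace_ = std::move(pairs);
    return *this;
  }
  int num_inputs() const { return num_inputs_; }
  int num_outputs() const { return num_outputs_; }
  bool inplace_allowed(int in, int out) const {
    return inplace_.count({in, out}) > 0;
  }

 private:
  int num_inputs_ = 0;
  int num_outputs_ = 0;
  std::set<std::pair<int, int>> inplace_;
};

// Function-local static: safe to use from other static initializers.
std::map<std::string, MiniSchema>& SchemaRegistry() {
  static std::map<std::string, MiniSchema> registry;
  return registry;
}

// Creates the schema for `name`; registering the same name twice is a bug.
MiniSchema& NewSchema(const std::string& name) {
  auto result = SchemaRegistry().emplace(name, MiniSchema());
  assert(result.second && "schema registered twice");
  return result.first->second;
}

// Binds a file-scope reference so that the chained setters run during
// static initialization, before main().
#define MINI_OPERATOR_SCHEMA(name) \
  static MiniSchema& mini_schema_##name = NewSchema(#name)

MINI_OPERATOR_SCHEMA(MyAdd).NumInputs(2).NumOutputs(1).AllowInplace({{0, 0}});
```

The registry is a function-local static to avoid initialization-order problems when schemas are registered from many translation units.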

Definition at line 39 of file operator_schema.h.


The documentation for this class was generated from the following files: