1 #ifndef CAFFE2_OPERATORS_IF_OP_H_ 2 #define CAFFE2_OPERATORS_IF_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 10 template <
class Context>
16 this->
template HasSingleArgumentOfType<NetDef>(
"then_net"),
17 "then_net must be specified in If operator");
19 this->
template GetSingleArgument<NetDef>(
"then_net", NetDef());
21 CAFFE_ENFORCE(then_net_,
"Failed to initialize then subnet");
23 if (this->
template HasSingleArgumentOfType<NetDef>(
"else_net")) {
25 this->
template GetSingleArgument<NetDef>(
"else_net", NetDef());
27 CAFFE_ENFORCE(else_net_,
"Failed to initialize else subnet");
31 USE_OPERATOR_CONTEXT_FUNCTIONS;
33 bool RunOnDevice()
override {
35 this->InputIsTensorType(0, Context::GetDeviceType()),
36 "Invalid condition in If operator: tensor expected");
38 const auto& condition =
Input(0);
42 "Invalid condition tensor in If operator: single value expected");
44 auto conditionValue = *condition.template data<bool>();
46 return then_net_->Run();
47 }
else if (else_net_) {
48 return else_net_->Run();
55 std::unique_ptr<NetBase> then_net_;
56 std::unique_ptr<NetBase> else_net_;
61 #endif // CAFFE2_OPERATORS_IF_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.