1 #ifndef CAFFE2_OPERATORS_WHILE_OP_H_ 2 #define CAFFE2_OPERATORS_WHILE_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 10 template <
class Context>
16 this->
template HasSingleArgumentOfType<NetDef>(
"loop_net"),
17 "loop_net must be specified in While operator");
19 this->
template GetSingleArgument<NetDef>(
"loop_net", NetDef());
21 CAFFE_ENFORCE(loop_net_,
"Failed to initialize loop subnet");
25 this->
template HasSingleArgumentOfType<NetDef>(
"cond_net");
28 this->
template GetSingleArgument<NetDef>(
"cond_net", NetDef());
30 CAFFE_ENFORCE(cond_net_,
"Failed to initialize condition subnet");
34 USE_OPERATOR_CONTEXT_FUNCTIONS;
36 bool RunOnDevice()
override {
38 this->InputIsTensorType(0, Context::GetDeviceType()),
39 "Invalid condition in While operator: tensor expected");
41 const auto& condition =
Input(0);
45 "Invalid condition tensor in While operator: single value expected");
48 if (cond_net_ && !cond_net_->Run()) {
51 if (!*condition.template data<bool>()) {
54 if (!loop_net_->Run()) {
64 std::unique_ptr<NetBase> loop_net_;
67 std::unique_ptr<NetBase> cond_net_;
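72 #endif // CAFFE2_OPERATORS_WHILE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks. It is the owner of all these objects and deals with the scaffolding logistics.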
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.