Caffe2 - C++ API
A deep learning, cross platform ML framework
while_op.h
1 #ifndef CAFFE2_OPERATORS_WHILE_OP_H_
2 #define CAFFE2_OPERATORS_WHILE_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 
8 namespace caffe2 {
9 
10 template <class Context>
11 class WhileOp final : public Operator<Context> {
12  public:
13  explicit WhileOp(const OperatorDef& operator_def, Workspace* ws)
14  : Operator<Context>(operator_def, ws) {
15  CAFFE_ENFORCE(
16  this->template HasSingleArgumentOfType<NetDef>("loop_net"),
17  "loop_net must be specified in While operator");
18  loop_net_def_ =
19  this->template GetSingleArgument<NetDef>("loop_net", NetDef());
20  loop_net_ = CreateNet(loop_net_def_, ws);
21  CAFFE_ENFORCE(loop_net_, "Failed to initialize loop subnet");
22 
23  cond_net_ = nullptr;
24  bool has_cond_net =
25  this->template HasSingleArgumentOfType<NetDef>("cond_net");
26  if (has_cond_net) {
27  cond_net_def_ =
28  this->template GetSingleArgument<NetDef>("cond_net", NetDef());
29  cond_net_ = CreateNet(cond_net_def_, ws);
30  CAFFE_ENFORCE(cond_net_, "Failed to initialize condition subnet");
31  }
32  }
33 
34  USE_OPERATOR_CONTEXT_FUNCTIONS;
35 
36  bool RunOnDevice() override {
37  CAFFE_ENFORCE(
38  this->InputIsTensorType(0, Context::GetDeviceType()),
39  "Invalid condition in While operator: tensor expected");
40 
41  const auto& condition = Input(0);
42  CAFFE_ENFORCE_EQ(
43  condition.numel(),
44  1,
45  "Invalid condition tensor in While operator: single value expected");
46 
47  while (true) {
48  if (cond_net_ && !cond_net_->Run()) {
49  return false;
50  }
51  if (!*condition.template data<bool>()) {
52  return true;
53  }
54  if (!loop_net_->Run()) {
55  return false;
56  }
57  }
58 
59  return true;
60  }
61 
62  private:
63  NetDef loop_net_def_;
64  std::unique_ptr<NetBase> loop_net_;
65 
66  NetDef cond_net_def_;
67  std::unique_ptr<NetBase> cond_net_;
68 };
69 
70 } // namespace caffe2
71 
72 #endif // CAFFE2_OPERATORS_WHILE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:151