Caffe2 - C++ API
A deep learning, cross platform ML framework
while_op.h
1 
17 #ifndef CAFFE2_OPERATORS_WHILE_OP_H_
18 #define CAFFE2_OPERATORS_WHILE_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 
24 namespace caffe2 {
25 
26 template <class Context>
27 class WhileOp final : public Operator<Context> {
28  public:
29  WhileOp(const OperatorDef& operator_def, Workspace* ws)
30  : Operator<Context>(operator_def, ws) {
31  CAFFE_ENFORCE(
32  this->template HasSingleArgumentOfType<NetDef>("loop_net"),
33  "loop_net must be specified in While operator");
34  loop_net_def_ =
35  this->template GetSingleArgument<NetDef>("loop_net", NetDef());
36  loop_net_ = CreateNet(loop_net_def_, ws);
37  CAFFE_ENFORCE(loop_net_, "Failed to initialize loop subnet");
38 
39  cond_net_ = nullptr;
40  bool has_cond_net =
41  this->template HasSingleArgumentOfType<NetDef>("cond_net");
42  if (has_cond_net) {
43  cond_net_def_ =
44  this->template GetSingleArgument<NetDef>("cond_net", NetDef());
45  cond_net_ = CreateNet(cond_net_def_, ws);
46  CAFFE_ENFORCE(cond_net_, "Failed to initialize condition subnet");
47  }
48  }
49 
50  USE_OPERATOR_CONTEXT_FUNCTIONS;
51 
52  bool RunOnDevice() override {
53  CAFFE_ENFORCE(
54  this->template InputIsType<Tensor<Context>>(0),
55  "Invalid condition in While operator: tensor expected");
56 
57  const auto& condition = Input(0);
58  CAFFE_ENFORCE_EQ(
59  condition.size(),
60  1,
61  "Invalid condition tensor in While operator: single value expected");
62 
63  while (true) {
64  if (cond_net_ && !cond_net_->Run()) {
65  return false;
66  }
67  if (!*condition.template data<bool>()) {
68  return true;
69  }
70  if (!loop_net_->Run()) {
71  return false;
72  }
73  }
74 
75  return true;
76  }
77 
78  private:
79  NetDef loop_net_def_;
80  std::unique_ptr<NetBase> loop_net_;
81 
82  NetDef cond_net_def_;
83  std::unique_ptr<NetBase> cond_net_;
84 };
85 
86 } // namespace caffe2
87 
88 #endif // CAFFE2_OPERATORS_WHILE_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:117