Caffe2 - C++ API
A deep-learning, cross-platform ML framework
if_op.h
1 #ifndef CAFFE2_OPERATORS_IF_OP_H_
2 #define CAFFE2_OPERATORS_IF_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 
8 namespace caffe2 {
9 
10 template <class Context>
11 class IfOp final : public Operator<Context> {
12  public:
13  explicit IfOp(const OperatorDef& operator_def, Workspace* ws)
14  : Operator<Context>(operator_def, ws) {
15  CAFFE_ENFORCE(
16  this->template HasSingleArgumentOfType<NetDef>("then_net"),
17  "then_net must be specified in If operator");
18  auto then_net_def =
19  this->template GetSingleArgument<NetDef>("then_net", NetDef());
20  then_net_ = CreateNet(then_net_def, ws);
21  CAFFE_ENFORCE(then_net_, "Failed to initialize then subnet");
22 
23  if (this->template HasSingleArgumentOfType<NetDef>("else_net")) {
24  auto else_net_def =
25  this->template GetSingleArgument<NetDef>("else_net", NetDef());
26  else_net_ = CreateNet(else_net_def, ws);
27  CAFFE_ENFORCE(else_net_, "Failed to initialize else subnet");
28  }
29  }
30 
31  USE_OPERATOR_CONTEXT_FUNCTIONS;
32 
33  bool RunOnDevice() override {
34  CAFFE_ENFORCE(
35  this->InputIsTensorType(0, Context::GetDeviceType()),
36  "Invalid condition in If operator: tensor expected");
37 
38  const auto& condition = Input(0);
39  CAFFE_ENFORCE_EQ(
40  condition.numel(),
41  1,
42  "Invalid condition tensor in If operator: single value expected");
43 
44  auto conditionValue = *condition.template data<bool>();
45  if (conditionValue) {
46  return then_net_->Run();
47  } else if (else_net_) {
48  return else_net_->Run();
49  }
50 
51  return true;
52  }
53 
54  private:
55  std::unique_ptr<NetBase> then_net_;
56  std::unique_ptr<NetBase> else_net_;
57 };
58 
59 } // namespace caffe2
60 
61 #endif // CAFFE2_OPERATORS_IF_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:151