Caffe2 - C++ API
A deep-learning, cross-platform ML framework
if_op.h
1 
17 #ifndef CAFFE2_OPERATORS_IF_OP_H_
18 #define CAFFE2_OPERATORS_IF_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 
24 namespace caffe2 {
25 
26 template <class Context>
27 class IfOp final : public Operator<Context> {
28  public:
29  IfOp(const OperatorDef& operator_def, Workspace* ws)
30  : Operator<Context>(operator_def, ws) {
31  CAFFE_ENFORCE(
32  this->template HasSingleArgumentOfType<NetDef>("then_net"),
33  "then_net must be specified in If operator");
34  auto then_net_def =
35  this->template GetSingleArgument<NetDef>("then_net", NetDef());
36  then_net_ = CreateNet(then_net_def, ws);
37  CAFFE_ENFORCE(then_net_, "Failed to initialize then subnet");
38 
39  if (this->template HasSingleArgumentOfType<NetDef>("else_net")) {
40  auto else_net_def =
41  this->template GetSingleArgument<NetDef>("else_net", NetDef());
42  else_net_ = CreateNet(else_net_def, ws);
43  CAFFE_ENFORCE(else_net_, "Failed to initialize else subnet");
44  }
45  }
46 
47  USE_OPERATOR_CONTEXT_FUNCTIONS;
48 
49  bool RunOnDevice() override {
50  CAFFE_ENFORCE(
51  this->template InputIsType<Tensor<Context>>(0),
52  "Invalid condition in If operator: tensor expected");
53 
54  const auto& condition = Input(0);
55  CAFFE_ENFORCE_EQ(
56  condition.size(),
57  1,
58  "Invalid condition tensor in If operator: single value expected");
59 
60  auto conditionValue = *condition.template data<bool>();
61  if (conditionValue) {
62  return then_net_->Run();
63  } else if (else_net_) {
64  return else_net_->Run();
65  }
66 
67  return true;
68  }
69 
70  private:
71  std::unique_ptr<NetBase> then_net_;
72  std::unique_ptr<NetBase> else_net_;
73 };
74 
75 } // namespace caffe2
76 
77 #endif // CAFFE2_OPERATORS_IF_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory along with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:117