Caffe2 - C++ API
A deep learning, cross-platform ML framework
onnx_while_op.h
1 
17 #ifndef CAFFE2_OPERATORS_ONNX_WHILE_OP_H_
18 #define CAFFE2_OPERATORS_ONNX_WHILE_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 
24 namespace caffe2 {
25 
26 template <class Context>
27 class ONNXWhileOp final : public Operator<Context> {
28  public:
29  ONNXWhileOp(const OperatorDef& operator_def, Workspace* ws)
30  : Operator<Context>(operator_def, ws),
31  ws_(ws),
32  has_trip_count_(
33  OperatorBase::GetSingleArgument<int64_t>("has_trip_count", 0)),
34  has_cond_(OperatorBase::GetSingleArgument<int64_t>("has_cond", 0)) {
35  CAFFE_ENFORCE(
36  this->template HasSingleArgumentOfType<NetDef>("body"),
37  "body net must be specified in ONNXWhile operator");
38  body_net_def_ = this->template GetSingleArgument<NetDef>("body", NetDef());
39 
40  // Create loop-carried deps in Workspace
41  for (int i = 2; i < body_net_def_.external_input_size(); ++i) {
42  Blob* b = ws_->CreateBlob(body_net_def_.external_input(i));
43  Tensor<Context>* t = b->template GetMutable<Tensor<Context>>();
44  lcd_tensors_.push_back(t);
45  }
46  // First output is the iteration variable
47  auto* iteration_var_blob = ws_->CreateBlob(body_net_def_.external_input(0));
48  iteration_var_ = iteration_var_blob->template GetMutable<Tensor<Context>>();
49 
50  input_condition_var_ = ws_->CreateBlob(body_net_def_.external_input(1))
51  ->template GetMutable<Tensor<Context>>();
52 
53  auto* condition_var_blob =
54  ws_->CreateBlob(body_net_def_.external_output(0));
55  condition_var_ = condition_var_blob->template GetMutable<Tensor<Context>>();
56 
57  body_net_ = CreateNet(body_net_def_, ws);
58  CAFFE_ENFORCE(body_net_, "Failed to initialize loop subnet");
59  }
60 
61  USE_OPERATOR_CONTEXT_FUNCTIONS;
62 
63  // Operator
64  // Inputs: max trip count, condition, initial loop-carried dependencies
65  // Outputs: Final loop-carried dependencies, scan_outputs
66  // Body
67  // Inputs: iteration number, condition, loop-carried dependencies
68  // Outputs: condition, loop-carried dependencies, scan_outputs
69  bool RunOnDevice() override {
70  constexpr int64_t num_inputs_before_lcds = 2;
71  // First input is the maximumt trip count. Second input is the condition
72  // variable (for the first iteration). The rest of the inputs are
73  // loop-carried dependencies.
74  int num_loop_carried_deps = InputSize() - num_inputs_before_lcds;
75  int64_t max_trip_count = *Input(0).template data<int64_t>();
76  const bool first_iter_condition = *Input(1).template data<bool>();
77 
78  // Body graph has 2+N inputs: iteration number, condition value, and N
79  // loop-carried dependencies
80  CAFFE_ENFORCE_EQ(
81  num_loop_carried_deps + 2,
82  body_net_->external_input().size(),
83  "Body graph must have 2+N inputs, where N is the number of "
84  "loop carried dependencies.");
85 
86  // Body graph has 1+N+K outputs: recalculated condition variable, N
87  // loop-carried dependencies, and K scan_outputs
88  int num_scan_outputs =
89  body_net_->external_output().size() - num_loop_carried_deps - 1;
90 
91  CAFFE_ENFORCE_GE(
92  num_scan_outputs,
93  0,
94  "Body graph must have N+K outputs, where N is the number "
95  "of loop-carried dependencies and K is the number of scan "
96  "outputs");
97 
98  // Copy initial loop-carried dependencies
99  for (int i = 0; i < num_loop_carried_deps; ++i) {
100  lcd_tensors_[i]->CopyFrom(Input(i + num_inputs_before_lcds));
101  }
102 
103  // Initialize iteration variable
104  iteration_var_->Resize(1);
105  auto* iteration_var_ptr = iteration_var_->template mutable_data<int64_t>();
106  *iteration_var_ptr = 0ll;
107 
108  // Input condition var. This requires special handling
109  input_condition_var_->Resize(1);
110  auto* input_condition_var_ptr =
111  input_condition_var_->template mutable_data<bool>();
112  *input_condition_var_ptr = first_iter_condition;
113 
114  // Output condition var. This is yielded by the body net and we will use its
115  // value to determine further iteration
116 
117  condition_var_->Resize(1);
118  auto* condition_var_ptr = condition_var_->template mutable_data<bool>();
119 
120  auto valid_iter_num = [this, max_trip_count](int64_t i) {
121  if (has_trip_count_) {
122  return i < max_trip_count;
123  } else {
124  return true;
125  }
126  };
127 
128  auto condition_true =
129  [this, first_iter_condition, condition_var_ptr](int64_t i) {
130  if (has_cond_) {
131  if (i == 0) {
132  return (bool)first_iter_condition;
133  } else {
134  return (bool)*condition_var_ptr;
135  }
136  } else {
137  return true;
138  }
139  };
140 
141  // Allocate scan_outputs for zero-iteration case
142  for (int i = 0; i < num_scan_outputs; ++i) {
143  Output(i + num_loop_carried_deps)->Resize(0);
144  Output(i + num_loop_carried_deps)->template mutable_data<int32_t>();
145  }
146 
147  // Use this to keep track of the sizes of the scan outputs and validate
148  // they're the same across iterations.
149  std::vector<std::vector<TIndex>> scan_outputs_sizes;
150 
151  while (true) {
152  int64_t itr = *iteration_var_ptr;
153  if (valid_iter_num(itr) && condition_true(itr)) {
154  if (!body_net_->Run()) {
155  return false;
156  }
157  // Copy forward loop-carried dependencies
158  for (int i = 0; i < num_loop_carried_deps; ++i) {
159  Blob* b = ws_->GetBlob(body_net_->external_output()[i + 1]);
160  const Tensor<Context>& t = b->template Get<Tensor<Context>>();
161  lcd_tensors_[i]->CopyFrom(t);
162  }
163  // Copy out scan_outputs
164  for (int i = 0; i < num_scan_outputs; ++i) {
165  int net_output_idx = i + 1 + num_loop_carried_deps;
166  const Tensor<Context>& scan_output =
167  ws_->GetBlob(body_net_->external_output()[net_output_idx])
168  ->template Get<Tensor<Context>>();
169  auto* scan_output_target = Output(i + num_loop_carried_deps);
170  if (itr == 0) {
171  auto dims = scan_output.dims();
172  scan_outputs_sizes.push_back(dims);
173  dims.insert(dims.begin(), 1);
174  scan_output_target->Resize(dims);
175  scan_output_target->CopyFrom(scan_output);
176  } else {
177  auto dims = scan_output.dims();
178  CAFFE_ENFORCE_EQ(
179  dims,
180  scan_outputs_sizes[i],
181  "Size of scan output changed across iterations");
182  dims.insert(dims.begin(), itr);
183  scan_output_target->Extend(1, 2.0f, &context_);
184 
185  TIndex timestep_size = 1;
186  for (const TIndex t : scan_outputs_sizes[i]) {
187  timestep_size *= t;
188  }
189 
190  const void* src_data = scan_output.raw_data();
191  auto& sot_meta = scan_output_target->meta();
192  void* dst_data =
193  (char*)scan_output_target->raw_mutable_data(sot_meta) +
194  timestep_size * scan_output.itemsize() * itr;
195  memcpy(dst_data, src_data, timestep_size * scan_output.itemsize());
196  }
197  }
198  } else {
199  break;
200  }
201  *iteration_var_ptr += 1ll;
202  *input_condition_var_ptr = *condition_var_ptr;
203  }
204 
205  if (*iteration_var_ptr > 0) {
206  // Copy out final loop-carried dependencies
207  for (int i = 0; i < num_loop_carried_deps; ++i) {
208  Output(i)->CopyFrom(*lcd_tensors_[i]);
209  }
210  } else {
211  // Copy out final loop-carried dependencies
212  for (int i = 0; i < num_loop_carried_deps; ++i) {
213  Output(i)->CopyFrom(Input(i + num_inputs_before_lcds));
214  }
215  }
216 
217  return true;
218  }
219 
220  NetDef body_net_def_;
221  std::unique_ptr<NetBase> body_net_;
222  Workspace* ws_;
223 
224  bool has_trip_count_, has_cond_;
225 
226  Tensor<Context>*iteration_var_, *input_condition_var_, *condition_var_;
227 
228  std::vector<Tensor<Context>*> lcd_tensors_;
229 };
230 
231 } // namespace caffe2
232 
233 #endif // CAFFE2_OPERATORS_ONNX_WHILE_OP_H
Blob is a general container that hosts a typed pointer.
Definition: blob.h:41
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:120
size_t itemsize() const
Return the number of bytes each item takes in the tensor.
Definition: tensor.h:613
Tensor is the basic class in Caffe2 that stores contiguous memory with its shape information...
Definition: tensor.h:109
const vector< TIndex > & dims() const
Returns the dimensions of the tensor as a vector.
Definition: tensor.h:627
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
Definition: workspace.cc:180
Copyright (c) 2016-present, Facebook, Inc.
const void * raw_data() const
Returns a const raw void* pointer of the underlying storage.
Definition: tensor.h:488
unique_ptr< NetBase > CreateNet(const NetDef &net_def, Workspace *ws)
Creates a network, accessing / creating blobs in the given workspace.
Definition: net.cc:117