Caffe2 - C++ API
A deep learning, cross platform ML framework
onnx_while_op.h
1 #ifndef CAFFE2_OPERATORS_ONNX_WHILE_OP_H_
2 #define CAFFE2_OPERATORS_ONNX_WHILE_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/operators/create_scope_op.h"
8 
9 namespace caffe2 {
10 
// ONNXWhileOp runs a "body" subnet repeatedly, following the structure of
// the ONNX Loop operator: an optional maximum trip count, an optional
// condition variable, N loop-carried dependencies that are fed forward
// between iterations, and K scan outputs that are concatenated across
// iterations along a new leading axis.
template <class Context>
class ONNXWhileOp final : public Operator<Context> {
 public:
  explicit ONNXWhileOp(const OperatorDef& operator_def, Workspace* ws)
      : Operator<Context>(operator_def, ws),
        parent_ws_(ws),
        // "has_trip_count"/"has_cond" select whether Input(0)/Input(1)
        // actually bound the loop; when both are 0 the loop only ends if
        // the body itself never flips the condition (see DoRunWithType).
        has_trip_count_(
            this->template GetSingleArgument<int64_t>("has_trip_count", 0)),
        has_cond_(this->template GetSingleArgument<int64_t>("has_cond", 0)),
        // "save_scopes" keeps one child workspace per iteration (the
        // workspace stack grows each iteration); incompatible with
        // "disable_scopes", which runs the body directly in the parent.
        save_scopes_(
            this->template GetSingleArgument<int64_t>("save_scopes", 0)),
        disable_scopes_(
            this->template GetSingleArgument<int64_t>("disable_scopes", 0)),
        // -1 is a sentinel meaning "infer N from InputSize()" at run time.
        num_loop_carried_deps_(this->template GetSingleArgument<int64_t>(
            "num_loop_carried_deps",
            -1)) {
    CAFFE_ENFORCE(
        this->template HasSingleArgumentOfType<NetDef>("body"),
        "body net must be specified in ONNXWhile operator");
    if (disable_scopes_) {
      CAFFE_ENFORCE(!save_scopes_, "Cannot save scopes when disable_scopes=True");
    }
    body_net_def_ = this->template GetSingleArgument<NetDef>("body", NetDef());
    // Give unnamed body nets a unique name ("loop_net", then
    // "loop_net.<k>") so multiple ONNXWhile instances sharing a workspace
    // do not collide on the net name looked up in LocalScope.
    // NOTE(review): `counter` is a function-local static shared by every
    // instance and is not synchronized — assumes operator construction is
    // single-threaded; confirm before constructing ops concurrently.
    static int64_t counter = -1;
    if (!body_net_def_.has_name()) {
      if (counter == -1) {
        ++counter;
        body_net_def_.set_name("loop_net");
      } else {
        ++counter;
        body_net_def_.set_name("loop_net." + c10::to_string(counter));
      }
    }
  }

  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // Dispatch on the element type of the condition tensor (Input(1)).
  bool RunOnDevice() {
    return DispatchHelper<TensorTypes<int, bool, long>>::call(this, Input(1));
  }

  // Operator
  //  Inputs: max trip count, condition, initial loop-carried dependencies
  //  Outputs: Final loop-carried dependencies, scan_outputs
  // Body
  //  Inputs: iteration number, condition, loop-carried dependencies
  //  Outputs: condition, loop-carried dependencies, scan_outputs
  template <typename CondVarType>
  bool DoRunWithType() {
    // Clear workspaces from the previous invocations of the loop
    // and setup a local scope for the first iteration
    ws_stack_.clear();
    auto loop_ws = !disable_scopes_
        ? ws_stack_.pushForwardWorkspace(parent_ws_).get()
        : parent_ws_;

    constexpr int64_t num_inputs_before_lcds = 2;
    // First input is the maximumt trip count. Second input is the condition
    // variable (for the first iteration). The rest of the inputs are
    // loop-carried dependencies.
    int64_t num_loop_carried_deps;
    if (num_loop_carried_deps_ != -1) {
      num_loop_carried_deps = num_loop_carried_deps_;
    } else {
      num_loop_carried_deps = InputSize() - num_inputs_before_lcds;
    }
    // Both values are read unconditionally even when has_trip_count_/
    // has_cond_ are false, so Input(0)/Input(1) must always be present
    // and non-empty (Input(1) is also used for type dispatch above).
    int64_t max_trip_count = *Input(0).template data<int64_t>();
    const bool first_iter_condition = *Input(1).template data<CondVarType>();

    scope_ = std::make_shared<LocalScope>(
        loop_ws, body_net_def_, num_loop_carried_deps);

    // Body graph has 1+N+K outputs: recalculated condition variable, N
    // loop-carried dependencies, and K scan_outputs
    int num_scan_outputs =
        scope_->net()->external_output().size() - num_loop_carried_deps - 1;

    CAFFE_ENFORCE_GE(
        num_scan_outputs,
        0,
        "Body graph must have N+K outputs, where N is the number "
        "of loop-carried dependencies and K is the number of scan "
        "outputs");

    // Copy initial loop-carried dependencies
    for (int i = 0; i < num_loop_carried_deps; ++i) {
      scope_->lcd_tensor(i)->CopyFrom(Input(i + num_inputs_before_lcds));
    }

    // Initialize iteration variable
    scope_->set_iteration(0ll);

    // Initialize input condition variable
    scope_->template set_input_condition<CondVarType>(first_iter_condition);

    // Predicate: iteration index is within the trip-count bound (always
    // true when no trip count was supplied).
    auto valid_iter_num = [this, max_trip_count](int64_t i) {
      if (has_trip_count_) {
        return i < max_trip_count;
      } else {
        return true;
      }
    };

    // Predicate: the loop condition holds. On iteration 0 the incoming
    // condition input is used; afterwards the condition recomputed by the
    // body (cond_value). Always true when has_cond is unset.
    auto condition_true =
        [this, first_iter_condition](int64_t i, bool cond_value) {
          if (has_cond_) {
            if (i == 0) {
              return (bool)first_iter_condition;
            } else {
              return cond_value;
            }
          } else {
            return true;
          }
        };

    // Allocate scan_outputs for zero-iteration case
    // NOTE(review): dtype is forced to int32 here as a placeholder; if the
    // loop runs at least once, CopyFrom below replaces it with the body's
    // actual dtype — only the zero-iteration output keeps int32. Confirm
    // downstream consumers tolerate that.
    for (int i = 0; i < num_scan_outputs; ++i) {
      Output(i + num_loop_carried_deps)->Resize(0);
      Output(i + num_loop_carried_deps)->template mutable_data<int32_t>();
    }

    // Use this to keep track of the sizes of the scan outputs and validate
    // they're the same across iterations.
    std::vector<std::vector<int64_t>> scan_outputs_sizes;

    Workspace *cur_ws = nullptr;
    bool cur_output_condition = false;

    while (true) {
      int64_t itr = scope_->iteration();
      if (valid_iter_num(itr) && condition_true(itr, cur_output_condition)) {
        if (!scope_->net()->Run()) {
          return false;
        }

        // Capture this iteration's workspace/condition before (possibly)
        // pushing a fresh scope for the next iteration.
        cur_ws = scope_->workspace();
        cur_output_condition = scope_->template output_condition<CondVarType>();
        if (save_scopes_) {
          loop_ws = ws_stack_.pushForwardWorkspace(parent_ws_).get();
          scope_ = std::make_shared<LocalScope>(
              loop_ws, body_net_def_, num_loop_carried_deps);
        }

        // Copy forward loop-carried dependencies
        // (body output layout: [0]=condition, [1..N]=lcds, [N+1..]=scans)
        for (int i = 0; i < num_loop_carried_deps; ++i) {
          Blob* b = cur_ws->GetBlob(
              scope_->net()->external_output()[i + 1]);
          const Tensor& t = b->template Get<Tensor>();
          scope_->lcd_tensor(i)->CopyFrom(t);
        }
        // Copy out scan_outputs
        for (int i = 0; i < num_scan_outputs; ++i) {
          int net_output_idx = i + 1 + num_loop_carried_deps;
          const Tensor& scan_output =
              cur_ws->GetBlob(scope_->net()->external_output()[net_output_idx])
                  ->template Get<Tensor>();
          auto* scan_output_target = Output(i + num_loop_carried_deps);
          if (itr == 0) {
            // First iteration establishes the per-step shape; the output
            // gets an extra leading axis of length 1.
            auto dims = scan_output.sizes().vec();
            scan_outputs_sizes.push_back(dims);
            dims.insert(dims.begin(), 1);
            scan_output_target->Resize(dims);
            scan_output_target->CopyFrom(scan_output);
          } else {
            auto dims = scan_output.sizes().vec();
            CAFFE_ENFORCE_EQ(
                dims,
                scan_outputs_sizes[i],
                "Size of scan output changed across iterations");
            dims.insert(dims.begin(), itr);
            // Grow the leading axis by one step; 100 is the growth
            // percentage hint for over-allocation on reallocs.
            scan_output_target->Extend(1, 100);

            // Flat element count of one timestep's slice.
            int64_t timestep_size = 1;
            for (const int64_t t : scan_outputs_sizes[i]) {
              timestep_size *= t;
            }

            // Raw append of this iteration's slice at offset itr.
            const void* src_data = scan_output.raw_data();
            auto& sot_meta = scan_output_target->dtype();
            void* dst_data =
                (char*)scan_output_target->raw_mutable_data(sot_meta) +
                timestep_size * scan_output.itemsize() * itr;
            memcpy(dst_data, src_data, timestep_size * scan_output.itemsize());
          }
        }
        // Advance the body's iteration counter and feed the recomputed
        // condition back in as next iteration's input condition.
        scope_->set_iteration(itr + 1ll);
        scope_->template set_input_condition<CondVarType>(cur_output_condition);
      } else {
        break;
      }
    }

    // Copy out final loop-carried dependencies
    for (int i = 0; i < num_loop_carried_deps; ++i) {
      Output(i)->CopyFrom(*scope_->lcd_tensor(i));
    }

    return true;
  }

 private:
  // LocalScope wires the body net into a given workspace: it creates the
  // blobs for the iteration counter, input/output condition variables and
  // the N loop-carried deps, and creates (or reuses) the body net itself.
  class LocalScope {
   public:
    LocalScope(
        Workspace *loop_ws,
        const NetDef& body_net_def, size_t num_lcds) : loop_ws_(loop_ws){
      CAFFE_ENFORCE(loop_ws_,
          "Failed to initialize local loop workspace");

      // Create loop-carried deps in Workspace
      // (body input layout: [0]=iteration, [1]=condition, [2..]=lcds)
      lcd_tensors_.clear();
      for (int i = 2; i < num_lcds + 2; ++i) {
        Blob* b = loop_ws_->CreateBlob(body_net_def.external_input(i));
        Tensor* t = BlobGetMutableTensor(b, Context::GetDeviceType());
        lcd_tensors_.push_back(t);
      }
      // First output is the iteration variable
      auto* iteration_var_blob = loop_ws_->CreateBlob(
          body_net_def.external_input(0));
      iteration_var_ =
          BlobGetMutableTensor(iteration_var_blob, Context::GetDeviceType());

      input_condition_var_ = BlobGetMutableTensor(
          loop_ws_->CreateBlob(body_net_def.external_input(1)),
          Context::GetDeviceType());

      // NOTE(review): the output condition is pre-allocated here as bool
      // but read back via output_condition<CondVarType>() — relies on the
      // body net re-allocating it with the dispatched type; confirm.
      auto* condition_var_blob =
          loop_ws_->CreateBlob(body_net_def.external_output(0));
      condition_var_ =
          BlobGetMutableTensor(condition_var_blob, Context::GetDeviceType());
      condition_var_->Resize(1);
      condition_var_->template mutable_data<bool>();

      // Reuse a previously created body net in this workspace if present.
      body_net_ = loop_ws_->GetNet(body_net_def.name());
      if (!body_net_) {
        body_net_ = loop_ws_->CreateNet(body_net_def, true);
      }
      CAFFE_ENFORCE(body_net_, "Failed to initialize loop subnet");
    }

    NetBase* net() const {
      return body_net_;
    }

    Workspace* workspace() const {
      return loop_ws_;
    }

    // Current iteration index as stored in the body's iteration blob.
    int64_t iteration() const {
      auto* iteration_var_ptr =
          iteration_var_->template mutable_data<int64_t>();
      return *iteration_var_ptr;
    }

    Tensor* lcd_tensor(int idx) {
      return lcd_tensors_[idx];
    }

    // Write the iteration index into the (scalar) iteration blob.
    void set_iteration(int64_t itr) {
      iteration_var_->Resize();
      auto* iteration_var_ptr =
          iteration_var_->template mutable_data<int64_t>();
      *iteration_var_ptr = itr;
    }

    // Write the condition the body will read on its next run.
    template <typename CondVarType>
    void set_input_condition(bool cond_value) {
      input_condition_var_->Resize(1);
      auto* input_condition_var_ptr =
          input_condition_var_->template mutable_data<CondVarType>();
      *input_condition_var_ptr = cond_value;
    }

    // Read back the condition recomputed by the body's last run.
    template <typename CondVarType>
    bool output_condition() const {
      auto* condition_var_ptr =
          condition_var_->template mutable_data<CondVarType>();
      return *condition_var_ptr;
    }

   private:
    Workspace *loop_ws_;

    NetBase* body_net_; // owned by a workspace
    Tensor* iteration_var_;       // body input 0 (int64 scalar)
    Tensor* input_condition_var_; // body input 1
    Tensor* condition_var_;       // body output 0

    std::vector<Tensor*> lcd_tensors_;
  };

  NetDef body_net_def_;          // body subnet definition (from the "body" arg)
  Workspace* parent_ws_;         // workspace this op runs in; not owned
  detail::WorkspaceStack ws_stack_; // per-invocation child workspaces

  bool has_trip_count_;
  bool has_cond_;
  bool save_scopes_;
  bool disable_scopes_;
  int64_t num_loop_carried_deps_; // -1 => inferred from InputSize()

  std::shared_ptr<LocalScope> scope_;
};
311 
312 } // namespace caffe2
313 
#endif // CAFFE2_OPERATORS_ONNX_WHILE_OP_H_
Blob is a general container that hosts a typed pointer.
Definition: blob.h:24
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
Definition: workspace.cc:160
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position &#39;idx&#39; for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13