// ONNX While/Loop operator for Caffe2.
// NOTE(review): this chunk is an extraction with gaps — the class declaration
// and the constructor's opening line fall in missing lines, so only the
// initializer list and the tail of the constructor body are visible here.
1 #ifndef CAFFE2_OPERATORS_ONNX_WHILE_OP_H_ 2 #define CAFFE2_OPERATORS_ONNX_WHILE_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/operators/create_scope_op.h" 11 template <
class Context>
// Constructor initializer list: each configuration value is read from the
// operator's arguments via OperatorBase::GetSingleArgument (default 0).
// has_trip_count_: nonzero when Input(0) carries a maximum trip count
// (consulted by the valid_iter_num lambda in DoRunWithType below).
18 this->
template GetSingleArgument<int64_t>(
"has_trip_count", 0)),
// has_cond_: nonzero when the loop has a termination-condition input.
19 has_cond_(this->
template GetSingleArgument<int64_t>(
"has_cond", 0)),
// save_scopes_: keep per-iteration workspaces; incompatible with
// disable_scopes_ (enforced below).
21 this->
template GetSingleArgument<int64_t>(
"save_scopes", 0)),
// disable_scopes_: run the body net directly in the parent workspace
// instead of pushing a fresh scoped workspace (see DoRunWithType).
23 this->
template GetSingleArgument<int64_t>(
"disable_scopes", 0)),
// num_loop_carried_deps_: explicit count of loop-carried dependencies.
// DoRunWithType treats -1 as "derive from InputSize()"; the default value
// itself falls in a missing line — presumably -1, TODO confirm.
24 num_loop_carried_deps_(this->
template GetSingleArgument<int64_t>(
25 "num_loop_carried_deps",
// The loop body must be supplied as a NetDef argument named "body".
28 this->
template HasSingleArgumentOfType<NetDef>(
"body"),
29 "body net must be specified in ONNXWhile operator");
30 if (disable_scopes_) {
31 CAFFE_ENFORCE(!save_scopes_,
"Cannot save scopes when disable_scopes=True");
33 body_net_def_ = this->
template GetSingleArgument<NetDef>(
"body", NetDef());
// NOTE(review): function-local static shared by all instances — not
// thread-safe if operators are constructed concurrently. Used to give
// anonymous body nets unique names.
34 static int64_t counter = -1;
35 if (!body_net_def_.has_name()) {
38 body_net_def_.set_name(
"loop_net");
// Later unnamed body nets get a numbered suffix; the branch selecting
// between the two set_name calls (and the counter increment) falls in
// missing lines.
41 body_net_def_.set_name(
"loop_net." + c10::to_string(counter));
46 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Runs the while loop: Input(0) = max trip count (int64), Input(1) = initial
// condition (CondVarType), Inputs(2..) = initial loop-carried deps.
// Outputs(0..N-1) = final loop-carried deps, Outputs(N..) = scan outputs
// stacked along a new leading axis, one slice per executed iteration.
// NOTE(review): several original lines (else branches, CAFFE_ENFORCEs,
// closing braces, the loop's `while` header) are missing from this chunk.
58 template <
typename CondVarType>
59 bool DoRunWithType() {
// With scopes enabled, each run gets a fresh workspace pushed on ws_stack_;
// otherwise the body runs directly in the parent workspace.
63 auto loop_ws = !disable_scopes_ ? ws_stack_.pushForwardWorkspace(parent_ws_).get() : parent_ws_;
// Inputs 0 and 1 are trip count and condition; loop-carried deps follow.
65 constexpr int64_t num_inputs_before_lcds = 2;
// -1 means "infer N from the number of operator inputs".
69 int64_t num_loop_carried_deps;
70 if (num_loop_carried_deps_ != -1) {
71 num_loop_carried_deps = num_loop_carried_deps_;
73 num_loop_carried_deps = InputSize() - num_inputs_before_lcds;
75 int64_t max_trip_count = *
Input(0).template data<int64_t>();
76 const bool first_iter_condition = *
Input(1).template data<CondVarType>();
// LocalScope wires the body net's external inputs/outputs into loop_ws.
78 scope_ = std::make_shared<LocalScope>(loop_ws, body_net_def_, num_loop_carried_deps);
// Body outputs are: [condition, N loop-carried deps, K scan outputs].
82 int num_scan_outputs =
83 scope_->net()->external_output().size() - num_loop_carried_deps - 1;
88 "Body graph must have N+K outputs, where N is the number " 89 "of loop-carried dependencies and K is the number of scan " 93 for (
int i = 0; i < num_loop_carried_deps; ++i) {
// Seed the body's loop-carried tensors from the operator inputs.
94 scope_->lcd_tensor(i)->CopyFrom(
Input(i + num_inputs_before_lcds));
98 scope_->set_iteration(0ll);
101 scope_->template set_input_condition<CondVarType>(first_iter_condition);
// True while the iteration index is under the trip count (when one is
// given); the no-trip-count branch falls in missing lines.
103 auto valid_iter_num = [
this, max_trip_count](int64_t i) {
104 if (has_trip_count_) {
105 return i < max_trip_count;
// True while the loop condition holds; on iteration 0 it uses the initial
// condition input (later-iteration branches fall in missing lines).
111 auto condition_true =
112 [
this, first_iter_condition](int64_t i,
bool cond_value) {
115 return (
bool)first_iter_condition;
// Pre-touch scan outputs so they exist (as empty int32 tensors) even if
// the loop body never runs.
125 for (
int i = 0; i < num_scan_outputs; ++i) {
126 Output(i + num_loop_carried_deps)->Resize(0);
127 Output(i + num_loop_carried_deps)->template mutable_data<int32_t>();
// Remembers each scan output's per-iteration shape so a shape change
// across iterations can be detected below.
132 std::vector<std::vector<int64_t>> scan_outputs_sizes;
135 bool cur_output_condition =
false;
// --- loop body (the enclosing while/for header is in missing lines) ---
138 int64_t itr = scope_->iteration();
139 if (valid_iter_num(itr) && condition_true(itr, cur_output_condition)) {
140 if (!scope_->net()->Run()) {
144 cur_ws = scope_->workspace();
145 cur_output_condition = scope_->template output_condition<CondVarType>();
// When scopes are enabled, push a new workspace and rebuild the scope
// for the next iteration (the guarding condition is in missing lines).
147 loop_ws = ws_stack_.pushForwardWorkspace(parent_ws_).get();
148 scope_ = std::make_shared<LocalScope>(loop_ws, body_net_def_, num_loop_carried_deps);
// Feed this iteration's loop-carried outputs (body outputs 1..N) back
// in as the next iteration's loop-carried inputs.
152 for (
int i = 0; i < num_loop_carried_deps; ++i) {
154 scope_->net()->external_output()[i + 1]);
155 const Tensor& t = b->template Get<Tensor>();
156 scope_->lcd_tensor(i)->CopyFrom(t);
// Append each scan output (body outputs N+1..N+K) as a new leading-axis
// slice of the corresponding operator output.
159 for (
int i = 0; i < num_scan_outputs; ++i) {
160 int net_output_idx = i + 1 + num_loop_carried_deps;
161 const Tensor& scan_output =
162 cur_ws->
GetBlob(scope_->net()->external_output()[net_output_idx])
163 ->
template Get<Tensor>();
164 auto* scan_output_target = Output(i + num_loop_carried_deps);
// First iteration: record the slice shape and size the target to
// [1, dims...] (the itr==0 guard falls in missing lines).
166 auto dims = scan_output.sizes().vec();
167 scan_outputs_sizes.push_back(dims);
168 dims.insert(dims.begin(), 1);
169 scan_output_target->Resize(dims);
170 scan_output_target->CopyFrom(scan_output);
// Later iterations: shape must match iteration 0, then grow the
// target by one slice (Extend with 100% growth headroom) and memcpy
// the new slice into place.
172 auto dims = scan_output.sizes().vec();
175 scan_outputs_sizes[i],
176 "Size of scan output changed across iterations");
177 dims.insert(dims.begin(), itr);
178 scan_output_target->Extend(1, 100);
// Element count of one slice = product of the recorded dims.
180 int64_t timestep_size = 1;
181 for (
const int64_t t : scan_outputs_sizes[i]) {
185 const void* src_data = scan_output.raw_data();
186 auto& sot_meta = scan_output_target->dtype();
// Raw byte copy into slice `itr` of the target buffer.
188 (
char*)scan_output_target->raw_mutable_data(sot_meta) +
189 timestep_size * scan_output.itemsize() * itr;
190 memcpy(dst_data, src_data, timestep_size * scan_output.itemsize());
// Advance the iteration counter and condition for the next pass.
193 scope_->set_iteration(itr + 1ll);
194 scope_->template set_input_condition<CondVarType>(cur_output_condition);
// After the loop: copy final loop-carried deps to Outputs(0..N-1).
201 for (
int i = 0; i < num_loop_carried_deps; ++i) {
202 Output(i)->CopyFrom(*scope_->lcd_tensor(i));
// LocalScope constructor (opening of the signature is in missing lines):
// binds the body net's conventional external inputs/outputs to tensors in
// the given loop workspace. Convention (per the indices used below):
//   external_input(0)  = iteration counter (int64)
//   external_input(1)  = input condition
//   external_input(2..)= loop-carried dependency tensors
//   external_output(0) = output condition
213 const NetDef& body_net_def,
size_t num_lcds) : loop_ws_(loop_ws){
214 CAFFE_ENFORCE(loop_ws_,
215 "Failed to initialize local loop workspace");
// Create one blob/tensor per loop-carried dependency.
// NOTE(review): `int i` vs `size_t num_lcds` is a signed/unsigned
// comparison; benign for realistic counts but worth tidying upstream.
218 lcd_tensors_.clear();
219 for (
int i = 2; i < num_lcds + 2; ++i) {
220 Blob* b = loop_ws_->CreateBlob(body_net_def.external_input(i));
221 Tensor* t = BlobGetMutableTensor(b, Context::GetDeviceType());
222 lcd_tensors_.push_back(t);
// Iteration-counter blob (assignment to iteration_var_ presumably happens
// on the missing line between 226 and 228 — TODO confirm).
225 auto* iteration_var_blob = loop_ws_->CreateBlob(
226 body_net_def.external_input(0));
228 BlobGetMutableTensor(iteration_var_blob, Context::GetDeviceType());
// Input-condition blob, fed before each iteration via set_input_condition.
230 input_condition_var_ = BlobGetMutableTensor(
231 loop_ws_->CreateBlob(body_net_def.external_input(1)),
232 Context::GetDeviceType());
// Output-condition blob, read after each iteration via output_condition.
234 auto* condition_var_blob =
235 loop_ws_->CreateBlob(body_net_def.external_output(0));
237 BlobGetMutableTensor(condition_var_blob, Context::GetDeviceType());
238 condition_var_->Resize(1);
239 condition_var_->template mutable_data<bool>();
// Reuse an existing instantiation of the body net if the workspace already
// has one; otherwise create it (the null-check between these falls in a
// missing line).
241 body_net_ = loop_ws_->GetNet(body_net_def.name());
243 body_net_ = loop_ws_->CreateNet(body_net_def,
true);
245 CAFFE_ENFORCE(body_net_,
"Failed to initialize loop subnet");
// Current iteration index, read from the iteration-counter tensor.
// NOTE(review): uses mutable_data on a const path — works because the
// tensor is already allocated, but a const accessor would be cleaner.
256 int64_t iteration()
const {
257 auto* iteration_var_ptr =
258 iteration_var_->template mutable_data<int64_t>();
259 return *iteration_var_ptr;
// Unchecked access to the idx-th loop-carried dependency tensor.
262 Tensor* lcd_tensor(
int idx) {
263 return lcd_tensors_[idx];
// Write the iteration counter (Resize() with no args makes it a scalar).
266 void set_iteration(int64_t itr) {
267 iteration_var_->Resize();
268 auto* iteration_var_ptr =
269 iteration_var_->template mutable_data<int64_t>();
270 *iteration_var_ptr = itr;
// Store the incoming condition as a 1-element tensor of CondVarType
// (bool narrowed/widened to the body net's condition dtype).
273 template <
typename CondVarType>
274 void set_input_condition(
bool cond_value) {
275 input_condition_var_->Resize(1);
276 auto* input_condition_var_ptr =
277 input_condition_var_->template mutable_data<CondVarType>();
278 *input_condition_var_ptr = cond_value;
// Read the condition produced by the last body-net run.
281 template <
typename CondVarType>
282 bool output_condition()
const {
283 auto* condition_var_ptr =
284 condition_var_->template mutable_data<CondVarType>();
285 return *condition_var_ptr;
// Non-owning pointer into the loop workspace: the 1-element condition input
// tensor written by set_input_condition.
293 Tensor* input_condition_var_;
// Non-owning pointers to the loop-carried dependency tensors (workspace owns).
296 std::vector<Tensor*> lcd_tensors_;
// Copy of the "body" NetDef argument, possibly renamed in the constructor.
299 NetDef body_net_def_;
// True when Input(0) supplies a maximum trip count.
303 bool has_trip_count_;
// True when the body runs directly in the parent workspace (no scoping).
306 bool disable_scopes_;
// Explicit loop-carried-dep count; -1 means derive from InputSize().
307 int64_t num_loop_carried_deps_;
// Scope for the current (or last) iteration; rebuilt per iteration when
// scoping is enabled.
309 std::shared_ptr<LocalScope> scope_;
314 #endif // CAFFE2_OPERATORS_ONNX_WHILE_OP_H_ Blob is a general container that hosts a typed pointer.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...