Caffe2 - C++ API
A deep learning, cross-platform ML framework
operator.cc
1 #include "caffe2/core/operator.h"
2 
3 #include <algorithm>
4 
5 #include "caffe2/core/init.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/net.h"
8 #include "caffe2/core/operator_gradient.h"
9 #include "caffe2/core/tensor.h"
10 #include "caffe2/core/types.h"
11 #include "caffe2/core/workspace.h"
12 
13 #include "caffe2/proto/caffe2_pb.h"
14 #include "caffe2/utils/proto_utils.h"
15 #include "caffe2/utils/string_utils.h"
16 
17 #include "caffe2/core/operator_c10wrapper.h"
18 
19 C10_DEFINE_int(
20  caffe2_operator_max_engine_name_length,
21  10,
22  "Maximum engine name length to be stored");
23 C10_DEFINE_bool(
24  caffe2_disable_implicit_engine_preference,
25  false,
26  "If set, disable implicit engine preferences. This is useful for unit "
27  "testing and debugging cases.");
28 C10_DEFINE_bool(
29  caffe2_operator_throw_if_fp_exceptions,
30  false,
31  "If set, throws if floating point exceptions (FE_DIVBYZERO, FE_INVALID, "
32  "FE_OVERFLOW) are detected when running any operator.");
33 
34 namespace caffe2 {
35 
36 OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws)
37  : operator_ws_(ws),
38  operator_def_(std::make_shared<OperatorDef>(operator_def)),
39  device_option_(
40  operator_def.has_device_option() ? operator_def.device_option()
41  : DeviceOption()),
42  input_size_(operator_def.input_size()),
43  event_(caffe2::make_unique<Event>(device_option_)) {
44  static GlobalInitIsCalledGuard guard;
45  for (const string& input_str : operator_def.input()) {
46  auto* blob = ws->GetBlob(input_str);
47  CAFFE_ENFORCE(
48  blob != nullptr,
49  "op ",
50  operator_def.type(),
51  ": Encountered a non-existing input blob: ",
52  input_str);
53  inputs_.push_back(blob);
54  }
55 
56  GetOperatorLogger()(operator_def);
57 
58  for (const string& output_str : operator_def.output()) {
59  outputs_.push_back(CHECK_NOTNULL(ws->CreateBlob(output_str)));
60  }
61 
62  type_ = operator_def.type();
63 }
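// A minimal usage sketch for this constructor path: input blobs must already
// exist in the workspace (see the enforce above), while output blobs are
// created on demand. The operator type "Relu" is only an assumed example of
// a registered op:
//
//   Workspace ws;
//   ws.CreateBlob("X");              // inputs must pre-exist
//   OperatorDef def;
//   def.set_type("Relu");
//   def.add_input("X");
//   def.add_output("Y");             // created by the constructor if absent
//   auto op = CreateOperator(def, &ws);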
64 
65 namespace {
66 int compute_input_size_(const std::vector<c10::IValue>& inputs) {
67  if (inputs.empty()) {
68  return 0;
69  }
70  if (inputs[0].isTensorList()) {
71  // If the first input is a tensor list, we get input tensors by indexing
72  // into that list. Currently, this means that only tensors from that list
73  // are accessible as inputs; any hypothetical input tensors that come after
74  // the list are not accessible.
75  return inputs[0].toTensorListRef().size();
76  }
77  // It's not a tensor list. Count the tensor inputs and return that count.
78  size_t num_tensor_inputs = 0;
79  bool found_nontensor = false;
80  for (const auto& input : inputs) {
81  if (input.isTensor()) {
82  AT_ASSERTM(
83  !found_nontensor,
84  "All tensor arguments must come before non-tensor arguments");
85  ++num_tensor_inputs;
86  } else {
87  found_nontensor = true;
88  }
89  }
90  return num_tensor_inputs;
91 }
92 } // namespace
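// For example: inputs {Tensor, Tensor, int} count as two tensor inputs;
// {Tensor, int, Tensor} trips the assertion above; and a leading TensorList
// contributes its length as the input count.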
93 
94 OperatorBase::OperatorBase(
95  const c10::FunctionSchema& fn_schema,
96  std::vector<c10::IValue> inputs,
97  std::vector<at::Tensor> outputs)
98  : fn_schema_(make_unique<c10::FunctionSchema>(std::move(fn_schema))),
99  newstyle_inputs_(std::move(inputs)),
100  newstyle_outputs_(std::move(outputs)),
101  input_size_(compute_input_size_(newstyle_inputs_)) {
102  input_tensors_.resize(input_size_);
103  output_tensors_.resize(newstyle_outputs_.size());
104 }
105 
106 vector<TensorShape> OperatorBase::InputTensorShapes() const {
107  vector<TensorShape> tps;
108  for (const auto& blob : inputs_) {
109  tps.push_back(GetTensorShapeOfBlob(blob));
110  }
111  return tps;
112 }
113 
114 namespace {
115 
116 PerOpEnginePrefType& g_per_op_engine_pref() {
117  static auto* g_per_op_engine_pref_ = new PerOpEnginePrefType();
118  return *g_per_op_engine_pref_;
119 }
120 
121 GlobalEnginePrefType& g_global_engine_pref() {
122  static auto* g_global_engine_pref_ =
123  new GlobalEnginePrefType{{CUDA, {"CUDNN"}}, {HIP, {"MIOPEN"}}};
124  return *g_global_engine_pref_;
125 }
126 
127 unique_ptr<OperatorBase> TryCreateOperator(
128  const string& key,
129  const OperatorDef& operator_def,
130  Workspace* ws) {
131  const auto& type_proto = operator_def.device_option().device_type();
132  const auto& type = ProtoToType(static_cast<DeviceTypeProto>(type_proto));
133  CAFFE_ENFORCE(
134  gDeviceTypeRegistry()->count(type),
135  "Device type ",
136  type,
137  " not registered.");
138  OperatorRegistry* registry = gDeviceTypeRegistry()->at(type);
139  VLOG(1) << "Creating operator with device type " << type;
140  try {
141  return registry->Create(key, operator_def, ws);
142  } catch (const UnsupportedOperatorFeature& err) {
143  LOG(WARNING) << "Operator " << operator_def.type()
144  << " does not support the requested feature. Msg: "
145  << err.what()
146  << ". Proto is: " << ProtoDebugString(operator_def);
147  return nullptr;
148  }
149 }
150 
151 unique_ptr<OperatorBase> _CreateOperator(
152  const OperatorDef& operator_def,
153  Workspace* ws) {
154  static StaticLinkingProtector g_protector;
155  const auto& op_type = operator_def.type();
156  const auto& device_type_proto = operator_def.device_option().device_type();
157  const auto& device_type =
158  ProtoToType(static_cast<DeviceTypeProto>(device_type_proto));
159 
160 #ifndef CAFFE2_NO_OPERATOR_SCHEMA
161  // First, check with OpSchema whether the operator is legal.
162  auto* schema = OpSchemaRegistry::Schema(op_type);
163  if (schema) {
164  CAFFE_ENFORCE(
165  schema->Verify(operator_def),
166  "Operator def did not pass schema checking: ",
167  ProtoDebugString(operator_def));
168  } else {
169  // We would like every op to register its schema, so if there is none we
170  // log an error here. But we will still allow the operator to be
171  // constructed.
172  LOG(ERROR) << "Cannot find operator schema for " << op_type
173  << ". Will skip schema checking.";
174  }
175 #endif
176 
177  // Second, try engines specified in the operator_def and preferred engines.
178  std::vector<std::string> engines{};
179  if (operator_def.engine().size()) {
180  const auto op_def_engines = split(',', operator_def.engine());
181  engines.insert(engines.end(), op_def_engines.begin(), op_def_engines.end());
182  }
183  if (!FLAGS_caffe2_disable_implicit_engine_preference &&
184  g_per_op_engine_pref().count(device_type) &&
185  g_per_op_engine_pref()[device_type].count(op_type)) {
186  const auto& preferred_engines =
187  g_per_op_engine_pref()[device_type][op_type];
188  VLOG(2) << "Inserting per-op engine preference: " << preferred_engines;
189  engines.insert(
190  engines.end(), preferred_engines.begin(), preferred_engines.end());
191  }
192  if (!FLAGS_caffe2_disable_implicit_engine_preference &&
193  g_global_engine_pref().count(device_type)) {
194  const auto& preferred_engines = g_global_engine_pref()[device_type];
195  VLOG(2) << "Inserting global engine preference: " << preferred_engines;
196  engines.insert(
197  engines.end(), preferred_engines.begin(), preferred_engines.end());
198  }
199  for (const auto& engine : engines) {
200  const std::string key = OpRegistryKey(op_type, engine);
201  VLOG(1) << "Trying to create operator " << op_type << " with engine "
202  << engine;
203  auto op = TryCreateOperator(key, operator_def, ws);
204  if (op) {
205  if (engine.size() <=
206  (unsigned)FLAGS_caffe2_operator_max_engine_name_length) {
207  op->annotate_engine(engine);
208  } else {
209  op->annotate_engine(
210  engine.substr(0, FLAGS_caffe2_operator_max_engine_name_length));
211  }
212  return op;
213  } else {
214  // This engine is unavailable; try the remaining engines and, eventually,
215  // the default implementation.
216  VLOG(1) << "Engine " << engine
217  << " is not available for operator " << op_type << ".";
218  }
219  }
220  if (operator_def.engine().size() && !VLOG_IS_ON(1)) {
221  static int log_occurrences = 0;
222  if (log_occurrences <= 64) {
223  ++log_occurrences;
224  LOG(INFO) << "Engine " << operator_def.engine()
225  << " is not available for operator " << op_type << ".";
226  }
227  }
228  VLOG(1) << "Using default implementation.";
229 
230  // Lastly, if none of the engines worked, fall back to the default implementation.
231  auto op = TryCreateOperator(op_type, operator_def, ws);
232  CAFFE_ENFORCE(
233  op,
234  "Cannot create operator of type '",
235  op_type,
236  "' on the device '",
237  DeviceTypeName(device_type),
238  "'. Verify that an implementation for the corresponding device exists. "
239  "This can also happen if the binary is not linked with the operator "
240  "implementation code, or, if the Python frontend is used, when the "
241  "dyndep.InitOpsLibrary call is missing. Operator def: ",
242  ProtoDebugString(operator_def));
243  return op;
244 }
245 
246 } // namespace
247 
248 const std::string OpRegistryKey(
249  const std::string& op_type,
250  const std::string& engine) {
251  if (engine == "" || engine == "DEFAULT") {
252  return op_type;
253  } else {
254  return op_type + "_ENGINE_" + engine;
255  }
256 }
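// For example, OpRegistryKey("Conv", "CUDNN") yields "Conv_ENGINE_CUDNN",
// while OpRegistryKey("Conv", "") and OpRegistryKey("Conv", "DEFAULT") both
// yield plain "Conv".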
257 
258 void SetPerOpEnginePref(const PerOpEnginePrefType& per_op_engine_pref) {
259  for (const auto& device_pref_pair : per_op_engine_pref) {
260  const auto& device_type = device_pref_pair.first;
261  CAFFE_ENFORCE(
262  gDeviceTypeRegistry()->count(device_type),
263  "Device type ",
264  device_type,
265  " not registered.");
266  auto* registry = gDeviceTypeRegistry()->at(device_type);
267 
268  for (const auto& op_pref_pair : device_pref_pair.second) {
269  const auto& op_type = op_pref_pair.first;
270  CAFFE_ENFORCE(
271  registry->Has(op_type),
272  "Operator type ",
273  op_type,
274  " not registered in ",
275  device_type,
276  " registry.");
277  }
278  }
279  g_per_op_engine_pref() = per_op_engine_pref;
280 }
281 
282 void SetGlobalEnginePref(const GlobalEnginePrefType& global_engine_pref) {
283  for (const auto& device_pref_pair : global_engine_pref) {
284  const auto& device_type = device_pref_pair.first;
285  CAFFE_ENFORCE(
286  gDeviceTypeRegistry()->count(device_type),
287  "Device type ",
288  device_type,
289  " not registered.");
290  }
291  g_global_engine_pref() = global_engine_pref;
292 }
293 
294 void SetEnginePref(
295  const PerOpEnginePrefType& per_op_engine_pref,
296  const GlobalEnginePrefType& global_engine_pref) {
297  SetPerOpEnginePref(per_op_engine_pref);
298  SetGlobalEnginePref(global_engine_pref);
299 }
300 
301 void SetOpEnginePref(
302  const std::string& op_type,
303  const CaffeMap<DeviceType, EnginePrefType>& op_pref) {
304  for (const auto& device_pref_pair : op_pref) {
305  const auto& device_type_proto = device_pref_pair.first;
306  const auto& device_type =
307  ProtoToType(static_cast<DeviceTypeProto>(device_type_proto));
308  CAFFE_ENFORCE(
309  gDeviceTypeRegistry()->count(device_type),
310  "Device type ",
311  device_type,
312  " not registered.");
313  CAFFE_ENFORCE(
314  gDeviceTypeRegistry()->at(device_type)->Has(op_type),
315  "Operator type ",
316  op_type,
317  " not registered in ",
318  device_type,
319  " registry.");
320  g_per_op_engine_pref()[device_type][op_type] = device_pref_pair.second;
321  }
322 }
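// A sketch of configuring engine preferences ("Conv" is an assumed registered
// operator; the brace-initialized maps follow the alias types used above):
//
//   SetOpEnginePref("Conv", {{CUDA, {"CUDNN"}}});  // one op, per device
//   SetGlobalEnginePref({{CUDA, {"CUDNN"}}});      // all ops on a device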
323 
324 unique_ptr<OperatorBase> CreateOperator(
325  const OperatorDef& operator_def,
326  Workspace* ws,
327  int net_position) {
328  try {
329  auto op = _CreateOperator(operator_def, ws);
330  op->set_net_position(net_position);
331  return op;
332  } catch (...) {
333  if (net_position != 0) {
334  VLOG(1) << "Operator constructor with net position " << net_position
335  << " failed";
336  ws->last_failed_op_net_position = net_position;
337  } else {
338  VLOG(1) << "Failed operator constructor doesn't have an id set";
339  }
340  throw;
341  }
342 }
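// A sketch of requesting a specific engine, assuming the `def` and `ws` from
// the constructor sketch above; if no "CUDNN" variant is registered,
// _CreateOperator falls back to the default implementation:
//
//   def.set_engine("CUDNN");
//   auto op = CreateOperator(def, &ws, /*net_position=*/3);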
343 
344 std::map<DeviceType, OperatorRegistry*>* gDeviceTypeRegistry() {
345  static std::map<DeviceType, OperatorRegistry*> g_device_type_registry;
346  return &g_device_type_registry;
347 }
348 
349 C10_DEFINE_REGISTRY(
350  CPUOperatorRegistry,
351  OperatorBase,
352  const OperatorDef&,
353  Workspace*);
354 CAFFE_REGISTER_DEVICE_TYPE(CPU, CPUOperatorRegistry);
355 
356 C10_DEFINE_REGISTRY(
357  CUDAOperatorRegistry,
358  OperatorBase,
359  const OperatorDef&,
360  Workspace*);
361 CAFFE_REGISTER_DEVICE_TYPE(CUDA, CUDAOperatorRegistry);
362 
363 C10_DEFINE_REGISTRY(
364  HIPOperatorRegistry,
365  OperatorBase,
366  const OperatorDef&,
367  Workspace*);
368 CAFFE_REGISTER_DEVICE_TYPE(HIP, HIPOperatorRegistry);
369 
370 C10_DEFINE_REGISTRY(
371  GradientRegistry,
372  GradientMakerBase,
373  const OperatorDef&,
374  const vector<GradientWrapper>&);
375 
376 GradientOpsMeta GetGradientForOp(
377  const OperatorDef& def, const vector<GradientWrapper>& g_output) {
378  std::unique_ptr<GradientMakerBase> maker(
379  GradientRegistry()->Create(def.type(), def, g_output));
380  CAFFE_ENFORCE(maker,
381  "Gradient maker for operator ", def.type(), " not implemented.");
382  GradientOpsMeta meta = maker->Get();
383  // Copy device option, engine, and arguments if needed.
384  if (maker->CopyDeviceOption() && def.has_device_option()) {
385  for (OperatorDef& grad_def : meta.ops_) {
386  grad_def.mutable_device_option()->CopyFrom(def.device_option());
387  }
388  }
389  // Copy engine if needed.
390  if (maker->CopyEngine() && def.has_engine()) {
391  for (OperatorDef& grad_def : meta.ops_) {
392  grad_def.set_engine(def.engine());
393  }
394  }
395  // Copy arguments if needed.
396  if (maker->CopyArguments() && def.arg_size()) {
397  for (OperatorDef& grad_def : meta.ops_) {
398  for (auto& arg : def.arg()) {
399  grad_def.add_arg()->CopyFrom(arg);
400  }
401  }
402  }
403  // VLOG for debugging purposes.
404  for (const OperatorDef& grad_def : meta.ops_) {
405  VLOG(1) << "Gradient ops: " << ProtoDebugString(grad_def);
406  }
407  // Check if the gradient computation has returned the right size for the
408  // gradient vector.
409  CAFFE_ENFORCE_EQ(meta.g_input_.size(), def.input_size());
410  VLOG(1) << "Gradients:";
411  for (const GradientWrapper& grad : meta.g_input_) {
412  // The gradient should either be (1) not set, or (2) dense, or (3) sparse,
413  // but cannot be both dense and sparse.
414  if (!grad.IsDense() && !grad.IsSparse()) {
415  VLOG(1) << "\t [no gradient]";
416  } else if (grad.IsDense()) {
417  VLOG(1) << "\t [dense]" << grad.dense_;
418  } else {
419  CAFFE_ENFORCE(
420  grad.indices_.size() && grad.values_.size(),
421  "For sparse gradient, one should set both indices and values. "
422  "Currently we have: (" +
423  grad.indices_ + ", " + grad.values_ + ").");
424  VLOG(1) << "\t [sparse] " << grad.indices_ << ", " << grad.values_;
425  }
426  }
427  return meta;
428 }
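// A hedged sketch of requesting gradients, assuming an OperatorDef `def`
// with a single output whose gradient blob is named "Y_grad":
//
//   vector<GradientWrapper> g_output(1);
//   g_output[0].dense_ = "Y_grad";
//   GradientOpsMeta meta = GetGradientForOp(def, g_output);
//   // meta.ops_ holds the gradient operator defs; meta.g_input_ holds the
//   // gradient wrappers for each input of `def`.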
429 
430 TensorShapes InferBlobShapesAndTypes(
431  CaffeMap<string, TensorShape>& blob_desc,
432  const vector<NetDef*>& nets) {
433  for (auto& defptr : nets) {
434  // Hack to work with auto-split gradients
435  CaffeMap<string, string> unmatched_sum_blobs;
436  CaffeMap<string, TensorShape> reshape_cache;
437 
438  for (const OperatorDef& op : defptr->op()) {
439  // Hack to ignore queues
440  if (op.type().find("Dequeue") != std::string::npos ||
441  op.type().find("Enqueue") != std::string::npos) {
442  continue;
443  }
444 
445  vector<TensorShape> input_desc;
446  bool found_all = true;
447  for (const string& in : op.input()) {
448  auto inp_desc = blob_desc.find(in);
449  if (inp_desc == blob_desc.end()) {
450  LOG(WARNING) << "Shape and type inference failed for input: " << in
451  << " for op " << op.type() << ", skipping.";
452  found_all = false;
453  break;
454  }
455  input_desc.push_back(inp_desc->second);
456  }
457  if (!found_all) {
458  continue;
459  }
460  auto op_schema = OpSchemaRegistry::Schema(op.type());
461  if (op_schema == nullptr) {
462  LOG(WARNING) << "Shape inference failed, no schema for: " << op.type();
463  continue;
464  }
465 
466  // Special handling for Sum, as it is used with the autosplits, which have
467  // a different naming convention. Assuming that all Sum inputs must be of
468  // the same size, we can infer their shapes.
469  if (op.type() == "Sum") {
470  TensorShape sum_shape;
471  for (auto inp : op.input()) {
472  auto it = blob_desc.find(inp);
473  if (it != blob_desc.end() && !it->second.unknown_shape()) {
474  if (it->second.dims_size() > 0) {
475  sum_shape = blob_desc[inp];
476  break;
477  }
478  }
479  }
480  for (auto inp : op.input()) {
481  auto it = blob_desc.find(inp);
482  if (it == blob_desc.end() || it->second.unknown_shape()) {
483  blob_desc[inp] = sum_shape;
484  if (sum_shape.dims_size() == 0) {
485  // Match later with the output
486  unmatched_sum_blobs[inp] = op.output(0);
487  }
488  }
489  }
490  }
491 
492  if (op.type() == "Reshape" && op.is_gradient_op()) {
493  CAFFE_ENFORCE(reshape_cache.find(op.input(1)) != reshape_cache.end());
494  TensorShape cached = reshape_cache[op.input(1)];
495  blob_desc[op.output(0)] = cached;
496  continue;
497  }
498 
499  std::vector<TensorShape> out;
500  try {
501  out = op_schema->InferTensor(op, input_desc);
502  if (op.is_gradient_op() && out.size()) {
503  // Special handling for gradient ops. We can assume gradients are of the
504  // same size as the corresponding variables. Basing this on string
505  // matching is a bit ugly, but the connection between a variable and its
506  // gradient is not specified explicitly.
507 
508  CaffeMap<string, string> grads_to_params =
509  GradientMakerBase::MatchGradsToParams(op);
510 
511  for (size_t i = 0; i < out.size(); i++) {
512  if (out[i].unknown_shape()) {
513  std::string gradout = op.output(i);
514 
515  if (grads_to_params.find(gradout) != grads_to_params.end()) {
516  std::string var = grads_to_params[gradout];
517  if (blob_desc.find(var) != blob_desc.end()) {
518  out[i] = blob_desc[var];
519  }
520  }
521  }
522  }
523  }
524 
525  if (op.type() == "Reshape") {
526  // Reshape stores the original input shape in its second output
527  // blob. We need this for the gradient Reshape.
528  reshape_cache[op.output(1)] = input_desc[0];
529  }
530 
531  } catch (::caffe2::EnforceNotMet& enf) {
532  LOG(ERROR) << "Shape inference error: " << enf.msg();
533  LOG(ERROR) << "Operator: " << ProtoDebugString(op) << std::endl;
534  LOG(ERROR) << "Returning empty results.";
535 
536  TensorShapes tps;
537  return tps;
538  }
539 
540  if (out.size() != (unsigned)op.output_size()) {
541  if (op.type() == "Slice") {
542  CAFFE_ENFORCE(
543  out.size() == 0,
544  "For the Slice operator, either the shapes of all output blobs "
545  "are inferred, or none of them are.");
546  } else {
547  CAFFE_THROW(
548  "Invalid shape inference for operator ",
549  op.type(),
550  ". Expected ",
551  op.output_size(),
552  " outputs, but got ",
553  out.size());
554  }
555  } else {
556  for (size_t i = 0; i < out.size(); i++) {
557  blob_desc[op.output(i)] = out[i];
558  }
559  }
560  } // net.ops
561 
562  for (auto& unmatched : unmatched_sum_blobs) {
563  if (blob_desc.find(unmatched.second) != blob_desc.end()) {
564  blob_desc[unmatched.first] = blob_desc[unmatched.second];
565  }
566  }
567 
568  } // nets
569  TensorShapes tps;
570  for (auto kv : blob_desc) {
571  TensorShape& tp = kv.second;
572  TensorShape* tpnew = tps.add_shapes();
573  tpnew->CopyFrom(tp);
574  tpnew->set_name(kv.first);
575  }
576  return tps;
577 }
578 
579 TensorShape GetTensorShapeOfBlob(const Blob* b) {
580  TypeCall type_fun = GetTypeCallFunction(b->meta().id());
581  TensorInfoCall tensor_info_fun = GetTensorInfoFunction(b->meta().id());
582  TensorShape tp;
583 
584  if (type_fun) {
585  tp.set_data_type(TypeMetaToDataType(type_fun(b->GetRaw())));
586  }
587  if (tensor_info_fun) {
588  size_t _capacity;
589  DeviceOption _device;
590  auto shape = tensor_info_fun(b->GetRaw(), &_capacity, &_device);
591  for (auto d : shape) {
592  tp.add_dims(d);
593  }
594  } else {
595  tp.set_unknown_shape(true);
596  }
597  return tp;
598 }
599 
600 TensorShapes InferBlobShapesAndTypesFromWorkspace(
601  Workspace* ws,
602  const vector<NetDef*>& nets) {
603  CaffeMap<string, TensorShape> blob_desc;
604  // Populate shapes from the workspace
605  const std::vector<string>& ws_blobs = ws->Blobs();
606  for (const auto& s : ws_blobs) {
607  Blob* b = ws->GetBlob(s);
608  TensorShape tp = GetTensorShapeOfBlob(b);
609  blob_desc[s] = tp;
610  }
611  return InferBlobShapesAndTypes(blob_desc, nets);
612 }
613 
614 TensorShapes InferBlobShapesAndTypesFromMap(
615  const CaffeMap<std::string, std::vector<int64_t>>& blob_dimensions,
616  const vector<NetDef*>& nets) {
617  CaffeMap<string, TensorShape> blob_desc;
618  // Populate shapes from known blobs
619  for (const auto& blob : blob_dimensions) {
620  TensorShape tp;
621  for (auto d : blob.second) {
622  CAFFE_ENFORCE_GE(d, 0, blob.first);
623  tp.add_dims(d);
624  }
625  blob_desc[blob.first] = tp;
626  }
627  return InferBlobShapesAndTypes(blob_desc, nets);
628 }
629 
630 TensorShapes InferBlobShapesAndTypesFromMap(
631  const CaffeMap<std::string, std::vector<int64_t>>& blob_dimensions,
632  const CaffeMap<std::string, TensorProto_DataType>& blob_types,
633  const vector<NetDef*>& nets) {
634  CaffeMap<string, TensorShape> blob_desc;
635  // Populate shapes from known blobs
636  for (const auto& blob : blob_dimensions) {
637  TensorShape tp;
638  for (auto d : blob.second) {
639  CAFFE_ENFORCE_GE(d, 0, blob.first);
640  tp.add_dims(d);
641  }
642  auto blob_type = blob_types.find(blob.first);
643  if (blob_type == blob_types.end()) {
644  LOG(WARNING) << "Missing type of " << blob.first
645  << "; assuming it is UNDEFINED";
646  tp.set_data_type(TensorProto_DataType_UNDEFINED);
647  } else {
648  tp.set_data_type(blob_type->second);
649  }
650  blob_desc[blob.first] = tp;
651  }
652  return InferBlobShapesAndTypes(blob_desc, nets);
653 }
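// A usage sketch: seed shapes and types for the external inputs and let the
// operator schemas infer the rest ("data" and `net` are assumed names):
//
//   CaffeMap<std::string, std::vector<int64_t>> dims{
//       {"data", {64, 3, 224, 224}}};
//   CaffeMap<std::string, TensorProto_DataType> types{
//       {"data", TensorProto_DataType_FLOAT}};
//   TensorShapes shapes = InferBlobShapesAndTypesFromMap(dims, types, {&net});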
654 
655 std::map<string, std::pair<DeviceOption, DeviceOption>> ValidateTensorDevices(
656  OperatorBase& op,
657  const OperatorDef& op_def) {
658  std::map<string, std::pair<DeviceOption, DeviceOption>> mismatches;
659  DeviceOption op_device = op_def.device_option();
660 
661 #ifndef CAFFE2_NO_OPERATOR_SCHEMA
662  // Check from the op schema whether this op is allowed to cross devices
663  auto op_schema = OpSchemaRegistry::Schema(op_def.type());
664  if (op_schema != nullptr) {
665  if (op_schema->inputs_can_cross_devices()) {
666  return mismatches;
667  }
668  }
669 #endif // CAFFE2_NO_OPERATOR_SCHEMA
670 
671  auto Check = [&](const Blob& blob, std::string blob_name) {
672  TensorInfoCall tensor_info_fun = GetTensorInfoFunction(blob.meta().id());
673  if (tensor_info_fun) {
674  size_t _capacity;
675  DeviceOption blob_device;
676  tensor_info_fun(
677  const_cast<Blob&>(blob).GetRaw(),
678  &_capacity,
679  &blob_device);
680 
681  if ((blob_device.device_type() == PROTO_CUDA ||
682  blob_device.device_type() == PROTO_HIP) &&
683  blob_device.device_id() != op_device.device_id()) {
684  mismatches[blob_name] = std::make_pair(op_device, blob_device);
685  }
686  }
687  };
688 
689  // Check that input and output blobs live on the same device as the op
690  for (int i = 0; i < op.InputSize(); i++) {
691  Check(op.InputBlob(i), op_def.input(i));
692  }
693  for (int i = 0; i < op.OutputSize(); i++) {
694  Check(*op.OutputBlob(i), op_def.output(i));
695  }
696  return mismatches;
697 }
698 
699 std::set<std::string> GetRegisteredOperators() {
700  std::set<std::string> all_keys;
701 
702  // CPU operators
703  for (const auto& name : CPUOperatorRegistry()->Keys()) {
704  all_keys.emplace(name);
705  }
706  // CUDA operators
707  for (const auto& name : CUDAOperatorRegistry()->Keys()) {
708  all_keys.emplace(name);
709  }
710 
711  // HIP operators
712  for (const auto& name : HIPOperatorRegistry()->Keys()) {
713  all_keys.emplace(name);
714  }
715 
716  return all_keys;
717 }
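// For example, to check whether a "Conv" implementation was linked into the
// binary:
//
//   const auto ops = GetRegisteredOperators();
//   const bool has_conv = ops.count("Conv") != 0;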
718 
719 static std::function<void(const OperatorDef&)> OperatorLogger =
720  [](const OperatorDef&) { return; };
721 
722 void SetOperatorLogger(std::function<void(const OperatorDef&)> tracer) {
723  OperatorLogger = tracer;
724 }
725 
726 std::function<void(const OperatorDef&)> GetOperatorLogger() {
727  return OperatorLogger;
728 }
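// A minimal sketch of installing a logger that records every operator as it
// is constructed:
//
//   SetOperatorLogger([](const OperatorDef& def) {
//     LOG(INFO) << "Constructing op: " << def.type();
//   });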
729 
730 } // namespace caffe2