Caffe2 - C++ API
A deep learning, cross platform ML framework
assert_op.h
1 #ifndef CAFFE2_OPERATORS_ASSERT_OP_H_
2 #define CAFFE2_OPERATORS_ASSERT_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 
7 namespace caffe2 {
8 
9 template <class Context>
10 class AssertOp final : public Operator<Context> {
11  public:
12  template <class... Args>
13  explicit AssertOp(Args&&... args)
14  : Operator<Context>(std::forward<Args>(args)...),
15  error_msg_(
16  this->template GetSingleArgument<std::string>("error_msg", "")) {}
17 
18  USE_OPERATOR_CONTEXT_FUNCTIONS;
19 
20  template <typename T>
21  bool DoRunWithType() {
22  // Copy into CPU context for comparison
23  cmp_tensor_.CopyFrom(Input(0));
24  auto* cmp_data = cmp_tensor_.template data<T>();
25 
26  for (int64_t i = 0; i < cmp_tensor_.numel(); ++i) {
27  CAFFE_ENFORCE((bool)cmp_data[i], [&]() {
28  std::stringstream ss;
29  ss << "Assert failed for element " << i
30  << " in tensor, value: " << cmp_data[i] << "\n";
31  if (!error_msg_.empty()) {
32  ss << "Error message: " << error_msg_;
33  }
34  return ss.str();
35  }());
36  }
37  return true;
38  }
39 
40  bool RunOnDevice() override {
41  return DispatchHelper<TensorTypes<long, int, bool>>::call(this, Input(0));
42  }
43 
44  private:
45  Tensor cmp_tensor_{CPU};
46  std::string error_msg_;
47 };
48 
49 } // namespace caffe2
50 
51 #endif /* CAFFE2_OPERATORS_ASSERT_OP_H_ */
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13