Caffe2 - C++ API
A deep learning, cross-platform ML framework
assert_op.h
1 
16 #ifndef CAFFE2_OPERATORS_ASSERT_OP_H_
17 #define CAFFE2_OPERATORS_ASSERT_OP_H_
18 
19 #include "caffe2/core/context.h"
20 #include "caffe2/core/operator.h"
21 
22 namespace caffe2 {
23 
24 template <class Context>
25 class AssertOp final : public Operator<Context> {
26  public:
27  AssertOp(const OperatorDef& operator_def, Workspace* ws)
28  : Operator<Context>(operator_def, ws),
29  error_msg_(
30  OperatorBase::GetSingleArgument<std::string>("error_msg", "")) {}
31 
32  USE_OPERATOR_CONTEXT_FUNCTIONS;
33 
34  template <typename T>
35  bool DoRunWithType() {
36  // Copy into CPU context for comparison
37  cmp_tensor_.CopyFrom(Input(0));
38  auto* cmp_data = cmp_tensor_.template data<T>();
39 
40  for (TIndex i = 0; i < cmp_tensor_.size(); ++i) {
41  CAFFE_ENFORCE((bool)cmp_data[i], [&]() {
42  std::stringstream ss;
43  ss << "Assert failed for element " << i
44  << " in tensor, value: " << cmp_data[i] << "\n";
45  if (!error_msg_.empty()) {
46  ss << "Error message: " << error_msg_;
47  }
48  return ss.str();
49  }());
50  }
51  return true;
52  }
53 
54  bool RunOnDevice() override {
55  return DispatchHelper<TensorTypes<long, int, bool>>::call(this, Input(0));
56  }
57 
58  private:
59  TensorCPU cmp_tensor_;
60  std::string error_msg_;
61 };
62 
63 } // namespace caffe2
64 
65 #endif /* CAFFE2_OPERATORS_ASSERT_OP_H_ */
void CopyFrom(const Tensor< SrcContext > &src, ContextForCopy *context)
Copies the data from a source tensor, with a context provided to carry out the underlying memcpy operation.
Definition: tensor.h:182
TIndex size() const
Returns the size (i.e. the number of elements) of the tensor.
Definition: tensor.h:609
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.