Caffe2 - C++ API
A deep learning, cross platform ML framework
test_utils.cc
1 #include "caffe2/core/tensor.h"
2 #include "caffe2/core/workspace.h"
3 
4 #include "test_utils.h"
5 
6 namespace {
7 template <typename T>
8 void assertTensorEqualsWithType(
9  const caffe2::TensorCPU& tensor1,
10  const caffe2::TensorCPU& tensor2) {
11  CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes());
12  for (auto idx = 0; idx < tensor1.numel(); ++idx) {
13  CAFFE_ENFORCE_EQ(tensor1.data<T>()[idx], tensor2.data<T>()[idx]);
14  }
15 }
16 } // namespace
17 
18 namespace caffe2 {
19 namespace testing {
20 
21 // Asserts that two float values are close within epsilon.
22 void assertNear(float value1, float value2, float epsilon) {
23  // These two enforces will give good debug messages.
24  CAFFE_ENFORCE_LE(value1, value2 + epsilon);
25  CAFFE_ENFORCE_GE(value1, value2 - epsilon);
26 }
27 
28 void assertTensorEquals(const TensorCPU& tensor1, const TensorCPU& tensor2) {
29  CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes());
30  if (tensor1.IsType<float>()) {
31  CAFFE_ENFORCE(tensor2.IsType<float>());
32  assertTensorEqualsWithType<float>(tensor1, tensor2);
33  } else if (tensor1.IsType<int>()) {
34  CAFFE_ENFORCE(tensor2.IsType<int>());
35  assertTensorEqualsWithType<int>(tensor1, tensor2);
36  } else if (tensor1.IsType<int64_t>()) {
37  CAFFE_ENFORCE(tensor2.IsType<int64_t>());
38  assertTensorEqualsWithType<int64_t>(tensor1, tensor2);
39  }
40  // Add more types if needed.
41 }
42 
43 void assertTensorListEquals(
44  const std::vector<std::string>& tensorNames,
45  const Workspace& workspace1,
46  const Workspace& workspace2) {
47  for (const std::string& tensorName : tensorNames) {
48  CAFFE_ENFORCE(workspace1.HasBlob(tensorName));
49  CAFFE_ENFORCE(workspace2.HasBlob(tensorName));
50  auto& tensor1 = getTensor(workspace1, tensorName);
51  auto& tensor2 = getTensor(workspace2, tensorName);
52  assertTensorEquals(tensor1, tensor2);
53  }
54 }
55 
56 const caffe2::Tensor& getTensor(
57  const caffe2::Workspace& workspace,
58  const std::string& name) {
59  CAFFE_ENFORCE(workspace.HasBlob(name));
60  return workspace.GetBlob(name)->Get<caffe2::Tensor>();
61 }
62 
63 caffe2::Tensor* createTensor(
64  const std::string& name,
65  caffe2::Workspace* workspace) {
66  return BlobGetMutableTensor(workspace->CreateBlob(name), caffe2::CPU);
67 }
68 
69 caffe2::OperatorDef* createOperator(
70  const std::string& type,
71  const std::vector<std::string>& inputs,
72  const std::vector<std::string>& outputs,
73  caffe2::NetDef* net) {
74  auto* op = net->add_op();
75  op->set_type(type);
76  for (const auto& in : inputs) {
77  op->add_input(in);
78  }
79  for (const auto& out : outputs) {
80  op->add_output(out);
81  }
82  return op;
83 }
84 
85 NetMutator& NetMutator::newOp(
86  const std::string& type,
87  const std::vector<std::string>& inputs,
88  const std::vector<std::string>& outputs) {
89  lastCreatedOp_ = createOperator(type, inputs, outputs, net_);
90  return *this;
91 }
92 
93 NetMutator& NetMutator::externalInputs(
94  const std::vector<std::string>& externalInputs) {
95  for (auto& blob : externalInputs) {
96  net_->add_external_input(blob);
97  }
98  return *this;
99 }
100 
101 NetMutator& NetMutator::externalOutputs(
102  const std::vector<std::string>& externalOutputs) {
103  for (auto& blob : externalOutputs) {
104  net_->add_external_output(blob);
105  }
106  return *this;
107 }
108 
109 NetMutator& NetMutator::setDeviceOptionName(const std::string& name) {
110  CAFFE_ENFORCE(lastCreatedOp_ != nullptr);
111  lastCreatedOp_->mutable_device_option()->set_node_name(name);
112  return *this;
113 }
114 
115 } // namespace testing
116 } // namespace caffe2
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Definition: workspace.cc:100
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
Definition: workspace.cc:160
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasBlob(const string &name) const
Checks if a blob with the given name is present in the current workspace.
Definition: workspace.h:179
const T & Get() const
Gets the const reference of the stored object.
Definition: blob.h:71