// Test helpers for Caffe2 unit tests: tensor/workspace assertions, tensor
// creation and filling, and a fluent NetDef mutator (NetMutator).
// NOTE(review): this copy contains fused line-number tokens and missing lines
// from extraction; it does not compile as-is.
//
// assertTensorEquals(tensor1, tensor2): compares two CPU tensors. Only the
// declaration is visible here; presumably it fails (CAFFE_ENFORCE-style) when
// the tensors differ — TODO confirm shape/dtype/tolerance rules in the .cc.
1 #ifndef CAFFE2_UTILS_TEST_UTILS_H_ 2 #define CAFFE2_UTILS_TEST_UTILS_H_ 4 #include "caffe2/core/tensor.h" 5 #include "caffe2/core/workspace.h" 6 #include "caffe2/utils/proto_utils.h" 19 void assertTensorEquals(
const TensorCPU& tensor1,
const TensorCPU& tensor2);
// Compares two floats using tolerance `epsilon`. Only declared here; the
// definition lives elsewhere. The templated assertTensorEquals in this header
// calls it for float tensors, element by element, with its `epsilon` argument.
// Presumably it fails when |value1 - value2| > epsilon — TODO confirm whether
// the bound is inclusive and how failure is reported.
22 void assertNear(
float value1,
float value2,
float epsilon);
26 void assertTensorEquals(
27 const TensorCPU& tensor,
28 const std::vector<T>& data,
29 float epsilon = 0.1f) {
30 CAFFE_ENFORCE(tensor.IsType<
T>());
31 CAFFE_ENFORCE_EQ(tensor.numel(), data.size());
32 for (
auto idx = 0; idx < tensor.numel(); ++idx) {
33 if (tensor.IsType<
float>()) {
34 assertNear(tensor.data<
T>()[idx], data[idx], epsilon);
36 CAFFE_ENFORCE_EQ(tensor.data<
T>()[idx], data[idx]);
44 const TensorCPU& tensor,
45 const std::vector<int64_t>& sizes,
46 const std::vector<T>& data,
47 float epsilon = 0.1f) {
48 CAFFE_ENFORCE_EQ(tensor.sizes(), sizes);
49 assertTensorEquals(tensor, data, epsilon);
// Declared only; definition not visible. Presumably, for each name in
// `tensorNames`, compares the tensor stored under that name in `workspace1`
// against the one in `workspace2` — TODO confirm (including behavior when a
// name is missing from either workspace).
53 void assertTensorListEquals(
54 const std::vector<std::string>& tensorNames,
55 const Workspace& workspace1,
56 const Workspace& workspace2);
61 const std::string& name);
65 const std::string& name,
69 caffe2::OperatorDef* createOperator(
70 const std::string& type,
71 const std::vector<std::string>& inputs,
72 const std::vector<std::string>& outputs,
78 const std::vector<int64_t>& shape,
79 const std::vector<T>& data,
81 tensor->Resize(shape);
82 CAFFE_ENFORCE_EQ(data.size(), tensor->numel());
83 auto ptr = tensor->mutable_data<
T>();
84 for (
int i = 0; i < tensor->numel(); ++i) {
92 const std::string& name,
93 const std::vector<int64_t>& shape,
94 const std::vector<T>& data,
95 Workspace* workspace) {
96 auto* tensor = createTensor(name, workspace);
97 fillTensor<T>(shape, data, tensor);
102 template <
typename T>
103 void constantFillTensor(
104 const vector<int64_t>& shape,
107 tensor->Resize(shape);
108 auto ptr = tensor->mutable_data<
T>();
109 for (
int i = 0; i < tensor->numel(); ++i) {
115 template <
typename T>
117 const std::string& name,
118 const std::vector<int64_t>& shape,
120 Workspace* workspace) {
121 auto* tensor = createTensor(name, workspace);
122 constantFillTensor<T>(shape, data, tensor);
129 explicit NetMutator(caffe2::NetDef* net) : net_(net) {}
132 const std::string& type,
133 const std::vector<std::string>& inputs,
134 const std::vector<std::string>& outputs);
// Sets the wrapped NetDef's external inputs to `externalInputs`. Definition not
// visible — TODO confirm whether it replaces or appends. Returns a NetMutator&,
// presumably *this, to allow call chaining.
136 NetMutator& externalInputs(
const std::vector<std::string>& externalInputs);
// Sets the wrapped NetDef's external outputs to `externalOutputs`. Definition
// not visible — TODO confirm whether it replaces or appends. Returns a
// NetMutator&, presumably *this, to allow call chaining.
138 NetMutator& externalOutputs(
const std::vector<std::string>& externalOutputs);
141 template <
typename T>
142 NetMutator& addArgument(
const std::string& name,
const T& value) {
143 CAFFE_ENFORCE(lastCreatedOp_ !=
nullptr);
144 AddArgument(name, value, lastCreatedOp_);
// Sets a device-option name to `name`. Definition not visible — presumably it
// applies to the most recently created op (lastCreatedOp_) rather than the whole
// net; TODO confirm. Returns a NetMutator&, presumably *this, for chaining.
149 NetMutator& setDeviceOptionName(
const std::string& name);
152 caffe2::NetDef* net_;
153 caffe2::OperatorDef* lastCreatedOp_;
160 : workspace_(workspace) {}
163 template <
typename T>
165 const std::string& name,
166 const std::vector<int64_t>& shape,
167 const std::vector<T>& data) {
168 createTensorAndFill<T>(name, shape, data, workspace_);
173 template <
typename T>
175 const std::string& name,
176 const std::vector<int64_t>& shape,
178 createTensorAndConstantFill<T>(name, shape, data, workspace_);
189 #endif // CAFFE2_UTILS_TEST_UTILS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...