1 #include "caffe2/core/tensor.h" 2 #include "caffe2/core/workspace.h" 4 #include "test_utils.h" 8 void assertTensorEqualsWithType(
11 CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes());
12 for (
auto idx = 0; idx < tensor1.numel(); ++idx) {
13 CAFFE_ENFORCE_EQ(tensor1.data<
T>()[idx], tensor2.data<
T>()[idx]);
22 void assertNear(
float value1,
float value2,
float epsilon) {
24 CAFFE_ENFORCE_LE(value1, value2 + epsilon);
25 CAFFE_ENFORCE_GE(value1, value2 - epsilon);
28 void assertTensorEquals(
const TensorCPU& tensor1,
const TensorCPU& tensor2) {
29 CAFFE_ENFORCE_EQ(tensor1.sizes(), tensor2.sizes());
30 if (tensor1.IsType<
float>()) {
31 CAFFE_ENFORCE(tensor2.IsType<
float>());
32 assertTensorEqualsWithType<float>(tensor1, tensor2);
33 }
else if (tensor1.IsType<
int>()) {
34 CAFFE_ENFORCE(tensor2.IsType<
int>());
35 assertTensorEqualsWithType<int>(tensor1, tensor2);
36 }
else if (tensor1.IsType<int64_t>()) {
37 CAFFE_ENFORCE(tensor2.IsType<int64_t>());
38 assertTensorEqualsWithType<int64_t>(tensor1, tensor2);
43 void assertTensorListEquals(
44 const std::vector<std::string>& tensorNames,
45 const Workspace& workspace1,
46 const Workspace& workspace2) {
47 for (
const std::string& tensorName : tensorNames) {
48 CAFFE_ENFORCE(workspace1.HasBlob(tensorName));
49 CAFFE_ENFORCE(workspace2.HasBlob(tensorName));
50 auto& tensor1 = getTensor(workspace1, tensorName);
51 auto& tensor2 = getTensor(workspace2, tensorName);
52 assertTensorEquals(tensor1, tensor2);
58 const std::string& name) {
59 CAFFE_ENFORCE(workspace.
HasBlob(name));
64 const std::string& name,
66 return BlobGetMutableTensor(workspace->
CreateBlob(name), caffe2::CPU);
69 caffe2::OperatorDef* createOperator(
70 const std::string& type,
71 const std::vector<std::string>& inputs,
72 const std::vector<std::string>& outputs,
73 caffe2::NetDef* net) {
74 auto* op = net->add_op();
76 for (
const auto& in : inputs) {
79 for (
const auto& out : outputs) {
85 NetMutator& NetMutator::newOp(
86 const std::string& type,
87 const std::vector<std::string>& inputs,
88 const std::vector<std::string>& outputs) {
89 lastCreatedOp_ = createOperator(type, inputs, outputs, net_);
93 NetMutator& NetMutator::externalInputs(
94 const std::vector<std::string>& externalInputs) {
95 for (
auto& blob : externalInputs) {
96 net_->add_external_input(blob);
101 NetMutator& NetMutator::externalOutputs(
102 const std::vector<std::string>& externalOutputs) {
103 for (
auto& blob : externalOutputs) {
104 net_->add_external_output(blob);
109 NetMutator& NetMutator::setDeviceOptionName(
const std::string& name) {
110 CAFFE_ENFORCE(lastCreatedOp_ !=
nullptr);
111 lastCreatedOp_->mutable_device_option()->set_node_name(name);
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasBlob(const string &name) const
Checks if a blob with the given name is present in the current workspace.
const T & Get() const
Gets the const reference of the stored object.