Caffe2 - C++ API
A deep learning, cross platform ML framework
blob.h
1 #ifndef CAFFE2_CORE_BLOB_H_
2 #define CAFFE2_CORE_BLOB_H_
3 
4 #include <cstddef>
5 #include <sstream>
6 #include <typeinfo>
7 #include <type_traits>
8 #include <vector>
9 #include "caffe2/core/common.h"
10 
11 #include <ATen/core/blob.h>
12 #include <c10/util/typeid.h>
13 #include "caffe2/core/logging.h"
14 #include "caffe2/core/tensor.h"
15 
16 namespace caffe2 {
17 
18 inline bool BlobIsTensorType(const Blob& blob, DeviceType device_type) {
19  bool is_match = blob.meta().Match<Tensor>();
20  if (!is_match) {
21  return false;
22  }
23  const Tensor* tensor = &blob.Get<Tensor>();
24  return tensor && *tensor && tensor->GetDeviceType() == device_type;
25 }
26 
27 inline Tensor* BlobSetTensor(Blob* blob, Tensor&& tensor) {
28  return blob->Reset<Tensor>(new Tensor(std::move(tensor)));
29 }
30 
31 inline Tensor GetSizedTensorWithOptions(
32  Tensor&& previous_tensor,
33  at::IntArrayRef dims,
34  at::TensorOptions options) {
35  Tensor tensor = std::move(previous_tensor);
36  if (!tensor.defined()) {
37  return caffe2::empty(dims, options);
38  }
39  if (tensor.GetDevice() == options.device() ||
40  (!tensor.GetDevice().has_index() &&
41  tensor.GetDeviceType() == options.device().type())) {
42  if (tensor.sizes() != dims) {
43  // Resize when the dims doesn't match
44  tensor.Resize(dims);
45  }
46  if (tensor.dtype() == options.dtype()) {
47  tensor.raw_mutable_data();
48  } else {
49  // create a new Tensor when the data_type doesn't match
50  return caffe2::empty(dims, options);
51  }
52  return tensor;
53  }
54  return caffe2::empty(dims, options);
55 }
56 
57 // need to keep both functions that returns Tensor* and the one
58 // returns Tensor for clangr codemod
59 inline Tensor*
60 BlobGetMutableTensor(Blob* blob, at::IntArrayRef dims, at::TensorOptions options) {
61  if (blob->IsType<Tensor>()) {
62  Tensor* tensor = blob->GetMutable<Tensor>();
63  if (*tensor) {
64  // We only compare device_type if the index is not set since there are Tensors
65  // TODO: remove the extra check when all the Tensors are properly initialized
66  if (tensor->GetDevice() == options.device() || (!tensor->GetDevice().has_index() && tensor->GetDeviceType() == options.device().type())) {
67  if (tensor->sizes() != dims) {
68  // Resize when the dims doesn't match
69  tensor->Resize(dims);
70  }
71  if (tensor->dtype() == options.dtype()) {
72  tensor->raw_mutable_data();
73  } else {
74  tensor->raw_mutable_data(options.dtype());
75  }
76  return tensor;
77  }
78  // create a new Tensor when device doesn't match
79  }
80  }
81 
82  VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
83  << " dims: " << dims;
84  // << " options: " << options; (operator<< for Options is in at:: now)
85  return BlobSetTensor(blob, caffe2::empty(dims, options));
86 }
87 
88 inline Tensor
89 XBlobGetMutableTensor(Blob* blob, at::IntArrayRef dims, at::TensorOptions options) {
90  return BlobGetMutableTensor(blob, dims, options)->UnsafeSharedInstance();
91 }
92 
93 inline Tensor* BlobGetMutableTensor(Blob* blob, DeviceType device_type) {
94  if (blob->IsType<Tensor>()) {
95  Tensor* tensor = blob->GetMutable<Tensor>();
96  if (*tensor && tensor->GetDeviceType() == device_type) {
97  return tensor;
98  }
99  }
100 
101  // if we're here, then either Blob didn't hold a Tensor
102  // or that Tensor had the wrong DeviceType.
103  VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<Tensor>()
104  << " DeviceType:" << device_type;
105 
106  return BlobSetTensor(blob, Tensor(device_type));
107 }
108 
109 inline const Tensor& BlobGetTensor(const Blob& blob, DeviceType device_type) {
110  if (blob.IsType<Tensor>()) {
111  const auto& tensor = blob.Get<Tensor>();
112  if (tensor.GetDeviceType() == device_type) {
113  return tensor;
114  }
115  }
116  CAFFE_THROW("Blob didn't contain a Tensor or the device_type doesn't match");
117 }
118 
119 inline Tensor BlobGetTensorOrUndefined(const Blob& blob) {
120  if (blob.IsType<Tensor>()) {
121  return blob.Get<Tensor>().UnsafeSharedInstance();
122  } else {
123  return Tensor();
124  }
125 }
126 
127 } // namespace caffe2
128 #endif // CAFFE2_CORE_BLOB_H_
C10_NODISCARD TensorOptions device(c10::optional< Device > device) const noexcept
Return a copy of TensorOptions with device set to the given one, or cleared if device is nullopt...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
C10_NODISCARD TensorOptions dtype(c10::optional< caffe2::TypeMeta > dtype) const noexcept
Return a copy of TensorOptions with dtype set to the given one.