1 #ifndef CAFFE2_CORE_BLOB_H_ 2 #define CAFFE2_CORE_BLOB_H_ 9 #include "caffe2/core/common.h" 11 #include <ATen/core/blob.h> 12 #include <c10/util/typeid.h> 13 #include "caffe2/core/logging.h" 14 #include "caffe2/core/tensor.h" 18 inline bool BlobIsTensorType(
const Blob& blob, DeviceType device_type) {
19 bool is_match = blob.meta().Match<
Tensor>();
24 return tensor && *tensor && tensor->GetDeviceType() == device_type;
27 inline Tensor* BlobSetTensor(Blob* blob,
Tensor&& tensor) {
28 return blob->Reset<
Tensor>(
new Tensor(std::move(tensor)));
31 inline Tensor GetSizedTensorWithOptions(
35 Tensor tensor = std::move(previous_tensor);
36 if (!tensor.defined()) {
37 return caffe2::empty(dims, options);
39 if (tensor.GetDevice() == options.
device() ||
40 (!tensor.GetDevice().has_index() &&
41 tensor.GetDeviceType() == options.
device().type())) {
42 if (tensor.sizes() != dims) {
46 if (tensor.dtype() == options.
dtype()) {
47 tensor.raw_mutable_data();
50 return caffe2::empty(dims, options);
54 return caffe2::empty(dims, options);
61 if (blob->IsType<
Tensor>()) {
66 if (tensor->GetDevice() == options.
device() || (!tensor->GetDevice().has_index() && tensor->GetDeviceType() == options.
device().type())) {
67 if (tensor->sizes() != dims) {
71 if (tensor->dtype() == options.
dtype()) {
72 tensor->raw_mutable_data();
74 tensor->raw_mutable_data(options.
dtype());
82 VLOG(1) <<
"Create new mutable object " << TypeMeta::TypeName<Tensor>()
85 return BlobSetTensor(blob, caffe2::empty(dims, options));
90 return BlobGetMutableTensor(blob, dims, options)->UnsafeSharedInstance();
93 inline Tensor* BlobGetMutableTensor(Blob* blob, DeviceType device_type) {
94 if (blob->IsType<
Tensor>()) {
96 if (*tensor && tensor->GetDeviceType() == device_type) {
103 VLOG(1) <<
"Create new mutable object " << TypeMeta::TypeName<Tensor>()
104 <<
" DeviceType:" << device_type;
106 return BlobSetTensor(blob,
Tensor(device_type));
109 inline const Tensor& BlobGetTensor(
const Blob& blob, DeviceType device_type) {
110 if (blob.IsType<
Tensor>()) {
111 const auto& tensor = blob.Get<
Tensor>();
112 if (tensor.GetDeviceType() == device_type) {
116 CAFFE_THROW(
"Blob didn't contain a Tensor or the device_type doesn't match");
119 inline Tensor BlobGetTensorOrUndefined(
const Blob& blob) {
120 if (blob.IsType<
Tensor>()) {
121 return blob.Get<
Tensor>().UnsafeSharedInstance();
128 #endif // CAFFE2_CORE_BLOB_H_ C10_NODISCARD TensorOptions device(c10::optional< Device > device) const noexcept
Return a copy of TensorOptions with device set to the given one, or cleared if device is nullopt...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
C10_NODISCARD TensorOptions dtype(c10::optional< caffe2::TypeMeta > dtype) const noexcept
Return a copy of TensorOptions with dtype set to the given one.