Caffe2 - C++ API
A deep learning, cross platform ML framework
input_metadata.h
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 
5 #include <cstdint>
6 
7 namespace torch { namespace autograd {
8 
12 struct InputMetadata {
13  InputMetadata() = default;
14 
15  InputMetadata(const at::Type& type, at::IntArrayRef shape, at::Device device)
16  : type_{&type} , shape_{shape}, device_{device} { }
17 
18  InputMetadata(const at::Tensor& t)
19  : InputMetadata(t.type(), t.sizes(), t.device()) { }
20 
21  bool is_valid() const {
22  return type_ != nullptr;
23  }
24 
25  const at::Type& type() const {
26  AT_ASSERT(type_);
27  return *type_;
28  }
29 
30  at::IntArrayRef shape() const {
31  return shape_;
32  }
33 
34  at::Device device() const {
35  return device_;
36  }
37 
38  at::Tensor zeros_like() const {
39  return at::zeros(shape_, type_->options(device_));
40  }
41 
42 private:
43  const at::Type* type_ = nullptr;
44  at::DimVector shape_;
45  at::Device device_ = at::kCPU;
46 };
47 
48 }}
A tensor's type and shape.
Represents a compute device on which a tensor is located.
Definition: Device.h:30
Device device() const
Returns a Tensor's device.
Definition: jit_type.h:17