Caffe2 - C++ API
A deep learning, cross platform ML framework
Related Pages
Modules
Data Structures
Files
C++ API
Python API
GitHub
File List
Globals
torch
csrc
autograd
input_metadata.h
1
#pragma once
2
3
#include <ATen/ATen.h>
4
5
#include <cstdint>
6
7
namespace
torch
{
namespace
autograd {
8
12
struct
InputMetadata
{
13
InputMetadata
() =
default
;
14
15
InputMetadata
(
const
at::Type
& type,
at::IntArrayRef
shape,
at::Device
device)
16
: type_{&type} , shape_{shape}, device_{device} { }
17
18
InputMetadata
(
const
at::Tensor
& t)
19
:
InputMetadata
(t.type(), t.sizes(), t.
device
()) { }
20
21
bool
is_valid()
const
{
22
return
type_ !=
nullptr
;
23
}
24
25
const
at::Type
& type()
const
{
26
AT_ASSERT(type_);
27
return
*type_;
28
}
29
30
at::IntArrayRef
shape()
const
{
31
return
shape_;
32
}
33
34
at::Device
device()
const
{
35
return
device_;
36
}
37
38
at::Tensor
zeros_like()
const
{
39
return
at::zeros(shape_, type_->options(device_));
40
}
41
42
private
:
43
const
at::Type
* type_ =
nullptr
;
44
at::DimVector
shape_;
45
at::Device
device_ = at::kCPU;
46
};
47
48
}}
c10::SmallVector< int64_t, 5 >
at::Tensor
Definition:
Tensor.h:48
torch::autograd::InputMetadata
A tensor's type and shape.
Definition:
input_metadata.h:12
c10::Type
Definition:
jit_type.h:65
c10::Device
Represents a compute device on which a tensor is located.
Definition:
Device.h:30
at::Tensor::device
Device device() const
Returns a Tensor's device.
Definition:
TensorMethods.h:1295
torch
Definition:
jit_type.h:17
c10::ArrayRef< int64_t >
Generated on Thu Mar 21 2019 13:06:22 for Caffe2 - C++ API by Doxygen 1.8.11