Caffe2 - C++ API
A deep learning, cross-platform ML framework
partition_desc.h
1 #pragma once
2 
3 #include <torch/csrc/WindowsTorchApiMacro.h>
4 #include <c10/util/Exception.h>
5 #include <torch/csrc/jit/fuser/tensor_desc.h>
6 
7 #include <cstdint>
8 #include <memory>
9 #include <vector>
10 
11 namespace torch {
12 namespace jit {
13 namespace fuser {
14 
15 // Descriptor for chunk-ing an input tensor into subtensors
16 // OR concat-ing an output tensor from subtensors
17 // Note: default constructed used for tensors that do not participate in
18 // chunk or cat operations.
19 struct TORCH_API PartitionDesc {
20  PartitionDesc() : nSubTensors_{1}, dim_{0} {}
21 
22  PartitionDesc(const TensorDesc& _desc, size_t _nSubTensors, size_t _dim)
23  : nSubTensors_{_nSubTensors}, dim_{_dim} {
24  AT_ASSERT(nSubTensors_ > 1);
25  std::vector<bool> cont = _desc.contiguity;
26  if (dim_ > 0) {
27  // when we narrow the concatenated output/chunked input
28  // we make the size[dim] smaller while keeping the stride[dim] the same,
29  // meaning: stride[dim - 1] != stride[dim]*size[dim]
30  // so dim - 1 is no longer contiguous
31  cont[dim_ - 1] = false;
32  }
33  subTensorDesc_.reset(new TensorDesc(_desc.scalar_type, cont));
34  }
35 
36  bool isNoop() const {
37  return (nSubTensors_ == 1);
38  }
39  size_t nSubTensors() const {
40  return nSubTensors_;
41  }
42  size_t dim() const {
43  return dim_;
44  }
45  std::shared_ptr<TensorDesc> subTensorDesc() {
46  return subTensorDesc_;
47  }
48  const std::shared_ptr<TensorDesc> subTensorDesc() const {
49  return subTensorDesc_;
50  }
51 
52  private:
53  size_t nSubTensors_; // == 1 for tensors that should not be operated on via
54  // chunk/cat
55  size_t dim_; // dimension along which the chunk/concat occurs
56  std::shared_ptr<TensorDesc>
57  subTensorDesc_; // descriptor for the subtensor, if it exists
58 };
59 
60 } // namespace fuser
61 } // namespace jit
62 } // namespace torch
Definition: jit_type.h:17