// PartitionDesc constructor.
// Builds subTensorDesc_ from `_desc`: same scalar type, and a copy of
// `_desc.contiguity` with the flag at index dim_ - 1 cleared.
//   _desc        - source descriptor; only scalar_type and contiguity are read.
//   _nSubTensors - stored in nSubTensors_; must be > 1 (checked by AT_ASSERT).
//   _dim         - stored in dim_; presumably the dimension along which the
//                  tensor is partitioned — confirm against callers.
// NOTE(review): the next line also carries the file's #include directives,
// fused with stray numbers, and the numbering in this region has gaps
// (e.g. 25 -> 31), which suggests lines were lost in extraction. Check this
// header against its source of truth.
3 #include <torch/csrc/WindowsTorchApiMacro.h> 4 #include <c10/util/Exception.h> 5 #include <torch/csrc/jit/fuser/tensor_desc.h> 22 PartitionDesc(
const TensorDesc& _desc,
size_t _nSubTensors,
size_t _dim)
23 : nSubTensors_{_nSubTensors}, dim_{_dim} {
// A count of 1 is rejected here: every object built by this constructor
// describes at least two sub-tensors.
24 AT_ASSERT(nSubTensors_ > 1);
// `_desc` is const, so its contiguity flags are copied before being edited.
25 std::vector<bool> cont = _desc.contiguity;
// Mark the sub-tensor as non-contiguous at index dim_ - 1.
// NOTE(review): dim_ is size_t, so when dim_ == 0 the index dim_ - 1 wraps
// to SIZE_MAX. std::vector::operator[] does no bounds checking, so that is
// undefined behavior; the same holds if dim_ > cont.size(). No `dim_ > 0`
// guard is visible here — confirm one exists, or add it.
31 cont[dim_ - 1] =
false;
// subTensorDesc_ takes ownership of a freshly allocated descriptor.
// NOTE(review): if subTensorDesc_ is a std::shared_ptr (the accessors below
// return it as std::shared_ptr<TensorDesc>), std::make_shared would avoid the
// raw `new` and allocate the object and control block together.
33 subTensorDesc_.reset(
new TensorDesc(_desc.scalar_type, cont));
// Yields whether nSubTensors_ == 1. The PartitionDesc constructor asserts
// nSubTensors_ > 1, so this is false for objects built by that constructor.
37 return (nSubTensors_ == 1);
// Read-only accessor. It presumably returns nSubTensors_, but its body does
// not follow this declaration here — confirm.
39 size_t nSubTensors()
const {
// Non-const accessor: returns a copy of the subTensorDesc_ pointer. The caller
// therefore shares ownership of the sub-tensor descriptor and can modify it.
45 std::shared_ptr<TensorDesc> subTensorDesc() {
46 return subTensorDesc_;
// Const overload: also returns a copy of subTensorDesc_.
// NOTE(review): a top-level `const` on a by-value return does not make the
// TensorDesc read-only; callers can still modify it through the returned
// pointer. It can also force a copy (atomic refcount bump) where a move would
// otherwise happen, e.g. when assigning the result to an existing pointer.
// std::shared_ptr<const TensorDesc> is probably what was intended. That is a
// public signature change, so audit callers before switching.
48 const std::shared_ptr<TensorDesc> subTensorDesc()
const {
49 return subTensorDesc_;
56 std::shared_ptr<TensorDesc>