3 #include <ATen/core/functional.h> 8 namespace torch {
namespace utils {
11 static auto flatten = [](
const at::Tensor &t) {
return t.contiguous().view({-1}); };
12 if (tensors.
size() == 1)
13 return flatten(tensors[0]);
14 return at::cat(fmap(tensors, flatten));
18 std::vector<at::Tensor> outputs;
19 outputs.reserve(tensors.
size());
21 for (
const auto & tensor : tensors) {
22 auto numel = tensor.numel();
23 outputs.push_back(flat.narrow(0, offset, numel).view(tensor.sizes()));
31 std::vector<at::Tensor> tensors;
35 AT_ASSERT(!tensors.empty());
36 return tensors[0].type();
62 std::vector<TensorGroup> take_tensors(
65 bool fine_grained =
false);
// Declaration only; the definition is not visible in this chunk.
// `tensors` is taken by non-const reference, so the callee may modify the
// caller's vector; `order` is passed as an at::TensorList.
// NOTE(review): the name suggests this reorders `tensors` in place to follow
// the ordering of `order` (e.g. by type/device) — TODO confirm against the
// definition before relying on that.
67 void reorder_tensors_like(std::vector<at::Tensor>& tensors,
at::TensorList order);
// Declaration only; the definition is not visible in this chunk.
// Takes a list of tensors and returns a pair of tensors.
// NOTE(review): presumably the pair is (flattened indices, flattened values)
// of sparse inputs, the sparse counterpart of the dense flatten helper — TODO
// confirm the element order and the input requirements against the definition.
69 std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(
at::TensorList tensors);
71 std::vector<at::Tensor> unflatten_sparse_tensors(
constexpr size_t size() const
size - Get the array size.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...