1 #include <torch/csrc/utils/tensor_flatten.h> 4 #include <unordered_map> 6 namespace torch {
namespace utils {
10 std::vector<TensorGroup> take_tensors(
14 std::vector<TensorGroup> results;
16 results.reserve(tensors.
size());
17 std::map<TypeID, TensorGroup> groups;
18 size_t cur_group_size = 0;
20 for (
const auto & tensor : tensors) {
21 auto& type = tensor.type();
23 if (type.is_sparse()) {
24 const auto& indices = tensor._indices();
25 const auto& values = tensor._values();
26 tensor_size = indices.numel() * indices.element_size() +
27 values.numel() * indices.element_size();
29 tensor_size = tensor.numel() * tensor.element_size();
32 auto& type_group = groups[type.ID()];
33 type_group.tensors.push_back(tensor);
36 cur_group_size += tensor_size;
38 if (cur_group_size >= size_limit) {
40 for (
auto& entry : groups) {
41 auto& group = entry.second;
42 results.emplace_back(std::move(group));
48 type_group.size += tensor_size;
49 if (type_group.size >= size_limit) {
50 results.emplace_back();
51 std::swap(results.back(), type_group);
56 for (
auto& entry : groups) {
57 auto& group = entry.second;
58 if (!fine_grained && group.size == 0) {
61 results.emplace_back(std::move(group));
66 void reorder_tensors_like(std::vector<Tensor>& tensors,
TensorList order) {
67 AT_ASSERT(tensors.size() == order.
size());
68 std::unordered_map<at::Type*, std::vector<size_t>> type_indices;
69 for (
size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i)
70 type_indices[&tensors[i].type()].push_back(i);
72 std::unordered_map<at::Type*, size_t> type_used;
73 std::vector<Tensor> ordered_tensors;
74 ordered_tensors.reserve(tensors.size());
75 for (
auto & tmpl_tensor : order) {
76 auto * type = &tmpl_tensor.type();
77 auto & indices = type_indices[type];
78 auto & used = type_used[type];
79 ordered_tensors.push_back(tensors[indices[used++]]);
81 std::swap(tensors, ordered_tensors);
96 std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(
at::TensorList tensors) {
97 auto flat_indices = flatten_dense_tensors(fmap(tensors, &get_indices));
98 auto flat_values = flatten_dense_tensors(fmap(tensors, &get_values));
99 return std::make_pair(flat_indices, flat_values);
102 std::vector<at::Tensor> unflatten_sparse_tensors(
105 if (tensors.
size() == 0)
return {};
107 auto indices = unflatten_dense_tensors(flat_indices, fmap(tensors, &get_indices));
108 auto values = unflatten_dense_tensors(flat_values, fmap(tensors, &get_values));
110 std::vector<at::Tensor> outputs;
111 outputs.reserve(tensors.
size());
112 for (
size_t i = 0, num_tensors = tensors.
size(); i < num_tensors; ++i) {
113 auto &ref_t = tensors[i];
114 auto t = at::_sparse_coo_tensor_unsafe(indices[i], values[i], ref_t.sizes());
115 outputs.emplace_back(t._coalesced_(ref_t.is_coalesced()));
constexpr size_t size() const
size - Get the array size.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Flush-To-Zero and Denormals-Are-Zero mode.