Caffe2 - C++ API
A deep learning, cross-platform ML framework
tensor_flatten.cpp
1 #include <torch/csrc/utils/tensor_flatten.h>
2 
3 #include <map>
4 #include <unordered_map>
5 
6 namespace torch { namespace utils {
7 
8 using namespace at;
9 
10 std::vector<TensorGroup> take_tensors(
11  TensorList tensors,
12  size_t size_limit,
13  bool fine_grained) {
14  std::vector<TensorGroup> results;
15  // an overapproximation, but at least we won't have to copy stuff around
16  results.reserve(tensors.size());
17  std::map<TypeID, TensorGroup> groups;
18  size_t cur_group_size = 0;
19 
20  for (const auto & tensor : tensors) {
21  auto& type = tensor.type();
22  size_t tensor_size;
23  if (type.is_sparse()) {
24  const auto& indices = tensor._indices();
25  const auto& values = tensor._values();
26  tensor_size = indices.numel() * indices.element_size() +
27  values.numel() * indices.element_size();
28  } else {
29  tensor_size = tensor.numel() * tensor.element_size();
30  }
31 
32  auto& type_group = groups[type.ID()];
33  type_group.tensors.push_back(tensor);
34 
35  if (fine_grained) {
36  cur_group_size += tensor_size;
37  // Regardless the type, the current total size exceeds the limit
38  if (cur_group_size >= size_limit) {
39  // Spill all types to separate groups in results
40  for (auto& entry : groups) {
41  auto& group = entry.second;
42  results.emplace_back(std::move(group));
43  }
44  cur_group_size = 0;
45  groups.clear();
46  }
47  } else {
48  type_group.size += tensor_size;
49  if (type_group.size >= size_limit) {
50  results.emplace_back();
51  std::swap(results.back(), type_group);
52  }
53  }
54  }
55  // End case. Look for any remaining groups and return them.
56  for (auto& entry : groups) {
57  auto& group = entry.second;
58  if (!fine_grained && group.size == 0) {
59  continue;
60  }
61  results.emplace_back(std::move(group));
62  }
63  return results;
64 }
65 
66 void reorder_tensors_like(std::vector<Tensor>& tensors, TensorList order) {
67  AT_ASSERT(tensors.size() == order.size());
68  std::unordered_map<at::Type*, std::vector<size_t>> type_indices;
69  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i)
70  type_indices[&tensors[i].type()].push_back(i);
71 
72  std::unordered_map<at::Type*, size_t> type_used;
73  std::vector<Tensor> ordered_tensors;
74  ordered_tensors.reserve(tensors.size());
75  for (auto & tmpl_tensor : order) {
76  auto * type = &tmpl_tensor.type();
77  auto & indices = type_indices[type];
78  auto & used = type_used[type];
79  ordered_tensors.push_back(tensors[indices[used++]]);
80  }
81  std::swap(tensors, ordered_tensors);
82 }
83 
84 namespace {
85 
86 at::Tensor get_indices(const at::Tensor& t) {
87  return t._indices();
88 }
89 
90 at::Tensor get_values(const at::Tensor& t) {
91  return t._values();
92 }
93 
94 }
95 
96 std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(at::TensorList tensors) {
97  auto flat_indices = flatten_dense_tensors(fmap(tensors, &get_indices));
98  auto flat_values = flatten_dense_tensors(fmap(tensors, &get_values));
99  return std::make_pair(flat_indices, flat_values);
100 }
101 
102 std::vector<at::Tensor> unflatten_sparse_tensors(
103  const at::Tensor& flat_indices, const at::Tensor& flat_values,
104  at::TensorList tensors) {
105  if (tensors.size() == 0) return {};
106 
107  auto indices = unflatten_dense_tensors(flat_indices, fmap(tensors, &get_indices));
108  auto values = unflatten_dense_tensors(flat_values, fmap(tensors, &get_values));
109 
110  std::vector<at::Tensor> outputs;
111  outputs.reserve(tensors.size());
112  for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; ++i) {
113  auto &ref_t = tensors[i];
114  auto t = at::_sparse_coo_tensor_unsafe(indices[i], values[i], ref_t.sizes());
115  outputs.emplace_back(t._coalesced_(ref_t.is_coalesced()));
116  }
117  return outputs;
118 }
119 
120 
121 }}
constexpr size_t size() const
size - Get the array size.
Definition: ArrayRef.h:138
Definition: jit_type.h:17
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41
Flush-To-Zero and Denormals-Are-Zero mode.