Caffe2 - C++ API
A deep learning, cross-platform ML framework
eigen_utils.h
1 // Copyright 2004-present Facebook. All Rights Reserved.
2 
3 #ifndef CAFFE2_OPERATORS_UTILS_EIGEN_H_
4 #define CAFFE2_OPERATORS_UTILS_EIGEN_H_
5 
6 #include "Eigen/Core"
7 #include "Eigen/Dense"
8 #include "caffe2/core/logging.h"
9 
10 namespace caffe2 {
11 
12 // 1-d array
13 template <typename T>
14 using EArrXt = Eigen::Array<T, Eigen::Dynamic, 1>;
15 using EArrXf = Eigen::ArrayXf;
16 using EArrXd = Eigen::ArrayXd;
17 using EArrXi = Eigen::ArrayXi;
18 using EArrXb = EArrXt<bool>;
19 
20 // 2-d array, column major
21 template <typename T>
22 using EArrXXt = Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>;
23 using EArrXXf = Eigen::ArrayXXf;
24 
25 // 2-d array, row major
26 template <typename T>
27 using ERArrXXt =
28  Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
29 using ERArrXXf = ERArrXXt<float>;
30 
31 // 1-d vector
32 template <typename T>
33 using EVecXt = Eigen::Matrix<T, Eigen::Dynamic, 1>;
34 using EVecXd = Eigen::VectorXd;
35 using EVecXf = Eigen::VectorXf;
36 
37 // 1-d row vector
38 using ERVecXd = Eigen::RowVectorXd;
39 using ERVecXf = Eigen::RowVectorXf;
40 
41 // 2-d matrix, column major
42 template <typename T>
43 using EMatXt = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
44 using EMatXd = Eigen::MatrixXd;
45 using EMatXf = Eigen::MatrixXf;
46 
47 // 2-d matrix, row major
48 template <typename T>
49 using ERMatXt =
50  Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
51 using ERMatXd = ERMatXt<double>;
52 using ERMatXf = ERMatXt<float>;
53 
54 namespace utils {
55 
56 template <typename T>
57 Eigen::Map<const EArrXt<T>> AsEArrXt(const std::vector<T>& arr) {
58  return {arr.data(), static_cast<int>(arr.size())};
59 }
60 template <typename T>
61 Eigen::Map<EArrXt<T>> AsEArrXt(std::vector<T>& arr) {
62  return {arr.data(), static_cast<int>(arr.size())};
63 }
64 
65 // return a sub array of 'array' based on indices 'indices'
66 template <class Derived, class Derived1, class Derived2>
67 void GetSubArray(
68  const Eigen::ArrayBase<Derived>& array,
69  const Eigen::ArrayBase<Derived1>& indices,
70  Eigen::ArrayBase<Derived2>* out_array) {
71  CAFFE_ENFORCE_EQ(array.cols(), 1);
72  // using T = typename Derived::Scalar;
73 
74  out_array->derived().resize(indices.size());
75  for (int i = 0; i < indices.size(); i++) {
76  DCHECK_LT(indices[i], array.size());
77  (*out_array)[i] = array[indices[i]];
78  }
79 }
80 
81 // return a sub array of 'array' based on indices 'indices'
82 template <class Derived, class Derived1>
83 EArrXt<typename Derived::Scalar> GetSubArray(
84  const Eigen::ArrayBase<Derived>& array,
85  const Eigen::ArrayBase<Derived1>& indices) {
86  using T = typename Derived::Scalar;
87  EArrXt<T> ret(indices.size());
88  GetSubArray(array, indices, &ret);
89  return ret;
90 }
91 
92 // return a sub array of 'array' based on indices 'indices'
93 template <class Derived>
94 EArrXt<typename Derived::Scalar> GetSubArray(
95  const Eigen::ArrayBase<Derived>& array,
96  const std::vector<int>& indices) {
97  return GetSubArray(array, AsEArrXt(indices));
98 }
99 
100 // return 2d sub array of 'array' based on row indices 'row_indices'
101 template <class Derived, class Derived1, class Derived2>
102 void GetSubArrayRows(
103  const Eigen::ArrayBase<Derived>& array2d,
104  const Eigen::ArrayBase<Derived1>& row_indices,
105  Eigen::ArrayBase<Derived2>* out_array) {
106  out_array->derived().resize(row_indices.size(), array2d.cols());
107 
108  for (int i = 0; i < row_indices.size(); i++) {
109  DCHECK_LT(row_indices[i], array2d.size());
110  out_array->row(i) =
111  array2d.row(row_indices[i]).template cast<typename Derived2::Scalar>();
112  }
113 }
114 
115 // return indices of 1d array for elements evaluated to true
116 template <class Derived>
117 std::vector<int> GetArrayIndices(const Eigen::ArrayBase<Derived>& array) {
118  std::vector<int> ret;
119  for (int i = 0; i < array.size(); i++) {
120  if (array[i]) {
121  ret.push_back(i);
122  }
123  }
124  return ret;
125 }
126 
127 } // namespace utils
128 } // namespace caffe2
129 
130 #endif
Copyright (c) 2016-present, Facebook, Inc.