// NOTE(review): this region is a lossy extraction of the original file.
// The original source line numbers are fused into the text (e.g. "1", "8",
// "14"), and several source lines are absent (2-7, 9-13, 15, 19, 22-26).
// It does not compile as shown; restore it from upstream rather than
// hand-editing these fragments.
//
// After the includes, the CAFFE2_SPECIALIZED_INCREASE_INDEX_IN_DIMS(TIndex)
// macro defines C10_EXPORT IncreaseIndexInDims<TIndex>(ndim, dims, index).
// Visible behavior: it walks i from ndim - 1 down to 0. Whenever
// index[i] >= dims[i], it wraps index[i] by subtracting dims[i].
// The name and the wrap-around suggest an odometer-style increment of
// `index` within the bounds `dims`. However, the increment step itself and
// the loop exit are on missing lines (19, 22-26) — TODO confirm upstream.
// The macro is instantiated for std::int32_t and std::int64_t.
1 #include "caffe2/utils/math/utils.h" 8 #include "caffe2/core/logging.h" 14 #define CAFFE2_SPECIALIZED_INCREASE_INDEX_IN_DIMS(TIndex) \ 16 C10_EXPORT void IncreaseIndexInDims<TIndex>( \ 17 const int ndim, const TIndex* dims, TIndex* index) { \ 18 for (int i = ndim - 1; i >= 0; --i) { \ 20 if (index[i] >= dims[i]) { \ 21 index[i] -= dims[i]; \ 27 CAFFE2_SPECIALIZED_INCREASE_INDEX_IN_DIMS(std::int32_t)
28 CAFFE2_SPECIALIZED_INCREASE_INDEX_IN_DIMS(
std::int64_t)
// NOTE(review): lossy extraction. The #undef of the increase-index macro is
// fused onto the same line as the GetIndexFromDims header. Source lines 30,
// 32, 34 and 36-40 are missing, so the declaration/initialization of `sum`,
// the return statement, and whatever line 34 held are not visible.
//
// GetIndexFromDims(n, dims, index): folds the n-element coordinate `index`
// into a single int. It accumulates sum = sum * dims[i] + index[i] for
// i = 0 .. n - 1, which makes dims[0] the most significant dimension
// (assuming the missing line 34 does not alter this). Presumably it
// returns `sum`; the return is not shown.
29 #undef CAFFE2_SPECIALIZED_INCREASE_INDEX_IN_DIMS 31 int GetIndexFromDims(
const int n,
const int* dims,
const int* index) {
33 for (
int i = 0; i < n; ++i) {
// Horner-style accumulation: each step shifts by the current extent.
35 sum = sum * dims[i] + index[i];
// IsIdentityPermutation(n, perm): iterates i over [0, n).
// NOTE(review): the loop body and the return statements (source lines
// 43-48) are missing from this extraction. From the name, it presumably
// returns true iff perm[i] == i for every i — TODO confirm upstream.
41 bool IsIdentityPermutation(
const int n,
const int* perm) {
42 for (
int i = 0; i < n; ++i) {
// CheckReduceDims(ndim, X_dims, Y_dims): scans all ndim dimensions looking
// for one where Y_dims[i] is neither equal to X_dims[i] nor 1, i.e. a
// dimension that is neither kept nor collapsed to 1.
// NOTE(review): the branch body and the final return (source lines 53+)
// are missing here. Presumably it returns false on such a dimension and
// true otherwise — TODO confirm upstream.
50 bool CheckReduceDims(
const int ndim,
const int* X_dims,
const int* Y_dims) {
51 for (
int i = 0; i < ndim; ++i) {
52 if (X_dims[i] != Y_dims[i] && Y_dims[i] != 1) {
// NOTE(review): headless fragment. The enclosing function's signature and
// the setup of `pivot` and `*cols` (before source line 67) are not in this
// extraction, nor are lines 69-70 and 73+.
// Visible logic:
//   1. Walk `pivot` downward over trailing dimensions where B_dims[pivot]
//      == 1, multiplying *cols by the matching A_dims entry.
//   2. From `pivot` down to 0, compare A_dims[i] with B_dims[i].
// What happens on a mismatch is on a missing line.
67 for (; pivot >= 0 && B_dims[pivot] == 1; --pivot) {
68 *cols *= A_dims[pivot];
71 for (
int i = pivot; i >= 0; --i) {
72 if (A_dims[i] != B_dims[i]) {
// NOTE(review): headless fragment. The enclosing function's signature and
// the setup of `pivot` and `*rows` (before source line 88) are not in this
// extraction, nor are lines 90-91 and 94+.
// Visible logic:
//   1. Walk `pivot` upward over leading dimensions where B_dims[pivot]
//      == 1, multiplying *rows by the matching A_dims entry.
//   2. From `pivot` up to ndim - 1, compare A_dims[i] with B_dims[i].
// What happens on a mismatch is on a missing line.
88 for (; pivot < ndim && B_dims[pivot] == 1; ++pivot) {
89 *rows *= A_dims[pivot];
92 for (
int i = pivot; i < ndim; ++i) {
93 if (A_dims[i] != B_dims[i]) {
// IsBothEndsReduce: visible logic only.
//   1. `r` moves downward over trailing dimensions where B_dims[r] == 1.
//   2. `l` moves upward (never past `r`) over leading dimensions where
//      B_dims[l] == 1.
//   3. The remaining middle span [l, r] is compared element-wise between
//      A_dims and B_dims.
// NOTE(review): the following are missing from this extraction:
//   - the parameter list and the initialization of `r` and `l`
//     (source lines 102-109);
//   - the scan-loop bodies (111-114, 116-118), presumably accumulating
//     output extents;
//   - the mismatch handling and returns (121+).
101 bool IsBothEndsReduce(
110 for (; r >= 0 && B_dims[r] == 1; --r) {
115 for (; l <= r && B_dims[l] == 1; ++l) {
119 for (
int i = l; i <= r; ++i) {
120 if (A_dims[i] != B_dims[i]) {
// ComputeBroadcastBinaryOpDims: aligns two shapes and computes their
// broadcast result.
//   - A_dims (A_ndim entries) and B_dims (B_ndim entries) are aligned to
//     ndim = max(A_ndim, B_ndim) dimensions by right-aligning them and
//     padding the leading positions with 1.
//   - The aligned shapes are written to A_broadcast_dims and
//     B_broadcast_dims; each must have room for ndim ints.
//   - C_broadcast_dims[i] receives 0 if either aligned extent is 0, and
//     otherwise the larger of the two.
//
// NOTE(review): lossy extraction. Missing pieces:
//   - the leading parameters (source lines 129-132; presumably A_ndim,
//     A_dims, B_ndim, B_dims, judging by the uses below);
//   - the opening of the per-dimension check (line 142);
//   - the `else` (line 147);
//   - the closing lines (149+).
128 void ComputeBroadcastBinaryOpDims(
133 int* A_broadcast_dims,
134 int* B_broadcast_dims,
135 int* C_broadcast_dims) {
136 const int ndim = std::max(A_ndim, B_ndim);
// Leading (missing) positions of the shorter shape are treated as size 1.
137 std::fill(A_broadcast_dims, A_broadcast_dims + ndim - A_ndim, 1);
138 std::fill(B_broadcast_dims, B_broadcast_dims + ndim - B_ndim, 1);
139 std::copy(A_dims, A_dims + A_ndim, A_broadcast_dims + ndim - A_ndim);
140 std::copy(B_dims, B_dims + B_ndim, B_broadcast_dims + ndim - B_ndim);
141 for (
int i = 0; i < ndim; ++i) {
// Compatibility condition: the extents must match, or one must be <= 1.
// The call that opens this condition is on missing source line 142
// (presumably an enforce-style check).
143 A_broadcast_dims[i] == B_broadcast_dims[i] ||
144 A_broadcast_dims[i] <= 1 || B_broadcast_dims[i] <= 1);
// A zero extent passes the check above against any size and forces a
// zero output extent.
145 if (A_broadcast_dims[i] == 0 || B_broadcast_dims[i] == 0) {
146 C_broadcast_dims[i] = 0;
148 C_broadcast_dims[i] = std::max(A_broadcast_dims[i], B_broadcast_dims[i]);
// IsRowwiseBroadcastBinaryOp: visible logic only.
//   1. For each operand, find the first dimension whose extent is not 1
//      (A_pivot, B_pivot).
//   2. pivot = max(A_pivot, B_pivot).
//   3. If A_pivot > B_pivot (A has more leading 1s), set
//      *broadcast_1st = true and *rows = product of B_dims[B_pivot, pivot).
//      Otherwise, set *broadcast_1st = false and
//      *rows = product of A_dims[A_pivot, pivot).
//   4. Compare the trailing dimensions [pivot, ndim) between A and B.
//
// NOTE(review): lossy extraction. Missing pieces:
//   - the parameter list (source lines 154-158) and setup (160-163),
//     including the initialization of A_pivot / B_pivot;
//   - the `;` empty bodies of the two scan loops (165, 168). As shown, the
//     first `for` would take the second `for` as its body;
//   - the equal-pivot branch body (170-171);
//   - the `else` (177);
//   - everything after line 184.
// Do not treat this text as compilable.
153 bool IsRowwiseBroadcastBinaryOp(
159 bool* broadcast_1st) {
164 for (; A_pivot < ndim && A_dims[A_pivot] == 1; ++A_pivot)
167 for (; B_pivot < ndim && B_dims[B_pivot] == 1; ++B_pivot)
169 if (A_pivot == B_pivot) {
172 const int pivot = std::max(A_pivot, B_pivot);
// The operand with more leading 1s is presumably the one being broadcast.
173 if (A_pivot > B_pivot) {
174 *rows = std::accumulate(
175 B_dims + B_pivot, B_dims + pivot, 1, std::multiplies<int>());
176 *broadcast_1st =
true;
178 *rows = std::accumulate(
179 A_dims + A_pivot, A_dims + pivot, 1, std::multiplies<int>());
180 *broadcast_1st =
false;
183 for (
int i = pivot; i < ndim; ++i) {
184 if (A_dims[i] != B_dims[i]) {
// IsColwiseBroadcastBinaryOp: mirror image of the row-wise check, scanning
// from the last dimension. Visible logic only.
//   1. For each operand, find the last dimension whose extent is not 1
//      (A_pivot, B_pivot).
//   2. pivot = min(A_pivot, B_pivot).
//   3. If A_pivot < B_pivot (A has more trailing 1s), set
//      *broadcast_1st = true and *cols = product of B_dims[pivot, B_pivot).
//      Otherwise, set *broadcast_1st = false and
//      *cols = product of A_dims[pivot, A_pivot).
//   4. Compare the leading dimensions [0, pivot) between A and B.
//
// NOTE(review): lossy extraction. Missing pieces:
//   - the parameter list (source lines 193-197) and setup (199-201);
//   - the `;` empty bodies of the scan loops (204, 207);
//   - lines 209-212 (equal-pivot branch and whatever follows it);
//   - the `else` (218);
//   - everything after line 225.
// The half-open ranges below exclude index A_pivot/B_pivot, even though
// those are the last non-1 dimensions as scanned. The missing lines
// 209-212 likely shift the pivots to one-past-the-end — confirm upstream
// before calling this an off-by-one.
192 bool IsColwiseBroadcastBinaryOp(
198 bool* broadcast_1st) {
202 int A_pivot = ndim - 1;
203 for (; A_pivot >= 0 && A_dims[A_pivot] == 1; --A_pivot)
205 int B_pivot = ndim - 1;
206 for (; B_pivot >= 0 && B_dims[B_pivot] == 1; --B_pivot)
208 if (A_pivot == B_pivot) {
213 const int pivot = std::min(A_pivot, B_pivot);
214 if (A_pivot < B_pivot) {
215 *cols = std::accumulate(
216 B_dims + pivot, B_dims + B_pivot, 1, std::multiplies<int>());
217 *broadcast_1st =
true;
219 *cols = std::accumulate(
220 A_dims + pivot, A_dims + A_pivot, 1, std::multiplies<int>());
221 *broadcast_1st =
false;
224 for (
int i = 0; i < pivot; ++i) {
225 if (A_dims[i] != B_dims[i]) {
// IsBothEndsBroadcastBinaryOp: combines the row-wise and column-wise
// scans. Visible logic only.
//   1. *_pre is the first non-1 dimension of each operand; *_nxt is the
//      last non-1 dimension as scanned.
//   2. If one operand's non-1 span lies strictly inside the other's (more
//      leading AND more trailing 1s), then *pre and *nxt receive the
//      products of the other operand's extents over the leading and
//      trailing gaps. *broadcast_1st is set to true when A is the inner
//      one and false when B is.
//   3. The overlapping middle [l, r) is then compared between A and B.
//
// NOTE(review): lossy extraction. Missing pieces:
//   - the parameter list (source lines 234-239), setup (241-244), and the
//     initialization of A_pre / B_pre;
//   - the `;` empty bodies of the scan loops (246, 249, 252);
//   - lines 255-257;
//   - the early-exit body (259-260);
//   - the fallback after the else-if (273-275);
//   - line 278;
//   - everything after line 280.
// The half-open ranges [A_nxt, B_nxt) and `i < r` suggest that the
// missing lines 255-257 shift the *_nxt values to one-past-the-end. Confirm
// upstream before treating these ranges as off-by-one.
233 bool IsBothEndsBroadcastBinaryOp(
240 bool* broadcast_1st) {
245 for (; A_pre < ndim && A_dims[A_pre] == 1; ++A_pre)
248 for (; B_pre < ndim && B_dims[B_pre] == 1; ++B_pre)
250 int A_nxt = ndim - 1;
251 for (; A_nxt >= 0 && A_dims[A_nxt] == 1; --A_nxt)
253 int B_nxt = ndim - 1;
254 for (; B_nxt >= 0 && B_dims[B_nxt] == 1; --B_nxt)
// Equal leading or equal trailing boundaries: handled on missing lines.
258 if (A_pre == B_pre || A_nxt == B_nxt) {
// A's non-1 span is strictly inside B's span.
261 if (A_pre > B_pre && A_nxt < B_nxt) {
262 *pre = std::accumulate(
263 B_dims + B_pre, B_dims + A_pre, 1, std::multiplies<int>());
264 *nxt = std::accumulate(
265 B_dims + A_nxt, B_dims + B_nxt, 1, std::multiplies<int>());
266 *broadcast_1st =
true;
267 }
// B's non-1 span is strictly inside A's span.
else if (A_pre < B_pre && A_nxt > B_nxt) {
268 *pre = std::accumulate(
269 A_dims + A_pre, A_dims + B_pre, 1, std::multiplies<int>());
270 *nxt = std::accumulate(
271 A_dims + B_nxt, A_dims + A_nxt, 1, std::multiplies<int>());
272 *broadcast_1st =
false;
276 const int l = std::max(A_pre, B_pre);
277 const int r = std::min(A_nxt, B_nxt);
279 for (
int i = l; i < r; ++i) {
280 if (A_dims[i] != B_dims[i]) {
// IsBatchTranspose2D(ndim, axes): the visible final expression tests
// whether the permutation `axes` swaps its last two entries
// (axes[ndim - 2] == ndim - 1 and axes[ndim - 1] == ndim - 2). The loop
// over the first ndim - 2 entries presumably requires them to be the
// identity, but its body (source lines 293-296) is missing.
// NOTE(review): source lines 289-291 are also missing. If ndim < 2, the
// return expression would read axes[-1] / axes[-2]. Whether an early guard
// exists cannot be confirmed from this extraction — TODO check upstream.
288 bool IsBatchTranspose2D(
const int ndim,
const int* axes) {
292 for (
int i = 0; i < ndim - 2; ++i) {
297 return axes[ndim - 2] == ndim - 1 && axes[ndim - 1] == ndim - 2;
// ComputeTransposeAxesForReduceOp(num_dims, num_reduce_axes, reduce_axes,
// transpose_axes): builds a permutation in transpose_axes.
//   - The reduced axes are copied into the tail
//     transpose_axes[d, num_dims) and sorted, where
//     d = num_dims - num_reduce_axes.
//   - The loop then walks every axis i in [0, num_dims) and writes the
//     axes that are not in the sorted tail to the front via
//     transpose_axes[p++].
// Net result: kept axes in ascending order, followed by reduced axes in
// ascending order.
// NOTE(review): lossy extraction. The `num_dims` parameter line (source
// line 301), the initialization of p and q (308-309; presumably p = 0,
// q = d), and the `++q` / `else` lines (312-313) are missing. No visible
// line checks that reduce_axes holds distinct values in [0, num_dims); the
// merge assumes it.
300 void ComputeTransposeAxesForReduceOp(
302 const int num_reduce_axes,
303 const int* reduce_axes,
304 int* transpose_axes) {
305 const int d = num_dims - num_reduce_axes;
// Reduced axes go to the tail, sorted so the loop below can merge-scan them.
306 std::copy_n(reduce_axes, num_reduce_axes, transpose_axes + d);
307 std::sort(transpose_axes + d, transpose_axes + num_dims);
310 for (
int i = 0; i < num_dims; ++i) {
311 if (q < num_dims && i == transpose_axes[q]) {
314 transpose_axes[p++] = i;
// ComputeTransposeAxesForReduceOp (shape-based overload): d is the number
// of entries in `dims` that are not equal to 1. The loop over all ndim
// dimensions follows.
// NOTE(review): the parameter list (source lines 320-322), lines 324-325,
// and the loop body (327+) are missing from this extraction. How `d` is
// used cannot be confirmed here.
319 void ComputeTransposeAxesForReduceOp(
323 const int d = ndim - std::count(dims, dims + ndim, 1);
326 for (
int i = 0; i < ndim; ++i) {
// CAFFE2_SPECIALIZED_COMPUTE_TRANSPOSED_STRIDES(TIndex) defines
// ComputeTransposedStrides<TIndex>(ndim, dims, axes, strides) in two steps:
//   1. Compute contiguous strides for `dims` with the last dimension
//      fastest into a temporary buffer: buff[ndim - 1] = 1 and
//      buff[i] = dims[i + 1] * ... * dims[ndim - 1].
//   2. Write strides[i] = buff[axes[i]], i.e. the input strides reordered
//      by the permutation `axes`.
// No range check on axes[i] is visible; each entry is assumed to lie in
// [0, ndim). The macro is instantiated for std::int32_t and std::int64_t,
// then #undef'd.
// NOTE(review): lossy extraction. Source line 336 (presumably
// `template <>`) and lines 344 and 347-348 (closing braces) are missing,
// so the macro text is not compilable as shown.
335 #define CAFFE2_SPECIALIZED_COMPUTE_TRANSPOSED_STRIDES(TIndex) \ 337 C10_EXPORT void ComputeTransposedStrides<TIndex>( \ 338 const int ndim, const TIndex* dims, const int* axes, TIndex* strides) { \ 339 std::vector<TIndex> buff(ndim); \ 340 TIndex cur_stride = 1; \ 341 for (int i = ndim - 1; i >= 0; --i) { \ 342 buff[i] = cur_stride; \ 343 cur_stride *= dims[i]; \ 345 for (int i = 0; i < ndim; ++i) { \ 346 strides[i] = buff[axes[i]]; \ 349 CAFFE2_SPECIALIZED_COMPUTE_TRANSPOSED_STRIDES(std::int32_t)
350 CAFFE2_SPECIALIZED_COMPUTE_TRANSPOSED_STRIDES(
std::int64_t)
351 #undef CAFFE2_SPECIALIZED_COMPUTE_TRANSPOSED_STRIDES
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...