1 #ifndef CAFFE2_OPERATORS_EXPAND_OP_H_ 2 #define CAFFE2_OPERATORS_EXPAND_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/core/types.h" 9 #include "caffe2/utils/math.h" 13 template <
typename InputTypes,
class Context>
16 USE_OPERATOR_CONTEXT_FUNCTIONS;
18 template <
class... Args>
22 bool RunOnDevice()
override {
26 bool DoRunWithType() {
27 const auto& X =
Input(0);
28 const auto& Y_shape_tensor =
Input(1);
29 std::vector<int64_t> shape_dims(Y_shape_tensor.numel());
30 context_.template CopyToCPU<int64_t>(
31 Y_shape_tensor.numel(),
32 Y_shape_tensor.template data<int64_t>(),
35 const int ndim = shape_dims.size();
36 const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
37 std::vector<int> Y_dims;
38 Y_dims.reserve(std::max(ndim, X.dim()));
40 for (
int i = ndim - 1, j = X.dim() - 1; i >= 0 || j >= 0; --i, --j) {
41 const int shape_x = (j >= 0 ? X_dims[j] : 1);
44 const int shape_y = ((i >= 0 && shape_dims[i] > 0) ? shape_dims[i] : 1);
47 shape_x == 1 || shape_y == 1 || shape_x == shape_y,
48 "Dimensions format invalid.");
49 Y_dims.push_back(std::max(shape_x, shape_y));
51 std::reverse(Y_dims.begin(), Y_dims.end());
53 std::vector<int64_t> Y_dims_int64;
54 std::copy(Y_dims.begin(), Y_dims.end(), std::back_inserter(Y_dims_int64));
55 auto* Y = Output(0, Y_dims_int64, at::dtype<T>());
56 math::Broadcast<T, Context>(
63 Y->template mutable_data<T>(),
70 template <
typename InputTypes,
class Context>
73 USE_OPERATOR_CONTEXT_FUNCTIONS;
75 template <
class... Args>
79 bool RunOnDevice()
override {
84 bool DoRunWithType() {
85 const auto& dY =
Input(0);
86 const auto& X =
Input(1);
88 const int ndim = dY.dim();
89 const std::vector<int> dX_dims(X.sizes().cbegin(), X.sizes().cend());
90 const std::vector<int> dY_dims(dY.sizes().cbegin(), dY.sizes().cend());
91 auto* dX = Output(0, X.sizes(), at::dtype<T>());
92 std::vector<int> axes;
93 const int offset = ndim - X.dim();
94 for (
int i = 0; i < ndim; i++) {
95 if (i < offset || dX_dims[i - offset] == 1) {
99 std::vector<int> X_dims = dY_dims;
100 for (
const int axis : axes) {
103 math::ReduceSum<T, Context>(
108 dY.template data<T>(),
109 dX->template mutable_data<T>(),
117 #endif // CAFFE2_OPERATORS_EXPAND_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...