1 #ifndef CAFFE2_OPERATORS_FLATTEN_OP_H_ 2 #define CAFFE2_OPERATORS_FLATTEN_OP_H_ 4 #include "caffe2/core/operator.h" 8 template <
class Context>
11 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor fragment. The signature line and the base-class initializer are
// not visible in this listing; only the variadic template header and the
// axis_ member initializer remain.
13 template <
class... Args>
// axis_ is read from the operator's integer "axis" argument; 1 is passed as
// the second argument to GetSingleArgument (presumably the fallback used when
// "axis" is not supplied -- confirm against the argument-helper API).
16 axis_(this->
template GetSingleArgument<int>(
"axis", 1)) {}
// Operator entry point (overrides a base-class virtual). From the visible
// statements: fetches input 0 and output 0, checks the input rank against
// axis_, resizes the output using two extents derived from axis_, and copies
// the input items into the output buffer.
// NOTE(review): this listing is incomplete -- the enforce macro name, most of
// the CopyItemsSameDevice arguments, the return statement and the closing
// brace are not visible here.
18 bool RunOnDevice()
override {
19 auto& input =
Input(0);
20 auto* output = Output(0);
// Arguments of a rank check whose macro line is missing; per the message, the
// input's dim() is expected to be >= axis_.
22 input.dim(), axis_,
"The rank of the tensor must be >= axis.");
// Output is resized with exactly two extents: size_to_dim(axis_) and
// size_from_dim(axis_) (presumably the products of the dimensions before
// axis_ and from axis_ onward -- confirm against the Tensor API).
23 output->Resize(input.size_to_dim(axis_), input.size_from_dim(axis_));
24 context_.CopyItemsSameDevice(
// Destination pointer: the output's raw buffer, requested with the input's
// dtype so the output element type follows the input.
28 output->raw_mutable_data(input.dtype()));
// Shape-inference helper for the flatten operator.
//
// Parameters:
//   def - operator definition; the "axis" argument is read from it through
//         `helper` (whose declaration is not visible in this listing).
//   in  - input shapes; only in[0] is consulted by the visible code.
//
// Builds a single-element vector of output shapes whose data type is copied
// from in[0] and to which two dimensions are appended: `outer`, then `inner`.
// NOTE(review): the declarations of `helper`, `outer` and `inner`, the loop
// body, and the return statement are missing from this listing; presumably
// the loop multiplies each dimension into `outer` or `inner` depending on
// whether its index is below `axis` -- confirm against the full source.
36 inline std::vector<TensorShape> TensorInferenceForFlatten(
37 const OperatorDef& def,
38 const std::vector<TensorShape>& in) {
// 1 is passed as the second argument to GetSingleArgument, matching the value
// used for the operator's own "axis" lookup.
40 const int axis = helper.GetSingleArgument<
int>(
"axis", 1);
41 std::vector<TensorShape> out(1);
// Running position of the dimension currently visited by the loop below.
44 std::size_t index = 0;
45 for (
auto d : in[0].dims()) {
53 out[0].set_data_type(in[0].data_type());
54 out[0].add_dims(outer);
55 out[0].add_dims(inner);
61 #endif // CAFFE2_OPERATORS_FLATTEN_OP_H_
// A helper class to index into arguments.
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...