Caffe2 - C++ API
A deep learning, cross platform ML framework
flatten_op.h
1 #ifndef CAFFE2_OPERATORS_FLATTEN_OP_H_
2 #define CAFFE2_OPERATORS_FLATTEN_OP_H_
3 
4 #include "caffe2/core/operator.h"
5 
6 namespace caffe2 {
7 
8 template <class Context>
9 class FlattenOp : public Operator<Context> {
10  public:
11  USE_OPERATOR_CONTEXT_FUNCTIONS;
12 
13  template <class... Args>
14  explicit FlattenOp(Args&&... args)
15  : Operator<Context>(std::forward<Args>(args)...),
16  axis_(this->template GetSingleArgument<int>("axis", 1)) {}
17 
18  bool RunOnDevice() override {
19  auto& input = Input(0);
20  auto* output = Output(0);
21  CAFFE_ENFORCE_GE(
22  input.dim(), axis_, "The rank of the tensor must be >= axis.");
23  output->Resize(input.size_to_dim(axis_), input.size_from_dim(axis_));
24  context_.CopyItemsSameDevice(
25  input.dtype(),
26  input.numel(),
27  input.raw_data(),
28  output->raw_mutable_data(input.dtype()));
29  return true;
30  }
31 
32  private:
33  int axis_;
34 };
35 
36 inline std::vector<TensorShape> TensorInferenceForFlatten(
37  const OperatorDef& def,
38  const std::vector<TensorShape>& in) {
39  ArgumentHelper helper(def);
40  const int axis = helper.GetSingleArgument<int>("axis", 1);
41  std::vector<TensorShape> out(1);
42  int64_t outer = 1;
43  int64_t inner = 1;
44  std::size_t index = 0;
45  for (auto d : in[0].dims()) {
46  if (index < axis) {
47  outer *= d;
48  } else {
49  inner *= d;
50  }
51  ++index;
52  }
53  out[0].set_data_type(in[0].data_type());
54  out[0].add_dims(outer);
55  out[0].add_dims(inner);
56  return out;
57 }
58 
59 } // namespace caffe2
60 
61 #endif // CAFFE2_OPERATORS_FLATTEN_OP_H_
A helper class to index into arguments.
Definition: proto_utils.h:200
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13