Caffe2 - C++ API
A deep learning, cross-platform ML framework
flatten_op.h
1 
17 #ifndef CAFFE2_OPERATORS_FLATTEN_OP_H_
18 #define CAFFE2_OPERATORS_FLATTEN_OP_H_
19 
20 #include "caffe2/core/operator.h"
21 
22 namespace caffe2 {
23 
24 template <class Context>
25 class FlattenOp : public Operator<Context> {
26  public:
27  USE_OPERATOR_CONTEXT_FUNCTIONS;
28 
29  FlattenOp(const OperatorDef& operator_def, Workspace* ws)
30  : Operator<Context>(operator_def, ws),
31  axis_(OperatorBase::GetSingleArgument<int>("axis", 1)) {}
32 
33  bool RunOnDevice() override {
34  auto& input = Input(0);
35  auto* output = Output(0);
36  CAFFE_ENFORCE_GE(
37  input.dims().size(), axis_, "The rank of the tensor must be >= axis.");
38  output->Resize(input.size_to_dim(axis_), input.size_from_dim(axis_));
39  context_.template CopyItems<Context, Context>(
40  input.meta(),
41  input.size(),
42  input.raw_data(),
43  output->raw_mutable_data(input.meta()));
44  return true;
45  }
46 
47  private:
48  int axis_;
49 };
50 
51 } // namespace caffe2
52 
53 #endif // CAFFE2_OPERATORS_FLATTEN_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.