Caffe2 - C++ API
A deep learning, cross platform ML framework
flatten_cpu.cc
1 #include <ATen/core/dispatch/KernelRegistration.h>
2 #include "caffe2/operators/experimental/c10/schemas/flatten.h"
3 #include "caffe2/utils/math.h"
4 #include "caffe2/core/tensor.h"
5 
7 using caffe2::Tensor;
8 
9 namespace caffe2 {
10 namespace {
11 template <class DataType, class Context>
12 void flatten_op_cpu_impl(
13  const at::Tensor& input_,
14  const at::Tensor& output_,
15  int64_t axis) {
16  Tensor input{C10Tensor(input_)};
17  Tensor output{C10Tensor(output_)};
18  CPUContext context;
19  CAFFE_ENFORCE_GE(
20  input.sizes().size(), axis, "The rank of the tensor must be >= axis.");
21  output.Resize(input.size_to_dim(axis), input.size_from_dim(axis));
22  context.CopyItemsSameDevice(
23  input.dtype(),
24  input.numel(),
25  input.raw_data(),
26  output.raw_mutable_data(input.dtype()));
27 }
28 } // namespace
29 } // namespace caffe2
30 
31 namespace c10 {
32 C10_REGISTER_KERNEL(caffe2::ops::Flatten)
33  .kernel<decltype(caffe2::flatten_op_cpu_impl<float, caffe2::CPUContext>), &caffe2::flatten_op_cpu_impl<float, caffe2::CPUContext>>()
34  .dispatchKey(CPUTensorId());
35 } // namespace c10
Tensor class holds a shared pointer to the implementation TensorImpl and redirects API calls to TensorImpl.
Definition: tensor.h:25
Virtual interface for the Context class in Caffe2.
Definition: context_base.h:32
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7