Caffe2 - C++ API
A deep learning, cross platform ML framework
shape_op.h
1 
18 #pragma once
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 
23 namespace caffe2 {
24 
25 // RecordShapeOp records the shape of the input tensor to a vector of int. You
26 // mostly don't need this operator explicitly, and it is mostly used in the
27 // autodiff process.
28 template <class Context>
29 class ShapeOp : public Operator<Context> {
30  public:
31  USE_OPERATOR_CONTEXT_FUNCTIONS;
32  USE_SIMPLE_CTOR_DTOR(ShapeOp);
33 
34  bool RunOnDevice() override {
35  auto& input = Input(0);
36  auto* output = OperatorBase::Output<Tensor<Context>>(0);
37  output->Resize(input.ndim());
38  TIndex* output_data = output->template mutable_data<TIndex>();
39  context_.template CopyBytes<Context, Context>(
40  input.ndim() * sizeof(TIndex), input.dims().data(), output_data);
41  return true;
42  }
43 };
44 
45 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.