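// Caffe2 - C++ API
// A deep learning, cross-platform ML framework.
// File: cast_op.h
// (Doxygen source listing: the listing jumps from line 1 to line 17; the
// collapsed lines hold the file's license header.)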
17 #pragma once
18 
19 #include "caffe2/core/context.h"
20 #include "caffe2/core/operator.h"
21 #include "caffe2/utils/cast.h"
22 #include "caffe2/utils/conversions.h"
23 #include "caffe2/utils/math.h"
24 #include "caffe2/core/logging.h"
25 #include "caffe2/core/types.h"
26 
27 namespace caffe2 {
28 
29 template <class Context>
30 class CastOp : public Operator<Context> {
31  public:
32  USE_OPERATOR_CONTEXT_FUNCTIONS;
33 
34  CastOp(const OperatorDef& operator_def, Workspace* ws)
35  : Operator<Context>(operator_def, ws) {
36  const ArgumentHelper helper(operator_def);
37  TensorProto_DataType to = cast::GetCastDataType(helper, "to");
38  TensorProto_DataType from = cast::GetCastDataType(helper, "from_type");
39 
40  SetBody(to);
41  }
42 
43  bool RunOnDevice() override {
44  return (this->*body_)();
45  }
46 
47  // Allow for Context-specific implementations
48  void SetBody(TensorProto_DataType to);
49 
50  template <typename DstType>
51  bool DoRunWithDstType();
52 
53  template <typename DstType, typename SrcType>
54  bool DoRunWithType() {
55  auto& input = Input(0);
56  auto* output = Output(0);
57  output->ResizeLike(input);
58  const auto* data = input.template data<SrcType>();
59  auto* out = output->template mutable_data<DstType>();
60  auto N = input.size();
61  for (TIndex i = 0; i < N; ++i) {
62  out[i] = static_cast<DstType>(data[i]);
63  }
64  return true;
65  }
66 
67  private:
68  bool (CastOp::*body_)();
69 };
70 
71 } // namespace caffe2
// Doxygen tooltip notes from the original listing:
//   A helper class to index into arguments. (Definition: proto_utils.h:198)
//   Workspace is a class that holds all the related objects created during
//   runtime: (1) all blobs... (Definition: workspace.h:63)
// Copyright (c) 2016-present, Facebook, Inc.