Caffe2 - C++ API
A deep-learning, cross-platform ML framework
replace_nan_op.h
1 
17 #ifndef CAFFE_OPERATORS_REPLACE_NAN_OP_H_
18 #define CAFFE_OPERATORS_REPLACE_NAN_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <class Context>
28 class ReplaceNaNOp final : public Operator<Context> {
29  public:
30  USE_OPERATOR_CONTEXT_FUNCTIONS;
31  ReplaceNaNOp(const OperatorDef& operator_def, Workspace* ws)
32  : Operator<Context>(operator_def, ws) {}
33 
34  bool RunOnDevice() override {
35  return DispatchHelper<TensorTypes<float, double>>::call(this, Input(0));
36  }
37 
38  template <typename T>
39  void ReplaceNaN(const T& value, const TIndex size, const T* X, T* Y);
40 
41  template <typename T>
42  bool DoRunWithType() {
43  T value = OperatorBase::GetSingleArgument<T>("value", 0);
44 
45  auto& input = Input(0);
46  auto* output = Output(0);
47  output->ResizeLike(input);
48 
49  const T* input_data = input.template data<T>();
50  T* output_data = output->template mutable_data<T>();
51 
52  ReplaceNaN<T>(value, input.size(), input_data, output_data);
53 
54  return true;
55  }
56 };
57 
58 } // namespace caffe2
59 
60 #endif // CAFFE_OPERATORS_REPLACE_NAN_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.