1 #ifndef CAFFE_OPERATORS_REPLACE_NAN_OP_H_ 2 #define CAFFE_OPERATORS_REPLACE_NAN_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/utils/math.h" 11 template <
class Context>
14 USE_OPERATOR_CONTEXT_FUNCTIONS;
15 template <
class... Args>
19 bool RunOnDevice()
override {
24 void ReplaceNaN(
const T& value,
const int64_t size,
const T* X,
T* Y);
27 bool DoRunWithType() {
28 T value = this->
template GetSingleArgument<T>(
"value", 0);
30 auto& input =
Input(0);
32 auto* output = Output(0, input.sizes(), at::dtype<T>());
34 const T* input_data = input.template data<T>();
35 T* output_data = output->template mutable_data<T>();
37 ReplaceNaN<T>(value, input.numel(), input_data, output_data);
45 #endif // CAFFE_OPERATORS_REPLACE_NAN_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...