1 #ifndef CAFFE2_OPERATORS_PACK_SEGMENTS_H_ 2 #define CAFFE2_OPERATORS_PACK_SEGMENTS_H_ 7 #include <unordered_map> 9 #include "caffe2/core/operator.h" 10 #include "caffe2/core/tensor.h" 11 #include "caffe2/utils/math.h" 15 template <
class Context>
18 USE_OPERATOR_CONTEXT_FUNCTIONS;
21 template <
class... Args>
24 max_length_(this->
template GetSingleArgument<int>(
"max_length", -1)),
25 pad_minf_(this->
template GetSingleArgument<bool>(
"pad_minf",
false)),
26 return_presence_mask_(this->
template GetSingleArgument<bool>(
27 "return_presence_mask",
30 padding_ = -1.0 * std::numeric_limits<float>::infinity();
43 template <
typename T,
typename Data_T>
44 bool DoRunWithType2();
46 INPUT_TAGS(LENGTHS, DATA);
52 bool return_presence_mask_;
55 Tensor dev_buffer_{Context::GetDeviceType()};
56 Tensor dev_lengths_prefix_sum_{Context::GetDeviceType()};
57 Tensor dev_max_length_{Context::GetDeviceType()};
58 Tensor host_max_length_{CPU};
61 template <
class Context>
64 USE_OPERATOR_CONTEXT_FUNCTIONS;
67 template <
class... Args>
70 max_length_(this->
template GetSingleArgument<int>(
"max_length", -1)) {}
72 bool RunOnDevice()
override {
79 template <
typename T,
typename Data_T>
80 bool DoRunWithType2();
82 INPUT_TAGS(LENGTHS, DATA);
86 Tensor dev_buffer_{Context::GetDeviceType()};
87 Tensor dev_lengths_prefix_sum_{Context::GetDeviceType()};
88 Tensor dev_max_length_{Context::GetDeviceType()};
89 Tensor dev_num_cell_{Context::GetDeviceType()};
90 Tensor host_max_length_{CPU};
91 Tensor host_num_cell_{CPU};
95 #endif // CAFFE2_OPERATORS_PACK_SEGMENTS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...