1 #ifndef CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_ 2 #define CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_ 4 #include "caffe2/core/common_omp.h" 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/logging.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/operators/elementwise_ops.h" 10 #include <unordered_set> 14 template <
class Context>
17 USE_OPERATOR_FUNCTIONS(Context);
20 template <
class... Args>
21 explicit WhereOp(Args&&... args)
23 OP_SINGLE_ARG(
bool,
"broadcast_on_rows", enable_broadcast_, 0) {}
25 bool RunOnDevice()
override {
32 bool DoRunWithType() {
33 auto& select =
Input(0);
34 auto& left =
Input(1);
35 auto& right =
Input(2);
37 if (enable_broadcast_) {
38 CAFFE_ENFORCE_EQ(select.dim(), 1);
39 CAFFE_ENFORCE_EQ(select.size(0), right.size(0));
40 CAFFE_ENFORCE_EQ(left.sizes(), right.sizes());
42 CAFFE_ENFORCE_EQ(select.sizes(), left.sizes());
43 CAFFE_ENFORCE_EQ(select.sizes(), right.sizes());
45 auto* output = Output(0, left.sizes(), at::dtype<T>());
47 const bool* select_data = select.template data<bool>();
48 const T* left_data = left.template data<T>();
49 const T* right_data = right.template data<T>();
50 T* output_data = output->template mutable_data<T>();
52 if (enable_broadcast_) {
53 size_t block_size = left.size_from_dim(1);
54 for (
int i = 0; i < select.numel(); i++) {
55 size_t offset = i * block_size;
57 context_.CopyItemsSameDevice(
61 output_data + offset);
63 context_.CopyItemsSameDevice(
67 output_data + offset);
71 for (
int i = 0; i < select.numel(); ++i) {
72 output_data[i] = select_data[i] ? left_data[i] : right_data[i];
79 bool enable_broadcast_;
83 std::unordered_set<int32_t> int32_values_;
84 std::unordered_set<int64_t> int64_values_;
85 std::unordered_set<bool> bool_values_;
86 std::unordered_set<std::string> string_values_;
87 bool has_values_ =
false;
91 std::unordered_set<T>&
get();
94 void set(
const std::vector<T>& args) {
96 auto& values = get<T>();
97 values.insert(args.begin(), args.end());
105 template <
class Context>
107 USE_OPERATOR_CONTEXT_FUNCTIONS;
110 static constexpr
const char* VALUE_TAG =
"value";
115 template <
class... Args>
119 static_cast<TensorProto_DataType
>(this->
template GetSingleArgument<int>(
120 "dtype", TensorProto_DataType_UNDEFINED));
122 case TensorProto_DataType_INT32:
123 values_.set(this->
template GetRepeatedArgument<int32_t>(VALUE_TAG));
125 case TensorProto_DataType_INT64:
126 values_.set(this->
template GetRepeatedArgument<int64_t>(VALUE_TAG));
128 case TensorProto_DataType_BOOL:
129 values_.set(this->
template GetRepeatedArgument<bool>(VALUE_TAG));
131 case TensorProto_DataType_STRING:
132 values_.set(this->
template GetRepeatedArgument<std::string>(VALUE_TAG));
134 case TensorProto_DataType_UNDEFINED:
139 CAFFE_THROW(
"Unexpected 'dtype' argument value: ", dtype);
142 virtual ~IsMemberOfOp() noexcept {}
144 bool RunOnDevice()
override {
149 template <
typename T>
150 bool DoRunWithType() {
151 auto& input =
Input(0);
153 auto* output = Output(0, input.sizes(), at::dtype<bool>());
155 if (!values_.has_values()) {
156 values_.set(this->
template GetRepeatedArgument<T>(VALUE_TAG));
158 const auto& values = values_.get<
T>();
160 const T* input_data = input.template data<T>();
161 bool* output_data = output->template mutable_data<bool>();
162 for (
int i = 0; i < input.numel(); ++i) {
163 output_data[i] = values.find(input_data[i]) != values.end();
174 #endif // CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...