Caffe2 - C++ API
A deep learning, cross-platform ML framework
elementwise_logical_ops.h
1 #ifndef CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_
2 #define CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_
3 
4 #include "caffe2/core/common_omp.h"
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/operators/elementwise_ops.h"
9 
10 #include <unordered_set>
11 
12 namespace caffe2 {
13 
14 template <class Context>
15 class WhereOp final : public Operator<Context> {
16  public:
17  USE_OPERATOR_FUNCTIONS(Context);
18  USE_DISPATCH_HELPER;
19 
20  template <class... Args>
21  explicit WhereOp(Args&&... args)
22  : Operator<Context>(std::forward<Args>(args)...),
23  OP_SINGLE_ARG(bool, "broadcast_on_rows", enable_broadcast_, 0) {}
24 
25  bool RunOnDevice() override {
26  return DispatchHelper<
28  call(this, Input(1));
29  }
30 
31  template <typename T>
32  bool DoRunWithType() {
33  auto& select = Input(0);
34  auto& left = Input(1);
35  auto& right = Input(2);
36 
37  if (enable_broadcast_) {
38  CAFFE_ENFORCE_EQ(select.dim(), 1);
39  CAFFE_ENFORCE_EQ(select.size(0), right.size(0));
40  CAFFE_ENFORCE_EQ(left.sizes(), right.sizes());
41  } else {
42  CAFFE_ENFORCE_EQ(select.sizes(), left.sizes());
43  CAFFE_ENFORCE_EQ(select.sizes(), right.sizes());
44  }
45  auto* output = Output(0, left.sizes(), at::dtype<T>());
46 
47  const bool* select_data = select.template data<bool>();
48  const T* left_data = left.template data<T>();
49  const T* right_data = right.template data<T>();
50  T* output_data = output->template mutable_data<T>();
51 
52  if (enable_broadcast_) {
53  size_t block_size = left.size_from_dim(1);
54  for (int i = 0; i < select.numel(); i++) {
55  size_t offset = i * block_size;
56  if (select_data[i]) {
57  context_.CopyItemsSameDevice(
58  output->dtype(),
59  block_size,
60  left_data + offset,
61  output_data + offset);
62  } else {
63  context_.CopyItemsSameDevice(
64  output->dtype(),
65  block_size,
66  right_data + offset,
67  output_data + offset);
68  }
69  }
70  } else {
71  for (int i = 0; i < select.numel(); ++i) {
72  output_data[i] = select_data[i] ? left_data[i] : right_data[i];
73  }
74  }
75  return true;
76  }
77 
78  private:
79  bool enable_broadcast_;
80 };
81 
/**
 * Storage for IsMemberOfOp's candidate values, with one hash set per
 * supported element type (int32_t, int64_t, bool, std::string).
 */
class IsMemberOfValueHolder {
  std::unordered_set<int32_t> int32_values_;
  std::unordered_set<int64_t> int64_values_;
  std::unordered_set<bool> bool_values_;
  std::unordered_set<std::string> string_values_;
  // Becomes true on the first set<T>() call, even if that call passes an
  // empty list.
  bool has_values_ = false;

 public:
  // Returns the value set for element type T. Only declared here; the
  // per-type definitions are not in this header (presumably explicit
  // specializations in the .cc -- TODO confirm).
  template <typename T>
  std::unordered_set<T>& get();

  // Inserts every element of `args` into the set for type T (existing
  // entries are kept) and marks the holder as populated.
  template <typename T>
  void set(const std::vector<T>& args) {
    has_values_ = true;
    auto& values = get<T>();
    values.insert(args.begin(), args.end());
  }

  // True once set<T>() has been called for any T. Const so that it can be
  // queried through a const reference.
  bool has_values() const {
    return has_values_;
  }
};
104 
105 template <class Context>
106 class IsMemberOfOp final : public Operator<Context> {
107  USE_OPERATOR_CONTEXT_FUNCTIONS;
108  USE_DISPATCH_HELPER;
109 
110  static constexpr const char* VALUE_TAG = "value";
111 
112  public:
114 
115  template <class... Args>
116  explicit IsMemberOfOp(Args&&... args)
117  : Operator<Context>(std::forward<Args>(args)...) {
118  auto dtype =
119  static_cast<TensorProto_DataType>(this->template GetSingleArgument<int>(
120  "dtype", TensorProto_DataType_UNDEFINED));
121  switch (dtype) {
122  case TensorProto_DataType_INT32:
123  values_.set(this->template GetRepeatedArgument<int32_t>(VALUE_TAG));
124  break;
125  case TensorProto_DataType_INT64:
126  values_.set(this->template GetRepeatedArgument<int64_t>(VALUE_TAG));
127  break;
128  case TensorProto_DataType_BOOL:
129  values_.set(this->template GetRepeatedArgument<bool>(VALUE_TAG));
130  break;
131  case TensorProto_DataType_STRING:
132  values_.set(this->template GetRepeatedArgument<std::string>(VALUE_TAG));
133  break;
134  case TensorProto_DataType_UNDEFINED:
135  // If dtype is not provided, values_ will be filled the first time that
136  // DoRunWithType is called.
137  break;
138  default:
139  CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
140  }
141  }
142  virtual ~IsMemberOfOp() noexcept {}
143 
144  bool RunOnDevice() override {
145  return DispatchHelper<
147  }
148 
149  template <typename T>
150  bool DoRunWithType() {
151  auto& input = Input(0);
152 
153  auto* output = Output(0, input.sizes(), at::dtype<bool>());
154 
155  if (!values_.has_values()) {
156  values_.set(this->template GetRepeatedArgument<T>(VALUE_TAG));
157  }
158  const auto& values = values_.get<T>();
159 
160  const T* input_data = input.template data<T>();
161  bool* output_data = output->template mutable_data<bool>();
162  for (int i = 0; i < input.numel(); ++i) {
163  output_data[i] = values.find(input_data[i]) != values.end();
164  }
165  return true;
166  }
167 
168  protected:
169  IsMemberOfValueHolder values_;
170 };
171 
172 } // namespace caffe2
173 
174 #endif // CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13