Caffe2 - C++ API
A deep learning, cross-platform ML framework
elementwise_logical_ops.h
1 
17 #ifndef CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_
18 #define CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_
19 
20 #include "caffe2/core/common_omp.h"
21 #include "caffe2/core/context.h"
22 #include "caffe2/core/logging.h"
23 #include "caffe2/core/operator.h"
24 #include "caffe2/operators/elementwise_op.h"
25 
26 #include <unordered_set>
27 
28 namespace caffe2 {
29 
30 template <class Context>
31 class WhereOp final : public Operator<Context> {
32  public:
33  USE_OPERATOR_FUNCTIONS(Context);
34  USE_DISPATCH_HELPER;
35 
36  WhereOp(const OperatorDef& operator_def, Workspace* ws)
37  : Operator<Context>(operator_def, ws),
38  OP_SINGLE_ARG(bool, "broadcast_on_rows", enable_broadcast_, 0) {}
39 
40  bool RunOnDevice() override {
41  return DispatchHelper<
43  call(this, Input(1));
44  }
45 
46  template <typename T>
47  bool DoRunWithType() {
48  auto& select = Input(0);
49  auto& left = Input(1);
50  auto& right = Input(2);
51  auto* output = Output(0);
52  if (enable_broadcast_) {
53  CAFFE_ENFORCE_EQ(select.ndim(), 1);
54  CAFFE_ENFORCE_EQ(select.dim(0), right.dim(0));
55  CAFFE_ENFORCE_EQ(left.dims(), right.dims());
56  } else {
57  CAFFE_ENFORCE_EQ(select.dims(), left.dims());
58  CAFFE_ENFORCE_EQ(select.dims(), right.dims());
59  }
60  output->ResizeLike(left);
61 
62  const bool* select_data = select.template data<bool>();
63  const T* left_data = left.template data<T>();
64  const T* right_data = right.template data<T>();
65  T* output_data = output->template mutable_data<T>();
66 
67  if (enable_broadcast_) {
68  size_t block_size = left.size_from_dim(1);
69  for (int i = 0; i < select.size(); i++) {
70  size_t offset = i * block_size;
71  if (select_data[i]) {
72  context_.template CopyItems<Context, Context>(
73  output->meta(),
74  block_size,
75  left_data + offset,
76  output_data + offset);
77  } else {
78  context_.template CopyItems<Context, Context>(
79  output->meta(),
80  block_size,
81  right_data + offset,
82  output_data + offset);
83  }
84  }
85  } else {
86  for (int i = 0; i < select.size(); ++i) {
87  output_data[i] = select_data[i] ? left_data[i] : right_data[i];
88  }
89  }
90  return true;
91  }
92 
93  private:
94  bool enable_broadcast_;
95 };
96 
98  std::unordered_set<int32_t> int32_values_;
99  std::unordered_set<int64_t> int64_values_;
100  std::unordered_set<bool> bool_values_;
101  std::unordered_set<std::string> string_values_;
102  bool has_values_ = false;
103 
104  public:
105  template <typename T>
106  std::unordered_set<T>& get();
107 
108  template <typename T>
109  void set(const std::vector<T>& args) {
110  has_values_ = true;
111  auto& values = get<T>();
112  values.insert(args.begin(), args.end());
113  }
114 
115  bool has_values() {
116  return has_values_;
117  }
118 };
119 
120 template <class Context>
121 class IsMemberOfOp final : public Operator<Context> {
122  USE_OPERATOR_CONTEXT_FUNCTIONS;
123  USE_DISPATCH_HELPER;
124 
125  static constexpr const char* VALUE_TAG = "value";
126 
127  public:
129 
130  IsMemberOfOp(const OperatorDef& op, Workspace* ws)
131  : Operator<Context>(op, ws) {
132  auto dtype =
133  static_cast<TensorProto_DataType>(OperatorBase::GetSingleArgument<int>(
134  "dtype", TensorProto_DataType_UNDEFINED));
135  switch (dtype) {
136  case TensorProto_DataType_INT32:
137  values_.set(OperatorBase::GetRepeatedArgument<int32_t>(VALUE_TAG));
138  break;
139  case TensorProto_DataType_INT64:
140  values_.set(OperatorBase::GetRepeatedArgument<int64_t>(VALUE_TAG));
141  break;
142  case TensorProto_DataType_BOOL:
143  values_.set(OperatorBase::GetRepeatedArgument<bool>(VALUE_TAG));
144  break;
145  case TensorProto_DataType_STRING:
146  values_.set(OperatorBase::GetRepeatedArgument<std::string>(VALUE_TAG));
147  break;
148  case TensorProto_DataType_UNDEFINED:
149  // If dtype is not provided, values_ will be filled the first time that
150  // DoRunWithType is called.
151  break;
152  default:
153  CAFFE_THROW("Unexpected 'dtype' argument value: ", dtype);
154  }
155  }
156  virtual ~IsMemberOfOp() noexcept {}
157 
158  bool RunOnDevice() override {
159  return DispatchHelper<
161  }
162 
163  template <typename T>
164  bool DoRunWithType() {
165  auto& input = Input(0);
166  auto* output = Output(0);
167  output->ResizeLike(input);
168 
169  if (!values_.has_values()) {
170  values_.set(OperatorBase::GetRepeatedArgument<T>(VALUE_TAG));
171  }
172  const auto& values = values_.get<T>();
173 
174  const T* input_data = input.template data<T>();
175  bool* output_data = output->template mutable_data<bool>();
176  for (int i = 0; i < input.size(); ++i) {
177  output_data[i] = values.find(input_data[i]) != values.end();
178  }
179  return true;
180  }
181 
182  protected:
183  IsMemberOfValueHolder values_;
184 };
185 
186 } // namespace caffe2
187 
188 #endif // CAFFE2_OPERATORS_ELEMENTWISE_LOGICAL_OPS_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.