Caffe2 - C++ API
A deep-learning, cross-platform ML framework
boolean_mask_ops.h
1 #ifndef BOOLEAN_MASK_OPS_H
2 #define BOOLEAN_MASK_OPS_H
3 
#include <limits>
#include <string>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/utils/conversions.h"
8 
9 namespace caffe2 {
10 
11 template <class Context>
12 class BooleanMaskOp final : public Operator<Context> {
13  public:
14  USE_OPERATOR_CONTEXT_FUNCTIONS;
15  template <class... Args>
16  explicit BooleanMaskOp(Args&&... args)
17  : Operator<Context>(std::forward<Args>(args)...) {}
18 
19  bool RunOnDevice() override;
20 };
21 
22 template <class Context>
23 class SequenceMaskOp final : public Operator<Context> {
24  public:
25  USE_OPERATOR_CONTEXT_FUNCTIONS;
26  explicit SequenceMaskOp(const OperatorDef& operator_def, Workspace* ws)
27  : Operator<Context>(operator_def, ws),
28  axis_(this->template GetSingleArgument<int>("axis", 1)),
29  radius_(this->template GetSingleArgument<int>("radius", 10)),
30  grad_(this->template GetSingleArgument<bool>("grad", false)),
31  fill_val_(this->template GetSingleArgument<float>(
32  "fill_val",
33  -1.0f * std::numeric_limits<float>::infinity())) {
34  // Mode argument is required
35  mode_ = GetArgument(operator_def, "mode").s();
36  // batch argument is optional, but if not given, we don't want a default val
37  if (HasArgument("batch")) {
38  batch_ = GetArgument(operator_def, "batch").i();
39  }
40 
41  if (HasArgument("repeat_from_axis")) {
42  CAFFE_ENFORCE(
43  mode_ == "sequence",
44  "repeat_from_axis currently only supported in sequence mode.");
45  CAFFE_ENFORCE(
46  !HasArgument("batch"),
47  "repeat_from_axis and batch not currently supported together.");
48  repeat_from_ =
49  this->template GetSingleArgument<int>("repeat_from_axis", -1);
50  }
51  }
52 
53  bool RunOnDevice() override;
54 
55  template <typename T>
56  bool DoRunWithType();
57 
58  private:
59  int axis_;
60  int radius_;
61  std::string mode_;
62  bool grad_;
63  float fill_val_;
64  int batch_;
65  int repeat_from_;
66 };
67 
68 } // namespace caffe2
69 
70 #endif
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70