Caffe2 - C++ API
A deep learning, cross-platform ML framework
boolean_mask_ops.h
1 
17 #ifndef BOOLEAN_MASK_OPS_H
18 #define BOOLEAN_MASK_OPS_H
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/core/tensor.h"
23 #include "caffe2/utils/conversions.h"
24 
25 namespace caffe2 {
26 
27 template <class Context>
28 class BooleanMaskOp final : public Operator<Context> {
29  public:
30  USE_OPERATOR_CONTEXT_FUNCTIONS;
31  BooleanMaskOp(const OperatorDef& operator_def, Workspace* ws)
32  : Operator<Context>(operator_def, ws) {}
33 
34  bool RunOnDevice() override;
35 };
36 
37 template <class Context>
38 class SequenceMaskOp final : public Operator<Context> {
39  public:
40  USE_OPERATOR_CONTEXT_FUNCTIONS;
41  explicit SequenceMaskOp(const OperatorDef& operator_def, Workspace* ws)
42  : Operator<Context>(operator_def, ws),
43  axis_(OperatorBase::GetSingleArgument<int>("axis", 1)),
44  radius_(OperatorBase::GetSingleArgument<int>("radius", 10)),
45  grad_(OperatorBase::GetSingleArgument<bool>("grad", false)),
46  fill_val_(OperatorBase::GetSingleArgument<float>(
47  "fill_val",
48  -1.0f * std::numeric_limits<float>::infinity())) {
49  // Mode argument is required
50  mode_ = GetArgument(operator_def, "mode").s();
51  // batch argument is optional, but if not given, we don't want a default val
52  if (HasArgument("batch")) {
53  batch_ = GetArgument(operator_def, "batch").i();
54  }
55 
56  if (HasArgument("repeat_from_axis")) {
57  CAFFE_ENFORCE(
58  mode_ == "sequence",
59  "repeat_from_axis currently only supported in sequence mode.");
60  CAFFE_ENFORCE(
61  !HasArgument("batch"),
62  "repeat_from_axis and batch not currently supported together.");
63  repeat_from_ =
64  OperatorBase::GetSingleArgument<int>("repeat_from_axis", -1);
65  }
66  }
67 
68  bool RunOnDevice() override;
69 
70  template <typename T>
71  bool DoRunWithType();
72 
73  private:
74  int axis_;
75  int radius_;
76  std::string mode_;
77  bool grad_;
78  float fill_val_;
79  int batch_;
80  int repeat_from_;
81 };
82 
83 } // namespace caffe2
84 
85 #endif
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:52