Caffe2 - C++ API
A deep learning, cross platform ML framework
atomic_ops.cc
#include <atomic>
#include <memory>
#include <mutex>
#include <thread>
#include <utility>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"

#ifdef CAFFE2_USE_MKLDNN
#include <caffe2/ideep/operators/operator_fallback_ideep.h>
#include <caffe2/ideep/utils/ideep_operator.h>
#endif
10 
11 namespace caffe2 {
12 namespace fb {
13 namespace {
14 
15 class CreateMutexOp final : public Operator<CPUContext> {
16  public:
17  template <class... Args>
18  explicit CreateMutexOp(Args&&... args)
19  : Operator<CPUContext>(std::forward<Args>(args)...) {}
20 
21  bool RunOnDevice() override {
22  *OperatorBase::Output<std::unique_ptr<std::mutex>>(0) =
23  std::unique_ptr<std::mutex>(new std::mutex);
24  return true;
25  }
26 };
27 
28 class AtomicFetchAddOp final : public Operator<CPUContext> {
29  public:
30  template <class... Args>
31  explicit AtomicFetchAddOp(Args&&... args)
32  : Operator<CPUContext>(std::forward<Args>(args)...) {}
33 
34  bool RunOnDevice() override {
35  auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(0);
36  std::lock_guard<std::mutex> lg(*mutex);
37  auto& a = Input(1);
38  auto& b = Input(2);
39  auto* c = Output(0);
40  auto* d = Output(1);
41  c->Resize();
42  d->Resize();
43  auto* aPtr = a.data<int32_t>();
44  auto* bPtr = b.data<int32_t>();
45  auto* cPtr = c->template mutable_data<int32_t>();
46  auto* dPtr = d->template mutable_data<int32_t>();
47  *dPtr = *aPtr;
48  *cPtr = *aPtr + *bPtr;
49  return true;
50  }
51 };
52 
53 class CreateAtomicBoolOp final : public Operator<CPUContext> {
54  public:
55  using Operator::Operator;
56 
57  bool RunOnDevice() override {
58  *OperatorBase::Output<std::unique_ptr<std::atomic<bool>>>(0) =
59  std::unique_ptr<std::atomic<bool>>(new std::atomic<bool>(false));
60  return true;
61  }
62 };
63 
64 class ConditionalSetAtomicBoolOp final : public Operator<CPUContext> {
65  public:
66  using Operator::Operator;
67 
68  bool RunOnDevice() override {
69  auto& ptr =
70  OperatorBase::Input<std::unique_ptr<std::atomic<bool>>>(ATOMIC_BOOL);
71  if (Input(CONDITION).data<bool>()[0]) {
72  ptr->store(true);
73  }
74  return true;
75  }
76 
77  private:
78  INPUT_TAGS(ATOMIC_BOOL, CONDITION);
79 };
80 
81 class CheckAtomicBoolOp final : public Operator<CPUContext> {
82  public:
83  using Operator::Operator;
84 
85  bool RunOnDevice() override {
86  auto& ptr = OperatorBase::Input<std::unique_ptr<std::atomic<bool>>>(0);
87  Output(0)->Resize(1);
88  *Output(0)->template mutable_data<bool>() = ptr->load();
89  return true;
90  }
91 };
92 
93 REGISTER_CPU_OPERATOR(CreateMutex, CreateMutexOp);
94 REGISTER_CPU_OPERATOR(AtomicFetchAdd, AtomicFetchAddOp);
95 
96 #ifdef CAFFE2_USE_MKLDNN
97 REGISTER_IDEEP_OPERATOR(CreateMutex, IDEEPFallbackOp<CreateMutexOp, SkipIndices<0>>);
98 #endif
99 
100 REGISTER_CPU_OPERATOR(CreateAtomicBool, CreateAtomicBoolOp);
101 REGISTER_CPU_OPERATOR(ConditionalSetAtomicBool, ConditionalSetAtomicBoolOp);
102 REGISTER_CPU_OPERATOR(CheckAtomicBool, CheckAtomicBoolOp);
103 
104 OPERATOR_SCHEMA(CreateMutex)
105  .NumInputs(0)
106  .NumOutputs(1)
107  .SetDoc("Creates an unlocked mutex and returns it in a unique_ptr blob.")
108  .Output(0, "mutex_ptr", "Blob containing a std::unique_ptr<mutex>.")
109  .ScalarType(TensorProto_DataType_UNDEFINED);
110 
111 OPERATOR_SCHEMA(AtomicFetchAdd)
112  .NumInputs(3)
113  .NumOutputs(2)
114  .SetDoc(R"DOC(
115 Given a mutex and two int32 scalar tensors, performs an atomic fetch add
116 by mutating the first argument and adding it to the second input
117 argument. Returns the updated integer and the value prior to the update.
118 )DOC")
119  .Input(0, "mutex_ptr", "Blob containing to a unique_ptr<mutex>")
120  .Input(1, "mut_value", "Value to be mutated after the sum.")
121  .Input(2, "increment", "Value to add to the first operand.")
122  .Output(0, "mut_value", "Mutated value after sum. Usually same as input 1.")
123  .Output(1, "fetched_value", "Value of the first operand before sum.")
124  .AllowInplace({{1, 0}});
125 
126 OPERATOR_SCHEMA(CreateAtomicBool)
127  .NumInputs(0)
128  .NumOutputs(1)
129  .SetDoc("Create an unique_ptr blob to hold an atomic<bool>")
130  .Output(0, "atomic_bool", "Blob containing a unique_ptr<atomic<bool>>");
131 
132 OPERATOR_SCHEMA(ConditionalSetAtomicBool)
133  .NumInputs(2)
134  .NumOutputs(0)
135  .SetDoc(R"DOC(
136 Set an atomic<bool> to true if the given condition bool variable is true
137  )DOC")
138  .Input(0, "atomic_bool", "Blob containing a unique_ptr<atomic<bool>>")
139  .Input(1, "condition", "Blob containing a bool");
140 
141 OPERATOR_SCHEMA(CheckAtomicBool)
142  .NumInputs(1)
143  .NumOutputs(1)
144  .SetDoc("Copy the value of an atomic<bool> to a bool")
145  .Input(0, "atomic_bool", "Blob containing a unique_ptr<atomic<bool>>")
146  .Output(0, "value", "Copy of the value for the atomic<bool>");
147 
148 SHOULD_NOT_DO_GRADIENT(CreateMutex);
149 SHOULD_NOT_DO_GRADIENT(AtomicFetchAdd);
150 SHOULD_NOT_DO_GRADIENT(CreateAtomicBool);
151 SHOULD_NOT_DO_GRADIENT(ConditionalSetAtomicBool);
152 SHOULD_NOT_DO_GRADIENT(CheckAtomicBool);
153 }
154 }
155 }
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13