Caffe2 - C++ API
A deep learning, cross-platform ML framework
atomic_ops.cc
1 
17 #include <mutex>
18 #include "caffe2/core/context.h"
19 #include "caffe2/core/operator.h"
20 
21 namespace caffe2 {
22 namespace fb {
23 namespace {
24 
25 class CreateMutexOp final : public Operator<CPUContext> {
26  public:
27  CreateMutexOp(const OperatorDef& operator_def, Workspace* ws)
28  : Operator<CPUContext>(operator_def, ws) {}
29 
30  bool RunOnDevice() override {
31  *OperatorBase::Output<std::unique_ptr<std::mutex>>(0) =
32  std::unique_ptr<std::mutex>(new std::mutex);
33  return true;
34  }
35 };
36 
37 class AtomicFetchAddOp final : public Operator<CPUContext> {
38  public:
39  AtomicFetchAddOp(const OperatorDef& operator_def, Workspace* ws)
40  : Operator<CPUContext>(operator_def, ws) {}
41 
42  bool RunOnDevice() override {
43  auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(0);
44  auto& a = Input(1);
45  auto& b = Input(2);
46  auto* c = Output(0);
47  auto* d = Output(1);
48  c->Resize(std::vector<TIndex>());
49  d->Resize(std::vector<TIndex>());
50  auto* aPtr = a.data<int32_t>();
51  auto* bPtr = b.data<int32_t>();
52  auto* cPtr = c->mutable_data<int32_t>();
53  auto* dPtr = d->mutable_data<int32_t>();
54  std::lock_guard<std::mutex> lg(*mutex);
55  *dPtr = *aPtr;
56  *cPtr = *aPtr + *bPtr;
57  return true;
58  }
59 };
60 
61 class CreateAtomicBoolOp final : public Operator<CPUContext> {
62  public:
63  using Operator::Operator;
64 
65  bool RunOnDevice() override {
66  *OperatorBase::Output<std::unique_ptr<std::atomic<bool>>>(0) =
67  std::unique_ptr<std::atomic<bool>>(new std::atomic<bool>(false));
68  return true;
69  }
70 };
71 
72 class ConditionalSetAtomicBoolOp final : public Operator<CPUContext> {
73  public:
74  using Operator::Operator;
75 
76  bool RunOnDevice() override {
77  auto& ptr =
78  OperatorBase::Input<std::unique_ptr<std::atomic<bool>>>(ATOMIC_BOOL);
79  if (Input(CONDITION).data<bool>()[0]) {
80  ptr->store(true);
81  }
82  return true;
83  }
84 
85  private:
86  INPUT_TAGS(ATOMIC_BOOL, CONDITION);
87 };
88 
89 class CheckAtomicBoolOp final : public Operator<CPUContext> {
90  public:
91  using Operator::Operator;
92 
93  bool RunOnDevice() override {
94  auto& ptr = OperatorBase::Input<std::unique_ptr<std::atomic<bool>>>(0);
95  Output(0)->Resize(1);
96  *Output(0)->mutable_data<bool>() = ptr->load();
97  return true;
98  }
99 };
100 
101 REGISTER_CPU_OPERATOR(CreateMutex, CreateMutexOp);
102 REGISTER_CPU_OPERATOR(AtomicFetchAdd, AtomicFetchAddOp);
103 
104 REGISTER_CPU_OPERATOR(CreateAtomicBool, CreateAtomicBoolOp);
105 REGISTER_CPU_OPERATOR(ConditionalSetAtomicBool, ConditionalSetAtomicBoolOp);
106 REGISTER_CPU_OPERATOR(CheckAtomicBool, CheckAtomicBoolOp);
107 
108 OPERATOR_SCHEMA(CreateMutex)
109  .NumInputs(0)
110  .NumOutputs(1)
111  .SetDoc("Creates an unlocked mutex and returns it in a unique_ptr blob.")
112  .Output(0, "mutex_ptr", "Blob containing a std::unique_ptr<mutex>.");
113 
114 OPERATOR_SCHEMA(AtomicFetchAdd)
115  .NumInputs(3)
116  .NumOutputs(2)
117  .SetDoc(R"DOC(
118 Given a mutex and two int32 scalar tensors, performs an atomic fetch add
119 by mutating the first argument and adding it to the second input
120 argument. Returns the updated integer and the value prior to the update.
121 )DOC")
122  .Input(0, "mutex_ptr", "Blob containing to a unique_ptr<mutex>")
123  .Input(1, "mut_value", "Value to be mutated after the sum.")
124  .Input(2, "increment", "Value to add to the first operand.")
125  .Output(0, "mut_value", "Mutated value after sum. Usually same as input 1.")
126  .Output(1, "fetched_value", "Value of the first operand before sum.")
127  .AllowInplace({{1, 0}});
128 
129 OPERATOR_SCHEMA(CreateAtomicBool)
130  .NumInputs(0)
131  .NumOutputs(1)
132  .SetDoc("Create an unique_ptr blob to hold an atomic<bool>")
133  .Output(0, "atomic_bool", "Blob containing a unique_ptr<atomic<bool>>");
134 
135 OPERATOR_SCHEMA(ConditionalSetAtomicBool)
136  .NumInputs(2)
137  .NumOutputs(0)
138  .SetDoc(R"DOC(
139 Set an atomic<bool> to true if the given condition bool variable is true
140  )DOC")
141  .Input(0, "atomic_bool", "Blob containing a unique_ptr<atomic<bool>>")
142  .Input(1, "condition", "Blob containing a bool");
143 
144 OPERATOR_SCHEMA(CheckAtomicBool)
145  .NumInputs(1)
146  .NumOutputs(1)
147  .SetDoc("Copy the value of an atomic<bool> to a bool")
148  .Input(0, "atomic_bool", "Blob containing a unique_ptr<atomic<bool>>")
149  .Output(0, "value", "Copy of the value for the atomic<bool>");
150 
151 SHOULD_NOT_DO_GRADIENT(CreateMutex);
152 SHOULD_NOT_DO_GRADIENT(AtomicFetchAdd);
153 SHOULD_NOT_DO_GRADIENT(CreateAtomicBool);
154 SHOULD_NOT_DO_GRADIENT(ConditionalSetAtomicBool);
155 SHOULD_NOT_DO_GRADIENT(CheckAtomicBool);
156 }
157 }
158 }
Copyright (c) 2016-present, Facebook, Inc.