Caffe2 - C++ API
A deep learning, cross-platform ML framework
iter_op.cc
1 #include "caffe2/sgd/iter_op.h"
2 
3 #ifdef CAFFE2_USE_MKLDNN
4 #include <caffe2/ideep/operators/operator_fallback_ideep.h>
5 #include <caffe2/ideep/utils/ideep_operator.h>
6 #endif
7 
8 namespace caffe2 {
9 
11  const void* pointer,
12  TypeMeta typeMeta,
13  const string& name,
14  BlobSerializerBase::SerializationAcceptor acceptor) {
15  CAFFE_ENFORCE(typeMeta.Match<std::unique_ptr<std::mutex>>());
16  BlobProto blob_proto;
17  blob_proto.set_name(name);
18  blob_proto.set_type("std::unique_ptr<std::mutex>");
19  blob_proto.set_content("");
20  acceptor(name, SerializeBlobProtoAsString_EnforceCheck(blob_proto));
21 }
22 
23 void MutexDeserializer::Deserialize(const BlobProto& /* unused */, Blob* blob) {
24  *blob->GetMutable<std::unique_ptr<std::mutex>>() =
25  caffe2::make_unique<std::mutex>();
26 }
27 
28 REGISTER_CPU_OPERATOR(Iter, IterOp<CPUContext>);
29 REGISTER_CPU_OPERATOR(AtomicIter, AtomicIterOp<CPUContext>);
30 
31 #ifdef CAFFE2_USE_MKLDNN
32 REGISTER_IDEEP_OPERATOR(AtomicIter, IDEEPFallbackOp<AtomicIterOp<CPUContext>>);
33 #endif
34 
35 REGISTER_BLOB_SERIALIZER(
36  (TypeMeta::Id<std::unique_ptr<std::mutex>>()),
38 REGISTER_BLOB_DESERIALIZER(std::unique_ptr<std::mutex>, MutexDeserializer);
39 
40 OPERATOR_SCHEMA(Iter)
41  .NumInputs(0, 1)
42  .NumOutputs(1)
43  .EnforceInplace({{0, 0}})
44  .SetDoc(R"DOC(
45 Stores a singe integer, that gets incremented on each call to Run().
46 Useful for tracking the iteration count during SGD, for example.
47 )DOC");
48 
49 OPERATOR_SCHEMA(AtomicIter)
50  .NumInputs(2)
51  .NumOutputs(1)
52  .EnforceInplace({{1, 0}})
53  .SetDoc(R"DOC(
54 Similar to Iter, but takes a mutex as the first input to make sure that
55 updates are carried out atomically. This can be used in e.g. Hogwild sgd
56 algorithms.
57 )DOC")
58  .Input(0, "mutex", "The mutex used to do atomic increment.")
59  .Input(1, "iter", "The iter counter as an int64_t TensorCPU.");
60 
61 NO_GRADIENT(Iter);
62 NO_GRADIENT(AtomicIter);
63 } // namespace caffe2
Blob is a general container that hosts a typed pointer.
Definition: blob.h:24
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
T * GetMutable()
Gets a mutable pointer to the stored object.
Definition: blob.h:100
void Serialize(const void *pointer, TypeMeta typeMeta, const string &name, BlobSerializerBase::SerializationAcceptor acceptor) override
Serializes a std::unique_ptr<std::mutex>.
Definition: iter_op.cc:10
A templated class to allow one to wrap a CPU operator as an IDEEP operator.
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:324