Caffe2 - C++ API
A deep learning, cross-platform ML framework
iter_op.cc
1 
17 #include "caffe2/sgd/iter_op.h"
18 
19 namespace caffe2 {
20 
22  const Blob& blob,
23  const string& name,
24  BlobSerializerBase::SerializationAcceptor acceptor) {
25  CAFFE_ENFORCE(blob.IsType<std::unique_ptr<std::mutex>>());
26  BlobProto blob_proto;
27  blob_proto.set_name(name);
28  blob_proto.set_type("std::unique_ptr<std::mutex>");
29  blob_proto.set_content("");
30  acceptor(name, blob_proto.SerializeAsString());
31 }
32 
33 void MutexDeserializer::Deserialize(const BlobProto& /* unused */, Blob* blob) {
34  *blob->GetMutable<std::unique_ptr<std::mutex>>() =
35  caffe2::make_unique<std::mutex>();
36 }
37 
38 REGISTER_CPU_OPERATOR(Iter, IterOp<CPUContext>);
39 REGISTER_CPU_OPERATOR(AtomicIter, AtomicIterOp<CPUContext>);
40 
41 REGISTER_BLOB_SERIALIZER(
42  (TypeMeta::Id<std::unique_ptr<std::mutex>>()),
44 REGISTER_BLOB_DESERIALIZER(std::unique_ptr<std::mutex>, MutexDeserializer);
45 
46 OPERATOR_SCHEMA(Iter)
47  .NumInputs(0, 1)
48  .NumOutputs(1)
49  .EnforceInplace({{0, 0}})
50  .SetDoc(R"DOC(
51 Stores a singe integer, that gets incremented on each call to Run().
52 Useful for tracking the iteration count during SGD, for example.
53 )DOC");
54 
55 OPERATOR_SCHEMA(AtomicIter)
56  .NumInputs(2)
57  .NumOutputs(1)
58  .EnforceInplace({{1, 0}})
59  .SetDoc(R"DOC(
60 Similar to Iter, but takes a mutex as the first input to make sure that
61 updates are carried out atomically. This can be used in e.g. Hogwild sgd
62 algorithms.
63 )DOC")
64  .Input(0, "mutex", "The mutex used to do atomic increment.")
65  .Input(1, "iter", "The iter counter as an int64_t TensorCPU.");
66 
67 NO_GRADIENT(Iter);
68 NO_GRADIENT(AtomicIter);
69 } // namespace caffe2
Blob is a general container that hosts a typed pointer.
Definition: blob.h:41
static CAFFE2_API CaffeTypeId Id()
Returns the unique id for the given type T.
Copyright (c) 2016-present, Facebook, Inc.
T * GetMutable(bool *is_new_object=nullptr)
Gets a mutable pointer to the stored object.
Definition: blob.h:117
bool IsType() const
Checks if the content stored in the blob is of type T.
Definition: blob.h:74
void Serialize(const Blob &blob, const string &name, BlobSerializerBase::SerializationAcceptor acceptor) override
Serializes a std::unique_ptr<std::mutex>.
Definition: iter_op.cc:21