Caffe2 - C++ API
A deep learning, cross platform ML framework
iter_op.h
1 
17 #ifndef CAFFE2_SGD_ITER_OP_H_
18 #define CAFFE2_SGD_ITER_OP_H_
19 
20 #include <limits>
21 #include <mutex>
22 
23 #include "caffe2/core/blob_serialization.h"
24 #include "caffe2/core/context.h"
25 #include "caffe2/core/operator.h"
26 
27 namespace caffe2 {
28 
29 inline void IncrementIter(TensorCPU* output) {
30  CAFFE_ENFORCE_EQ(
31  output->size(),
32  1,
33  "The output of IterOp exists, but not of the right size.");
34  int64_t* iter = output->template mutable_data<int64_t>();
35  CAFFE_ENFORCE(*iter >= 0, "Previous iteration number is negative.");
36  CAFFE_ENFORCE(
37  *iter < std::numeric_limits<int64_t>::max(), "Overflow will happen!");
38  (*iter)++;
39 }
40 
41 // IterOp runs an iteration counter. I cannot think of a case where we would
42 // need to access the iter variable on device, so this will always produce a
43 // tensor on the CPU side. If the blob already exists and is a tensor<int64_t>
44 // object, we will simply increment it (this emulates the case when we want to
45 // resume training). Otherwise we will have the iter starting with 0.
46 template <class Context>
47 class IterOp final : public Operator<Context> {
48  public:
49  USE_OPERATOR_CONTEXT_FUNCTIONS;
50 
51  IterOp(const OperatorDef& operator_def, Workspace* ws)
52  : Operator<Context>(operator_def, ws) {}
53 
54  bool RunOnDevice() override {
55  if (InputSize() == 0) {
56  if (!OperatorBase::OutputIsType<TensorCPU>(0)) {
57  // This is the first run; set the iter to start with 0.
58  LOG(ERROR) << "You are using an old definition of IterOp that will "
59  "be deprecated soon. More specifically, IterOp now "
60  "requires an explicit in-place input and output.";
61 
62  auto* output = OperatorBase::Output<TensorCPU>(0);
63  VLOG(1) << "Initializing iter counter.";
64  output->Resize(1);
65  output->template mutable_data<int64_t>()[0] = 0;
66  }
67  }
68  IncrementIter(OperatorBase::Output<TensorCPU>(0));
69  return true;
70  }
71 };
72 
73 template <class Context>
74 class AtomicIterOp final : public Operator<Context> {
75  public:
76  USE_OPERATOR_CONTEXT_FUNCTIONS;
77 
78  AtomicIterOp(const OperatorDef& operator_def, Workspace* ws)
79  : Operator<Context>(operator_def, ws) {}
80 
81  bool RunOnDevice() override {
82  auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(0);
83  std::lock_guard<std::mutex> lg(*mutex);
84  IncrementIter(OperatorBase::Output<TensorCPU>(0));
85  return true;
86  }
87 };
88 
90  public:
96  void Serialize(
97  const Blob& blob,
98  const string& name,
99  BlobSerializerBase::SerializationAcceptor acceptor) override;
100 };
101 
103  public:
104  void Deserialize(const BlobProto& proto, Blob* blob) override;
105 };
106 
107 } // namespace caffe2
108 
109 #endif // CAFFE2_SGD_ITER_OP_H_
Blob is a general container that hosts a typed pointer.
Definition: blob.h:41
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
BlobSerializerBase is an abstract class that serializes a blob to a string.