1 #ifndef CAFFE2_SGD_ITER_OP_H_ 2 #define CAFFE2_SGD_ITER_OP_H_ 7 #include "caffe2/core/blob_serialization.h" 8 #include "caffe2/core/context.h" 9 #include "caffe2/core/operator.h" 10 #include "caffe2/core/stats.h" 14 inline void IncrementIter(TensorCPU* output) {
18 "The output of IterOp exists, but not of the right size.");
19 int64_t* iter = output->template mutable_data<int64_t>();
20 CAFFE_ENFORCE(*iter >= 0,
"Previous iteration number is negative.");
22 *iter < std::numeric_limits<int64_t>::max(),
"Overflow will happen!");
31 template <
class Context>
34 USE_OPERATOR_CONTEXT_FUNCTIONS;
39 bool RunOnDevice()
override {
40 if (InputSize() == 0) {
41 LOG(INFO) <<
"[Input size is zero]";
42 if (!OperatorBase::OutputIsTensorType(0, CPU)) {
44 LOG(ERROR) <<
"You are using an old definition of IterOp that will " 45 "be deprecated soon. More specifically, IterOp now " 46 "requires an explicit in-place input and output.";
48 VLOG(1) <<
"Initializing iter counter.";
49 auto* output = OperatorBase::OutputTensor(
50 0, {1}, at::dtype<int64_t>().device(CPU));
51 output->template mutable_data<int64_t>()[0] = 0;
54 IncrementIter(OperatorBase::Output<Tensor>(0, CPU));
59 template <
class Context>
62 USE_OPERATOR_CONTEXT_FUNCTIONS;
66 stats_(std::string(
"atomic_iter/stats/") + operator_def.input(1)) {}
68 bool RunOnDevice()
override {
69 auto& mutex = OperatorBase::Input<std::unique_ptr<std::mutex>>(0);
70 std::lock_guard<std::mutex> lg(*mutex);
71 IncrementIter(OperatorBase::Output<Tensor>(0, CPU));
72 CAFFE_EVENT(stats_, num_iter);
77 struct AtomicIterOpStats {
78 CAFFE_STAT_CTOR(AtomicIterOpStats);
79 CAFFE_EXPORTED_STAT(num_iter);
94 BlobSerializerBase::SerializationAcceptor acceptor)
override;
99 void Deserialize(
const BlobProto& proto,
Blob* blob)
override;
104 #endif // CAFFE2_SGD_ITER_OP_H_ Blob is a general container that hosts a typed pointer.
BlobDeserializerBase is an abstract class that deserializes a blob from a BlobProto or a TensorProto...
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
BlobSerializerBase is an abstract class that serializes a blob to a string.