Caffe2 - C++ API
A deep learning, cross-platform ML framework
annotations.h
#pragma once

#include <string>
#include <utility>
#include <vector>

#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/proto/caffe2_pb.h"
#include "nomnigraph/Representations/NeuralNet.h"

8 namespace caffe2 {
9 
10 class CAFFE2_API Caffe2Annotation : public nom::repr::Annotation {
11  public:
12  Caffe2Annotation() : Annotation(AnnotationKind::Caffe2) {}
13  Caffe2Annotation(std::string device)
14  : Annotation(AnnotationKind::Caffe2), Device(device) {}
15  virtual ~Caffe2Annotation() {}
16 
17  void setOperatorDef(const caffe2::OperatorDef& opDef);
18  bool hasOperatorDef() const;
19  const caffe2::OperatorDef& getOperatorDef() const;
20  caffe2::OperatorDef* getMutableOperatorDef();
21 
22  void setDeviceOption(const caffe2::DeviceOption& opDef);
23  bool hasDeviceOption() const;
24  const caffe2::DeviceOption& getDeviceOption() const;
25  caffe2::DeviceOption* getMutableDeviceOption();
26 
27  // Distributed annotations
28  void setDevice(std::string device);
29  const std::string getDevice() const;
30  void setDeviceType(int device);
31  int getDeviceType() const;
32 
33  enum class ParallelizationScheme {
34  none,
35  split_by_batch,
36  split_by_length,
37  shard,
38  shard_by_number
39  };
40  void setParallelization(ParallelizationScheme, int num = -1);
41  ParallelizationScheme getParallelizationScheme() const;
42  int getParallelization() const;
43 
44  void setKeyNode(nom::repr::NNGraph::NodeRef);
45  const nom::repr::NNGraph::NodeRef& getKeyNode() const;
46  void setLengthNode(nom::repr::NNGraph::NodeRef);
47  const nom::repr::NNGraph::NodeRef& getLengthNode() const;
48 
49  void setComponentLevels(std::vector<std::string> components);
50  std::vector<std::string> getComponentLevels() const;
51 
52  static bool classof(const Annotation* A);
53 
54  private:
55  std::string Device = "";
56  caffe2::OperatorDef OpDef;
57  bool OpDefExists = false;
58 
59  // Distributed annotations
60  int DeviceType = caffe2::DeviceTypeProto::PROTO_CPU;
61  ParallelizationScheme parallelization_scheme_ = ParallelizationScheme::none;
62  int parallelization_ = -1;
63  nom::repr::NNGraph::NodeRef key_node_ = nullptr;
64  nom::repr::NNGraph::NodeRef length_node_ = nullptr;
65  std::vector<std::string> component_levels_;
66 };
67 
68 } // namespace caffe2
Annotations allow for generic manipulation of neural network operations.
Definition: NeuralNet.h:44
Represents a compute device on which a tensor is located.
Definition: Device.h:30
Definition: static.cpp:52
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13