Caffe2 - C++ API
A deep learning, cross-platform ML framework
passes.h
1 #ifndef CAFFE2_OPT_OPT_PASSS_H
2 #define CAFFE2_OPT_OPT_PASSS_H
3 
4 #include "caffe2/core/common.h"
5 #include "caffe2/core/workspace.h"
6 #include "caffe2/proto/caffe2_pb.h"
7 
8 #include "nomnigraph/Representations/NeuralNet.h"
9 
10 using namespace nom::repr;
11 
12 namespace caffe2 {
13 
14 /* This file sets up the optimization pass registry.
15  *
16  * You'll want to either create a class that inherits from OptimizationPass
17  * and implements run or use the REGISTER_OPT_PASS_FROM_FUNC(name, func)
18  * to register a function that takes in an NNModule*.
19  *
20  * If you need access to the workspace in the optimization you'll need to
21  * use a different registry and inherit from WorkspaceOptimizationPass.
22  */
23 
24 class CAFFE2_API OptimizationPass {
25  public:
26  OptimizationPass(NNModule* nn) : nn_(nn) {}
27  virtual void run() = 0;
28  virtual ~OptimizationPass() {}
29 
30  protected:
31  NNModule* nn_;
32 };
33 
34 class CAFFE2_API WorkspaceOptimizationPass : public OptimizationPass {
35  public:
37  virtual ~WorkspaceOptimizationPass() {}
38 
39  protected:
40  Workspace* ws_;
41 };
42 
43 C10_DECLARE_REGISTRY(
44  WorkspaceOptimizationPassRegistry,
46  NNModule*,
47  Workspace*);
48 #define REGISTER_WS_OPT_PASS(clsname) \
49  C10_REGISTER_CLASS(WorkspaceOptimizationPassRegistry, clsname, clsname)
50 #define REGISTER_WS_OPT_PASS_FROM_FUNC(passname, funcname) \
51  class passname : public WorkspaceOptimizationPass { \
52  public: \
53  using WorkspaceOptimizationPass::WorkspaceOptimizationPass; \
54  void run() override { \
55  funcname(nn_, ws_); \
56  } \
57  }; \
58  REGISTER_WS_OPT_PASS(passname);
59 
60 C10_DECLARE_REGISTRY(OptimizationPassRegistry, OptimizationPass, NNModule*);
61 #define REGISTER_OPT_PASS(clsname) \
62  C10_REGISTER_CLASS(OptimizationPassRegistry, clsname, clsname)
63 #define REGISTER_OPT_PASS_FROM_FUNC(passname, funcname) \
64  class passname : public OptimizationPass { \
65  public: \
66  using OptimizationPass::OptimizationPass; \
67  void run() override { \
68  funcname(nn_); \
69  } \
70  }; \
71  REGISTER_OPT_PASS(passname);
72 
73 } // namespace caffe2
74 
75 #endif // CAFFE2_OPT_OPT_PASSS_H
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13