Caffe2 - C++ API
A deep learning, cross-platform ML framework
common_subexpression_elimination.h
1 
2 #pragma once
3 
4 #include "caffe2/core/common.h"
5 #include "caffe2/core/transform.h"
6 #include "caffe2/proto/caffe2_pb.h"
7 #include "caffe2/utils/proto_utils.h"
8 
9 namespace caffe2 {
10 
29  public:
31  SetPatternMatchType(SORTED_WRT_EXECUTION_ORDER);
32  }
33 
34  protected:
35  bool PatternRule(
36  const transform::Graph& g,
37  const std::vector<int>& subgraph,
38  int idx) override;
39  bool ValidatorRule(
40  const transform::Graph& g,
41  const std::vector<int>& subgraph) override;
42  bool ReplaceRule(const std::vector<int>& subgraph, transform::Graph* g_ptr)
43  override;
44 
45  private:
46  bool IsWhitelisted(string op_type) {
47  return whitelisted_ops_.count(op_type);
48  }
49  std::set<string> whitelisted_ops_ = {"LearningRate", "FC"};
50 };
51 
52 } // namespace caffe2
Graph representation of a NetDef.
Definition: graph.h:48
The Transform Base Object.
Definition: transform.h:34
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13