Caffe2 - C++ API
A deep learning, cross-platform ML framework
activation_distribution_observer.h
1 #pragma once
2 
#include "caffe2/core/observer.h"
#include "caffe2/core/operator.h"
#include "caffe2/quantization/server/dnnlowp.h"
#include "caffe2/quantization/server/dynamic_histogram.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
11 
12 namespace caffe2 {
13 
14 class OutputMinMaxObserver final : public ObserverBase<OperatorBase> {
15  public:
16  explicit OutputMinMaxObserver(OperatorBase* op);
18 
19  struct TensorInfo {
20  explicit TensorInfo(const std::string& name)
21  : min(std::numeric_limits<float>::max()),
22  max(std::numeric_limits<float>::lowest()),
23  total_min(std::numeric_limits<float>::max()),
24  total_max(std::numeric_limits<float>::lowest()),
25  name(name) {}
26 
27  void Update(float cur_min, float cur_max) {
28  min = std::min(min, cur_min);
29  max = std::max(max, cur_max);
30  total_min = std::min(total_min, cur_min);
31  total_max = std::max(total_max, cur_max);
32  }
33 
34  float min, max;
35  float total_min, total_max;
36  std::string name;
37  };
38 
39  struct OperatorInfo {
40  std::vector<TensorInfo> tensor_infos;
41  std::string type;
42  };
43 
44  // OutputMinMaxObserver is assumed to be used together with
45  // OutputMinMaxNetObserver and the information shared via shared_ptr to be
46  // prepared for the case when OutputMinMaxObserver is destroyed before
47  // OutputMinMaxNetObserver
48  std::shared_ptr<OperatorInfo> GetInfo() {
49  return info_;
50  }
51 
52  private:
53  void Stop() override;
54 
55  std::shared_ptr<OperatorInfo> info_;
56  bool warning_printed_ = false;
57 }; // class OutputMinMaxObserver
58 
59 class OutputMinMaxNetObserver final : public NetObserver {
60  public:
62  // Otherwise, print out every dum_freq invocations
63  explicit OutputMinMaxNetObserver(
64  NetBase* subject,
65  const std::string& out_file_name,
66  int dump_freq = -1);
68 
69  private:
70  void Stop() override;
71  void DumpAndReset_(
72  const std::string& out_file_name,
73  bool print_total_min_max = false);
74 
75  int dump_freq_, cnt_;
76  const std::string out_file_name_;
77  std::vector<std::shared_ptr<OutputMinMaxObserver::OperatorInfo>>
78  min_max_infos_;
79 };
80 
// Per-operator observer that collects dnnlowp::DynamicHistogram data
// (presumably of the operator's output tensors) into a caller-provided,
// shared Info object.
// NOTE(review): lines 81-83 of the original listing (presumably this class's
// leading doc comment) were lost when this header was extracted.
84 class HistogramObserver final : public ObserverBase<OperatorBase> {
85  public:
// Histogram storage. It is handed in and held as std::shared_ptr (see the
// constructor and info_), and HistogramNetObserver keeps a vector of the same
// shared_ptr<Info> type, so the data can outlive this observer.
86  struct Info {
// Presumably the histograms for the current dump interval -- confirm.
87  std::vector<dnnlowp::DynamicHistogram> histograms;
// Presumably histograms accumulated over the whole run -- confirm.
88  std::vector<dnnlowp::DynamicHistogram> total_histograms;
// NOTE(review): line 89 of the original listing was lost in extraction
// (possibly another Info member); Info as shown here may be incomplete.
90  };
91 
// @param op    the operator being observed.
// @param info  histogram storage supplied by the caller; presumably stored in
//              info_ (the constructor body is not visible here).
92  explicit HistogramObserver(OperatorBase* op, std::shared_ptr<Info> info);
93 
94  private:
// Override of ObserverBase<OperatorBase>::Stop(); defined out of line.
95  void Stop() override;
96 
97  std::shared_ptr<Info> info_;
// Presumably ensures a warning in Stop() is printed at most once -- confirm.
98  bool warning_printed_ = false;
99 }; // class HistogramObserver
100 
// Net-level observer that keeps one HistogramObserver::Info per entry of
// hist_infos_ and presumably dumps the histograms to out_file_name_.
101 class HistogramNetObserver final : public NetObserver {
102  public:
// @param subject        the net being observed.
// @param out_file_name  presumably the destination of the dumped histograms.
// @param nbins          presumably the number of bins per DynamicHistogram.
// @param dump_freq      dump period; defaults to -1 (presumably "dump only at
//                       the end", mirroring OutputMinMaxNetObserver -- confirm).
// @param mul_nets       presumably true when this observer monitors multiple
//                       nets; defaults to false -- confirm.
103  explicit HistogramNetObserver(
104  NetBase* subject,
105  const std::string& out_file_name,
106  int nbins,
107  int dump_freq = -1,
108  bool mul_nets = false);
// NOTE(review): line 109 of the original listing was lost in extraction
// (presumably the destructor declaration) -- restore from the real header.
110 
111  private:
// Override of NetObserver::Stop(); defined out of line.
112  void Stop() override;
// Presumably writes the collected histograms to out_file_name and resets the
// per-interval state. print_total_min_max mirrors the identically named
// parameter on OutputMinMaxNetObserver -- confirm its meaning here.
113  void DumpAndReset_(
114  const std::string& out_file_name,
115  bool print_total_min_max = false);
116 
// Presumably the dump period and the invocation counter compared against it.
117  int dump_freq_, cnt_;
118 
// NOTE(review): lines 119-121 of the original listing were lost in
// extraction (possibly a comment documenting mul_nets_).
122  bool mul_nets_;
123  const std::string out_file_name_;
// Shared with the per-operator HistogramObservers (their constructor takes a
// shared_ptr<Info>), presumably so the histograms outlive those observers.
124  std::vector<std::shared_ptr<HistogramObserver::Info>> hist_infos_;
125 };
126 
// NOTE(review): the head of this class (lines 127-131 of the original
// listing) and the first line of its constructor (line 133) were lost when
// this header was extracted. The page tooltip "Set quantization parameters of
// operators based on min/max collected from OutputMinMaxObserver."
// presumably describes this class -- restore the missing lines from the real
// header.
132  public:
// Constructor parameters:
// @param subject                  the net being observed.
// @param min_max_file_name        presumably a file of min/max values as
//                                 produced by OutputMinMaxNetObserver -- confirm.
// @param is_weight                defaults to false; meaning not visible here.
// @param qparams_output_file_name defaults to ""; presumably an optional file
//                                 to which the chosen quantization parameters
//                                 are written -- confirm.
134  NetBase* subject,
135  const std::string& min_max_file_name,
136  bool is_weight = false,
137  const std::string& qparams_output_file_name = "");
138 };
139 
// NOTE(review): the head of this class (lines 140-144 of the original
// listing) and the first line of its constructor (line 147) were lost when
// this header was extracted; only the base-class clause survives below.
// Restore the missing lines from the real header.
145  : public NetObserver {
146  public:
// Constructor parameters:
// @param subject                  the net being observed.
// @param histogram_file_name      presumably a file of histograms as produced
//                                 by HistogramNetObserver -- confirm.
// @param is_weight                defaults to false; meaning not visible here.
// @param qparams_output_file_name defaults to ""; presumably an optional file
//                                 to which the chosen quantization parameters
//                                 are written -- confirm.
148  NetBase* subject,
149  const std::string& histogram_file_name,
150  bool is_weight = false,
151  const std::string& qparams_output_file_name = "");
152 };
153 
154 } // namespace caffe2
Use this to implement an Observer using the Observer Pattern template.
Definition: observer.h:15
Set quantization parameters of operators based on min/max collected from OutputMinMaxObserver.
Given min/max, collect histogram.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Set quantization parameters of operators based on min/max collected from OutputMinMaxObserver.