Caffe2 - C++ API
A deep learning, cross-platform ML framework
net_supplier.h
1 #pragma once
#include <deque>
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <utility>

#include "caffe2/predictor/emulator/data_filler.h"
#include "caffe2/predictor/emulator/utils.h"
6 
7 namespace caffe2 {
8 namespace emulator {
9 
10 struct RunnableNet {
11  const caffe2::NetDef& netdef;
12  const Filler* filler;
13  std::string debug_info;
14 
16  const caffe2::NetDef& netdef_,
17  const Filler* filler_,
18  const std::string& info_ = "")
19  : netdef(netdef_), filler(filler_), debug_info(info_) {}
20 };
21 
22 /*
23  * An interface to supplier a pair of net and its filler.
24  * The net should be able to run once the filler fills the workspace.
25  * The supplier should take the ownership of both net and filler.
26  */
27 class NetSupplier {
28  public:
29  // next() should be thread-safe
30  virtual RunnableNet next() = 0;
31 
32  virtual ~NetSupplier() noexcept {}
33 };
34 
35 /*
36  * A simple net supplier that always return the same net and filler pair.
37  */
39  public:
40  SingleNetSupplier(unique_ptr<Filler> filler, caffe2::NetDef netdef)
41  : filler_(std::move(filler)), netdef_(netdef) {}
42 
43  RunnableNet next() override {
44  return RunnableNet(netdef_, filler_.get());
45  }
46 
47  protected:
48  const unique_ptr<Filler> filler_;
49  const caffe2::NetDef netdef_;
50 };
51 
52 /*
53  * A simple net supplier that always return the same net and filler pair.
54  * The SingleLoadedNetSupplier contains a shared ptr to a workspace with
55  * parameters already loaded by net loader.
56  */
58  public:
60  std::unique_ptr<Filler> filler,
61  caffe2::NetDef netdef,
62  std::shared_ptr<Workspace> ws)
63  : SingleNetSupplier(std::move(filler), netdef), ws_(ws) {}
64 
65  std::shared_ptr<Workspace> get_loaded_workspace() {
66  return ws_;
67  }
68 
69  private:
70  const std::shared_ptr<Workspace> ws_;
71 };
72 
74  public:
75  explicit MutatingNetSupplier(
76  std::unique_ptr<NetSupplier>&& core,
77  std::function<void(NetDef*)> m)
78  : core_(std::move(core)), mutator_(m) {}
79 
80  RunnableNet next() override {
81  RunnableNet orig = core_->next();
82  NetDef* new_net = nullptr;
83  {
84  std::lock_guard<std::mutex> guard(lock_);
85  nets_.push_back(orig.netdef);
86  new_net = &nets_.back();
87  }
88  mutator_(new_net);
89  return RunnableNet(*new_net, orig.filler, orig.debug_info);
90  }
91 
92  private:
93  std::mutex lock_;
94  std::unique_ptr<NetSupplier> core_;
95  std::vector<NetDef> nets_;
96  std::function<void(NetDef*)> mutator_;
97 };
98 
99 } // namespace emulator
100 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13