Caffe2 - C++ API
A deep learning, cross platform ML framework
any.h
1 #pragma once
2 
3 #include <torch/detail/static.h>
4 #include <torch/nn/module.h>
5 #include <torch/nn/pimpl.h>
6 #include <torch/types.h>
7 
8 #include <torch/csrc/autograd/variable.h>
9 #include <torch/csrc/utils/memory.h>
10 #include <torch/csrc/utils/variadic.h>
11 
12 #include <ATen/Device.h>
13 
14 #include <memory>
15 #include <type_traits>
16 #include <typeinfo>
17 #include <utility>
18 #include <vector>
19 
20 namespace torch {
21 namespace nn {
22 
/// Stores a type erased `Module`.
///
/// The module lives behind the `content_` pointer. A default-constructed
/// `AnyModule` holds nothing, and `is_empty()` returns true for it.
108 class AnyModule {
109  public:
     /// Type-erased wrapper used for the arguments and the return value of
     /// `forward()`.
111  class Value;
112 
     /// A default-constructed `AnyModule` is in an empty state.
114  AnyModule() = default;
115 
     /// Constructs an `AnyModule` from a `shared_ptr` to a concrete module.
117  template <typename ModuleType>
118  explicit AnyModule(std::shared_ptr<ModuleType> module);
119 
     /// Constructs an `AnyModule` from a concrete module object, which is placed
     /// into a newly created `shared_ptr`.
121  template <
122  typename ModuleType,
123  typename = torch::detail::enable_if_module_t<ModuleType>>
124  explicit AnyModule(ModuleType&& module);
125 
     /// Constructs an `AnyModule` from the pointer held by a `ModuleHolder`.
127  template <typename ModuleType>
128  explicit AnyModule(const ModuleHolder<ModuleType>& module_holder);
129 
     /// Move construction and assignment transfer the stored holder.
132  AnyModule(AnyModule&&) = default;
133  AnyModule& operator=(AnyModule&&) = default;
134 
     /// Copy construction and assignment go through the holder's `copy()`, which
     /// copies the `shared_ptr` to the module rather than cloning the module.
136  AnyModule(const AnyModule& other);
137  AnyModule& operator=(const AnyModule& other);
138 
     /// Creates a deep copy of an `AnyModule` if it contains a module, else an
     /// empty `AnyModule` if it is empty.
141  AnyModule clone(optional<Device> device = nullopt) const;
142 
     /// Replaces the contained module with `module`.
145  template <typename ModuleType>
146  AnyModule& operator=(std::shared_ptr<ModuleType> module);
147 
     /// Invokes `forward()` on the contained module with the given arguments,
     /// and returns the return value as a `Value`.
151  template <typename... ArgumentTypes>
152  Value any_forward(ArgumentTypes&&... arguments);
153 
     /// Invokes `forward()` on the contained module with the given arguments,
     /// and casts the returned `Value` to `ReturnType` (`torch::Tensor` unless
     /// specified otherwise).
157  template <typename ReturnType = torch::Tensor, typename... ArgumentTypes>
158  ReturnType forward(ArgumentTypes&&... arguments);
159 
     /// Attempts to cast the underlying module to the given module type and
     /// returns a reference to it.
162  template <typename T, typename = torch::detail::enable_if_module_t<T>>
163  T& get();
164 
     /// Const overload of the reference-returning `get()`.
167  template <typename T, typename = torch::detail::enable_if_module_t<T>>
168  const T& get() const;
169 
     /// Returns the contained module wrapped in a `T`, where `T::ContainedType`
     /// names the module class and `T` is constructible from a `shared_ptr` to
     /// it.
172  template <typename T, typename ContainedType = typename T::ContainedType>
173  T get() const;
174 
     /// Returns a `std::shared_ptr` whose dynamic type is that of the underlying
     /// module.
177  std::shared_ptr<Module> ptr() const;
178 
     /// Like `ptr()`, but checks the module type and casts the pointer to `T`.
180  template <typename T, typename = torch::detail::enable_if_module_t<T>>
181  std::shared_ptr<T> ptr() const;
182 
     /// Returns the `type_info` object of the contained value.
184  const std::type_info& type_info() const;
185 
     /// Returns true if the `AnyModule` does not contain a module.
187  bool is_empty() const noexcept;
188 
189  private:
     /// Abstract, type-erased interface to the stored module.
194  struct Placeholder;
195 
     /// Concrete holder for a module of type `ModuleType` whose `forward()`
     /// takes `ArgumentTypes...`.
201  template <typename ModuleType, typename... ArgumentTypes>
202  struct Holder;
203 
     /// Creates a `Holder` for `module`. The member-function-pointer parameter
     /// is only used to deduce the return and argument types of `forward()`.
207  template <
208  typename ModuleType,
209  typename Class,
210  typename ReturnType,
211  typename... ArgumentTypes>
212  std::unique_ptr<Placeholder> make_holder(
213  std::shared_ptr<ModuleType>&& module,
214  ReturnType (Class::*)(ArgumentTypes...));
215 
     /// Returns the contained module as `ModuleType&`. The pointer to its
     /// `forward()` method is used to deduce `ArgumentTypes...`.
217  template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
218  ModuleType& get_(ReturnType (ModuleType::*)(ArgumentTypes...)) const;
219 
     /// Convenience overload that passes the module's `forward` method to the
     /// overload above.
221  template <typename ModuleType>
222  ModuleType& get_() const;
223 
     /// The type erased module; null while the `AnyModule` is empty.
225  std::unique_ptr<Placeholder> content_;
226 };
227 
228 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule::Value ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
229 
236  public:
/// Move construction and assignment are allowed.
239  Value(Value&&) = default;
240  Value& operator=(Value&&) = default;
241 
/// Copy construction and assignment are disallowed.
243  Value(const Value& other) = delete;
244  Value& operator=(const Value& other) = delete;
245 
/// Returns a pointer to the value contained in the `Value` if the type passed
/// as template parameter matches the type of the value stored, and returns a
/// null pointer otherwise.
///
/// `T` must be the decayed type: reference and array types are rejected at
/// compile time.
///
/// The match is decided with `std::type_info::operator==`. Comparing
/// `hash_code()` values is not sufficient, because two distinct types may
/// share a hash code; the `static_cast` below would then treat the holder as
/// the wrong `Holder<T>`.
249  template <typename T>
250  T* try_get() {
251  static_assert(
252  !std::is_reference<T>::value,
253  "Value stores decayed types, you cannot cast it to a reference type");
254  static_assert(
255  !std::is_array<T>::value,
256  "Value stores decayed types, you must cast it to T* instead of T[]");
257  if (typeid(T) == type_info()) {
     // Safe downcast: the holder was created as `Holder<T>` for exactly this T.
258  return &static_cast<Holder<T>&>(*content_).value;
259  }
260  return nullptr;
261  }
262 
/// Returns a copy of the contained value as `T` when `try_get<T>()` succeeds.
/// Otherwise raises an error via `AT_ERROR` whose message names both the
/// requested type and the type actually stored.
266  template <typename T>
267  T get() {
268  if (auto* maybe_value = try_get<T>()) {
269  return *maybe_value;
270  }
271  AT_ERROR(
272  "Attempted to cast Value to ",
273  c10::demangle(typeid(T).name()),
274  ", but its actual type is ",
275  c10::demangle(type_info().name()));
276  }
277 
/// Returns the `type_info` object of the contained value.
/// NOTE(review): `content_` is dereferenced without a null check — presumably
/// a moved-from `Value` must not be queried; confirm with callers.
279  const std::type_info& type_info() const noexcept {
280  return content_->type_info;
281  }
282 
283  private:
284  friend class AnyModule;
285  friend struct TestValue;
286 
/// Constructs the `Value` from any value other than an `autograd::Variable`
/// (this overload is disabled when `T` is exactly `autograd::Variable`).
/// The value is forwarded into a heap-allocated `Holder` of its decayed type.
288  template <
289  typename T,
290  typename =
291  torch::disable_if_t<std::is_same<autograd::Variable, T>::value>>
292  explicit Value(T&& value)
293  : content_(
294  torch::make_unique<Holder<decay_t<T>>>(std::forward<T>(value))) {}
295 
/// Constructs the `Value` from an `autograd::Variable` by first converting it
/// to a `Tensor`, so the stored (and later requested) type is `Tensor`.
298  explicit Value(autograd::Variable variable)
299  : Value(Tensor(std::move(variable))) {}
300 
/// The static type of the object stored in the `Value`. It erases the actual
/// type of the stored object, but records it in the `type_info` member.
305  struct Placeholder {
306  explicit Placeholder(const std::type_info& type_info_) noexcept
307  : type_info(type_info_) {}
     // Virtual so that a derived holder is destroyed correctly through a
     // `Placeholder` pointer.
308  virtual ~Placeholder() = default;
309  const std::type_info& type_info;
310  };
311 
/// The dynamic type of the object stored in the `Value`. It holds the value
/// itself as type `T` and records `typeid(T)` in the base `Placeholder`.
315  template <typename T>
316  struct Holder : public Placeholder {
     /// Forwards `value_` into the stored `value`.
     ///
     /// The constructor is `noexcept` only when constructing `T` from `U`
     /// cannot throw. An unconditional `noexcept` would turn a throwing copy
     /// (for example of a `std::string` or `std::vector`) into a call to
     /// `std::terminate` instead of letting the exception propagate.
318  template <typename U>
319  explicit Holder(U&& value_) noexcept(std::is_nothrow_constructible<T, U>::value)
320  : Placeholder(typeid(T)), value(std::forward<U>(value_)) {}
321  T value;
322  };
323 
325  std::unique_ptr<Placeholder> content_;
326 };
327 
328 // ~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule::Placeholder ~~~~~~~~~~~~~~~~~~~~~~~~~~
329 
/// The static type of the object stored in the `AnyModule`. It erases the
/// module's concrete type and exposes it only through virtual functions. It
/// reuses `Value::Placeholder` (and its constructor) to record the `type_info`.
330 struct AnyModule::Placeholder : public AnyModule::Value::Placeholder {
331  using AnyModule::Value::Placeholder::Placeholder;
332 
     /// Invokes `forward()` on the underlying module with the type-erased
     /// arguments and returns the result as a `Value`.
334  virtual Value forward(std::vector<Value>&& arguments) = 0;
335 
     /// Returns a `std::shared_ptr<Module>` pointing to the underlying module.
337  virtual std::shared_ptr<Module> ptr() = 0;
338 
     /// Returns a new holder that refers to the same module.
340  virtual std::unique_ptr<Placeholder> copy() const = 0;
341 
     /// Returns a new holder around the result of the module's `clone(device)`.
343  virtual std::unique_ptr<Placeholder> clone(optional<Device> device) const = 0;
344 };
345 
346 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule::Holder ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
347 
/// The dynamic type of the object stored in the `AnyModule`. It keeps the
/// concrete `ModuleType` and the `ArgumentTypes...` of its `forward()` method,
/// so type-erased arguments can be checked and converted back before the call.
348 template <typename ModuleType, typename... ArgumentTypes>
349 struct AnyModule::Holder : public AnyModule::Placeholder {
     /// Fetches the argument at a given index as `decay_t<T>`, raising an error
     /// if the `Value` at that index holds a different type.
351  struct CheckedGetter {
352  template <typename T>
353  decay_t<T>&& operator()(size_t index) {
354  AT_ASSERT(index < arguments_.size());
355  auto& value = arguments_[index];
356  if (auto* maybe_value = value.template try_get<decay_t<T>>()) {
     // Returned as an rvalue reference so the receiver can move the value
     // out of its slot in the argument vector.
357  return std::move(*maybe_value);
358  }
359  AT_ERROR(
360  "Expected argument #",
361  index,
362  " to be of type ",
363  c10::demangle(typeid(T).name()),
364  ", but received value of type ",
365  c10::demangle(value.type_info().name()));
366  }
     // Borrowed: refers to the vector owned by `Holder::forward`.
367  std::vector<Value>& arguments_;
368  };
369 
     /// Calls `forward()` on the module with the unpacked arguments and wraps
     /// the result in a `Value`.
371  struct InvokeForward {
372  template <typename... Ts>
373  Value operator()(Ts&&... ts) {
374  return Value(module_->forward(std::forward<Ts>(ts)...));
375  }
376  std::shared_ptr<ModuleType>& module_;
377  };
378 
     /// Constructs the `Holder` from a pointer to a concrete module and records
     /// `typeid(ModuleType)` in the base class.
380  explicit Holder(std::shared_ptr<ModuleType>&& module_)
381  : Placeholder(typeid(ModuleType)), module(std::move(module_)) {}
382 
     /// Calls `forward()` on the underlying module. Raises an error if the
     /// number of arguments differs from `sizeof...(ArgumentTypes)`, or (through
     /// `CheckedGetter`) if an argument has the wrong type.
     /// NOTE(review): `torch::unpack` is defined elsewhere; it presumably calls
     /// the getter once per argument index and passes the results to
     /// `InvokeForward` — confirm in torch/csrc/utils/variadic.h.
385  Value forward(std::vector<Value>&& arguments) override {
386  AT_CHECK(
387  arguments.size() == sizeof...(ArgumentTypes),
388  c10::demangle(type_info.name()),
389  "'s forward() method expects ",
390  sizeof...(ArgumentTypes),
391  " arguments, but received ",
392  arguments.size());
393  // FYI: During invocation of a module's `forward()` method, the values live
394  // in the `arguments` vector inside this function.
395  return torch::unpack<Value, ArgumentTypes...>(
396  InvokeForward{module}, CheckedGetter{arguments});
397  }
398 
     /// Returns the module pointer, converted to `std::shared_ptr<Module>`.
399  std::shared_ptr<Module> ptr() override {
400  return module;
401  }
402 
     /// Copy-constructs a new `Holder`; the copy shares the same module object
     /// because only the `shared_ptr` is copied.
403  std::unique_ptr<Placeholder> copy() const override {
404  return torch::make_unique<Holder>(*this);
405  }
406 
     /// Creates a new `Holder` around the result of `module->clone(device)`,
     /// cast back to `ModuleType`.
407  std::unique_ptr<Placeholder> clone(optional<Device> device) const override {
408  return torch::make_unique<Holder>(
409  std::dynamic_pointer_cast<ModuleType>(module->clone(device)));
410  }
411 
     /// The actual concrete module instance.
413  std::shared_ptr<ModuleType> module;
414 };
415 
416 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
417 
/// Stores `module` in a holder created by `make_holder`. The address of the
/// module's `forward` method is passed along so that `make_holder` can deduce
/// the return and argument types of `forward()`.
418 template <typename ModuleType>
419 AnyModule::AnyModule(std::shared_ptr<ModuleType> module)
420  : content_(make_holder(
421  std::move(module),
422  &std::remove_reference<ModuleType>::type::forward)) {}
423 
/// Places the given module into a new `shared_ptr` and delegates to the
/// `shared_ptr` constructor.
/// NOTE(review): if `ModuleType` is deduced as an lvalue reference (an lvalue
/// module is passed), `std::make_shared<ModuleType>` would name a reference
/// type. Presumably only rvalue modules are intended here — confirm against
/// `enable_if_module_t` and the call sites.
424 template <typename ModuleType, typename>
425 AnyModule::AnyModule(ModuleType&& module)
426  : AnyModule(
427  std::make_shared<ModuleType>(std::forward<ModuleType>(module))) {}
428 
429 template <typename ModuleType>
431  : AnyModule(module_holder.ptr()) {}
432 
/// Copy constructor: asks the other holder for a `copy()`, or stays empty if
/// `other` is empty.
433 inline AnyModule::AnyModule(const AnyModule& other)
434  : content_(other.content_ ? other.content_->copy() : nullptr) {}
435 
/// Copy assignment: replaces the holder with a `copy()` of the other holder,
/// or becomes empty if `other` is empty. Self-assignment is a no-op.
436 inline AnyModule& AnyModule::operator=(const AnyModule& other) {
437  if (this != &other) {
438  content_ = other.content_ ? other.content_->copy() : nullptr;
439  }
440  return *this;
441 }
442 
445  clone.content_ = content_ ? content_->clone(device) : nullptr;
446  return clone;
447 }
448 
/// Assigns a new module by building a temporary `AnyModule` from `module` and
/// assigning that temporary to `*this`.
449 template <typename ModuleType>
450 AnyModule& AnyModule::operator=(std::shared_ptr<ModuleType> module) {
451  return (*this = AnyModule(std::move(module)));
452 }
453 
/// Invokes `forward()` on the contained module with the given arguments, and
/// returns the return value as a `Value`.
///
/// Each argument is wrapped in a `Value`, the values are collected into a
/// vector, and the vector is handed to the holder's `forward()`. Raises an
/// error if the `AnyModule` is empty.
454 template <typename... ArgumentTypes>
455 AnyModule::Value AnyModule::any_forward(ArgumentTypes&&... arguments) {
456  AT_CHECK(!is_empty(), "Cannot call forward() on an empty AnyModule");
457  std::vector<Value> values;
458  values.reserve(sizeof...(ArgumentTypes));
     // NOTE(review): `torch::apply` is defined elsewhere; it presumably calls the
     // lambda once per `Value`, in argument order — confirm.
459  torch::apply(
460  [&values](Value&& value) { values.push_back(std::move(value)); },
461  Value(std::forward<ArgumentTypes>(arguments))...);
462  return content_->forward(std::move(values));
463 }
464 
/// Invokes `forward()` on the contained module with the given arguments, and
/// casts the returned `Value` to `ReturnType`. `Value::get` raises an error if
/// the returned value is not of that type.
465 template <typename ReturnType, typename... ArgumentTypes>
466 ReturnType AnyModule::forward(ArgumentTypes&&... arguments) {
467  return any_forward(std::forward<ArgumentTypes>(arguments)...)
468  .template get<ReturnType>();
469 }
470 
471 template <typename T, typename>
473  AT_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
474  return get_<T>();
475 }
476 
/// Const overload: casts the underlying module to `T` and returns a const
/// reference to it. Raises an error if the `AnyModule` is empty or if the
/// contained module is not of type `T`.
477 template <typename T, typename>
478 const T& AnyModule::get() const {
479  AT_CHECK(!is_empty(), "Cannot call get() on an empty AnyModule");
480  return get_<T>();
481 }
482 
/// Returns the contained module wrapped in a `T`, constructed from the typed
/// pointer returned by `ptr<ContainedType>()` (which performs the emptiness
/// and type checks).
483 template <typename T, typename ContainedType>
484 T AnyModule::get() const {
485  return T(ptr<ContainedType>());
486 }
487 
/// Returns a `std::shared_ptr` whose dynamic type is that of the underlying
/// module. Raises an error if the `AnyModule` is empty.
488 inline std::shared_ptr<Module> AnyModule::ptr() const {
489  AT_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
490  return content_->ptr();
491 }
492 
/// Like `ptr()`, but casts the pointer to `T`. Raises an error if the
/// `AnyModule` is empty or if the contained module is not of type `T`.
493 template <typename T, typename>
494 std::shared_ptr<T> AnyModule::ptr() const {
495  AT_CHECK(!is_empty(), "Cannot call ptr() on an empty AnyModule");
496  // Call get() but discard the value, just to do the type checking.
497  get_<T>();
498  return std::dynamic_pointer_cast<T>(ptr());
499 }
500 
/// Returns the `type_info` object of the contained value. Raises an error if
/// the `AnyModule` is empty.
501 inline const std::type_info& AnyModule::type_info() const {
502  AT_CHECK(!is_empty(), "Cannot call type_info() on an empty AnyModule");
503  return content_->type_info;
504 }
505 
/// Returns true if the `AnyModule` does not contain a module, i.e. `content_`
/// is null.
506 inline bool AnyModule::is_empty() const noexcept {
507  return content_ == nullptr;
508 }
509 
510 // Private Methods
511 
/// Creates the type-erased holder for `module`.
///
/// The unnamed member-function-pointer parameter is never read. It exists so
/// that `ReturnType` and `ArgumentTypes...` are deduced from the module's
/// `forward()` method. Two constraints are enforced at compile time:
/// `forward()` must not take lvalue-reference parameters and must not return
/// `void`.
512 template <
513  typename ModuleType,
514  typename Class,
515  typename ReturnType,
516  typename... ArgumentTypes>
517 std::unique_ptr<AnyModule::Placeholder> AnyModule::make_holder(
518  std::shared_ptr<ModuleType>&& module,
519  ReturnType (Class::*)(ArgumentTypes...)) {
520  static_assert(
521  torch::detail::check_not_lvalue_references<ArgumentTypes...>(),
522  "Modules stored inside AnyModule must not take references. "
523  "Use pointers instead.");
524  static_assert(
525  !std::is_void<ReturnType>::value,
526  "AnyModule cannot store modules that return void "
527  "(you can return a dummy value).");
528  return torch::make_unique<Holder<decay_t<ModuleType>, ArgumentTypes...>>(
529  std::move(module));
530 }
531 
532 template <typename ModuleType>
533 ModuleType& AnyModule::get_() const {
534  using M = typename std::remove_reference<ModuleType>::type;
535  static_assert(
537  "Can only call AnyModule::get<T> with a type T that has a forward method");
538  return get_(&M::forward);
539 }
540 
/// Returns the contained module as `ModuleType&`.
///
/// The unnamed member-function-pointer parameter is never read. It exists so
/// that `ArgumentTypes...` can be deduced from `ModuleType::forward`, which is
/// needed to name the exact `Holder` specialization to cast to.
///
/// Raises an error if the contained module is not of type `ModuleType`, or
/// (through `type_info()`) if the `AnyModule` is empty.
///
/// Types are compared with `std::type_info::operator==`. Comparing
/// `hash_code()` values is not sufficient, because two distinct types may
/// share a hash code; the `static_cast` below would then treat the holder as
/// the wrong `Holder` specialization.
541 template <typename ModuleType, typename ReturnType, typename... ArgumentTypes>
542 ModuleType& AnyModule::get_(
543  ReturnType (ModuleType::*)(ArgumentTypes...)) const {
544  if (typeid(ModuleType) == type_info()) {
545  return *static_cast<Holder<ModuleType, ArgumentTypes...>&>(*content_)
546  .module;
547  }
548  AT_ERROR(
549  "Attempted to cast module of type ",
550  c10::demangle(type_info().name()),
551  " to type ",
552  c10::demangle(typeid(ModuleType).name()));
553 }
554 
555 } // namespace nn
556 } // namespace torch
const std::type_info & type_info() const
Returns the type_info object of the contained value.
Definition: any.h:501
Definition: any.cpp:108
std::shared_ptr< Module > ptr() const
Returns a std::shared_ptr whose dynamic type is that of the underlying module.
Definition: any.h:488
T & get()
Attempts to cast the underlying module to the given module type.
Definition: any.h:472
A simplified implementation of std::any which stores a type erased object, whose concrete value can be retrieved at runtime when the requested type matches the type of the stored object.
Definition: any.h:235
const std::type_info & type_info() const noexcept
Returns the type_info object of the contained value.
Definition: any.h:279
Value any_forward(ArgumentTypes &&...arguments)
Invokes forward() on the contained module with the given arguments, and returns the return value as a Value.
Definition: any.h:455
AnyModule clone(optional< Device > device=nullopt) const
Creates a deep copy of an AnyModule if it contains a module, else an empty AnyModule if it is empty...
Definition: any.h:443
Detects if a type T has a forward() method.
Definition: static.h:19
A ModuleHolder is essentially a wrapper around std::shared_ptr<M> where M is an nn::Module subclass...
Definition: pimpl.h:26
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...
Definition: variable.h:85
Definition: jit_type.h:17
ReturnType forward(ArgumentTypes &&...arguments)
Invokes forward() on the contained module with the given arguments, and casts the returned Value to the supplied return type (which defaults to torch::Tensor).
Definition: any.h:466
Stores a type erased Module.
Definition: any.h:108
std::string demangle(const char *name)
Utility to demangle a C++ symbol name.
Definition: Type.cpp:23
AnyModule()=default
A default-constructed AnyModule is in an empty state.
T * try_get()
Returns a pointer to the value contained in the Value if the type passed as template parameter matches the type of the value stored, and returns a null pointer otherwise.
Definition: any.h:250
bool is_empty() const noexcept
Returns true if the AnyModule does not contain a module.
Definition: any.h:506