3 #include <torch/detail/static.h> 4 #include <torch/nn/module.h> 5 #include <torch/nn/pimpl.h> 6 #include <torch/types.h> 8 #include <torch/csrc/autograd/variable.h> 9 #include <torch/csrc/utils/memory.h> 10 #include <torch/csrc/utils/variadic.h> 12 #include <ATen/Device.h> 15 #include <type_traits> 117 template <
typename ModuleType>
118 explicit AnyModule(std::shared_ptr<ModuleType> module);
123 typename = torch::detail::enable_if_module_t<ModuleType>>
127 template <
typename ModuleType>
133 AnyModule& operator=(AnyModule&&) =
default;
137 AnyModule& operator=(
const AnyModule& other);
145 template <
typename ModuleType>
146 AnyModule& operator=(std::shared_ptr<ModuleType> module);
151 template <
typename... ArgumentTypes>
157 template <
typename ReturnType =
torch::Tensor,
typename... ArgumentTypes>
158 ReturnType
forward(ArgumentTypes&&... arguments);
162 template <
typename T,
typename = torch::detail::enable_if_module_t<T>>
167 template <
typename T,
typename = torch::detail::enable_if_module_t<T>>
168 const T&
get()
const;
172 template <
typename T,
typename ContainedType =
typename T::ContainedType>
177 std::shared_ptr<Module>
ptr()
const;
180 template <
typename T,
typename = torch::detail::enable_if_module_t<T>>
181 std::shared_ptr<T>
ptr()
const;
201 template <
typename ModuleType,
typename... ArgumentTypes>
211 typename... ArgumentTypes>
212 std::unique_ptr<Placeholder> make_holder(
213 std::shared_ptr<ModuleType>&& module,
214 ReturnType (Class::*)(ArgumentTypes...));
217 template <
typename ModuleType,
typename ReturnType,
typename... ArgumentTypes>
218 ModuleType& get_(ReturnType (ModuleType::*)(ArgumentTypes...))
const;
221 template <
typename ModuleType>
222 ModuleType& get_()
const;
225 std::unique_ptr<Placeholder> content_;
244 Value& operator=(
const Value& other) =
delete;
249 template <
typename T>
252 !std::is_reference<T>::value,
253 "Value stores decayed types, you cannot cast it to a reference type");
255 !std::is_array<T>::value,
256 "Value stores decayed types, you must cast it to T* instead of T[]");
257 if (
typeid(
T).hash_code() ==
type_info().hash_code()) {
258 return &
static_cast<Holder<T>&
>(*content_).value;
266 template <
typename T>
268 if (
auto* maybe_value = try_get<T>()) {
272 "Attempted to cast Value to ",
274 ", but its actual type is ",
280 return content_->type_info;
291 torch::disable_if_t<std::is_same<autograd::Variable, T>::value>>
294 torch::make_unique<Holder<decay_t<T>>>(std::forward<T>(value))) {}
299 : Value(
Tensor(std::move(variable))) {}
306 explicit Placeholder(
const std::type_info& type_info_) noexcept
308 virtual ~Placeholder() =
default;
315 template <
typename T>
316 struct Holder :
public Placeholder {
318 template <
typename U>
319 explicit Holder(U&& value_) noexcept
320 : Placeholder(
typeid(
T)), value(std::forward<U>(value_)) {}
325 std::unique_ptr<Placeholder> content_;
331 using AnyModule::Value::Placeholder::Placeholder;
334 virtual Value forward(std::vector<Value>&& arguments) = 0;
337 virtual std::shared_ptr<Module>
ptr() = 0;
340 virtual std::unique_ptr<Placeholder> copy()
const = 0;
348 template <
typename ModuleType,
typename... ArgumentTypes>
352 template <
typename T>
353 decay_t<T>&& operator()(
size_t index) {
354 AT_ASSERT(index < arguments_.size());
355 auto& value = arguments_[index];
356 if (
auto* maybe_value = value.template try_get<decay_t<T>>()) {
357 return std::move(*maybe_value);
360 "Expected argument #",
364 ", but received value of type ",
367 std::vector<Value>& arguments_;
372 template <
typename... Ts>
373 Value operator()(Ts&&... ts) {
374 return Value(module_->forward(std::forward<Ts>(ts)...));
376 std::shared_ptr<ModuleType>& module_;
380 explicit Holder(std::shared_ptr<ModuleType>&& module_)
381 :
Placeholder(
typeid(ModuleType)), module(std::move(module_)) {}
385 Value forward(std::vector<Value>&& arguments)
override {
387 arguments.size() ==
sizeof...(ArgumentTypes),
389 "'s forward() method expects ",
390 sizeof...(ArgumentTypes),
391 " arguments, but received ",
395 return torch::unpack<
Value, ArgumentTypes...>(
399 std::shared_ptr<Module>
ptr()
override {
403 std::unique_ptr<Placeholder> copy()
const override {
404 return torch::make_unique<Holder>(*this);
408 return torch::make_unique<Holder>(
409 std::dynamic_pointer_cast<ModuleType>(module->clone(device)));
413 std::shared_ptr<ModuleType> module;
418 template <
typename ModuleType>
420 : content_(make_holder(
422 &
std::remove_reference<ModuleType>::type::
forward)) {}
424 template <
typename ModuleType,
typename>
427 std::make_shared<ModuleType>(
std::
forward<ModuleType>(module))) {}
429 template <
typename ModuleType>
434 : content_(other.content_ ? other.content_->copy() : nullptr) {}
437 if (
this != &other) {
438 content_ = other.content_ ? other.content_->copy() :
nullptr;
445 clone.content_ = content_ ? content_->clone(device) :
nullptr;
449 template <
typename ModuleType>
450 AnyModule& AnyModule::operator=(std::shared_ptr<ModuleType> module) {
451 return (*
this =
AnyModule(std::move(module)));
454 template <
typename... ArgumentTypes>
456 AT_CHECK(!
is_empty(),
"Cannot call forward() on an empty AnyModule");
457 std::vector<Value> values;
458 values.reserve(
sizeof...(ArgumentTypes));
460 [&values](
Value&& value) { values.push_back(std::move(value)); },
461 Value(std::forward<ArgumentTypes>(arguments))...);
462 return content_->forward(std::move(values));
465 template <
typename ReturnType,
typename... ArgumentTypes>
467 return any_forward(std::forward<ArgumentTypes>(arguments)...)
468 .template get<ReturnType>();
471 template <
typename T,
typename>
473 AT_CHECK(!
is_empty(),
"Cannot call get() on an empty AnyModule");
477 template <
typename T,
typename>
479 AT_CHECK(!
is_empty(),
"Cannot call get() on an empty AnyModule");
483 template <
typename T,
typename ContainedType>
485 return T(ptr<ContainedType>());
489 AT_CHECK(!
is_empty(),
"Cannot call ptr() on an empty AnyModule");
490 return content_->ptr();
493 template <
typename T,
typename>
495 AT_CHECK(!
is_empty(),
"Cannot call ptr() on an empty AnyModule");
498 return std::dynamic_pointer_cast<
T>(
ptr());
502 AT_CHECK(!
is_empty(),
"Cannot call type_info() on an empty AnyModule");
503 return content_->type_info;
507 return content_ ==
nullptr;
516 typename... ArgumentTypes>
517 std::unique_ptr<AnyModule::Placeholder> AnyModule::make_holder(
518 std::shared_ptr<ModuleType>&& module,
519 ReturnType (Class::*)(ArgumentTypes...)) {
521 torch::detail::check_not_lvalue_references<ArgumentTypes...>(),
522 "Modules stored inside AnyModule must not take references. " 523 "Use pointers instead.");
525 !std::is_void<ReturnType>::value,
526 "AnyModule cannot store modules that return void " 527 "(you can return a dummy value).");
528 return torch::make_unique<Holder<decay_t<ModuleType>, ArgumentTypes...>>(
532 template <
typename ModuleType>
533 ModuleType& AnyModule::get_()
const {
534 using M =
typename std::remove_reference<ModuleType>::type;
537 "Can only call AnyModule::get<T> with a type T that has a forward method");
538 return get_(&M::forward);
541 template <
typename ModuleType,
typename ReturnType,
typename... ArgumentTypes>
542 ModuleType& AnyModule::get_(
543 ReturnType (ModuleType::*)(ArgumentTypes...))
const {
544 if (
typeid(ModuleType).hash_code() ==
type_info().hash_code()) {
545 return *
static_cast<Holder<ModuleType, ArgumentTypes...
>&>(*content_)
549 "Attempted to cast module of type ",
const std::type_info & type_info() const
Returns the type_info object of the contained value.
std::shared_ptr< Module > ptr() const
Returns a std::shared_ptr whose dynamic type is that of the underlying module.
T & get()
Attempts to cast the underlying module to the given module type.
A simplified implementation of std::any which stores a type erased object, whose concrete value can b...
const std::type_info & type_info() const noexcept
Returns the type_info object of the contained value.
Value any_forward(ArgumentTypes &&...arguments)
Invokes forward() on the contained module with the given arguments, and returns the return value as a...
AnyModule clone(optional< Device > device=nullopt) const
Creates a deep copy of an AnyModule if it contains a module, else an empty AnyModule if it is empty...
Detects if a type T has a forward() method.
A ModuleHolder is essentially a wrapper around std::shared_ptr<M> where M is an nn::Module subclass...
Variable A Variable augments a Tensor with the ability to interact in our autograd machinery...
ReturnType forward(ArgumentTypes &&...arguments)
Invokes forward() on the contained module with the given arguments, and casts the returned Value to t...
Stores a type erased Module.
std::string demangle(const char *name)
Utility to demangle a C++ symbol name.
AnyModule()=default
A default-constructed AnyModule is in an empty state.
T * try_get()
Returns a pointer to the value contained in the Value if the type passed as template parameter matche...
bool is_empty() const noexcept
Returns true if the AnyModule does not contain a module.