Caffe2 - C++ API
A deep learning, cross platform ML framework
named_value.h
1 #pragma once
2 #include <ATen/ATen.h>
3 #include <ATen/core/ivalue.h>
4 #include <torch/csrc/jit/constants.h>
5 #include <torch/csrc/jit/source_range.h>
6 #include <torch/csrc/utils/variadic.h>
7 
8 namespace torch {
9 namespace jit {
10 
11 struct Value;
12 
13 struct NamedValue {
14  NamedValue(const SourceRange& loc, const std::string& name, Value* value)
15  : loc_(loc), name_(name), value_(value) {}
16  NamedValue(const SourceRange& loc, Value* value) : loc_(loc), value_(value) {}
17 
18  /* implicit */ NamedValue(Value* value) : value_(value) {}
19  NamedValue(const std::string& name, Value* value)
20  : name_(name), value_(value) {}
21 
22  /* implicit */ NamedValue(IValue value)
23  : value_(nullptr), ivalue_(std::move(value)) {}
24 
25  NamedValue(const std::string& name, IValue value)
26  : name_(name), ivalue_(std::move(value)) {}
27 
28  template <
29  typename T,
30  typename = enable_if_t<
31  (!std::is_same<decay_t<T>, NamedValue>::value &&
32  !std::is_same<decay_t<T>, Value*>::value &&
33  !std::is_same<decay_t<T>, IValue>::value)>>
34  NamedValue(T&& t) : NamedValue(IValue(std::forward<T>(t))) {}
35 
36  template <
37  typename T,
38  typename = enable_if_t<
39  (!std::is_same<decay_t<T>, Value*>::value &&
40  !std::is_same<decay_t<T>, IValue>::value)>>
41  NamedValue(const std::string& name, T&& t)
42  : NamedValue(name, IValue(std::forward<T>(t))) {}
43 
44  SourceRange locOr(const SourceRange& backup_location) const {
45  if (!loc_)
46  return backup_location;
47  return loc();
48  }
49 
50  // note: this will insert a constant node into the graph at the current
51  // insert point if this NamedValue is actually a constant
52  Value* value(Graph& g) const {
53  if (!value_)
54  return insertConstant(
55  g, ivalue_); // use insertConstant to remove need to include ir.h here
56  return value_;
57  }
58 
59  const std::string& name() const {
60  AT_ASSERT(name_);
61  return *name_;
62  }
63 
64  const SourceRange& loc() const {
65  AT_ASSERT(loc_);
66  return *loc_;
67  }
68 
69  private:
72  Value* value_{nullptr};
73  // only valid if value_ == nullptr;
74  IValue ivalue_;
75 };
76 
77 } // namespace jit
78 } // namespace torch
Definition: jit_type.h:17