Caffe2 - C++ API
A deep learning, cross platform ML framework
constants.cpp
1 #include <torch/csrc/jit/constants.h>
2 #include <ATen/core/functional.h>
3 #include <torch/csrc/autograd/variable.h>
4 #include <torch/csrc/jit/custom_operator.h>
5 #include <torch/csrc/jit/operator.h>
6 
7 namespace torch {
8 namespace jit {
9 
10 // IValue -> Constant node
11 Value* insertConstant(
12  Graph& g,
13  const IValue& val,
14  const c10::TypePtr& result_type,
17  Node* n = g.create(prim::Constant);
18  if (val.isTensor()) {
19  at::Tensor ref = val.toTensor();
20  if (!ref.defined()) {
21  n->destroy();
22  return g.insertNode(g.createNone(TensorType::get()))->output();
23  }
24  // TODO: fix all cases where we are not passing in a variable,
25  // and then change this to an AT_ASSERT
26  if (!ref.is_variable()) {
27  ref = autograd::make_variable(ref, /*requires_grad=*/false);
28  } else {
29  AT_ASSERT(!ref.requires_grad());
30  }
31  n->output()->inferTypeFrom(
32  ref); // note: before t_ because of std::move(ref)
33  n->t_(attr::value, std::move(ref));
34  } else if (val.isInt()) {
35  n->i_(attr::value, val.toInt());
36  n->output()->setType(IntType::get());
37  } else if (val.isDouble()) {
38  n->f_(attr::value, val.toDouble());
39  n->output()->setType(FloatType::get());
40  } else if (val.isBool()) {
41  n->i_(attr::value, val.toBool());
42  n->output()->setType(BoolType::get());
43  } else if (val.isBoolList()) {
44  auto bool_list = val.toBoolList()->elements();
45  n->is_(
46  attr::value, std::vector<int64_t>(bool_list.begin(), bool_list.end()));
47  n->output()->setType(ListType::ofBools());
48  } else if (val.isIntList()) {
49  n->is_(attr::value, val.toIntList()->elements());
50  n->output()->setType(ListType::ofInts());
51  } else if (val.isTensorList()) {
52  n->ts_(
53  attr::value,
54  fmap(val.toTensorList()->elements(), [](const at::Tensor& t) {
55  AT_ASSERT(t.is_variable() && !t.requires_grad());
56  return t;
57  }));
58  n->output()->setType(ListType::ofTensors());
59  } else if (val.isString()) {
60  n->s_(attr::value, val.toString()->string());
61  n->output()->setType(StringType::get());
62  } else if (val.isDevice()) {
63  std::stringstream ss;
64  ss << val.toDevice();
65  n->s_(attr::value, ss.str());
66  n->output()->setType(DeviceObjType::get());
67  } else if (val.isNone()) {
68  n->output()->setType(NoneType::get());
69  } else {
70  n->destroy();
71  throw constant_not_supported_error(
72  "Unsupported value kind: " + val.tagKind());
73  }
74  if (loc)
75  n->setSourceLocation(std::make_shared<SourceRange>(*loc));
76  if (scope)
77  n->setScope(*scope);
78  if (result_type) {
79  auto inferred_type = n->output()->type();
80  // Retain more type information in case of tensor constant
81  if (!(inferred_type->isSubtypeOf(TensorType::get()) &&
82  result_type->isSubtypeOf(inferred_type))) {
83  n->output()->setType(result_type);
84  }
85  }
86  return g.insertNode(n)->output();
87 }
88 
89 RegisterOperators reg({
90  Operator(
91  FunctionSchema(
92  prim::Constant,
93  "",
94  {},
95  {},
96  /*is_vararg=*/false,
97  /*is_varret=*/true),
98  [](const Node* node) -> Operation {
99  TypePtr type = node->output()->type();
100  if (type->isSubtypeOf(TensorType::get())) {
101  auto t = node->t(attr::value);
102  return [t](Stack& stack) {
103  push(stack, t);
104  return 0;
105  };
106  } else if (type->isSubtypeOf(BoolType::get())) {
107  bool b = node->i(attr::value);
108  return [b](Stack& stack) {
109  push(stack, b);
110  return 0;
111  };
112  } else if (
113  type->isSubtypeOf(NumberType::get()) &&
114  node->kindOf(attr::value) == AttributeKind::i) {
115  auto i = node->i(attr::value);
116  return [i](Stack& stack) {
117  push(stack, i);
118  return 0;
119  };
120  } else if (
121  type->isSubtypeOf(NumberType::get()) &&
122  node->kindOf(attr::value) == AttributeKind::f) {
123  auto f = node->f(attr::value);
124  return [f](Stack& stack) {
125  push(stack, f);
126  return 0;
127  };
128  } else if (type->isSubtypeOf(ListType::ofInts())) {
129  const auto& is = node->is(attr::value);
130  return [is](Stack& stack) {
131  push(stack, is);
132  return 0;
133  };
134  } else if (type->isSubtypeOf(ListType::ofBools())) {
135  const auto bs = fmap<bool>(node->is(attr::value));
136  return [bs](Stack& stack) {
137  push(stack, bs);
138  return 0;
139  };
140  } else if (type->isSubtypeOf(ListType::ofTensors())) {
141  const auto& ts = node->ts(attr::value);
142  return [ts](Stack& stack) {
143  push(stack, ts);
144  return 0;
145  };
146  } else if (type == StringType::get()) {
147  const auto& s = node->s(attr::value);
148  return [s](Stack& stack) {
149  push(stack, s);
150  return 0;
151  };
152  } else if (type == DeviceObjType::get()) {
153  auto d = c10::Device(node->s(attr::value));
154  return [d](Stack& stack) {
155  push(stack, d);
156  return 0;
157  };
158  } else if (node->mustBeNone()) {
159  return [](Stack& stack) {
160  push(stack, IValue());
161  return 0;
162  };
163  } else {
164  std::stringstream ss;
165  ss << "constant literal not supported for: " << type->str();
166  throw std::runtime_error(ss.str());
167  }
168  }),
169 });
170 
171 c10::optional<IValue> toIValue(const Value* v) {
172  if (v->node()->kind() != prim::Constant) {
173  return c10::nullopt;
174  }
175  // use implemenation of prim::Constant to compute the output IValue
176  auto op = getOperation(v->node());
177  Stack stack;
178  op(stack);
179  return stack.back();
180 }
181 } // namespace jit
182 } // namespace torch
Represents a compute device on which a tensor is located.
Definition: Device.h:30
bool is_variable() const noexcept
Returns true if the Tensor is actually a torch::autograd::Variable.
Definition: jit_type.h:17