// NOTE(review): this listing is a lossy extract. Every line carries its original
// source line number, and gaps in that numbering mark lines that are not shown
// here (including most of the insertConstant parameter list).
//
// insertConstant: creates a prim::Constant node, stores the given value in the
// node's attr::value attribute using the setter that matches the value's kind,
// sets the output type accordingly, inserts the node into the graph `g`, and
// returns the node's output Value.
// Throws constant_not_supported_error when the value kind matches none of the
// branches below.
1 #include <torch/csrc/jit/constants.h> 2 #include <ATen/core/functional.h> 3 #include <torch/csrc/autograd/variable.h> 4 #include <torch/csrc/jit/custom_operator.h> 5 #include <torch/csrc/jit/operator.h> 11 Value* insertConstant(
14 const c10::TypePtr& result_type,
17 Node* n = g.create(prim::Constant);
// Early return that inserts a None node created with TensorType instead of using `n`.
// NOTE(review): the condition guarding this return is on lines not shown here.
22 return g.insertNode(g.createNone(TensorType::get()))->output();
// Tensor handling: `ref` is re-wrapped through autograd::make_variable and must
// not require grad before it is moved into the node's tensor attribute.
// NOTE(review): the branch header and the declaration of `ref` are not shown.
27 ref = autograd::make_variable(ref,
false);
29 AT_ASSERT(!ref.requires_grad());
31 n->output()->inferTypeFrom(
33 n->t_(attr::value, std::move(ref));
34 }
// Int: stored as an integer attribute, output typed as IntType.
else if (val.isInt()) {
35 n->i_(attr::value, val.toInt());
36 n->output()->setType(IntType::get());
37 }
// Double: stored as a float attribute, output typed as FloatType.
else if (val.isDouble()) {
38 n->f_(attr::value, val.toDouble());
39 n->output()->setType(FloatType::get());
40 }
// Bool: stored through the integer attribute setter; only the output type
// (BoolType) distinguishes it from an Int constant.
else if (val.isBool()) {
41 n->i_(attr::value, val.toBool());
42 n->output()->setType(BoolType::get());
43 }
// Bool list: elements are copied into a std::vector<int64_t>; the output type
// is ListType::ofBools().
// NOTE(review): the setter call that receives these arguments is on a line not shown.
else if (val.isBoolList()) {
44 auto bool_list = val.toBoolList()->elements();
46 attr::value, std::vector<int64_t>(bool_list.begin(), bool_list.end()));
47 n->output()->setType(ListType::ofBools());
48 }
// Int list: stored as an integer-list attribute, output typed as ListType::ofInts().
else if (val.isIntList()) {
49 n->is_(attr::value, val.toIntList()->elements());
50 n->output()->setType(ListType::ofInts());
51 }
// Tensor list: each element is mapped through a lambda via fmap; output typed
// as ListType::ofTensors().
// NOTE(review): the lambda body and the attribute setter call are not shown.
else if (val.isTensorList()) {
54 fmap(val.toTensorList()->elements(), [](
const at::Tensor& t) {
58 n->output()->setType(ListType::ofTensors());
59 }
// String: stored as a string attribute, output typed as StringType.
else if (val.isString()) {
60 n->s_(attr::value, val.toString()->string());
61 n->output()->setType(StringType::get());
62 }
// Device: stored as the string contents of the stream `ss`; output typed as
// DeviceObjType.
// NOTE(review): the lines that declare and fill `ss` are not shown.
else if (val.isDevice()) {
65 n->s_(attr::value, ss.str());
66 n->output()->setType(DeviceObjType::get());
67 }
// None: no attribute is set in the visible lines; only the output type is set.
else if (val.isNone()) {
68 n->output()->setType(NoneType::get());
// Fallback for any other value kind: report the tag kind in the error message.
71 throw constant_not_supported_error(
72 "Unsupported value kind: " + val.tagKind());
// Attach a copy of the source range `*loc` to the node.
// NOTE(review): presumably guarded by a check that `loc` is set — confirm against the full file.
75 n->setSourceLocation(std::make_shared<SourceRange>(*loc));
// Override the inferred output type with `result_type`, except when the inferred
// type is a tensor type and `result_type` is a subtype of it.
// NOTE(review): the enclosing condition for this section is on lines not shown.
79 auto inferred_type = n->output()->type();
81 if (!(inferred_type->isSubtypeOf(TensorType::get()) &&
82 result_type->isSubtypeOf(inferred_type))) {
83 n->output()->setType(result_type);
// Insert the finished node into the graph and hand back its output Value.
86 return g.insertNode(n)->output();
// NOTE(review): this listing is a lossy extract. Every line carries its original
// source line number, and gaps in that numbering mark lines that are not shown
// here (including the operator being registered and the closure bodies).
//
// Operator registration whose factory lambda receives a Node and returns an
// Operation. The factory inspects the node's output type, reads attr::value with
// the matching attribute getter, and returns a closure over that value.
// Throws std::runtime_error when the output type matches none of the branches.
89 RegisterOperators reg({
98 [](
const Node* node) -> Operation {
99 TypePtr type = node->output()->type();
// Tensor constant: the tensor attribute is copied into the closure.
100 if (type->isSubtypeOf(TensorType::get())) {
101 auto t = node->t(attr::value);
102 return [t](Stack& stack) {
106 }
// Bool constant: read from the integer attribute and converted to bool. This
// check comes before the NumberType checks below.
else if (type->isSubtypeOf(BoolType::get())) {
107 bool b = node->i(attr::value);
108 return [b](Stack& stack) {
// Number constant whose attribute kind is integer.
113 type->isSubtypeOf(NumberType::get()) &&
114 node->kindOf(attr::value) == AttributeKind::i) {
115 auto i = node->i(attr::value);
116 return [i](Stack& stack) {
// Number constant whose attribute kind is float.
121 type->isSubtypeOf(NumberType::get()) &&
122 node->kindOf(attr::value) == AttributeKind::f) {
123 auto f = node->f(attr::value);
124 return [f](Stack& stack) {
128 }
// Int list constant: the closure captures `is` by value.
else if (type->isSubtypeOf(ListType::ofInts())) {
129 const auto& is = node->is(attr::value);
130 return [is](Stack& stack) {
134 }
// Bool list constant: the integer-list attribute is mapped to bools with fmap<bool>.
else if (type->isSubtypeOf(ListType::ofBools())) {
135 const auto bs = fmap<bool>(node->is(attr::value));
136 return [bs](Stack& stack) {
140 }
// Tensor list constant: the closure captures `ts` by value.
else if (type->isSubtypeOf(ListType::ofTensors())) {
141 const auto& ts = node->ts(attr::value);
142 return [ts](Stack& stack) {
146 }
// String constant: the type is compared with == rather than isSubtypeOf.
else if (type == StringType::get()) {
147 const auto& s = node->s(attr::value);
148 return [s](Stack& stack) {
152 }
// Device constant: the type is compared with == rather than isSubtypeOf.
// NOTE(review): the line that defines the captured `d` is not shown.
else if (type == DeviceObjType::get()) {
154 return [d](Stack& stack) {
158 }
// None constant: the closure pushes a default-constructed IValue onto the stack.
else if (node->mustBeNone()) {
159 return [](Stack& stack) {
160 push(stack, IValue());
// Fallback: build an error message naming the unsupported type and throw.
164 std::stringstream ss;
165 ss <<
"constant literal not supported for: " << type->str();
166 throw std::runtime_error(ss.str());
172 if (v->node()->kind() != prim::Constant) {
176 auto op = getOperation(v->node());
Represents a compute device on which a tensor is located.
bool is_variable() const noexcept
Returns true if the Tensor is actually a torch::autograd::Variable.