3 #include <torch/csrc/jit/constants.h> 4 #include <torch/csrc/jit/ir.h> 14 operator Value*()
const {
18 return g.addInput(std::move(name));
21 return g.addInput()->setType(std::move(type));
23 const std::vector<int64_t>& sizes()
const {
26 void addAsOutput()
const {
27 v->owningGraph()->registerOutput(v);
29 static std::vector<SymbolicVariable> create(
33 Node** created_node =
nullptr,
36 g = inputs.
at(0).value()->owningGraph();
38 Node* n = g->insertNode(g->create(kind, num_outputs));
41 for (
auto n : inputs) {
42 size_t d = n.value()->node()->scope()->getDepth();
45 s = n.value()->node()->scope();
50 for (
auto i : inputs) {
51 n->addInput(i.value());
56 std::vector<SymbolicVariable> out;
57 for (
auto v : n->outputs()) {
62 static bool isConstInt(
at::Scalar s, int32_t i) {
64 if (s.isFloatingPoint()) {
65 return (
double)i == s.toDouble();
67 return (int64_t)i == s.toLong();
71 return create(aten::mul, {*
this, rhs})[0].typeLike(*
this);
74 return create(aten::div, {*
this, rhs})[0].typeLike(*
this);
77 if (isConstInt(rhs, 1))
79 return (*
this) * insertConstant(rhs);
82 return create(aten::gt, {*
this, insertConstant(rhs)})[0]
83 .typeLikeWithScalarType(*
this, at::kByte);
86 return create(aten::gt, {*
this, rhs})[0].typeLikeWithScalarType(
90 return create(aten::lt, {*
this, insertConstant(rhs)})[0]
91 .typeLikeWithScalarType(*
this, at::kByte);
94 return create(aten::lt, {*
this, rhs})[0].typeLikeWithScalarType(
98 return create(aten::ge, {*
this, insertConstant(rhs)})[0]
99 .typeLikeWithScalarType(*
this, at::kByte);
102 return create(aten::ge, {*
this, rhs})[0].typeLikeWithScalarType(
106 return create(aten::le, {*
this, insertConstant(rhs)})[0]
107 .typeLikeWithScalarType(*
this, at::kByte);
110 return create(aten::le, {*
this, rhs})[0].typeLikeWithScalarType(
114 return create(aten::eq, {*
this, insertConstant(rhs)})[0]
115 .typeLikeWithScalarType(*
this, at::kByte);
118 return create(aten::ne, {*
this, insertConstant(rhs)})[0]
119 .typeLikeWithScalarType(*
this, at::kByte);
122 return create(aten::add, {*
this, rhs, insertConstant(1)})[0].typeLike(
126 return (*
this) + insertConstant(rhs);
129 return create(aten::neg, {*
this})[0].typeLike(*
this);
132 return create(aten::sub, {*
this, rhs, insertConstant(1)})[0].typeLike(
136 return create(aten::div, {*
this, insertConstant(rhs)})[0].typeLike(*
this);
139 return create(aten::remainder, {*
this, insertConstant(rhs)})[0].typeLike(
142 Value* size()
const {
143 return v->owningGraph()->insert(aten::size, {v});
146 return create(aten::_grad_sum_to_size, {*
this, size})[0];
149 return v->owningGraph()->insert(aten::expand, {v, size});
152 return create(aten::ne, {*
this, *
this})[0].typeLikeWithScalarType(
156 return create(t(
"mm"), {*
this, rhs})[0];
159 return create(t(
"t"), {*
this})[0];
162 return create(aten::sigmoid, {*
this})[0].typeLike(*
this);
165 return create(aten::tanh, {*
this})[0].typeLike(*
this);
167 std::vector<SymbolicVariable> chunk(int64_t chunks,
int dim)
const {
169 auto outputs = create(prim::ConstantChunk, {value()}, chunks, &chunk);
170 chunk->i_(attr::chunks, chunks)->i_(attr::dim, dim);
174 return create(aten::type_as, {*
this, rhs})[0].typeLikeWithRhsScalarType(
182 insertConstant(start),
183 insertConstant(length)},
187 Graph* g = dim->owningGraph();
189 if (inputs.
size() == 1 &&
190 inputs[0].value()->type()->isSubtypeOf(ListType::ofTensors())) {
191 input_list = inputs[0];
196 g->insertNode(g->createList(TensorType::get(), value_inputs))
199 return create(aten::cat, {input_list, dim})[0];
202 AT_ASSERT(inputs.
size() > 0);
203 return SymbolicVariable::cat(inputs, inputs[0].insertConstant(dim));
206 Graph* g = dim->owningGraph();
210 g->insertNode(g->createList(TensorType::get(), value_inputs))
212 return create(aten::stack, {input_list, dim})[0];
215 AT_ASSERT(inputs.
size() > 0);
216 return SymbolicVariable::stack(inputs, inputs[0].insertConstant(dim));
218 static std::vector<SymbolicVariable> broadcast_tensors(
220 AT_ASSERT(inputs.
size() > 0);
221 Graph* g = inputs[0].value()->owningGraph();
225 g->insertNode(g->createList(TensorType::get(), value_inputs))
227 Value* output_list = g->insert(aten::broadcast_tensors, {input_list});
228 Node* unpack = g->insertNode(
229 g->create(prim::ListUnpack, {output_list}, inputs.
size()));
230 return fmap<SymbolicVariable>(unpack->outputs());
233 return create(t(
"zeros_like"), {input})[0];
236 return create(t(
"cos"), {*
this})[0];
239 return create(t(
"cosh"), {*
this})[0];
242 return create(t(
"exp"), {*
this})[0];
245 return create(t(
"pow"), {*
this, insertConstant(other)})[0];
248 return create(t(
"rsqrt"), {*
this})[0];
251 return create(t(
"sign"), {*
this})[0];
254 return create(t(
"sin"), {*
this})[0];
257 return create(t(
"sinh"), {*
this})[0];
260 return create(t(
"sum"), {*
this})[0];
265 {*
this, insertConstant(
at::IntArrayRef{dim}), insertConstant(keepdim)})[0];
268 return create(t(
"squeeze"), {*
this, dim})[0];
271 return squeeze(insertConstant(dim));
274 return create(t(
"unsqueeze"), {*
this, dim})[0];
277 return unsqueeze(insertConstant(dim));
280 return create(aten::view, {*
this, sizes})[0];
283 return view(insertConstant(std::move(sizes)));
286 return create(aten::reshape, {*
this, sizes})[0];
289 return reshape(insertConstant(std::move(sizes)));
294 {*
this, mat1, mat2, insertConstant(1), insertConstant(1)})[0];
296 Value* value()
const {
302 return v->owningGraph()->insertConstant(std::move(value));
306 v->setType(other_type->contiguous());
311 at::ScalarType type)
const {
313 auto new_type = other_type->toScalarType(type)->contiguous();
314 v->setType(new_type);
323 if (other_type && rhs_type) {
325 other_type->toScalarType(rhs_type->scalarType())->contiguous();
326 v->setType(new_type);
330 static Symbol a(
const char* s_) {
331 return Symbol::attr(s_);
333 static Symbol t(
const char* s_) {
334 return Symbol::aten(s_);
346 typename =
typename std::enable_if<std::is_arithmetic<T>::value>::type>
356 return (lhs + (-rhs));
Scalar represents a 0-dimensional tensor which contains a single element.
constexpr size_t size() const
size - Get the array size.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
AT_CPP14_CONSTEXPR const T & at(size_t Index) const
Vector compatibility.