#include <torch/csrc/jit/fuser/executor.h>

#include <ATen/ExpandUtils.h>
#include <ATen/core/functional.h>
#include <ATen/core/stack.h>
#include <c10/util/Optional.h>
#include <torch/csrc/jit/fuser/compiler.h>
#include <torch/csrc/jit/fuser/interface.h>
#include <torch/csrc/jit/fuser/kernel_cache.h>
#include <torch/csrc/jit/fuser/kernel_spec.h>
#include <torch/csrc/jit/fuser/tensor_info.h>

#include <limits>
#include <vector>

namespace torch {
namespace jit {
namespace fuser {

// Computes the "map size" for this run: the common broadcast shape shared by
// all intermediate tensors, considering only the arguments in arg_subset.
// Returns c10::nullopt if the shapes cannot be broadcast together.
static c10::optional<std::vector<int64_t>> getMapSize(
    const KernelSpec& spec,
    at::TensorList args,
    at::IntArrayRef arg_subset) {
  // Note: left empty since the empty shape broadcasts to any shape
  std::vector<int64_t> map_size;
  for (const auto arg_idx : arg_subset) {
    auto& arg = args.at(arg_idx);
    auto& chunk_desc = spec.inputChunks().at(arg_idx);
    if (chunk_desc.nSubTensors() == 1) {
      try {
        map_size = at::infer_size(map_size, arg.sizes());
      } catch (...) {
        return c10::nullopt;
      }
    } else {
      // A chunked input contributes the shape of a single chunk
      auto tensor_sizes = arg.sizes().vec();
      const auto num_chunks = chunk_desc.nSubTensors();
      const auto dim =
          at::maybe_wrap_dim(chunk_desc.dim(), tensor_sizes.size());
      if (tensor_sizes[dim] % num_chunks != 0) {
        return c10::nullopt;
      }
      tensor_sizes[dim] /= num_chunks;
      try {
        map_size = at::infer_size(map_size, tensor_sizes);
      } catch (...) {
        return c10::nullopt;
      }
    }
  }

  return {map_size};
}
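// For example, getMapSize over unchunked inputs of sizes {5, 1} and {1, 4}
// yields a map size of {5, 4}; an input of size {5, 8} that is chunked into
// two sub-tensors along dim 1 contributes {5, 4} as well.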
// Tries to determine the map size for the kernel by checking that every
// broadcast group of inputs agrees on the same map size.
static c10::optional<std::vector<int64_t>> canRunKernel(
    const KernelSpec& spec,
    at::TensorList args) {
  // Short-circuits on an argument-count mismatch
  AT_CHECK(
      args.size() == spec.inputChunks().size(),
      "Expected ",
      spec.inputChunks().size(),
      " arguments, but got ",
      args.size());

  c10::optional<std::vector<int64_t>> map_size;
  for (const auto& broadcast_group : spec.inputBroadcastGroups()) {
    if (!map_size) {
      map_size = getMapSize(spec, args, broadcast_group);
      if (!map_size)
        return c10::nullopt;
    } else {
      const auto group_map_size = getMapSize(spec, args, broadcast_group);
      // Note: this also rejects the case where group_map_size is undefined
      if (map_size != group_map_size)
        return c10::nullopt;
    }
  }

  return map_size;
}
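// For example, with broadcast groups {0, 1} and {2}, the map size inferred
// from inputs 0 and 1 must equal the one inferred from input 2; otherwise
// canRunKernel returns c10::nullopt and runFusion refuses to fuse.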
// Expands the arguments to the common map size (mutating args). When dry_run
// is true, nothing is mutated; the function only reports whether any argument
// would need to be broadcast. map_size is restored before returning.
static bool expandArgs(
    const KernelSpec& spec,
    std::vector<at::Tensor>& args,
    std::vector<int64_t>& map_size,
    bool dry_run) {
  bool has_broadcast = false;
  for (size_t i = 0; i < args.size(); ++i) {
    auto& arg = args[i];
    const auto& pdesc = spec.inputChunks()[i];
    if (pdesc.nSubTensors() == 1) {
      if (arg.sizes().equals(map_size))
        continue;
      if (!dry_run) {
        arg = arg.expand(map_size);
        has_broadcast = true;
      } else {
        return true;
      }
    } else {
      // Chunked inputs are compared against the map size scaled back up
      // along the chunked dimension
      map_size.at(pdesc.dim()) *= pdesc.nSubTensors();
      if (!arg.sizes().equals(map_size)) {
        if (!dry_run) {
          arg = arg.expand(map_size);
          has_broadcast = true;
        } else {
          return true;
        }
      }
      map_size.at(pdesc.dim()) /= pdesc.nSubTensors();
    }
  }
  return has_broadcast;
}
static bool shouldExpandArgs(
    const KernelSpec& spec,
    std::vector<at::Tensor>& args,
    std::vector<int64_t>& map_size) {
  return expandArgs(spec, args, map_size, /*dry_run=*/true);
}
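// runFusion (below) uses this dry-run mode to detect whether broadcasting
// would be required; kernels that use random numbers are not fused in that
// case.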
// Note: assumes the number of elements fits in a uint32_t (see the assert in
// launchFusion below)
static uint32_t computeNumel(const at::IntArrayRef sizes) {
  uint32_t result = 1;
  for (const auto& size : sizes)
    result *= size;
  return result;
}
// Computes the map size implied by chunking `tensor` according to chunkDesc:
// the size of a single sub-tensor after the chunk.
static std::vector<int64_t> computeMapSize(
    const at::Tensor& tensor,
    const PartitionDesc& chunkDesc) {
  std::vector<int64_t> sizes(tensor.sizes().begin(), tensor.sizes().end());
  AT_ASSERT(sizes[chunkDesc.dim()] % chunkDesc.nSubTensors() == 0);
  sizes[chunkDesc.dim()] /= chunkDesc.nSubTensors();
  return sizes;
}
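// For example, a tensor of size {4, 6} chunked into three sub-tensors along
// dim 1 yields a map size of {4, 2}.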
// Collapses runs of dimensions that are contiguous with each other (as
// described by `cont`) into single dimensions, writing the compressed sizes
// and strides into c_sizes and c_strides.
static void compressContiguous(
    const at::IntArrayRef& sizes,
    const at::IntArrayRef& strides,
    const std::vector<bool>& cont,
    uint32_t* c_sizes,
    uint32_t* c_strides) {
  size_t compressed_dims = 0;
  size_t cur = 0;
  size_t ndim = sizes.size();
  while (cur < ndim) {
    size_t total_size = sizes[cur];
    cur++;
    while (cont[cur - 1] && cur < ndim) {
      AT_ASSERT(strides[cur - 1] == sizes[cur] * strides[cur]);
      total_size *= sizes[cur];
      cur++;
    }
    c_sizes[compressed_dims] = total_size;
    c_strides[compressed_dims] = strides[cur - 1];
    compressed_dims++;
  }

  if (ndim > 0)
    AT_ASSERT(!cont.back() || strides.back() == 1);
}
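// For example, a contiguous tensor with sizes {2, 3, 4} and strides
// {12, 4, 1} (cont = {true, true, true}) compresses to a single dimension
// with size 24 and stride 1.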
// Launches the requested fusion on the given device with the given inputs.
// Allocated outputs are appended to `outputs`.
void launchFusion(
    const FusedKernel& fusion,
    const at::Device device,
    const at::ArrayRef<at::Tensor>& inputs,
    std::vector<at::Tensor>& outputs) {
  // Fails if the fusion and the given inputs disagree
  AT_ASSERT(inputs.size() == fusion.inputDesc().size());

  // Computes the number of flattened (post-chunk / pre-concat) inputs and
  // outputs
  size_t flat_inputs_size = 0;
  size_t flat_outputs_size = 0;
  for (const auto& c : fusion.chunkDesc())
    flat_inputs_size += c.nSubTensors();
  for (const auto& c : fusion.concatDesc())
    flat_outputs_size += c.nSubTensors();

  // Asserts the first input is 32-bit addressable (the kernel indexes with
  // uint32_t)
  AT_ASSERT(inputs[0].numel() <= std::numeric_limits<uint32_t>::max());

  // Computes map_size and numel from the first input
  at::IntArrayRef map_size;
  uint32_t numel;
  std::vector<int64_t> keep_alive_size;
  if (fusion.chunkDesc()[0].isNoop()) {
    map_size = inputs[0].sizes();
    numel = inputs[0].numel();
  } else {
    keep_alive_size = computeMapSize(inputs[0], fusion.chunkDesc()[0]);
    map_size = keep_alive_size;
    numel = computeNumel(map_size);
  }

  // Computes the storage needed to hold a TensorInfo struct for every
  // flattened input and output
  size_t uncompressedDim = fusion.inputDesc().at(0).contiguity.size();
  size_t maxPossibleTensorInfoSize =
      sizeof(TensorInfo) + 2 * sizeof(uint32_t) * uncompressedDim;
  size_t maxPossibleBufferSize =
      maxPossibleTensorInfoSize * (flat_inputs_size + flat_outputs_size);
  std::vector<char> buffer(maxPossibleBufferSize);
  char* buffer_next = buffer.data();

  // The kernel arguments: numel, then a TensorInfo* per flattened input and
  // per flattened output
  std::vector<void*> arguments;
  arguments.reserve(3 + flat_inputs_size + flat_outputs_size);
  arguments.push_back(&numel);

  // Packs a TensorInfo (data pointer plus compressed sizes/strides) into the
  // buffer and appends it to the kernel arguments
  auto addTensorInfoRaw = [&](const TensorDesc& desc,
                              void* data_ptr,
                              at::IntArrayRef sizes,
                              at::IntArrayRef strides) {
    const auto nDim = desc.nDim(); // NOTE: this is the compressed dim
    AT_ASSERT(nDim <= uncompressedDim); // would overflow the buffer slot
    auto ti = reinterpret_cast<TensorInfo*>(buffer_next);
    ti->data = data_ptr;
    compressContiguous(
        sizes, strides, desc.contiguity, ti->sizes(nDim), ti->strides(nDim));
    buffer_next += maxPossibleTensorInfoSize;
    arguments.push_back(ti);
  };

  // Convenience wrapper that passes a tensor's data pointer, sizes, and
  // strides
  auto addTensorInfo = [&](const TensorDesc& desc, const at::Tensor& t) {
    addTensorInfoRaw(desc, t.data_ptr(), t.sizes(), t.strides());
  };

  // Adds the (flattened) input arguments
  for (size_t i = 0; i < fusion.inputDesc().size(); ++i) {
    const auto& chunk = fusion.chunkDesc()[i];
    const at::Tensor& tensor = inputs[i];
    if (chunk.isNoop()) {
      addTensorInfo(fusion.inputDesc()[i], tensor);
    } else {
      // Chunked inputs are passed as one TensorInfo per chunk, offsetting the
      // data pointer by the chunk's byte size
      size_t chunk_offset = map_size[chunk.dim()] * tensor.stride(chunk.dim()) *
          elementSize(tensor.scalar_type());
      char* data_ptr = reinterpret_cast<char*>(tensor.data_ptr());
      for (size_t chunks = 0; chunks < chunk.nSubTensors(); ++chunks) {
        addTensorInfoRaw(
            *chunk.subTensorDesc(), data_ptr, map_size, tensor.strides());
        data_ptr += chunk_offset;
      }
    }
  }

  // Adds the (flattened) output arguments, allocating the output tensors
  outputs.reserve(fusion.outputDesc().size());
  const auto& ref_options = inputs[0].options();
  for (size_t i = 0; i < fusion.outputDesc().size(); ++i) {
    const auto& c = fusion.concatDesc()[i];
    if (c.isNoop()) {
      outputs.push_back(at::empty(
          map_size, ref_options.dtype(fusion.outputDesc()[i].scalar_type)));
      addTensorInfo(fusion.outputDesc()[i], outputs[i]);
    } else {
      // Concatenated outputs are allocated whole; each sub-tensor is passed
      // to the kernel as a narrowed view into that allocation
      size_t small_size = map_size[c.dim()];
      std::vector<int64_t> concat_size(map_size.begin(), map_size.end());
      concat_size[c.dim()] = small_size * c.nSubTensors();
      outputs.push_back(at::empty(concat_size, ref_options));
      const auto& o = outputs[i];
      size_t offset = 0;
      for (size_t j = 0; j < c.nSubTensors(); ++j) {
        const auto view = o.narrow(c.dim(), offset, small_size);
        addTensorInfo(*c.subTensorDesc(), view);
        offset += small_size;
      }
    }
  }

  fusion.launch_raw(numel, arguments);
}
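// The argument buffer handed to launch_raw therefore looks like:
//   arguments[0]                  -> &numel
//   arguments[1 .. #flat inputs]  -> TensorInfo* for each flattened input
//   arguments[remaining]          -> TensorInfo* for each flattened output
// where each TensorInfo carries a data pointer plus compressed sizes/strides.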
bool runFusion(const int64_t key, Stack& stack) {
  // Short-circuits if fusion isn't enabled at all
  if (!canFuseOnCPU() && !canFuseOnGPU())
    return false;

  // Acquires the kernel spec for this key
  auto maybe_spec = retrieve(key);
  AT_ASSERT(maybe_spec);
  auto& spec = *(*maybe_spec);

  // Acquires the inputs from the stack (tensor inputs come first)
  auto all_inputs = last(stack, spec.nInputs());
  std::vector<at::Tensor> inputs;
  inputs.reserve(spec.nTensorInputs());
  for (int64_t i = 0; i < spec.nTensorInputs(); i++) {
    inputs.emplace_back(all_inputs[i].toTensor());
  }

  // Determines the device to dispatch to; a device mismatch means fallback
  at::Device device = inputs.at(0).device();
  for (const auto& t : inputs) {
    if (t.device() != device) {
      return false;
    }
  }

  // Falls back if fusion is disabled for this device type
  if (device.is_cuda() && !canFuseOnGPU())
    return false;
  if (device.is_cpu() && !canFuseOnCPU())
    return false;

  // Validates sizes and expands the inputs as needed
  auto maybe_map_size = canRunKernel(spec, inputs);
  if (!maybe_map_size)
    return false;
  if (spec.hasRandom()) {
    // Kernels with random are not fused when broadcasting would be required
    bool hasBroadcast = shouldExpandArgs(spec, inputs, *maybe_map_size);
    if (hasBroadcast)
      return false;
  }
  expandArgs(spec, inputs, *maybe_map_size, /*dry_run=*/false);

  // Retrieves the kernel, compiling (and caching) it if necessary
  ArgSpec arg_spec{inputs, device.index()};
  auto maybe_kernel = spec.findKernel(arg_spec);
  if (!maybe_kernel) {
    const auto kernel = compileKernel(spec, arg_spec, *maybe_map_size, device);
    spec.cacheKernel(arg_spec, kernel);
  }
  maybe_kernel = spec.findKernel(arg_spec);
  AT_ASSERT(maybe_kernel);

  // Launches the fusion
  std::vector<at::Tensor> raw_outputs;
  launchFusion(*(*maybe_kernel), device, inputs, raw_outputs);

  // Maps the raw kernel outputs to the graph outputs, summing to the
  // requested size where required
  auto outputs = fmap(
      spec.outputMapAndSizes(), [&](const OutputMapAndSize& omap) {
        if (omap.needsSumToSize()) {
          return at::sum_to(
              raw_outputs[omap.offset()],
              all_inputs[omap.sizeInput()].toIntList()->elements());
        }
        return raw_outputs[omap.offset()];
      });

  // Updates the stack
  drop(stack, spec.nInputs());
  stack.insert(
      stack.end(),
      std::make_move_iterator(outputs.begin()),
      std::make_move_iterator(outputs.end()));

  return true;
}

} // namespace fuser
} // namespace jit
} // namespace torch
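// Typical call flow (a sketch of how this entry point is driven): the fusion
// pass registers a KernelSpec and obtains a key; at runtime the fusion-group
// operator pushes its inputs onto the interpreter stack and calls
// runFusion(key, stack), falling back to interpreting the original subgraph
// whenever runFusion returns false.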