Caffe2 - C++ API
A deep learning, cross-platform ML framework
arg_spec.h
1 #pragma once
2 #include <ATen/ATen.h>
3 #include <ATen/core/functional.h> // fmap
4 #include <torch/csrc/WindowsTorchApiMacro.h>
5 #include <torch/csrc/jit/fuser/tensor_desc.h>
6 #include <torch/csrc/utils/hash.h>
7 
8 #include <cstdint>
9 #include <vector>
10 
11 namespace torch {
12 namespace jit {
13 namespace fuser {
14 
15 // Describes the (runtime) arguments to a kernel.
16 // ArgSpecs are also used as keys to lookup instantiated kernels, so
17 // they are hashable.
18 // Note: the device to run on is included in the arg spec because kernels
19 // are compiled per-device.
20 struct TORCH_API ArgSpec {
21  ArgSpec(at::TensorList inputs, const int _device)
22  : descs_{c10::fmap<TensorDesc>(inputs)},
23  hash_code_{torch::get_hash(_device, inputs.size(), descs_)},
24  device_{_device} {}
25 
26  // (Common) hash function
27  static size_t hash(const ArgSpec& spec) {
28  return spec.hash_code_;
29  }
30 
31  // Comparators
32  bool operator==(const ArgSpec& other) const {
33  return (descs_ == other.descs_ && device_ == other.device_);
34  }
35 
36  bool operator!=(const ArgSpec& spec) const {
37  return !(*this == spec);
38  }
39 
40  // Getters
41  size_t hashCode() const {
42  return hash_code_;
43  }
44  const std::vector<TensorDesc>& descs() const {
45  return descs_;
46  }
47  int device() const {
48  return device_;
49  }
50 
51  private:
52  std::vector<TensorDesc> descs_;
53  size_t hash_code_;
54  int device_;
55 };
56 
57 } // namespace fuser
58 } // namespace jit
59 } // namespace torch
constexpr size_t size() const
size - Get the array size.
Definition: ArrayRef.h:138
Definition: jit_type.h:17
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory)...
Definition: ArrayRef.h:41