// NOTE(review): This is a doc-site extraction of caffe2/operators/slice_op.h.
// The leading numbers on each line are the ORIGINAL file's line numbers that
// the extractor fused into the text; gaps in those numbers mean source lines
// (including the SliceImpl signature and several closing braces) were dropped.
// Shared implementation for Slice forward and SliceGradient backward.
4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/utils/math.h" 12 template <
class SIndex,
class Context>
// Signature elided by extraction; visible tail shows an optional gradient
// input `go`, defaulted to nullptr — TODO confirm full parameter list upstream.
20 const Tensor* go =
nullptr) {
// Backward mode is signaled by passing a null `output` tensor.
21 bool backward = output ==
nullptr;
// Raw start/end index arrays, one entry per (leading) dimension.
23 auto* starts_data = starts.template data<SIndex>();
24 auto* ends_data = ends.template data<SIndex>();
// Validate: starts/ends are 1-D, same length, and no longer than data's rank.
26 CAFFE_ENFORCE_EQ(starts.dim(), 1);
27 CAFFE_ENFORCE_EQ(ends.dim(), 1);
28 CAFFE_ENFORCE_GE(data.dim(), starts.numel());
29 CAFFE_ENFORCE_EQ(starts.numel(), ends.numel());
// Per-dimension resolved slice bounds and resulting output sizes.
31 std::vector<SIndex> starts_idx(data.dim());
32 std::vector<SIndex> ends_idx(data.dim());
33 std::vector<SIndex> dst_sizes(data.dim());
35 for (
int i = 0; i < data.dim(); ++i) {
// Dimensions beyond the provided starts/ends are taken whole.
36 if (i >= starts.numel()) {
38 ends_idx[i] = data.sizes()[i];
41 if (data.sizes()[i] > 0) {
42 auto start = starts_data[i];
43 auto end = ends_data[i];
// Negative indices wrap from the end; the `+ 1` makes e.g. end = -1 mean
// "through the last element" — the guarding `if` lines were elided here,
// so presumably these run only when start/end are negative; verify upstream.
45 start = data.sizes()[i] + 1 + start;
48 end = data.sizes()[i] + 1 + end;
// Clamp both bounds to the dimension size.
50 if (start > data.sizes()[i]) {
51 start = data.sizes()[i];
53 if (end > data.sizes()[i]) {
54 end = data.sizes()[i];
// After normalization the range must be valid and non-negative.
56 CAFFE_ENFORCE_GE(start, 0);
57 CAFFE_ENFORCE_GE(end, 0);
58 CAFFE_ENFORCE_GE(end, start);
59 starts_idx[i] = start;
61 dst_sizes[i] = end - start;
// Early-out for empty input: allocate the (empty) output and return
// (intermediate lines elided by the extraction).
69 if (data.numel() <= 0) {
72 output->Resize(dst_sizes);
73 output->raw_mutable_data(data.dtype());
// Find the single dimension actually being sliced; more than one sliced
// dimension is rejected (the enclosing enforce-macro name was elided).
79 for (
int i = 0; i < data.dim(); ++i) {
80 if (starts_idx[i] > 0 || ends_idx[i] < data.sizes()[i]) {
82 dim, -1,
"Currently only possible to slice in 1 dimension.");
// No dimension sliced: whole-tensor copy fast path (forward copies data to
// output; backward copies the incoming gradient `go` to `gdata`).
88 output->CopyFrom(data,
true );
90 gdata->CopyFrom(*go,
true );
// unit = number of elements in one "inner" row past the sliced dim;
// num_blocks = product of the sizes before the sliced dim.
94 size_t unit = std::accumulate(
95 data.sizes().begin() + dim + 1,
98 std::multiplies<SIndex>());
99 size_t num_blocks = std::accumulate(
100 data.sizes().begin(),
101 data.sizes().begin() + dim,
103 std::multiplies<SIndex>());
// Forward allocates the sliced output; backward sizes the gradient like data.
105 output->Resize(dst_sizes);
107 gdata->ResizeLike(data);
110 size_t itemsize = data.dtype().itemsize();
// ---- Forward path: gather contiguous sub-blocks data -> output ----
113 char* src_bytes = (
char*)data.raw_data();
114 char* dst_bytes = (
char*)output->raw_mutable_data(data.dtype());
116 size_t src_nbytes = data.nbytes();
117 size_t dst_nbytes = output->nbytes();
// Each outer block: full sliced-dim extent in src, [start, end) extent in dst.
119 size_t src_block_size = unit * data.sizes()[dim];
120 size_t dst_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
121 size_t src_offset = unit * starts_idx[dim];
// Nothing to copy when the slice is empty.
123 if (num_blocks == 0 || dst_block_size == 0) {
127 size_t src_block_size_bytes = itemsize * src_block_size;
128 size_t dst_block_size_bytes = itemsize * dst_block_size;
130 char* src_offset_bytes = src_bytes + itemsize * src_offset;
131 char* dst_offset_bytes = dst_bytes;
132 for (
size_t i = 0; i < num_blocks; ++i) {
133 char* local_src_offset_bytes =
134 src_offset_bytes + i * src_block_size_bytes;
135 char* local_dst_offset_bytes =
136 dst_offset_bytes + i * dst_block_size_bytes;
// Bounds checks that each block copy stays inside src/dst buffers
// (the DCHECK/ENFORCE macro names were elided by the extraction).
138 static_cast<void*>(local_src_offset_bytes + dst_block_size_bytes),
139 static_cast<void*>(src_bytes + src_nbytes));
141 static_cast<void*>(local_dst_offset_bytes + dst_block_size_bytes),
142 static_cast<void*>(dst_bytes + dst_nbytes));
143 context->CopyItemsSameDevice(
146 (
void*)local_src_offset_bytes,
147 (
void*)local_dst_offset_bytes);
// ---- Backward path: scatter gradient blocks go -> gdata ----
// Source/destination roles of the block sizes are mirrored versus forward.
150 char* src_bytes = (
char*)go->raw_data();
151 char* dst_bytes = (
char*)gdata->raw_mutable_data(go->dtype());
153 size_t src_nbytes = go->nbytes();
154 size_t dst_nbytes = gdata->nbytes();
156 size_t src_block_size = unit * (ends_idx[dim] - starts_idx[dim]);
157 size_t dst_block_size = unit * data.sizes()[dim];
158 size_t dst_offset = unit * starts_idx[dim];
160 if (num_blocks == 0 || dst_block_size == 0) {
164 size_t src_block_size_bytes = itemsize * src_block_size;
165 size_t dst_block_size_bytes = itemsize * dst_block_size;
167 char* src_offset_bytes = src_bytes;
168 char* dst_offset_bytes = dst_bytes + itemsize * dst_offset;
// Zero the whole gradient first; only the sliced region receives data below.
171 math::Set<char, Context>(dst_nbytes, 0, dst_bytes, context);
178 for (
size_t i = 0; i < num_blocks; ++i) {
179 char* local_src_offset_bytes =
180 src_offset_bytes + i * src_block_size_bytes;
181 char* local_dst_offset_bytes =
182 dst_offset_bytes + i * dst_block_size_bytes;
// Bounds checks (macro names elided), mirroring the forward path.
184 local_src_offset_bytes + src_block_size_bytes,
185 src_bytes + src_nbytes);
187 local_dst_offset_bytes + src_block_size_bytes,
188 dst_bytes + dst_nbytes);
189 context->CopyItemsSameDevice(
192 (
void*)local_src_offset_bytes,
193 (
void*)local_dst_offset_bytes);
// SliceOp: produces a slice of the input tensor along one dimension.
// NOTE(review): the class-declaration line itself was elided by the
// extraction; only the member listing survives below.
201 template <class Context>
204 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor reads optional "starts"/"ends" repeated int64 arguments.
205 template <
class... Args>
206 explicit SliceOp(Args&&... args)
208 starts_(this->
template GetRepeatedArgument<int64_t>(
"starts")),
209 ends_(this->
template GetRepeatedArgument<int64_t>(
"ends")),
210 statically_inited_(
false) {}
212 bool RunOnDevice()
override {
// With extra inputs the starts/ends come in as tensors (the branch body
// between these lines was elided); otherwise the static args are used.
213 if (InputSize() > 1) {
216 return DoRunWithType<int64_t>();
220 template <
typename SIndex>
221 bool DoRunWithType() {
// Runtime mode: copy starts/ends tensors from inputs 1 and 2 to host.
222 if (InputSize() > 1) {
223 ReinitializeAndCopyFrom(&starts_host_, at::dtype<SIndex>().device(CPU), Input(1));
224 ReinitializeAndCopyFrom(&ends_host_, at::dtype<SIndex>().device(CPU), Input(2));
// Static mode: materialize the argument vectors into host tensors once,
// then reuse them on subsequent runs.
226 if (!statically_inited_) {
227 CAFFE_ENFORCE(HasArgument(
"starts"));
228 CAFFE_ENFORCE(HasArgument(
"ends"));
229 CAFFE_ENFORCE_EQ(starts_.size(), ends_.size());
231 ReinitializeTensor(&starts_host_, {
static_cast<int64_t
>(starts_.size())}, at::dtype<SIndex>().device(CPU));
232 ReinitializeTensor(&ends_host_, {
static_cast<int64_t
>(ends_.size())}, at::dtype<SIndex>().device(CPU));
// Raw memcpy-style fill of the host tensors from the argument vectors
// (the copy-call names on the intervening lines were elided).
235 starts_host_.template mutable_data<SIndex>(),
237 sizeof(SIndex) * starts_.size());
239 ends_host_.template mutable_data<SIndex>(),
241 sizeof(SIndex) * ends_.size());
242 statically_inited_ =
true;
246 const auto& data = Input(0);
247 auto output = Output(0);
// Forward call: non-null output selects the forward path in SliceImpl.
249 return SliceImpl<SIndex, Context>(
250 output, data, starts_host_, ends_host_, &context_);
253 C10_DISABLE_COPY_AND_ASSIGN(SliceOp);
// Slice bounds from operator arguments; statically_inited_ guards the
// one-time host-tensor setup above.
256 std::vector<int64_t> starts_;
257 std::vector<int64_t> ends_;
258 bool statically_inited_;
// SliceGradientOp: backward of SliceOp — scatters the output gradient back
// into a zero-filled tensor shaped like the forward input.
// NOTE(review): the class-declaration and constructor-head lines were elided
// by the extraction; only the member listing survives below.
263 template <
class Context>
266 USE_OPERATOR_CONTEXT_FUNCTIONS;
268 template <
class... Args>
// Same "starts"/"ends" argument handling as SliceOp.
270 starts_(this->
template GetRepeatedArgument<int64_t>(
"starts")),
271 ends_(this->
template GetRepeatedArgument<int64_t>(
"ends")),
272 statically_inited_(
false) {}
274 C10_DISABLE_COPY_AND_ASSIGN(SliceGradientOp);
276 bool RunOnDevice()
override {
// Four inputs means starts/ends arrive as tensors (branch body elided).
277 if (InputSize() == 4) {
280 return DoRunWithType<int64_t>();
284 template <
typename SIndex>
285 bool DoRunWithType() {
286 auto* gdata = Output(0);
287 auto& data =
Input(0);
// Runtime mode: starts/ends from inputs 1 and 2; the gradient input `go`
// is bound on an elided line — presumably the last input; verify upstream.
289 if (InputSize() == 4) {
290 ReinitializeAndCopyFrom(&starts_host_, at::dtype<SIndex>().device(CPU),
Input(1));
291 ReinitializeAndCopyFrom(&ends_host_, at::dtype<SIndex>().device(CPU),
Input(2));
// Null output + non-null gdata/go selects the backward path in SliceImpl.
295 return SliceImpl<SIndex, Context>(
296 nullptr, data, starts_host_, ends_host_, &context_, gdata, &go);
// Static mode: one-time setup of host tensors from the argument vectors,
// mirroring SliceOp::DoRunWithType.
298 if (!statically_inited_) {
301 CAFFE_ENFORCE_EQ(starts_.size(), ends_.size());
304 &starts_host_, {
static_cast<int64_t
>(starts_.size())}, at::dtype<SIndex>().device(CPU));
306 &ends_host_, {
static_cast<int64_t
>(ends_.size())}, at::dtype<SIndex>().device(CPU));
// Raw copies of the argument vectors into the host tensors (call names
// on the intervening lines elided).
309 starts_host_.template mutable_data<SIndex>(),
311 sizeof(SIndex) * starts_.size());
313 ends_host_.template mutable_data<SIndex>(),
315 sizeof(SIndex) * ends_.size());
317 statically_inited_ =
true;
321 return SliceImpl<SIndex, Context>(
322 nullptr, data, starts_host_, ends_host_, &context_, gdata, &go);
328 std::vector<int64_t> starts_;
329 std::vector<int64_t> ends_;
330 bool statically_inited_;
void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to given dims and options if necessary, note that this will not do anything if the tensor already matches the requested dims and options.
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime environment.
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.