def batch_tanh(data, mask, dims):
    data = torch.tanh(data)
    return data, mask, dims
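
# The operators below all follow the (data, mask, dims) convention used in this
# file: `data` holds padded per-example values with an explicit batch dimension
# in front, `mask` is a byte tensor marking which entries of `data` are valid,
# and `dims` is a byte vector with one entry per non-batch dimension recording
# whether that dimension is dynamic (varies across examples).
# Illustrative example only (values and shapes are not taken from this file):
# a batch of the sequences [1, 2, 3] and [4, 5], padded to length 3, would be
# represented roughly as
#   data = torch.tensor([[1., 2., 3.], [4., 5., 0.]])
#   mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.uint8)
#   dims = torch.tensor([1], dtype=torch.uint8)  # the sequence dimension is dynamic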


@torch.jit.script
def batch_sigmoid(data, mask, dims):
    data = torch.sigmoid(data)
    return data, mask, dims


@torch.jit.script
def batch_relu(data, mask, dims):
    data = torch.relu(data)
    return data, mask, dims


@torch.jit.script
def batch_neg(data, mask, dims):
    data = torch.neg(data)
    return data, mask, dims


@torch.jit.script
def batch_neg_scalar(data):
    return torch.neg(data)


@torch.jit.script
def batch_add(data1, mask1, dims1, data2, mask2, dims2, alpha_):
    alpha = float(alpha_)
    data = torch.add(data1, data2, alpha=alpha)
    mask = mask1 * mask2
    dims = dims1.__or__(dims2)
    return data, mask, dims


@torch.jit.script
def batch_add_scalar(data, mask, dims, other, alpha_):
    alpha = float(alpha_)
    data = torch.add(data, other.type_as(data), alpha=alpha)
    return data, mask, dims
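
# Binary operators such as batch_add, batch_sub, and batch_mul combine their
# operands' metadata elementwise: the result is valid only where both inputs
# are valid (mask1 * mask2), and a result dimension is dynamic if it is dynamic
# in either input (dims1 | dims2).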


@torch.jit.script
def batch_sub(data1, mask1, dims1, data2, mask2, dims2, alpha_):
    alpha = float(alpha_)
    data = torch.sub(data1, data2, alpha=alpha)
    mask = mask1 * mask2
    dims = dims1.__or__(dims2)
    return data, mask, dims


@torch.jit.script
def batch_sub_scalar(data1, data2):
    return data1 - data2


@torch.jit.script
def batch_mul(data1, mask1, dims1, data2, mask2, dims2):
    data = torch.mul(data1, data2)
    mask = mask1 * mask2
    dims = dims1.__or__(dims2)
    return data, mask, dims


@torch.jit.script
def batch_mul_scalar(data1, data2):
    return data1 * data2


@torch.jit.script
def batch_div(data, mask, dims, other):
    data = torch.div(data, other)
    return data, mask, dims


@torch.jit.script
def batch_mm(data1, mask1, dims1, data2, mask2, dims2):
    data1 = data1 * mask1.type_as(data1)
    data2 = data2 * mask2.type_as(data2)
    data = torch.bmm(data1, data2)
    mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1))
    dims = torch.cat((dims1[:1], dims2[1:dims2.size(0)]))
    return data, mask, dims
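
# In batch_mm above (and batch_matmul below), padded entries are zeroed out
# before torch.bmm so that they cannot contribute to the product, and the
# output mask and dims are rebuilt from the row mask of the first operand and
# the column mask of the second.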


@torch.jit.script
def batch_matmul(data1, mask1, dims1, data2, mask2, dims2):
    d1 = data1.dim() - 1  # per-example dimensionality of the first operand
    d2 = data2.dim() - 1  # per-example dimensionality of the second operand
    data1 = data1 * mask1.type_as(data1)
    data2 = data2 * mask2.type_as(data2)
    if d1 == 1:
        data1 = data1.unsqueeze(-2)
    if d2 == 1:
        data2 = data2.unsqueeze(-1)
    data = torch.bmm(data1, data2)
    if d1 == 1 and d2 == 1:
        # vector . vector: each example reduces to a scalar
        data = data.squeeze(-1).squeeze(-1)
        mask = mask1.narrow(1, 0, 1).squeeze(-1)
        dims = dims1[:0]
    elif d1 == 2 and d2 == 1:
        # matrix x vector: one non-batch dimension remains
        data = data.squeeze(-1)
        mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1).unsqueeze(-1)).squeeze(-1)
        dims = dims1[:1]
    elif d1 == 1 and d2 == 2:
        # vector x matrix
        data = data.squeeze(-2)
        mask = torch.bmm(mask1.narrow(1, 0, 1).unsqueeze(-2), mask2.narrow(1, 0, 1)).squeeze(-2)
        dims = dims2[1:dims2.size(0)]
    elif d1 == 2 and d2 == 2:
        # matrix x matrix
        mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1))
        dims = torch.cat((dims1[:1], dims2[1:dims2.size(0)]))
    return data, mask, dims
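
# The shape-manipulating operators below index `dims` with `dim - 1` because
# `dims` has no entry for the batch dimension: entry `dim - 1` of `dims`
# corresponds to dimension `dim` of `data` and `mask`.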


@torch.jit.script
def batch_select(data, mask, dims, dim_, index_):
    dim = int(dim_)
    index = int(index_)
    data = data.select(dim, index)
    if bool(dims[dim - 1]):
        mask = mask.select(dim, index)
    else:
        mask = mask.select(dim, 0)
    dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
    return data, mask, dims


@torch.jit.script
def batch_fmod(data, mask, dims, other_):
    other = int(other_)  # the scalar divisor is assumed to be integral here
    data = torch.fmod(data, other)
    return data, mask, dims


@torch.jit.script
def batch_zeros_like(data, mask, dims):
    res_data = torch.zeros_like(data)
    return res_data, mask, dims


@torch.jit.script
def batch_index_select(data, mask, dims, dim_, index_data, index_mask, index_dims):
    dim = int(dim_)
    batch_size = data.size(0)
    res_data = torch.zeros([0])
    res_mask = torch.zeros([0])
    for i in range(batch_size):
        d = data[i].index_select(dim - 1, index_data[i]).unsqueeze(0)
        if bool(dims[dim - 1]):
            m = mask[i].index_select(dim - 1, index_data[i]).unsqueeze(0)
        else:
            m = mask[i].unsqueeze(0)
        if i == 0:
            res_data = d
            res_mask = m
        else:
            res_data = torch.cat((res_data, d), 0)
            res_mask = torch.cat((res_mask, m), 0)
    return res_data, res_mask, dims
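
# batch_index_select gathers per example in a Python loop because each example
# can select with its own index tensor; the same per-example loop pattern is
# used by batch_argmax, batch_topk, and batch_softmax further below to respect
# each example's valid length.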


@torch.jit.script
def batch_view_as(data, mask, dims, data1, mask1, dims1):
    data = data.view_as(data1)
    mask = mask.view_as(mask1)
    dims = dims1
    return data, mask, dims


@torch.jit.script
def batch_where(data, mask, dims, data1, mask1, dims1, data2, mask2, dims2):
    # `data` here is the batched condition; zero out its padded entries
    data = data * mask.type_as(data)
    for _ in range(data1.dim() - 1):
        data = data.unsqueeze(data.dim())
    cond_data = data.expand_as(data1)
    cond_mask = data.expand_as(mask1)
    res_data = torch.where(cond_data, data1, data2)
    res_mask = torch.where(cond_mask, mask1, mask2)
    res_dims = dims1.__or__(dims2)
    return res_data, res_mask, res_dims


@torch.jit.script
def batch_where_scalar(cond, data1, mask1, dims1, data2, mask2, dims2):
    # NOTE: the condition argument is replaced by a zero byte tensor here
    cond = torch.zeros([1], dtype=torch.uint8)
    res_data = torch.where(cond, data1, data2)
    res_mask = torch.where(cond, mask1, mask2)
    res_dims = torch.where(cond, dims1, dims2)
    return res_data, res_mask, res_dims


@torch.jit.script
def batch_update(batch_data, batch_mask, batch_dims, new_data, new_mask, new_dims):
    data = torch.where(new_mask, new_data, batch_data)
    return data, new_mask, new_dims


@torch.jit.script
def batch_any(data, mask, dims):
    return torch.gt(torch.sum(data * mask), 0)


@torch.jit.script
def batch_type_as(data, mask, dims, data1, mask1, dims1):
    return data.type_as(data1), mask, dims


@torch.jit.script
def batch_gt(data, mask, dims, data1, mask1, dims1):
    return torch.gt(data, data1), mask * mask1, dims.__or__(dims1)


@torch.jit.script
def batch_gt_scalar(data1, data2):
    return torch.gt(data1, data2)


@torch.jit.script
def batch_gt_one_scalar(data, mask, dims, other_):
    other = float(other_)
    return torch.gt(data, other), mask, dims


@torch.jit.script
def batch_lt(data, mask, dims, data1, mask1, dims1):
    return torch.lt(data, data1), mask * mask1, dims.__or__(dims1)


@torch.jit.script
def batch_eq(data, mask, dims, data1, mask1, dims1):
    return torch.eq(data, data1), mask * mask1, dims.__or__(dims1)


@torch.jit.script
def batch_size(data, mask, dims, dim_):
    dim = int(dim_)
    return data.size(dim)


@torch.jit.script
def batch_dim(data, mask, dims):
    # number of per-example dimensions, excluding the leading batch dimension
    return data.dim() - 1


@torch.jit.script
def batch_squeeze(data, mask, dims, dim_):
    if bool(dim_ < 0):
        dim_ = dim_ + data.dim()
    dim = int(dim_)
    data = data.squeeze(dim)
    mask = mask.squeeze(dim)
    dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
    return data, mask, dims


@torch.jit.script
def batch_unsqueeze(data, mask, dims, dim_):
    if bool(dim_ < 0):
        dim_ = dim_ + data.dim() + 1
    dim = int(dim_)
    data = data.unsqueeze(dim)
    mask = mask.unsqueeze(dim)
    dims = torch.cat((dims[:dim], torch.zeros([1], dtype=torch.uint8), dims[dim:dims.size(0)]))
    return data, mask, dims
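
# The reduction-style operators below (batch_argmax, batch_topk, batch_softmax)
# handle a dynamic dimension by moving it to the front of the per-example mask,
# summing the mask to obtain that example's valid length, and then operating
# only on the valid prefix of the data.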


@torch.jit.script
def batch_argmax(data, mask, dims, dim_, keepdim_):
    dim = int(dim_)
    keepdim = bool(keepdim_)
    batch_size = data.size(0)
    res_data = torch.zeros([0])
    for i in range(batch_size):
        if bool(dims[dim - 1]):
            # dynamic dimension: argmax only over this example's valid prefix
            m = mask[i].transpose(0, dim - 1)
            valid_num = m.sum(0, keepdim=True)
            while valid_num.dim() >= 1:
                valid_num = valid_num[0]
            d = data[i].unsqueeze(0).narrow(dim, 0, int(valid_num))
        else:
            d = data[i].unsqueeze(0)
        d = d.argmax(dim, keepdim)
        if i == 0:
            res_data = d
        else:
            res_data = torch.cat([res_data, d], 0)
    if keepdim:
        mask = mask.narrow(dim, 0, 1)
    else:
        mask = mask.select(dim, 0)
        dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
    return res_data, mask, dims


@torch.jit.script
def batch_topk(data, mask, dims, k_, dim_, largest_, sorted_):
    k = int(k_)
    dim = int(dim_)
    largest = bool(largest_)
    sorted = bool(sorted_)
    batch_size = data.size(0)
    res_data = torch.zeros([0])
    res_index = torch.zeros([0])
    for i in range(batch_size):
        if bool(dims[dim - 1]):
            # dynamic dimension: take the top-k of this example's valid prefix
            m = mask[i].transpose(0, dim - 1)
            valid_num = m.sum(0, keepdim=True)
            while valid_num.dim() >= 1:
                valid_num = valid_num[0]
            d = data[i].unsqueeze(0).narrow(dim, 0, int(valid_num))
        else:
            d = data[i].unsqueeze(0)
        d, idx = d.topk(k, dim, largest, sorted)
        if i == 0:
            res_data = d
            res_index = idx
        else:
            res_data = torch.cat([res_data, d], 0)
            res_index = torch.cat([res_index, idx], 0)
    if bool(dims[dim - 1]):
        mask = mask.narrow(dim, 0, k)
    return res_data, mask, dims, res_index, mask, dims


@torch.jit.script
def batch_softmax(data, mask, dims, dim_):
    dim = int(dim_)
    batch_size = data.size(0)
    max_len = data.size(dim)
    res_data = torch.zeros([0])
    for i in range(batch_size):
        if bool(dims[dim - 1]):
            # dynamic dimension: softmax over the valid prefix, pass padding through
            m = mask[i].transpose(0, dim - 1)
            valid_num = m.sum(0, keepdim=True)
            while valid_num.dim() >= 1:
                valid_num = valid_num[0]
            valid_num = int(valid_num)
            d = data[i].unsqueeze(0).narrow(dim, 0, valid_num).softmax(dim)
            if valid_num < max_len:
                d = torch.cat([d, data[i].unsqueeze(0).narrow(dim, valid_num, max_len - valid_num)], dim)
        else:
            d = data[i].unsqueeze(0).softmax(dim)
        if i == 0:
            res_data = d
        else:
            res_data = torch.cat([res_data, d], 0)
    return res_data, mask, dims
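
# batch_view below reshapes data and mask together: the leading batch dimension
# always keeps its size, while the mask sizes and dynamic flags of the result
# are rebuilt entry by entry from the requested target sizes.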


@torch.jit.script
def batch_view(data, mask, dims, sizes):
    batch_size = data.size(0)
    sizes = sizes.type_as(torch.ones([1], dtype=torch.int))
    data_sizes_ = torch.cat([torch.ones([1], dtype=torch.int) * batch_size, sizes.narrow(0, 1, sizes.size(0) - 1)], 0)
    data_sizes = data_sizes_._tensor_to_list()
    res_data = data.view(data_sizes)
    mask_sizes_ = data_sizes_.narrow(0, 0, 1)
    res_dims = data_sizes_.narrow(0, 0, 1)
    for i_ in range(sizes.size(0) - 1):
        i = i_ + 1
        if bool(sizes[i] == -1):
            cur_size_ = mask.size(i)
            cur_dim = int(dims[i - 1])
        else:
            # a fixed target size is treated as a static dimension
            cur_size_ = 1
            cur_dim = 0
        mask_sizes_ = torch.cat([mask_sizes_, torch.ones([1], dtype=torch.int) * cur_size_])
        res_dims = torch.cat([res_dims, torch.ones([1], dtype=torch.int) * cur_dim])
    mask_sizes = mask_sizes_._tensor_to_list()
    res_mask = mask.view(mask_sizes)
    return res_data, res_mask, res_dims.narrow(0, 1, res_dims.size(0) - 1).type_as(dims)


@torch.jit.script
def batch_cat2(data1, mask1, dims1, data2, mask2, dims2, dim_):
    dim = int(dim_)
    data = torch.cat([data1, data2], dim)
    if bool(dims1[dim - 1]):
        mask = torch.cat([mask1, mask2], dim)
    else:
        mask = mask1
    return data, mask, dims1


@torch.jit.script
def batch_cat3(data1, mask1, dims1, data2, mask2, dims2, data3, mask3, dims3, dim_):
    dim = int(dim_)
    data = torch.cat([data1, data2, data3], dim)
    if bool(dims1[dim - 1]):
        mask = torch.cat([mask1, mask2, mask3], dim)
    else:
        mask = mask1
    return data, mask, dims1


@torch.jit.script
def batch_narrow(data, mask, dims, dimension_, start_, length_):
    dimension = int(dimension_)
    start = int(start_)
    length = int(length_)
    data = data.narrow(dimension, start, length)
    if bool(dims[dimension - 1]):
        mask = mask.narrow(dimension, start, length)
    else:
        mask = mask.narrow(dimension, 0, 1)
    return data, mask, dims


@torch.jit.script
def batch_sum(data, mask, dims):
    data = data * mask.type_as(data)
    for _ in range(dims.size(0)):
        data = data.sum(1)
    mask = torch.ones([data.size(0)], dtype=torch.uint8)
    dims = dims[:0]  # all non-batch dimensions have been reduced away
    return data, mask, dims


@torch.jit.script
def batch_from_scalar_tensor(data):
    data = data.unsqueeze(0)
    mask = torch.ones([1], dtype=torch.uint8)
    dims = torch.zeros([0], dtype=torch.uint8)
    return data, mask, dims
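
# Each scripted function above is registered (possibly with several overloads
# per name) as the batched implementation of the corresponding operator by
# handing its graph to torch.register_batch_operator.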
torch.register_batch_operator("tanh", batch_tanh.graph)
torch.register_batch_operator("sigmoid", batch_sigmoid.graph)
torch.register_batch_operator("relu", batch_relu.graph)
torch.register_batch_operator("neg", batch_neg.graph)
torch.register_batch_operator("neg", batch_neg_scalar.graph)
torch.register_batch_operator("add", batch_add.graph)
torch.register_batch_operator("add", batch_add_scalar.graph)
torch.register_batch_operator("sub", batch_sub.graph)
torch.register_batch_operator("sub", batch_sub_scalar.graph)
torch.register_batch_operator("mul", batch_mul.graph)
torch.register_batch_operator("mul", batch_mul_scalar.graph)
torch.register_batch_operator("div", batch_div.graph)
torch.register_batch_operator("matmul", batch_matmul.graph)
torch.register_batch_operator("mm", batch_mm.graph)
torch.register_batch_operator("fmod", batch_fmod.graph)
torch.register_batch_operator("zeros_like", batch_zeros_like.graph)
torch.register_batch_operator("select", batch_select.graph)
torch.register_batch_operator("index_select", batch_index_select.graph)
torch.register_batch_operator("view_as", batch_view_as.graph)
torch.register_batch_operator("where", batch_where.graph)
torch.register_batch_operator("where", batch_where_scalar.graph)
torch.register_batch_operator("update", batch_update.graph)
torch.register_batch_operator("any", batch_any.graph)
torch.register_batch_operator("type_as", batch_type_as.graph)
torch.register_batch_operator("gt", batch_gt.graph)
torch.register_batch_operator("gt", batch_gt_scalar.graph)
torch.register_batch_operator("gt", batch_gt_one_scalar.graph)
torch.register_batch_operator("lt", batch_lt.graph)
torch.register_batch_operator("eq", batch_eq.graph)
torch.register_batch_operator("size", batch_size.graph)
torch.register_batch_operator("dim", batch_dim.graph)
torch.register_batch_operator("squeeze", batch_squeeze.graph)
torch.register_batch_operator("unsqueeze", batch_unsqueeze.graph)
torch.register_batch_operator("argmax", batch_argmax.graph)
torch.register_batch_operator("topk", batch_topk.graph)
torch.register_batch_operator("softmax", batch_softmax.graph)
torch.register_batch_operator("view", batch_view.graph)
torch.register_batch_operator("cat", batch_cat2.graph)
torch.register_batch_operator("cat", batch_cat3.graph)
torch.register_batch_operator("narrow", batch_narrow.graph)
torch.register_batch_operator("sum", batch_sum.graph)
torch.register_batch_operator("batch_from_scalar_tensor", batch_from_scalar_tensor.graph)