Caffe2 - Python API
A deep learning, cross-platform ML framework
batchop.py
import torch
from torch.jit import BatchTensor


# TODO: there are some commented raise statements
# when we support raising exceptions in script, we want to check them
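
# The batched ops below all take and return an unpacked BatchTensor triple:
#   data - the padded payload tensor, with the batch along dimension 0
#   mask - a tensor marking which entries of data are valid (non-padding)
#   dims - a 1-D uint8 tensor with one flag per non-batch dimension; a nonzero
#          flag marks that dimension as dynamic (its valid length varies per example)
# The *_scalar variants cover overloads in which one or both operands are plain
# (unbatched) tensors or scalars.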
@torch.jit.script
def batch_tanh(data, mask, dims):
    data = torch.tanh(data)
    return data, mask, dims


@torch.jit.script
def batch_sigmoid(data, mask, dims):
    data = torch.sigmoid(data)
    return data, mask, dims


@torch.jit.script
def batch_relu(data, mask, dims):
    data = torch.relu(data)
    return data, mask, dims


@torch.jit.script
def batch_neg(data, mask, dims):
    data = torch.neg(data)
    return data, mask, dims


@torch.jit.script
def batch_neg_scalar(data):
    return torch.neg(data)


@torch.jit.script
def batch_add(data1, mask1, dims1, data2, mask2, dims2, alpha_):
    alpha = float(alpha_)
    data = torch.add(data1, data2, alpha=alpha)
    mask = mask1 * mask2
    dims = dims1.__or__(dims2)
    return data, mask, dims


@torch.jit.script
def batch_add_scalar(data, mask, dims, other, alpha_):
    alpha = float(alpha_)
    data = torch.add(data, other.type_as(data), alpha=alpha)
    return data, mask, dims


@torch.jit.script
def batch_sub(data1, mask1, dims1, data2, mask2, dims2, alpha_):
    alpha = float(alpha_)
    data = torch.sub(data1, data2, alpha=alpha)
    mask = mask1 * mask2
    dims = dims1.__or__(dims2)
    return data, mask, dims


@torch.jit.script
def batch_sub_scalar(data1, data2):
    return data1 - data2


@torch.jit.script
def batch_mul(data1, mask1, dims1, data2, mask2, dims2):
    data = torch.mul(data1, data2)
    mask = mask1 * mask2
    dims = dims1.__or__(dims2)
    return data, mask, dims


@torch.jit.script
def batch_mul_scalar(data1, data2):
    return data1 * data2


@torch.jit.script
def batch_div(data, mask, dims, other):  # div(batchtensor, scalar)
    data = torch.div(data, other)
    return data, mask, dims


@torch.jit.script
def batch_mm(data1, mask1, dims1, data2, mask2, dims2):
    data1 = data1 * mask1.type_as(data1)
    data2 = data2 * mask2.type_as(data2)
    data = torch.bmm(data1, data2)
    mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1))
    dims = torch.cat((dims1[:1], dims2[1:dims2.size(0)]))
    return data, mask, dims


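# batch_matmul mirrors torch.matmul's vector/matrix promotion rules per example:
# d1/d2 below are the per-example ranks (batch dimension excluded). 1-D operands
# are temporarily unsqueezed so bmm can be used, then the result is squeezed back,
# and the masked multiply up front zeroes padded entries so they cannot contribute
# to the contraction.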
@torch.jit.script
def batch_matmul(data1, mask1, dims1, data2, mask2, dims2):
    d1 = data1.dim() - 1
    d2 = data2.dim() - 1
    data1 = data1 * mask1.type_as(data1)
    data2 = data2 * mask2.type_as(data2)
    if d1 == 1:
        data1 = data1.unsqueeze(-2)
    if d2 == 1:
        data2 = data2.unsqueeze(-1)
    data = torch.bmm(data1, data2)
    mask = mask1
    dims = dims1
    if d1 == 1 and d2 == 1:
        # if (batch1.dims[0] or batch2.dims[0]) and not batch1.mask.eq(batch2.mask).all():
        #     raise ValueError("cannot contract non-matching dimensions")
        data = data.squeeze(-1).squeeze(-1)
        mask = mask1.narrow(1, 0, 1).squeeze(-1)
        dims = dims1[:0]  # empty tensor
    if d1 == 2 and d2 == 1:
        # if (batch1.dims[1] or batch2.dims[0]) and not batch1.mask[:, 0].eq(batch2.mask).all():
        #     raise ValueError("cannot contract non-matching dimensions")
        data = data.squeeze(-1)
        mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1).unsqueeze(-1)).squeeze(-1)
        dims = dims1[:1]
    elif d1 == 1 and d2 == 2:
        # if (batch1.dims[0] or batch2.dims[0]) and not batch1.mask.eq(batch2.mask[:, :, 0]).all():
        #     raise ValueError("cannot contract non-matching dimensions")
        data = data.squeeze(-2)
        mask = torch.bmm(mask1.narrow(1, 0, 1).unsqueeze(-2), mask2.narrow(1, 0, 1)).squeeze(-2)
        dims = dims2[1:dims2.size(0)]
    elif d1 == 2 and d2 == 2:
        # if (batch1.dims[1] or batch2.dims[0]) and not batch1.mask[:, 0].eq(batch2.mask[:, :, 0]).all():
        #     raise ValueError("cannot contract non-matching dimensions")
        mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1))
        dims = torch.cat((dims1[:1], dims2[1:dims2.size(0)]))
    # else:
    #     raise NotImplementedError("matmul not implemented with batches of 3+D tensors")
    return data, mask, dims


@torch.jit.script
def batch_select(data, mask, dims, dim_, index_):
    dim = int(dim_)
    index = int(index_)
    # if dim == 0:
    #     raise ValueError("Cannot select 0 dim in BatchTensor")
    data = data.select(dim, index)
    if bool(dims[dim - 1]):
        mask = mask.select(dim, index)
    else:
        mask = mask.select(dim, 0)
    dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
    return data, mask, dims


@torch.jit.script
def batch_fmod(data, mask, dims, other_):
    other = int(other_)
    data = torch.fmod(data, other)
    return data, mask, dims


@torch.jit.script
def batch_zeros_like(data, mask, dims):
    res_data = torch.zeros_like(data)
    return res_data, mask, dims


@torch.jit.script
def batch_index_select(data, mask, dims, dim_, index_data, index_mask, index_dims):
    dim = int(dim_)
    # if dim == 0:
    #     raise ValueError("Cannot index_select along 0 dim in BatchTensor")
    batch_size = data.size(0)  # TODO maybe index_mask will be used at some point
    res_data = torch.zeros([0])
    res_mask = torch.zeros([0])
    for i in range(batch_size):
        d = data[i].index_select(dim - 1, index_data[i]).unsqueeze(0)
        if bool(dims[dim - 1]):
            m = mask[i].index_select(dim - 1, index_data[i]).unsqueeze(0)
        else:
            m = mask[i].unsqueeze(0)
        if i == 0:
            res_data = d
            res_mask = m
        else:
            res_data = torch.cat((res_data, d), 0)
            res_mask = torch.cat((res_mask, m), 0)
    return res_data, res_mask, dims


@torch.jit.script
def batch_view_as(data, mask, dims, data1, mask1, dims1):
    # if data.size(0) != data1.size(0):
    #     raise ValueError("In view_as, tensor and target tensor should have the same batch_size")
    # if not torch.equal(dims, dims1):
    #     raise ValueError("In batched view_as, dims and target dims should be the same")
    data = data.view_as(data1)
    mask = mask.view_as(mask1)
    dims = dims1
    return data, mask, dims


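# batch_where accepts the condition either element-wise (same shape as data1) or as a
# per-example flag (1-D, one entry per batch element); in the latter case the flag is
# broadcast to data1's shape before torch.where is applied. The condition is masked
# first, so padded positions always fall back to data2.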
# assume data, data1, data2 have same size
@torch.jit.script
def batch_where(data, mask, dims, data1, mask1, dims1, data2, mask2, dims2):
    data = data * mask.type_as(data)
    cond_data = data
    cond_mask = data
    if data.dim() == 1:
        for _ in range(data1.dim() - 1):
            data = data.unsqueeze(data.dim())
        cond_data = data.expand_as(data1)
        cond_mask = data.expand_as(mask1)
    res_data = torch.where(cond_data, data1, data2)
    res_mask = torch.where(cond_mask, mask1, mask2)
    res_dims = dims1.__or__(dims2)
    return res_data, res_mask, res_dims


@torch.jit.script
def batch_where_scalar(cond, data1, mask1, dims1, data2, mask2, dims2):
    cond = torch.zeros([1], dtype=torch.uint8)
    res_data = torch.where(cond, data1, data2)
    res_mask = torch.where(cond, mask1, mask2)
    res_dims = torch.where(cond, dims1, dims2)
    return res_data, res_mask, res_dims


@torch.jit.script
def batch_update(batch_data, batch_mask, batch_dims, new_data, new_mask, new_dims):
    data = torch.where(new_mask, new_data, batch_data)
    return data, new_mask, new_dims  # TODO: consider whether to return new_mask and new_dims


@torch.jit.script
def batch_any(data, mask, dims):
    return torch.gt(torch.sum(data * mask), 0)


@torch.jit.script
def batch_type_as(data, mask, dims, data1, mask1, dims1):
    return data.type_as(data1), mask, dims


@torch.jit.script
def batch_gt(data, mask, dims, data1, mask1, dims1):
    return torch.gt(data, data1), mask * mask1, dims.__or__(dims1)


@torch.jit.script
def batch_gt_scalar(data1, data2):
    return torch.gt(data1, data2)


@torch.jit.script
def batch_gt_one_scalar(data, mask, dims, other_):
    other = float(other_)
    return torch.gt(data, other), mask, dims


@torch.jit.script
def batch_lt(data, mask, dims, data1, mask1, dims1):
    return torch.lt(data, data1), mask * mask1, dims.__or__(dims1)


@torch.jit.script
def batch_eq(data, mask, dims, data1, mask1, dims1):
    return torch.eq(data, data1), mask * mask1, dims.__or__(dims1)


@torch.jit.script
def batch_size(data, mask, dims, dim_):
    dim = int(dim_)
    return data.size(dim)


@torch.jit.script
def batch_dim(data, mask, dims):
    return data.dim()


@torch.jit.script
def batch_squeeze(data, mask, dims, dim_):
    if int(dim_) < 0:
        dim_ = dim_ + data.dim()
    dim = int(dim_)
    # if dim == 0:
    #     raise ValueError("cannot do squeeze along batch_dim")
    data = data.squeeze(dim)
    mask = mask.squeeze(dim)
    dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
    return data, mask, dims


@torch.jit.script
def batch_unsqueeze(data, mask, dims, dim_):
    if int(dim_) < 0:
        dim_ = dim_ + data.dim() + 1
    dim = int(dim_)
    # if dim == 0:
    #     raise ValueError("cannot do unsqueeze along batch_dim")
    data = data.unsqueeze(dim)
    mask = mask.unsqueeze(dim)
    dims = torch.cat((dims[:dim], torch.zeros([1], dtype=torch.uint8), dims[dim:dims.size(0)]))
    return data, mask, dims


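# The reduction-style ops below (argmax, topk, softmax) share a per-example pattern:
# when the reduced dimension is dynamic, each example is first narrowed to its valid
# length, recovered by summing that example's mask along the dimension, before the
# underlying op is applied; per-example results are then re-concatenated along the
# batch dimension.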
@torch.jit.script
def batch_argmax(data, mask, dims, dim_, keepdim_):
    dim = int(dim_)
    keepdim = bool(keepdim_)
    # if dim == 0:
    #     raise ValueError("cannot do argmax along batch_dim")
    batch_size = data.size(0)
    res_data = torch.zeros([0])
    for i in range(batch_size):
        if bool(dims[dim - 1]):
            if dim - 1 != 0:
                m = mask[i].transpose(0, dim - 1)
            else:
                m = mask[i]
            valid_num = m.sum(0, keepdim=True)
            while valid_num.dim() >= 1:
                valid_num = valid_num[0]
            d = data[i].unsqueeze(0).narrow(dim, 0, int(valid_num))
        else:
            d = data[i].unsqueeze(0)
        d = d.argmax(dim, keepdim)
        if i == 0:
            res_data = d
        else:
            res_data = torch.cat([res_data, d], 0)
    if keepdim:
        mask = mask
    else:
        mask = mask.select(dim, 0)
    dims = torch.cat((dims[:dim - 1], dims[dim:dims.size(0)]))
    return res_data, mask, dims


@torch.jit.script
def batch_topk(data, mask, dims, k_, dim_, largest_, sorted_):
    k = int(k_)
    dim = int(dim_)
    largest = bool(largest_)
    sorted = bool(sorted_)
    # if dim == 0:
    #     raise ValueError("cannot do topk along batch_dim")
    batch_size = data.size(0)
    res_data = torch.zeros([0])
    res_index = torch.zeros([0])
    for i in range(batch_size):
        if bool(dims[dim - 1]):
            if dim - 1 != 0:
                m = mask[i].transpose(0, dim - 1)
            else:
                m = mask[i]
            valid_num = m.sum(0, keepdim=True)
            while valid_num.dim() >= 1:
                valid_num = valid_num[0]
            d = data[i].unsqueeze(0).narrow(dim, 0, int(valid_num))
        else:
            d = data[i].unsqueeze(0)
        d, idx = d.topk(k, dim, largest, sorted)
        if i == 0:
            res_data = d
            res_index = idx
        else:
            res_data = torch.cat([res_data, d], 0)
            res_index = torch.cat([res_index, idx], 0)
    if bool(dims[dim - 1]):
        mask = mask.narrow(dim, 0, k)
    return res_data, mask, dims, res_index, mask, dims


@torch.jit.script
def batch_softmax(data, mask, dims, dim_):
    dim = int(dim_)
    # if dim == 0:
    #     raise ValueError("cannot do softmax along batch_dim")
    batch_size = data.size(0)
    max_len = data.size(dim)
    res_data = torch.zeros([0])
    for i in range(batch_size):
        if bool(dims[dim - 1]):
            if dim - 1 != 0:
                m = mask[i].transpose(0, dim - 1)
            else:
                m = mask[i]
            valid_num = m.sum(0, keepdim=True)
            while valid_num.dim() >= 1:
                valid_num = valid_num[0]
            valid_num = int(valid_num)
            d = data[i].unsqueeze(0).narrow(dim, 0, valid_num).softmax(dim)
            if valid_num < max_len:
                d = torch.cat([d, data[i].unsqueeze(0).narrow(dim, valid_num, max_len - valid_num)], dim)
        else:
            d = data[i].unsqueeze(0).softmax(dim)
        if i == 0:
            res_data = d
        else:
            res_data = torch.cat([res_data, d], 0)
    return res_data, mask, dims


# size argument in dynamic dimension has to be -1
# in static dimension, size has to be specified, -1 is not supported
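# For example (illustrative, assumed shapes): reshaping a batch of padded sequences of
# V-dimensional vectors, where the sequence dimension is dynamic and the vector
# dimension is static, would use a sizes argument of the form (batch_size, -1, V).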
@torch.jit.script
def batch_view(data, mask, dims, sizes):
    batch_size = data.size(0)
    # if(sizes[0] != batch_size and sizes[0] != -1 and sizes[0] != 1):
    #     raise "first dim in view must be 1, -1, or batch size"
    # for i in range(dims.size(0)):
    #     if dims[0] == 1 and sizes[i + 1] != -1:
    #         raise "size argument in dynamic dimension has to be -1"
    sizes = sizes.type_as(torch.ones([1], dtype=torch.int))
    data_sizes_ = torch.cat([torch.ones([1], dtype=torch.int) * batch_size, sizes.narrow(0, 1, sizes.size(0) - 1)], 0)
    data_sizes = data_sizes_._tensor_to_list()
    res_data = data.view(data_sizes)
    mask_sizes_ = data_sizes_.narrow(0, 0, 1)
    res_dims = data_sizes_.narrow(0, 0, 1)
    for i_ in range(sizes.size(0) - 1):
        i = i_ + 1
        if bool(sizes[i] == -1):
            cur_size_ = mask.size(i)
            cur_dim = 1
        else:
            cur_size_ = 1
            cur_dim = 0
        mask_sizes_ = torch.cat([mask_sizes_, torch.ones([1], dtype=torch.int) * cur_size_])
        res_dims = torch.cat([res_dims, torch.ones([1], dtype=torch.int) * cur_dim])
    mask_sizes = mask_sizes_._tensor_to_list()
    res_mask = mask.view(mask_sizes)
    return res_data, res_mask, res_dims.narrow(0, 1, res_dims.size(0) - 1).type_as(dims)


@torch.jit.script
def batch_cat2(data1, mask1, dims1, data2, mask2, dims2, dim_):
    dim = int(dim_)
    data = torch.cat([data1, data2], dim)
    if bool(dims1[dim - 1]):
        mask = torch.cat([mask1, mask2], dim)
    else:
        mask = mask1
    return data, mask, dims1


@torch.jit.script
def batch_cat3(data1, mask1, dims1, data2, mask2, dims2, data3, mask3, dims3, dim_):
    dim = int(dim_)
    data = torch.cat([data1, data2, data3], dim)
    if bool(dims1[dim - 1]):
        mask = torch.cat([mask1, mask2, mask3], dim)
    else:
        mask = mask1
    return data, mask, dims1


@torch.jit.script
def batch_narrow(data, mask, dims, dimension_, start_, length_):
    dimension = int(dimension_)
    start = int(start_)
    length = int(length_)
    # if dimension == 0:
    #     raise ValueError("cannot do narrow along batch_dim")
    data = data.narrow(dimension, start, length)
    if bool(dims[dimension - 1]):
        mask = mask.narrow(dimension, start, length)
    else:
        mask = mask.narrow(dimension, 0, 1)
    return data, mask, dims


@torch.jit.script
def batch_sum(data, mask, dims):
    data = data * mask.type_as(data)
    for _ in range(dims.size(0)):
        data = data.sum(1)
    mask = torch.ones([data.size(0)], dtype=torch.uint8)
    dims = dims[:0]  # empty tensor
    return data, mask, dims


@torch.jit.script
def batch_from_scalar_tensor(data):
    data = data.unsqueeze(0)
    mask = torch.ones([1], dtype=torch.uint8)
    dims = torch.zeros([0], dtype=torch.uint8)
    return data, mask, dims

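# Register each scripted graph as the batched implementation of the corresponding
# operator name. Several names ("neg", "add", "sub", "mul", "gt", "where", "cat") are
# registered more than once to cover both the tensor-tensor and the scalar overloads.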
torch.register_batch_operator("tanh", batch_tanh.graph)
torch.register_batch_operator("sigmoid", batch_sigmoid.graph)
torch.register_batch_operator("relu", batch_relu.graph)
torch.register_batch_operator("neg", batch_neg.graph)
torch.register_batch_operator("neg", batch_neg_scalar.graph)
torch.register_batch_operator("add", batch_add.graph)
torch.register_batch_operator("add", batch_add_scalar.graph)
torch.register_batch_operator("sub", batch_sub.graph)
torch.register_batch_operator("sub", batch_sub_scalar.graph)
torch.register_batch_operator("mul", batch_mul.graph)
torch.register_batch_operator("mul", batch_mul_scalar.graph)
torch.register_batch_operator("div", batch_div.graph)
torch.register_batch_operator("matmul", batch_matmul.graph)
torch.register_batch_operator("mm", batch_mm.graph)
torch.register_batch_operator("fmod", batch_fmod.graph)
torch.register_batch_operator("zeros_like", batch_zeros_like.graph)
torch.register_batch_operator("select", batch_select.graph)
torch.register_batch_operator("index_select", batch_index_select.graph)
torch.register_batch_operator("view_as", batch_view_as.graph)
torch.register_batch_operator("where", batch_where.graph)
torch.register_batch_operator("where", batch_where_scalar.graph)
torch.register_batch_operator("update", batch_update.graph)
torch.register_batch_operator("any", batch_any.graph)
torch.register_batch_operator("type_as", batch_type_as.graph)
torch.register_batch_operator("gt", batch_gt.graph)
torch.register_batch_operator("gt", batch_gt_scalar.graph)
torch.register_batch_operator("gt", batch_gt_one_scalar.graph)
torch.register_batch_operator("lt", batch_lt.graph)
torch.register_batch_operator("eq", batch_eq.graph)
torch.register_batch_operator("size", batch_size.graph)
torch.register_batch_operator("dim", batch_dim.graph)
torch.register_batch_operator("squeeze", batch_squeeze.graph)
torch.register_batch_operator("unsqueeze", batch_unsqueeze.graph)
torch.register_batch_operator("argmax", batch_argmax.graph)
torch.register_batch_operator("topk", batch_topk.graph)
torch.register_batch_operator("softmax", batch_softmax.graph)
torch.register_batch_operator("view", batch_view.graph)
torch.register_batch_operator("cat", batch_cat2.graph)
torch.register_batch_operator("cat", batch_cat3.graph)
torch.register_batch_operator("narrow", batch_narrow.graph)
torch.register_batch_operator("sum", batch_sum.graph)
torch.register_batch_operator("batch_from_scalar_tensor", batch_from_scalar_tensor.graph)
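

# Minimal illustrative sketch (assumed shapes, not part of the registration logic):
# the scripted ops above are ordinary script functions and can be exercised directly
# on an unpacked (data, mask, dims) triple. Here: a batch of 2 sequences padded to
# length 3, whose single non-batch dimension is dynamic.
if __name__ == "__main__":
    data = torch.rand(2, 3)
    mask = torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.uint8)  # second example uses all 3 slots
    dims = torch.tensor([1], dtype=torch.uint8)                     # non-batch dim 1 is dynamic
    out_data, out_mask, out_dims = batch_tanh(data, mask, dims)
    prod_data, prod_mask, prod_dims = batch_mul(data, mask, dims, out_data, out_mask, out_dims)
    print(prod_data.shape, prod_mask.shape, prod_dims)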