# Elementwise tanh of the padded batch payload; mask and dims pass through unchanged.
@torch.jit.script
def batch_tanh(data, mask, dims):
    data = torch.tanh(data)
    return data, mask, dims
    # Elementwise sigmoid of the padded batch payload; mask and dims pass through unchanged.
@torch.jit.script
def batch_sigmoid(data, mask, dims):
    data = torch.sigmoid(data)
    return data, mask, dims
    # Elementwise ReLU of the padded batch payload; mask and dims pass through unchanged.
@torch.jit.script
def batch_relu(data, mask, dims):
    data = torch.relu(data)
    return data, mask, dims
    # Elementwise negation of the padded batch payload; mask and dims pass through unchanged.
@torch.jit.script
def batch_neg(data, mask, dims):
    data = torch.neg(data)
    return data, mask, dims
    # Negation of a plain (non-batched) tensor; no mask/dims bookkeeping involved.
@torch.jit.script
def batch_neg_scalar(data):
    return torch.neg(data)
    # add(batch, batch, alpha): data1 + alpha * data2.
    # An entry is valid only where both inputs are valid (masks multiplied);
    # the dims flags of both operands are OR-ed together.
@torch.jit.script
def batch_add(data1, mask1, dims1, data2, mask2, dims2, alpha_):
    # alpha_ arrives as a tensor; torch.add expects a Python scalar.
    alpha = float(alpha_)
    data = torch.add(data1, data2, alpha=alpha)
    mask = mask1 * mask2
    dims = dims1.__or__(dims2)
    return data, mask, dims
    # add(batch, scalar, alpha): data + alpha * other; mask and dims pass through.
@torch.jit.script
def batch_add_scalar(data, mask, dims, other, alpha_):
    alpha = float(alpha_)
    # Cast the scalar to the payload dtype before adding.
    data = torch.add(data, other.type_as(data), alpha=alpha)
    return data, mask, dims
    # sub(batch, batch, alpha): data1 - alpha * data2.
    # An entry is valid only where both inputs are valid (masks multiplied);
    # the dims flags of both operands are OR-ed together.
@torch.jit.script
def batch_sub(data1, mask1, dims1, data2, mask2, dims2, alpha_):
    # alpha_ arrives as a tensor; torch.sub expects a Python scalar.
    alpha = float(alpha_)
    data = torch.sub(data1, data2, alpha=alpha)
    mask = mask1 * mask2
    dims = dims1.__or__(dims2)
    return data, mask, dims
    # Subtraction of two plain (non-batched) tensors; counterpart of the other *_scalar helpers.
    # NOTE(review): body was missing from the source and is reconstructed - confirm.
@torch.jit.script
def batch_sub_scalar(data1, data2):
    return torch.sub(data1, data2)
    # Elementwise product of two batches; masks multiplied, dims flags OR-ed.
@torch.jit.script
def batch_mul(data1, mask1, dims1, data2, mask2, dims2):
    data = torch.mul(data1, data2)
    mask = mask1 * mask2
    dims = dims1.__or__(dims2)
    return data, mask, dims
    # Product of two plain (non-batched) tensors; counterpart of the other *_scalar helpers.
    # NOTE(review): body was missing from the source and is reconstructed - confirm.
@torch.jit.script
def batch_mul_scalar(data1, data2):
    return torch.mul(data1, data2)
    # div(batch, scalar): divide the payload by `other`; mask and dims pass through.
@torch.jit.script
def batch_div(data, mask, dims, other):
    data = torch.div(data, other)
    return data, mask, dims
    # Batched matrix product of two batches of matrices (one torch.bmm over the batch).
@torch.jit.script
def batch_mm(data1, mask1, dims1, data2, mask2, dims2):
    # Zero the padding first so padded entries cannot contribute to the sums.
    data1 = data1 * mask1.type_as(data1)
    data2 = data2 * mask2.type_as(data2)
    data = torch.bmm(data1, data2)
    # Output validity = row validity of mask1 (its first column) times
    # column validity of mask2 (its first row).
    mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1))
    # Output rows come from operand 1, output columns from operand 2.
    dims = torch.cat([dims1[:1], dims2[1:dims2.size(0)]])
    return data, mask, dims
    # Batched torch.matmul for per-example vectors/matrices (rank 1 or 2 per example).
    # Vectors are promoted to matrices so one torch.bmm handles every case, and the
    # promoted dims are squeezed back out afterwards.
@torch.jit.script
def batch_matmul(data1, mask1, dims1, data2, mask2, dims2):
    # Per-example rank: the leading batch dim is not counted.
    d1 = data1.dim() - 1
    d2 = data2.dim() - 1
    # Zero the padding first so padded entries cannot contribute to the sums.
    data1 = data1 * mask1.type_as(data1)
    data2 = data2 * mask2.type_as(data2)
    if d1 == 1:
        data1 = data1.unsqueeze(-2)
    if d2 == 1:
        data2 = data2.unsqueeze(-1)
    data = torch.bmm(data1, data2)
    mask = mask1
    dims = dims1
    if d1 == 1 and d2 == 1:
        # vector . vector -> one value per example, so no per-example dims remain.
        data = data.squeeze(-1).squeeze(-1)
        mask = mask1.narrow(1, 0, 1).squeeze(-1)
        dims = dims1[:0]
    elif d1 == 2 and d2 == 1:
        # matrix @ vector -> vector indexed by operand 1's rows.
        data = data.squeeze(-1)
        mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1).unsqueeze(-1)).squeeze(-1)
        dims = dims1[:1]
    elif d1 == 1 and d2 == 2:
        # vector @ matrix -> vector indexed by operand 2's columns.
        data = data.squeeze(-2)
        mask = torch.bmm(mask1.narrow(1, 0, 1).unsqueeze(-2), mask2.narrow(1, 0, 1)).squeeze(-2)
        dims = dims2[1:dims2.size(0)]
    elif d1 == 2 and d2 == 2:
        mask = torch.bmm(mask1.narrow(2, 0, 1), mask2.narrow(1, 0, 1))
        dims = torch.cat([dims1[:1], dims2[1:dims2.size(0)]])
    return data, mask, dims
   # select(dim, index) on a batch. `dim` indexes the padded payload (dim 0 is the
# batch), and dims[dim - 1] is the flag for that dimension.
@torch.jit.script
def batch_select(data, mask, dims, dim_, index_):
    dim = int(dim_)
    index = int(index_)
    data = data.select(dim, index)
    if bool(dims[dim - 1]):
        # Flagged dim: its mask has a slot per position, so take the same slot.
        mask = mask.select(dim, index)
    else:
        # Unflagged dim: slot 0 of the mask is used (presumably its only slot).
        mask = mask.select(dim, 0)
    # The selected dimension disappears, and so does its flag.
    dims = torch.cat([dims[:dim - 1], dims[dim:dims.size(0)]])
    return data, mask, dims
   # fmod(batch, scalar): elementwise remainder; mask and dims pass through.
@torch.jit.script
def batch_fmod(data, mask, dims, other_):
    # Convert to float, as the other scalar ops in this file do, so that a
    # fractional divisor is not truncated.
    other = float(other_)
    data = torch.fmod(data, other)
    return data, mask, dims
   # zeros_like for a batch: zero payload of the same shape; mask and dims reused.
@torch.jit.script
def batch_zeros_like(data, mask, dims):
    res_data = torch.zeros_like(data)
    return res_data, mask, dims
   # index_select(dim, index) where each example has its own index vector
# (index_data[i]), so the selection is done per example and re-stacked.
@torch.jit.script
def batch_index_select(data, mask, dims, dim_, index_data, index_mask, index_dims):
    dim = int(dim_)
    batch_size = data.size(0)  # index_mask / index_dims are not used
    res_data = torch.zeros([0])
    res_mask = torch.zeros([0])
    for i in range(batch_size):
        # data[i] has no batch dim, hence dim - 1.
        d = data[i].index_select(dim - 1, index_data[i]).unsqueeze(0)
        if bool(dims[dim - 1]):
            m = mask[i].index_select(dim - 1, index_data[i]).unsqueeze(0)
        else:
            m = mask[i].unsqueeze(0)
        # Seed the accumulators from the first example, so dtypes match the inputs.
        if i == 0:
            res_data = d
            res_mask = m
        else:
            res_data = torch.cat([res_data, d], 0)
            res_mask = torch.cat([res_mask, m], 0)
    return res_data, res_mask, dims
   # view_as for batches: reshape payload and mask to match the target batch.
@torch.jit.script
def batch_view_as(data, mask, dims, data1, mask1, dims1):
    data = data.view_as(data1)
    mask = mask.view_as(mask1)
    # The result now has the target's layout, so its dims flags must be the
    # target's too (the source flags may not even have the right length).
    dims = dims1
    return data, mask, dims
   # where(cond, x, y) on batches. (data, mask, dims) is the condition batch.
# NOTE(review): assumes the condition has the same shape as data1/data2, or is a
# single condition per example (1-D), which is then broadcast.
@torch.jit.script
def batch_where(data, mask, dims, data1, mask1, dims1, data2, mask2, dims2):
    # Padding positions of the condition are zeroed, i.e. they pick data2.
    data = data * mask.type_as(data)
    cond_data = data
    cond_mask = data
    if data.dim() == 1:
        # One condition per example: add trailing dims, then expand.
        for _ in range(data1.dim() - 1):
            data = data.unsqueeze(data.dim())
        cond_data = data.expand_as(data1)
        cond_mask = data.expand_as(mask1)
    res_data = torch.where(cond_data, data1, data2)
    res_mask = torch.where(cond_mask, mask1, mask2)
    res_dims = dims1.__or__(dims2)
    return res_data, res_mask, res_dims
   # where(cond, x, y) where `cond` is one condition shared by the whole batch, so
# either the first batch or the second is returned in full (data, mask and dims).
@torch.jit.script
def batch_where_scalar(cond, data1, mask1, dims1, data2, mask2, dims2):
    # The original overwrote `cond` with a zero tensor, so it always returned
    # the second batch. Use the caller's condition: turn it into a one-element
    # truth tensor that torch.where can broadcast.
    # NOTE(review): assumes cond holds exactly one element.
    cond = cond.ne(0).view([1])
    res_data = torch.where(cond, data1, data2)
    res_mask = torch.where(cond, mask1, mask2)
    res_dims = torch.where(cond, dims1, dims2)
    return res_data, res_mask, res_dims
   # Overwrite batch_data with new_data wherever new_mask is set.
# NOTE(review): returns new_mask/new_dims as-is rather than merging them with
# batch_mask/batch_dims - confirm this is intended.
@torch.jit.script
def batch_update(batch_data, batch_mask, batch_dims, new_data, new_mask, new_dims):
    data = torch.where(new_mask, new_data, batch_data)
    return data, new_mask, new_dims
   # True (as a 0-dim tensor) if any valid (unmasked) entry in the whole batch is positive/non-zero.
@torch.jit.script
def batch_any(data, mask, dims):
    return torch.gt(torch.sum(data * mask), 0)
   # Cast the payload to data1's dtype; mask and dims pass through unchanged.
@torch.jit.script
def batch_type_as(data, mask, dims, data1, mask1, dims1):
    return data.type_as(data1), mask, dims
   # Elementwise data > data1; masks multiplied, dims flags OR-ed.
@torch.jit.script
def batch_gt(data, mask, dims, data1, mask1, dims1):
    return torch.gt(data, data1), mask * mask1, dims.__or__(dims1)
   # data1 > data2 for plain (non-batched) tensors.
@torch.jit.script
def batch_gt_scalar(data1, data2):
    return torch.gt(data1, data2)
   # Elementwise data > scalar; mask and dims pass through unchanged.
@torch.jit.script
def batch_gt_one_scalar(data, mask, dims, other_):
    # other_ arrives as a tensor; compare against it as a Python float.
    other = float(other_)
    return torch.gt(data, other), mask, dims
   # Elementwise data < data1; masks multiplied, dims flags OR-ed.
@torch.jit.script
def batch_lt(data, mask, dims, data1, mask1, dims1):
    return torch.lt(data, data1), mask * mask1, dims.__or__(dims1)
   # Elementwise data == data1; masks multiplied, dims flags OR-ed.
@torch.jit.script
def batch_eq(data, mask, dims, data1, mask1, dims1):
    return torch.eq(data, data1), mask * mask1, dims.__or__(dims1)
   # Size of the padded payload along `dim` (dim 0 is the batch dim).
@torch.jit.script
def batch_size(data, mask, dims, dim_):
    dim = int(dim_)
    return data.size(dim)
   # Rank of the padded payload, batch dim included (other ops here, e.g. squeeze,
# normalize negative dims against data.dim() in the same coordinates).
# NOTE(review): body was missing from the source and is reconstructed - confirm.
@torch.jit.script
def batch_dim(data, mask, dims):
    return data.dim()
   # squeeze(dim) on a batch; negative dims count from the end of the padded payload.
@torch.jit.script
def batch_squeeze(data, mask, dims, dim_):
    dim = int(dim_)
    if dim < 0:
        dim = dim + data.dim()
    data = data.squeeze(dim)
    mask = mask.squeeze(dim)
    # Drop the flag of the squeezed dimension (dims[dim - 1] belongs to dim).
    dims = torch.cat([dims[:dim - 1], dims[dim:dims.size(0)]])
    return data, mask, dims
   # unsqueeze(dim) on a batch; negative dims count from the end of the result.
@torch.jit.script
def batch_unsqueeze(data, mask, dims, dim_):
    dim = int(dim_)
    if dim < 0:
        dim = dim + data.dim() + 1
    data = data.unsqueeze(dim)
    mask = mask.unsqueeze(dim)
    # The flag for padded dim `dim` lives at dims[dim - 1] (same convention as
    # squeeze/select), so the new dim's flag (0, size-1 mask) goes there. The
    # original inserted it at index `dim`, which mislabelled the shifted dims.
    dims = torch.cat([dims[:dim - 1], torch.zeros([1], dtype=torch.uint8), dims[dim - 1:dims.size(0)]])
    return data, mask, dims
   # argmax(dim, keepdim) per example. When dim is flagged in dims, only each
# example's leading valid entries along dim (counted from its mask) are considered.
@torch.jit.script
def batch_argmax(data, mask, dims, dim_, keepdim_):
    dim = int(dim_)
    keepdim = bool(keepdim_)
    batch_size = data.size(0)
    res_data = torch.zeros([0])
    for i in range(batch_size):
        if bool(dims[dim - 1]):
            # Move dim to the front (a no-op when it already is), count the
            # valid entries, and collapse the count to a 0-dim tensor.
            m = mask[i].transpose(0, dim - 1)
            valid_num = m.sum(0, keepdim=True)
            while valid_num.dim() >= 1:
                valid_num = valid_num[0]
            d = data[i].unsqueeze(0).narrow(dim, 0, int(valid_num))
        else:
            d = data[i].unsqueeze(0)
        d = d.argmax(dim, keepdim)
        if i == 0:
            res_data = d
        else:
            res_data = torch.cat([res_data, d], 0)
    if keepdim:
        # The reduced dim now has size 1; narrow the mask so it still lines up
        # with the result (it used to keep its full, mismatched extent).
        mask = mask.narrow(dim, 0, 1)
    else:
        mask = mask.select(dim, 0)
        dims = torch.cat([dims[:dim - 1], dims[dim:dims.size(0)]])
    return res_data, mask, dims
   # topk(k, dim, largest, sorted) per example. When dim is flagged in dims, only
# each example's leading valid entries along dim take part.
# Returns two batches: (values, mask, dims) and (indices, mask, dims).
@torch.jit.script
def batch_topk(data, mask, dims, k_, dim_, largest_, sorted_):
    k = int(k_)
    dim = int(dim_)
    largest = bool(largest_)
    sorted_flag = bool(sorted_)  # avoid shadowing the builtin `sorted`
    batch_size = data.size(0)
    res_data = torch.zeros([0])
    res_index = torch.zeros([0])
    for i in range(batch_size):
        if bool(dims[dim - 1]):
            # Count this example's valid entries along dim (transpose is a no-op for dim 1).
            m = mask[i].transpose(0, dim - 1)
            valid_num = m.sum(0, keepdim=True)
            while valid_num.dim() >= 1:
                valid_num = valid_num[0]
            d = data[i].unsqueeze(0).narrow(dim, 0, int(valid_num))
        else:
            d = data[i].unsqueeze(0)
        # NOTE(review): fails if an example has fewer than k valid entries.
        d, idx = d.topk(k, dim, largest, sorted_flag)
        if i == 0:
            res_data = d
            res_index = idx
        else:
            res_data = torch.cat([res_data, d], 0)
            res_index = torch.cat([res_index, idx], 0)
    if bool(dims[dim - 1]):
        mask = mask.narrow(dim, 0, k)
    return res_data, mask, dims, res_index, mask, dims
   # softmax(dim) per example. When dim is flagged in dims, softmax covers only
# each example's leading valid entries; the padding tail is copied through as-is.
@torch.jit.script
def batch_softmax(data, mask, dims, dim_):
    dim = int(dim_)
    batch_size = data.size(0)
    max_len = data.size(dim)
    res_data = torch.zeros([0])
    for i in range(batch_size):
        if bool(dims[dim - 1]):
            # Count this example's valid entries along dim (transpose is a no-op for dim 1).
            m = mask[i].transpose(0, dim - 1)
            valid_num = m.sum(0, keepdim=True)
            while valid_num.dim() >= 1:
                valid_num = valid_num[0]
            valid_num = int(valid_num)
            d = data[i].unsqueeze(0).narrow(dim, 0, valid_num).softmax(dim)
            if valid_num < max_len:
                # Re-attach the padding so every example keeps the full extent.
                d = torch.cat([d, data[i].unsqueeze(0).narrow(dim, valid_num, max_len - valid_num)], dim)
        else:
            d = data[i].unsqueeze(0).softmax(dim)
        if i == 0:
            res_data = d
        else:
            res_data = torch.cat([res_data, d], 0)
    return res_data, mask, dims
   # view(sizes) on a batch. sizes[0] is ignored (the batch size always comes from
# data). A -1 in a later position marks that dim as flagged: the mask keeps its
# extent there and the dims flag is 1; any other entry gives a size-1 mask dim
# with flag 0.
# NOTE(review): uses the legacy TorchScript helper Tensor._tensor_to_list.
@torch.jit.script
def batch_view(data, mask, dims, sizes):
    batch_size = data.size(0)
    sizes = sizes.type_as(torch.ones([1], dtype=torch.int))
    data_sizes_ = torch.cat([torch.ones([1], dtype=torch.int) * batch_size, sizes.narrow(0, 1, sizes.size(0) - 1)], 0)
    data_sizes = data_sizes_._tensor_to_list()
    res_data = data.view(data_sizes)
    # Both accumulators start with a placeholder batch entry.
    mask_sizes_ = data_sizes_.narrow(0, 0, 1)
    res_dims = data_sizes_.narrow(0, 0, 1)
    for i_ in range(sizes.size(0) - 1):
        i = i_ + 1  # skip the batch entry of `sizes`
        if bool(sizes[i] == -1):
            cur_size_ = mask.size(i)
            cur_dim = 1
        else:
            cur_size_ = 1
            cur_dim = 0
        mask_sizes_ = torch.cat([mask_sizes_, torch.ones([1], dtype=torch.int) * cur_size_])
        res_dims = torch.cat([res_dims, torch.ones([1], dtype=torch.int) * cur_dim])
    mask_sizes = mask_sizes_._tensor_to_list()
    res_mask = mask.view(mask_sizes)
    # Strip the placeholder entry and restore the caller's dims dtype.
    return res_data, res_mask, res_dims.narrow(0, 1, res_dims.size(0) - 1).type_as(dims)
   # cat([batch1, batch2], dim). Both inputs are assumed to share dims flags; the
# result reuses dims1.
@torch.jit.script
def batch_cat2(data1, mask1, dims1, data2, mask2, dims2, dim_):
    dim = int(dim_)
    data = torch.cat([data1, data2], dim)
    if bool(dims1[dim - 1]):
        # Flagged dim: the masks grow along dim together with the data.
        mask = torch.cat([mask1, mask2], dim)
    else:
        # Unflagged dim: the mask does not track dim's extent, so keep it.
        mask = mask1
    return data, mask, dims1
   # cat([batch1, batch2, batch3], dim). All inputs are assumed to share dims
# flags; the result reuses dims1.
@torch.jit.script
def batch_cat3(data1, mask1, dims1, data2, mask2, dims2, data3, mask3, dims3, dim_):
    dim = int(dim_)
    data = torch.cat([data1, data2, data3], dim)
    if bool(dims1[dim - 1]):
        # Flagged dim: the masks grow along dim together with the data.
        mask = torch.cat([mask1, mask2, mask3], dim)
    else:
        # Unflagged dim: the mask does not track dim's extent, so keep it.
        mask = mask1
    return data, mask, dims1
   # narrow(dimension, start, length) on a batch; dims flags are unchanged.
@torch.jit.script
def batch_narrow(data, mask, dims, dimension_, start_, length_):
    dimension = int(dimension_)
    start = int(start_)
    length = int(length_)
    data = data.narrow(dimension, start, length)
    if bool(dims[dimension - 1]):
        # Flagged dim: the mask has per-position entries, so slice it the same way.
        mask = mask.narrow(dimension, start, length)
    else:
        # Unflagged dim: keep a single mask slot.
        mask = mask.narrow(dimension, 0, 1)
    return data, mask, dims
   # Full sum of each example's valid entries, giving one value per example. The
# result is a batch of 0-dim values: all-ones mask and empty dims.
@torch.jit.script
def batch_sum(data, mask, dims):
    # Zero the padding so it does not contribute.
    data = data * mask.type_as(data)
    # One reduction per non-batch dim; dim 1 is always the next one left.
    for _ in range(dims.size(0)):
        data = data.sum(1)
    mask = torch.ones([data.size(0)], dtype=torch.uint8)
    dims = dims[:0]
    return data, mask, dims
   # Wrap a single 0-dim tensor as a batch of size 1: data of shape [1], mask of
# ones, and an empty uint8 dims (no per-example dims).
@torch.jit.script
def batch_from_scalar_tensor(data):
    data = data.unsqueeze(0)
    mask = torch.ones([1], dtype=torch.uint8)
    dims = torch.zeros([0], dtype=torch.uint8)
    return data, mask, dims
   # Register each batched implementation with the JIT batching pass, keyed by the
# name of the operator it replaces. Several names ("neg", "add", "sub", "mul",
# "where", "gt", "cat") are registered more than once, in this order.
# NOTE(review): presumably the pass chooses among them by signature - confirm.
_BATCH_OPERATORS = (
    ("tanh", batch_tanh),
    ("sigmoid", batch_sigmoid),
    ("relu", batch_relu),
    ("neg", batch_neg),
    ("neg", batch_neg_scalar),
    ("add", batch_add),
    ("add", batch_add_scalar),
    ("sub", batch_sub),
    ("sub", batch_sub_scalar),
    ("mul", batch_mul),
    ("mul", batch_mul_scalar),
    ("div", batch_div),
    ("matmul", batch_matmul),
    ("mm", batch_mm),
    ("fmod", batch_fmod),
    ("zeros_like", batch_zeros_like),
    ("select", batch_select),
    ("index_select", batch_index_select),
    ("view_as", batch_view_as),
    ("where", batch_where),
    ("where", batch_where_scalar),
    ("update", batch_update),
    ("any", batch_any),
    ("type_as", batch_type_as),
    ("gt", batch_gt),
    ("gt", batch_gt_scalar),
    ("gt", batch_gt_one_scalar),
    ("lt", batch_lt),
    ("eq", batch_eq),
    ("size", batch_size),
    ("dim", batch_dim),
    ("squeeze", batch_squeeze),
    ("unsqueeze", batch_unsqueeze),
    ("argmax", batch_argmax),
    ("topk", batch_topk),
    ("softmax", batch_softmax),
    ("view", batch_view),
    ("cat", batch_cat2),
    ("cat", batch_cat3),
    ("narrow", batch_narrow),
    ("sum", batch_sum),
    ("batch_from_scalar_tensor", batch_from_scalar_tensor),
)

for _op_name, _op_fn in _BATCH_OPERATORS:
    torch.register_batch_operator(_op_name, _op_fn.graph)
del _op_name, _op_fn