#ifndef THCUNN_ROW2COL_H
#define THCUNN_ROW2COL_H

#include <THC/THCNumerics.cuh>
#include <THCUNN/common.h>

// Unfolds rows into columns (a 1-D analogue of im2col).
//
// Each value of `index` in [0, n) handles one (channel, output column) pair:
//   w_out      = index % width_col
//   channel_in = index / width_col
// and writes `ksize_w` values into data_col, one per kernel tap, at
//   data_col[(channel_in * ksize_w + j) * width_col + w_out]
// reading from
//   data_row[channel_in * width + w_out * stride_w - pad_w + j * dilation_w]
// Taps that fall outside [0, width) are written as zero.
//
// Parameters:
//   n          - number of (channel, output column) pairs to process
//   data_row   - input rows, indexed as channel * width + w
//   width      - length of one input row
//   ksize_w    - number of kernel taps
//   pad_w      - left padding applied to the row
//   stride_w   - step between consecutive output columns
//   dilation_w - spacing between kernel taps
//   width_col  - number of output columns per channel
//   data_col   - output buffer, indexed as shown above
//
// NOTE(review): CUDA_KERNEL_LOOP comes from an included header (presumably
// THCUNN/common.h); it is assumed to iterate `index` over [0, n) and may run
// the body more than once per thread. For that reason the loop variable and
// the base pointers are never modified inside the body.
template <typename Dtype>
__global__ void
row2col_kernel(const int n, const Dtype *data_row, const int width,
               const int ksize_w, const int pad_w, const int stride_w,
               const int dilation_w, const int width_col, Dtype *data_col) {
  CUDA_KERNEL_LOOP(index, n) {
    const int w_out = index % width_col;
    const int channel_in = index / width_col;
    const int channel_out = channel_in * ksize_w;
    const int w_in = w_out * stride_w - pad_w;

    // Per-iteration cursors; the kernel arguments stay untouched so that a
    // repeated loop body starts from the true buffer bases every time.
    Dtype *col_ptr = data_col + channel_out * width_col + w_out;
    const Dtype *row_base = data_row + channel_in * width;

    for (int j = 0; j < ksize_w; ++j) {
      const int w = w_in + j * dilation_w;
      // Only dereference the row when the tap lies inside it; padding
      // positions are filled with zero.
      *col_ptr = (w >= 0 && w < width) ? row_base[w]
                                       : ScalarConvert<int, Dtype>::to(0);
      col_ptr += width_col;
    }
  }
}
// Host-side launcher for row2col_kernel.
//
// Computes the number of output columns per channel,
//   width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1))
//               / stride_w + 1,
// and launches channels * width_col work items on `stream`, one per
// (channel, output column) pair. A launch error is reported through
// THCudaCheck(cudaGetLastError()).
//
// Parameters:
//   stream     - CUDA stream the kernel is enqueued on
//   data_row   - input rows (device pointer)
//   channels   - number of input channels
//   width      - length of one input row
//   ksize_w    - number of kernel taps
//   pad_w      - padding applied on each side of the row
//   stride_w   - step between consecutive output columns
//   dilation_w - spacing between kernel taps
//   data_col   - output column buffer (device pointer)
template <typename Dtype>
void row2col(cudaStream_t stream, const Dtype *data_row, const int channels,
             const int width, const int ksize_w, const int pad_w,
             const int stride_w, const int dilation_w, Dtype *data_col) {
  const int width_col =
      (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
  const int num_kernels = channels * width_col;

  // The kernel must see the same dilation that was used to derive width_col;
  // passing a fixed 1 here would sample the wrong taps whenever
  // dilation_w != 1.
  row2col_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
      num_kernels, data_row, width, ksize_w, pad_w, stride_w, dilation_w,
      width_col, data_col);
  THCudaCheck(cudaGetLastError());
}
// Folds columns back into rows (the accumulation counterpart of row2col).
//
// Each value of `index` in [0, n) produces one element of data_row:
//   c_row = index / width            (channel)
//   w_row = index % width + pad_w    (position in the padded row)
// It sums, in Acctype precision, every data_col entry whose kernel tap
// landed on that position, then stores the converted sum to data_row[index].
// Because each index gathers its own contributions, no atomics are used.
//
// Parameters:
//   n          - number of row elements to produce
//   data_col   - input columns, indexed as
//                (channel * kernel_w + tap) * width_col + w_col
//   width      - length of one output row
//   channels   - number of channels (not read by the kernel body)
//   kernel_w   - number of kernel taps
//   pad_w      - left padding that was applied to the row
//   stride_w   - step between consecutive columns
//   dilation_w - spacing between kernel taps
//   width_col  - number of columns per channel
//   data_row   - output rows, written at data_row[index]
template <typename Dtype, typename Acctype>
__global__ void col2row_kernel(const int n, const Dtype *data_col,
                               const int width, const int channels,
                               const int kernel_w, const int pad_w,
                               const int stride_w, const int dilation_w,
                               const int width_col, Dtype *data_row) {
  CUDA_KERNEL_LOOP(index, n) {
    Acctype val = Acctype(0);
    const int w_row = index % width + pad_w;
    const int c_row = index / width;
    const int kernel_extent_w = (kernel_w - 1) * dilation_w + 1;

    // Range of columns whose receptive field can cover w_row.
    const int w_col_start = (w_row < kernel_extent_w)
                                ? 0
                                : (w_row - kernel_extent_w) / stride_w + 1;
    const int w_col_end = min(w_row / stride_w + 1, width_col);

    for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) {
      // Offset of w_row from the start of this column's window.
      int w_k = (w_row - w_col * stride_w);
      // Only offsets that are a whole multiple of the dilation correspond
      // to an actual kernel tap.
      if (w_k % dilation_w == 0) {
        // Convert the offset into a tap index before addressing data_col.
        w_k /= dilation_w;
        const int data_col_index =
            (c_row * kernel_w + w_k) * width_col + w_col;
        val += data_col[data_col_index];
      }
    }
    data_row[index] = ScalarConvert<Acctype, Dtype>::to(val);
  }
}
// Host-side launcher for col2row_kernel.
//
// Computes the number of columns per channel,
//   width_col = (width + 2 * pad_w - (dilation_w * (patch_w - 1) + 1))
//               / stride_w + 1,
// and launches channels * width work items on `stream`, one per output row
// element; each work item sums its own contributions inside the kernel.
// A launch error is reported through THCudaCheck(cudaGetLastError()).
//
// Parameters:
//   stream     - CUDA stream the kernel is enqueued on
//   data_col   - input column buffer (device pointer)
//   channels   - number of channels
//   width      - length of one output row
//   patch_w    - number of kernel taps
//   pad_w      - padding applied on each side of the row
//   stride_w   - step between consecutive columns
//   dilation_w - spacing between kernel taps
//   data_row   - output rows (device pointer)
template <typename Dtype, typename Acctype>
void col2row(cudaStream_t stream, const Dtype *data_col, const int channels,
             const int width, const int patch_w, const int pad_w,
             const int stride_w, const int dilation_w, Dtype *data_row) {
  const int width_col =
      (width + 2 * pad_w - (dilation_w * (patch_w - 1) + 1)) / stride_w + 1;
  const int num_kernels = channels * width;

  col2row_kernel<Dtype, Acctype>
      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
          num_kernels, data_col, width, channels, patch_w, pad_w, stride_w,
          dilation_w, width_col, data_row);

  THCudaCheck(cudaGetLastError());
}

#endif  // THCUNN_ROW2COL_H