#ifndef THCUNN_VOL2COL_H
#define THCUNN_VOL2COL_H

#include <THCUNN/common.h>
#include <THC/THCNumerics.cuh>

// Unfolds a volume into column form.
//
// Each value of `index` in [0, n) names one (channel_in, t_out, h_out, w_out)
// location, with w_out varying fastest. For that location the kernel walks
// the ksize_t * ksize_h * ksize_w window and writes one element per tap into
// data_col; taps that fall outside the [depth, height, width] volume are
// written as zero.
//
// Parameters:
//   n                      number of (channel, t_out, h_out, w_out) locations
//   data_vol               input, read at ((c * depth + t) * height + h) * width + w
//   depth/height/width     extent of one input channel
//   ksize_*                window size per axis
//   pad_*                  offset subtracted from the window origin per axis
//   stride_*               step of the window origin per output location
//   dilation_*             spacing between window taps
//   depth_col/height_col/width_col
//                          extent of the output grid per axis
//   data_col               output, written at
//                          ((row * depth_col + t_out) * height_col + h_out)
//                              * width_col + w_out
//                          where row = channel_in * ksize_t*ksize_h*ksize_w + tap
//
// CUDA_KERNEL_LOOP comes from THCUNN/common.h; it is presumably a grid-stride
// loop over [0, n) -- confirm there. The body below neither modifies `index`
// nor the pointer parameters, so it stays correct if a thread executes it
// more than once.
template <typename Dtype>
__global__ void vol2col_kernel(
    const int n,
    const Dtype* data_vol,
    const int depth,
    const int height,
    const int width,
    const int ksize_t,
    const int ksize_h,
    const int ksize_w,
    const int pad_t,
    const int pad_h,
    const int pad_w,
    const int stride_t,
    const int stride_h,
    const int stride_w,
    const int dilation_t,
    const int dilation_h,
    const int dilation_w,
    const int depth_col,
    const int height_col,
    const int width_col,
    Dtype* data_col) {
  CUDA_KERNEL_LOOP(index, n) {
    // Split the flat index into output coordinates; each coordinate needs the
    // faster-varying extents divided out before taking the modulo.
    const int w_out = index % width_col;
    const int h_out = (index / width_col) % height_col;
    const int t_out = (index / width_col / height_col) % depth_col;
    const int channel_in = index / width_col / height_col / depth_col;
    const int channel_out = channel_in * ksize_t * ksize_h * ksize_w;

    // Origin of the window in input coordinates (may be negative with padding).
    const int t_in = t_out * stride_t - pad_t;
    const int h_in = h_out * stride_h - pad_h;
    const int w_in = w_out * stride_w - pad_w;

    // Per-iteration cursors. The input side is kept as an integer offset so
    // no pointer is formed for an out-of-range (padded) origin.
    Dtype* col_ptr = data_col +
        ((channel_out * depth_col + t_out) * height_col + h_out) * width_col +
        w_out;
    const int vol_base =
        ((channel_in * depth + t_in) * height + h_in) * width + w_in;
    // Distance between consecutive window taps of the same output location.
    const int col_step = depth_col * height_col * width_col;

    for (int i = 0; i < ksize_t; ++i) {
      for (int j = 0; j < ksize_h; ++j) {
        for (int k = 0; k < ksize_w; ++k) {
          const int t = t_in + i * dilation_t;
          const int h = h_in + j * dilation_h;
          const int w = w_in + k * dilation_w;
          const bool inside = t >= 0 && h >= 0 && w >= 0 &&
              t < depth && h < height && w < width;
          *col_ptr = inside
              ? data_vol[vol_base + i * dilation_t * height * width +
                         j * dilation_h * width + k * dilation_w]
              : ScalarConvert<int, Dtype>::to(0);
          col_ptr += col_step;
        }
      }
    }
  }
}
// Host-side launcher for vol2col_kernel.
//
// Enqueues the unfold of `data_vol` into `data_col` on `stream` and then
// checks cudaGetLastError() through THCudaCheck. One unit of work is issued
// per (channel, t_col, h_col, w_col) output location; each one copies the
// whole ksize_t * ksize_h * ksize_w window for that location.
//
// Parameters:
//   stream                 CUDA stream the kernel is launched on
//   data_vol               input volume (forwarded to the kernel)
//   channels               number of input channels
//   depth/height/width     extent of one input channel
//   depth_col/height_col/width_col
//                          extent of the output grid
//   ksize_*, pad_*, stride_*, dilation_*
//                          window geometry, forwarded unchanged
//   data_col               output column buffer (forwarded to the kernel)
//
// NOTE(review): the call returns after the launch is checked; nothing here
// waits for the kernel to finish.
template <typename Dtype>
void vol2col(
    cudaStream_t stream,
    const Dtype* data_vol,
    const int channels,
    const int depth,
    const int height,
    const int width,
    const int depth_col,
    const int height_col,
    const int width_col,
    const int ksize_t,
    const int ksize_h,
    const int ksize_w,
    const int pad_t,
    const int pad_h,
    const int pad_w,
    const int stride_t,
    const int stride_h,
    const int stride_w,
    const int dilation_t,
    const int dilation_h,
    const int dilation_w,
    Dtype* data_col) {
  const int num_kernels = channels * depth_col * height_col * width_col;

  // Nothing to unfold: skip the launch instead of requesting one with no work
  // (GET_BLOCKS is defined in THCUNN/common.h and presumably yields 0 blocks
  // for 0 items, which is not a valid grid size -- confirm there).
  if (num_kernels == 0) {
    return;
  }

  vol2col_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
      num_kernels, data_vol, depth, height, width, ksize_t, ksize_h, ksize_w,
      pad_t, pad_h, pad_w, stride_t, stride_h, stride_w,
      dilation_t, dilation_h, dilation_w,
      depth_col, height_col, width_col, data_col);
  // Surfaces launch-configuration errors from the launch above.
  THCudaCheck(cudaGetLastError());
}
// Folds a column buffer back into a volume (the inverse direction of
// vol2col_kernel), accumulating overlapping window contributions.
//
// Each value of `index` in [0, n) names one (c, t, h, w) element of data_vol,
// with w varying fastest. The kernel gathers every data_col entry whose
// window covers that element, sums them in Acctype, and writes the converted
// sum once -- so no two iterations write the same data_vol element.
//
// Parameters:
//   n                      number of data_vol elements to produce
//   data_col               input, read at
//                          (((((c * kernel_t + kt) * kernel_h + kh) * kernel_w + kw)
//                              * depth_col + t_col) * height_col + h_col)
//                              * width_col + w_col
//   depth/height/width     extent of one data_vol channel
//   channels               not read by the kernel body
//   kernel_*               window size per axis
//   pad_*, stride_*, dilation_*
//                          window geometry
//   depth_col/height_col/width_col
//                          extent of the column grid per axis
//   data_vol               output, written at data_vol[index]
//
// CUDA_KERNEL_LOOP comes from THCUNN/common.h; presumably a grid-stride loop
// over [0, n) -- confirm there.
template <typename Dtype, typename Acctype>
__global__ void vol2im_kernel(
    const int n,
    const Dtype* data_col,
    const int depth,
    const int height,
    const int width,
    const int channels,
    const int kernel_t,
    const int kernel_h,
    const int kernel_w,
    const int pad_t,
    const int pad_h,
    const int pad_w,
    const int stride_t,
    const int stride_h,
    const int stride_w,
    const int dilation_t,
    const int dilation_h,
    const int dilation_w,
    const int depth_col,
    const int height_col,
    const int width_col,
    Dtype* data_vol) {
  CUDA_KERNEL_LOOP(index, n) {
    Acctype val = Acctype(0);

    // Coordinates of this element in the padded volume.
    const int w_im = index % width + pad_w;
    const int h_im = (index / width) % height + pad_h;
    const int t_im = (index / width / height) % depth + pad_t;
    const int c_im = index / (width * height * depth);

    // Span covered by one dilated window along each axis.
    const int kernel_extent_w = (kernel_w - 1) * dilation_w + 1;
    const int kernel_extent_h = (kernel_h - 1) * dilation_h + 1;
    const int kernel_extent_t = (kernel_t - 1) * dilation_t + 1;

    // Half-open range of column positions whose window can reach this element.
    const int w_col_start =
        (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1;
    const int w_col_end = min(w_im / stride_w + 1, width_col);
    const int h_col_start =
        (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1;
    const int h_col_end = min(h_im / stride_h + 1, height_col);
    const int t_col_start =
        (t_im < kernel_extent_t) ? 0 : (t_im - kernel_extent_t) / stride_t + 1;
    const int t_col_end = min(t_im / stride_t + 1, depth_col);

    for (int t_col = t_col_start; t_col < t_col_end; t_col += 1) {
      for (int h_col = h_col_start; h_col < h_col_end; h_col += 1) {
        for (int w_col = w_col_start; w_col < w_col_end; w_col += 1) {
          // Offset of the element from this window's origin, in input units.
          int t_k = (t_im - t_col * stride_t);
          int h_k = (h_im - h_col * stride_h);
          int w_k = (w_im - w_col * stride_w);
          // Only offsets that land exactly on a dilated tap contribute.
          if (t_k % dilation_t == 0 && h_k % dilation_h == 0 &&
              w_k % dilation_w == 0) {
            // Convert input-unit offsets into tap indices before they are
            // combined with kernel_t/kernel_h/kernel_w below.
            t_k /= dilation_t;
            h_k /= dilation_h;
            w_k /= dilation_w;
            const int data_col_index =
                (((((c_im * kernel_t + t_k) * kernel_h + h_k) * kernel_w + w_k)
                  * depth_col + t_col) * height_col + h_col) * width_col + w_col;
            val += data_col[data_col_index];
          }
        }
      }
    }
    data_vol[index] = ScalarConvert<Acctype, Dtype>::to(val);
  }
}
// Host-side launcher for vol2im_kernel.
//
// Enqueues the fold of `data_col` back into `data_vol` on `stream` and then
// checks cudaGetLastError() through THCudaCheck. One unit of work is issued
// per data_vol element (channels * depth * height * width in total); each one
// sums its own contributions inside the kernel, so the kernel needs no atomic
// updates on data_vol.
//
// Parameters:
//   stream                 CUDA stream the kernel is launched on
//   data_col               input column buffer (forwarded to the kernel)
//   channels               number of channels in data_vol
//   depth/height/width     extent of one data_vol channel
//   output_depth/output_height/output_width
//                          passed to the kernel as depth_col/height_col/width_col
//   patch_*                passed to the kernel as kernel_t/kernel_h/kernel_w
//   pad_*, stride_*, dilation_*
//                          window geometry, forwarded unchanged
//   data_vol               output volume (forwarded to the kernel)
//
// NOTE(review): the call returns after the launch is checked; nothing here
// waits for the kernel to finish.
template <typename Dtype, typename Acctype>
void col2vol(
    cudaStream_t stream,
    const Dtype* data_col,
    const int channels,
    const int depth,
    const int height,
    const int width,
    const int output_depth,
    const int output_height,
    const int output_width,
    const int patch_t,
    const int patch_h,
    const int patch_w,
    const int pad_t,
    const int pad_h,
    const int pad_w,
    const int stride_t,
    const int stride_h,
    const int stride_w,
    const int dilation_t,
    const int dilation_h,
    const int dilation_w,
    Dtype* data_vol) {
  const int num_kernels = channels * depth * height * width;

  // Nothing to fold: skip the launch instead of requesting one with no work
  // (GET_BLOCKS is defined in THCUNN/common.h and presumably yields 0 blocks
  // for 0 items, which is not a valid grid size -- confirm there).
  if (num_kernels == 0) {
    return;
  }

  vol2im_kernel<Dtype, Acctype>
      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
          num_kernels, data_col, depth, height, width, channels,
          patch_t, patch_h, patch_w, pad_t, pad_h, pad_w,
          stride_t, stride_h, stride_w,
          dilation_t, dilation_h, dilation_w,
          output_depth, output_height, output_width, data_vol);
  // Surfaces launch-configuration errors from the launch above.
  THCudaCheck(cudaGetLastError());
}