Caffe2 - C++ API
A deep learning, cross platform ML framework
mpscnn_kernels.h
// @generated

static const char* MPSCNN_KERNELS = R"V0G0N(


#include <metal_stdlib>

using namespace metal;

constant ushort ushort_arg_0[[function_constant(0)]];
constant ushort ushort_arg_1[[function_constant(1)]];
constant ushort ushort_arg_2[[function_constant(2)]];
constant ushort ushort_arg_3[[function_constant(3)]];
constant ushort ushort_arg_4[[function_constant(4)]];
constant ushort ushort_arg_5[[function_constant(5)]];
constant ushort ushort_arg_6[[function_constant(6)]];
constant ushort ushort_arg_7[[function_constant(7)]];
constant ushort ushort_arg_8[[function_constant(8)]];
constant ushort ushort_arg_9[[function_constant(9)]];

inline constexpr ushort divRoundUp(ushort x, ushort y) { return (x + (y - 1)) / y; }
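// Channels are packed four per texel: a tensor with C channels occupies
// divRoundUp(C, 4) slices of a texture2d_array (e.g. divRoundUp(10, 4) == 3),
// which is why gid.z is repeatedly mapped through divRoundUp(C, 4) below.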

kernel void affine(constant half4* scale[[buffer(0)]],
                   constant half4* shift[[buffer(1)]],
                   texture2d_array<half, access::read> in[[texture(0)]],
                   texture2d_array<half, access::write> out[[texture(1)]],
                   ushort3 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  const half4 scale_c = scale[gid.z % divRoundUp(C, 4)];
  const half4 shift_c = shift[gid.z % divRoundUp(C, 4)];
  ushort2 gid_(gid.x, gid.y);
  const half4 x = in.read(gid_, gid.z);
  const half4 y = scale_c * x + shift_c;
  out.write(y, gid_, gid.z);
}

kernel void affine_nonarray(constant half4* scale[[buffer(0)]],
                            constant half4* shift[[buffer(1)]],
                            texture2d<half, access::read> in[[texture(0)]],
                            texture2d<half, access::write> out[[texture(1)]],
                            ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  const half4 scale_c = scale[0];
  const half4 shift_c = shift[0];
  half4 x = in.read(gid);
  const half4 y = scale_c * x + shift_c;
  out.write(y, gid);
}

kernel void prelu_nonshared(constant half4* weights[[buffer(0)]],
                            texture2d_array<half, access::read> in[[texture(0)]],
                            texture2d_array<half, access::write> out[[texture(1)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort S = ushort_arg_1;
  const bool channel_shared = S == 1;
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  half4 w = channel_shared ? half4(weights[0][0], weights[0][0], weights[0][0], weights[0][0])
                           : weights[gid.z % divRoundUp(C, 4)];
  ushort2 gid_(gid.x, gid.y);
  half4 x = in.read(gid_, gid.z);
  half4 y = select(x * w, x, x > 0.0h);
  out.write(y, gid_, gid.z);
}

kernel void prelu_nonshared_nonarray(constant half4* weights[[buffer(0)]],
                                     texture2d<half, access::read> in[[texture(0)]],
                                     texture2d<half, access::write> out[[texture(1)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
  // const ushort C = ushort_arg_0;
  const ushort S = ushort_arg_1;
  const bool channel_shared = S == 1;
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  half4 w = channel_shared ? half4(weights[0][0], weights[0][0], weights[0][0], weights[0][0])
                           : weights[0];
  half4 x = in.read(gid);
  half4 y = select(x * w, x, x > 0.0h);
  out.write(y, gid);
}

// One block per texture.
// 256 threads per block.
using AccT = float4;

constant const bool instance_norm_has_prelu = ushort_arg_1 > 0;

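// instance_norm below computes, for each (image, channel-slice) texture:
//   1) the mean over H x W via a 256-thread threadgroup sum plus tree reduction,
//   2) the variance the same way,
//   3) then writes (x - mean) * inv_std * weight + bias, optionally fused with
//      a PReLU when instance_norm_has_prelu is set.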
kernel void instance_norm(
    constant half4* weights[[buffer(0)]],
    constant half4* bias[[buffer(1)]],
    constant half4* preluWeights[[ buffer(2), function_constant(instance_norm_has_prelu) ]],
    texture2d_array<half, access::read> in[[texture(0)]],
    texture2d_array<half, access::write> out[[texture(1)]],
    ushort3 gid[[thread_position_in_grid]],
    ushort tid[[thread_index_in_threadgroup]],
    ushort3 tcount[[threads_per_threadgroup]]) {
  if (gid.z >= out.get_array_size()) {
    return;
  }
  const ushort C = ushort_arg_0;
  const ushort S = ushort_arg_1;
  const bool channel_shared = S == 1;
  const ushort c = gid.z % divRoundUp(C, 4);
  constexpr ushort THREADGROUP_SIZE = 256;

  threadgroup AccT per_thread_state[THREADGROUP_SIZE];
  // Each block handles a single texture.
  per_thread_state[tid] = 0;
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      per_thread_state[tid] += static_cast<AccT>(in.read(ushort2(x, y), gid.z));
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // 256 -> 32 reduction
  if (tid < 32) {
    per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
                             per_thread_state[tid + 96] + per_thread_state[tid + 128] +
                             per_thread_state[tid + 160] + per_thread_state[tid + 192] +
                             per_thread_state[tid + 224];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid == 0) {
    AccT sum = 0.0;
    for (ushort i = 0; i < 32; ++i) {
      sum += per_thread_state[i];
    }
    sum /= (in.get_width() * in.get_height());
    per_thread_state[0] = sum;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Broadcast to all threads.
  const AccT mean = per_thread_state[0];

  threadgroup_barrier(mem_flags::mem_threadgroup);

  per_thread_state[tid] = 0;
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      AccT delta = static_cast<AccT>(in.read(ushort2(x, y), gid.z)) - mean;
      per_thread_state[tid] += delta * delta;
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // 256 -> 32 reduction
  if (tid < 32) {
    per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
                             per_thread_state[tid + 96] + per_thread_state[tid + 128] +
                             per_thread_state[tid + 160] + per_thread_state[tid + 192] +
                             per_thread_state[tid + 224];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid == 0) {
    AccT sum = 0.0;
    for (ushort i = 0; i < 32; ++i) {
      sum += per_thread_state[i];
    }
    sum /= (in.get_width() * in.get_height());
    per_thread_state[0] = 1.0 / sqrt(max(sum, AccT(1e-5, 1e-5, 1e-5, 1e-5)) + 1.0e-5);
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Broadcast to all threads.
  const AccT inv_var = per_thread_state[0];

  const AccT c_weights = static_cast<AccT>(weights[c]);
  const AccT c_bias = static_cast<AccT>(bias[c]);

  const AccT scale = inv_var * c_weights;
  const AccT shift = c_bias - mean * scale;

  half4 w;
  if (instance_norm_has_prelu) {
    w = channel_shared ? half4(preluWeights[0][0]) : preluWeights[c];
  }
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      half4 scaled =
          static_cast<half4>(static_cast<AccT>(in.read(ushort2(x, y), gid.z)) * scale + shift);
      if (instance_norm_has_prelu) {
        scaled = select(scaled * w, scaled, scaled > 0.0h);
      }
      out.write(scaled, ushort2(x, y), gid.z);
    }
  }
}

// One block per texture.
// 256 threads per block.
kernel void instance_norm_nonarray(
    constant half4* weights[[buffer(0)]],
    constant half4* bias[[buffer(1)]],
    constant half4* preluWeights[[ buffer(2), function_constant(instance_norm_has_prelu) ]],
    texture2d<half, access::read> in[[texture(0)]],
    texture2d<half, access::write> out[[texture(1)]],
    ushort3 gid[[thread_position_in_grid]],
    ushort tid[[thread_index_in_threadgroup]],
    ushort3 tcount[[threads_per_threadgroup]]) {
  // const ushort C = ushort_arg_0;
  const ushort S = ushort_arg_1;
  const bool channel_shared = S == 1;
  constexpr ushort THREADGROUP_SIZE = 256;

  threadgroup AccT per_thread_state[THREADGROUP_SIZE];
  // Each block handles a single texture.
  per_thread_state[tid] = 0;
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      per_thread_state[tid] += static_cast<AccT>(in.read(ushort2(x, y)));
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // 256 -> 32 reduction
  if (tid < 32) {
    per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
                             per_thread_state[tid + 96] + per_thread_state[tid + 128] +
                             per_thread_state[tid + 160] + per_thread_state[tid + 192] +
                             per_thread_state[tid + 224];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid == 0) {
    AccT sum = 0.0;
    for (ushort i = 0; i < 32; ++i) {
      sum += per_thread_state[i];
    }
    sum /= (in.get_width() * in.get_height());
    per_thread_state[0] = sum;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Broadcast to all threads.
  const AccT mean = per_thread_state[0];

  threadgroup_barrier(mem_flags::mem_threadgroup);

  per_thread_state[tid] = 0;
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      AccT delta = static_cast<AccT>(in.read(ushort2(x, y))) - mean;
      per_thread_state[tid] += delta * delta;
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // 256 -> 32 reduction
  if (tid < 32) {
    per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
                             per_thread_state[tid + 96] + per_thread_state[tid + 128] +
                             per_thread_state[tid + 160] + per_thread_state[tid + 192] +
                             per_thread_state[tid + 224];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid == 0) {
    AccT sum = 0.0;
    for (ushort i = 0; i < 32; ++i) {
      sum += per_thread_state[i];
    }
    sum /= (in.get_width() * in.get_height());
    per_thread_state[0] = 1.0 / sqrt(max(sum, AccT(1e-5, 1e-5, 1e-5, 1e-5)) + 1.0e-5);
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Broadcast to all threads.
  const AccT inv_var = per_thread_state[0];

  const AccT c_weights = static_cast<AccT>(weights[0]);
  const AccT c_bias = static_cast<AccT>(bias[0]);

  const AccT scale = inv_var * c_weights;
  const AccT shift = c_bias - mean * scale;

  half4 w;
  if (instance_norm_has_prelu) {
    w = channel_shared ? half4(preluWeights[0][0]) : preluWeights[0];
  }
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      half4 scaled = static_cast<half4>(static_cast<AccT>(in.read(ushort2(x, y))) * scale + shift);
      if (instance_norm_has_prelu) {
        scaled = select(scaled * w, scaled, scaled > 0.0h);
      }
      out.write(scaled, ushort2(x, y));
    }
  }
}

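// The copy kernels below translate between a linear NCHW float buffer and the
// packed texture layout: channel c of image n lands in slice
// n * divRoundUp(C, 4) + c / 4, component c % 4. Channels past C are
// zero-filled on the way into Metal and skipped on the way out.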
kernel void copy_nchw_to_metal(constant float* in[[buffer(0)]],
                               texture2d_array<half, access::write> out[[texture(0)]],
                               ushort3 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort H = ushort_arg_1;
  const ushort W = ushort_arg_2;
  if (gid.x >= W || gid.y >= H) {
    return;
  }

  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);

  // TODO: are the `else` branches needed?
  // TODO: trick the optimizer for case where C == 4?
#define CHW_TO_CHWP4(idx, n, c_, h, w)                                       \
  if ((c_) < C) {                                                            \
    trns[idx] = in[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)];   \
  } else {                                                                   \
    trns[idx] = 0.0h;                                                        \
  }

  half4 trns;
  CHW_TO_CHWP4(0, n, c * 4 + 0, gid.y, gid.x);
  CHW_TO_CHWP4(1, n, c * 4 + 1, gid.y, gid.x);
  CHW_TO_CHWP4(2, n, c * 4 + 2, gid.y, gid.x);
  CHW_TO_CHWP4(3, n, c * 4 + 3, gid.y, gid.x);
#undef CHW_TO_CHWP4

  out.write(trns, gid.xy, gid.z);
}

kernel void copy_nchw_to_metal_nonarray(constant float* in[[buffer(0)]],
                                        texture2d<half, access::write> out[[texture(0)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort H = ushort_arg_1;
  const ushort W = ushort_arg_2;

  if (gid.x >= W || gid.y >= H) {
    return;
  }

  half4 trns;
  // TODO: are the `else` branches needed?
  // TODO: trick the optimizer for case where C % 4 == 0?

#define CHW_TO_CHWP4(idx, c, h, w)                          \
  if ((c) < C) {                                            \
    trns[idx] = in[int(c) * H * W + int(h) * W + int(w)];   \
  } else {                                                  \
    trns[idx] = 0.0h;                                       \
  }

  CHW_TO_CHWP4(0, 0, gid.y, gid.x);
  CHW_TO_CHWP4(1, 1, gid.y, gid.x);
  CHW_TO_CHWP4(2, 2, gid.y, gid.x);
  CHW_TO_CHWP4(3, 3, gid.y, gid.x);
#undef CHW_TO_CHWP4

  out.write(trns, gid.xy);
}

kernel void copy_metal_to_nchw(texture2d_array<half, access::read> in[[texture(0)]],
                               device float* out[[buffer(0)]],
                               ushort3 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort H = ushort_arg_1;
  const ushort W = ushort_arg_2;

  if (gid.x >= W || gid.y >= H) {
    return;
  }
  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);

  half4 cs = in.read(gid.xy, gid.z);

#define CHWP4_TO_CHW(idx, n, c_, h, w)                                       \
  if ((c_) < C) {                                                            \
    out[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)] = cs[idx];    \
  }

  CHWP4_TO_CHW(0, n, c * 4 + 0, gid.y, gid.x);
  CHWP4_TO_CHW(1, n, c * 4 + 1, gid.y, gid.x);
  CHWP4_TO_CHW(2, n, c * 4 + 2, gid.y, gid.x);
  CHWP4_TO_CHW(3, n, c * 4 + 3, gid.y, gid.x);
#undef CHWP4_TO_CHW
}

kernel void copy_metal_to_nchw_nonarray(texture2d<half, access::read> in[[texture(0)]],
                                        device float* out[[buffer(0)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort H = ushort_arg_1;
  const ushort W = ushort_arg_2;

  if (gid.x >= W || gid.y >= H) {
    return;
  }

  half4 cs = in.read(gid.xy);

#define CHWP4_TO_CHW(idx, c, h, w)                          \
  if ((c) < C) {                                            \
    out[int(c) * H * W + int(h) * W + int(w)] = cs[idx];    \
  }

  CHWP4_TO_CHW(0, 0, gid.y, gid.x);
  CHWP4_TO_CHW(1, 1, gid.y, gid.x);
  CHWP4_TO_CHW(2, 2, gid.y, gid.x);
  CHWP4_TO_CHW(3, 3, gid.y, gid.x);
#undef CHWP4_TO_CHW
}

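// convtranspose_upscale scatters each input pixel onto a zero-filled, strided
// grid (the upsampling half of a transposed convolution): output positions
// that do not line up with stride * input + (kernel_ - 1 - pad) are written
// as zero.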
kernel void convtranspose_upscale(texture2d_array<half, access::read> in[[texture(0)]],
                                  texture2d_array<half, access::write> out[[texture(1)]],
                                  ushort3 gid[[thread_position_in_grid]]) {
  // All resolved at compile time.
  // Assume symmetric kernel/stride/pad for now.
  const ushort kernel_ = ushort_arg_0;
  const ushort stride = ushort_arg_1;
  const ushort pad = ushort_arg_2;

  half4 zero(0.0h, 0.0h, 0.0h, 0.0h);

  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  const ushort2 gid_ = gid.xy;
  if (gid.x < kernel_ - 1 - pad || gid.y < kernel_ - 1 - pad) {
    out.write(zero, gid_, gid.z);
    return;
  }

  if (((gid.x - (kernel_ - 1 - pad)) % stride == 0) &&
      ((gid.y - (kernel_ - 1 - pad)) % stride == 0)) {
    ushort2 in_pos((gid.x - (kernel_ - 1 - pad)) / stride, (gid.y - (kernel_ - 1 - pad)) / stride);

    if (in_pos.x < in.get_width() && in_pos.y < in.get_height()) {
      half4 input = in.read(in_pos, gid.z);
      out.write(input, gid_, gid.z);
    } else {
      out.write(zero, gid_, gid.z);
    }
  } else {
    out.write(zero, gid_, gid.z);
  }
}

constant bool has_in_arr = (ushort_arg_7 > 1 || ushort_arg_0 * ushort_arg_1 * ushort_arg_6 > 4);
constant bool has_out_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
constant bool has_in_tex = (!has_in_arr);
constant bool has_out_tex = (!has_out_arr);

kernel void col2im(
    texture2d_array<half, access::read> ina[[ texture(0), function_constant(has_in_arr) ]],
    texture2d<half, access::read> in[[ texture(0), function_constant(has_in_tex) ]],
    texture2d_array<half, access::write> outa[[ texture(1), function_constant(has_out_arr) ]],
    texture2d<half, access::write> out[[ texture(1), function_constant(has_out_tex) ]],
    constant half4* bias[[buffer(0)]],
    ushort3 gid[[thread_position_in_grid]]) {
  if (has_out_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
  }
  const ushort kernel_h = ushort_arg_0;
  const ushort kernel_w = ushort_arg_1;
  const ushort stride_h = ushort_arg_2;
  const ushort stride_w = ushort_arg_3;
  const ushort pad_l = ushort_arg_4;
  const ushort pad_t = ushort_arg_5;
  const ushort C = ushort_arg_6;
  // const int N = ushort_arg_7;
  const ushort height_col = ushort_arg_8;
  const ushort width_col = ushort_arg_9;

  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);

  const ushort w = gid.x + pad_l;
  const ushort h = gid.y + pad_t;

  // compute the start and end of the output
  const ushort w_col_start = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
  const ushort w_col_end = min(ushort(w / stride_w + 1), ushort(width_col));
  const ushort h_col_start = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
  const ushort h_col_end = min(ushort(h / stride_h + 1), ushort(height_col));

  float4 val = static_cast<float4>(bias[c]);
  for (ushort h_col = h_col_start; h_col < h_col_end; ++h_col) {
    for (ushort w_col = w_col_start; w_col < w_col_end; ++w_col) {
      const ushort w_k = w - w_col * stride_w;
      const ushort h_k = h - h_col * stride_h;

      // layout is essentially: [N][K][K][C][H][W]
      // - where the divRoundUp(K * K * C, 4) channels are interleaved as usual.
      // Thus, it's actually [N][divRoundUp(K * K * C, 4)][H][W].

      // If C % 4 is not zero, then we have to play some games via partial indexing.
      // TODO: is it worth optimizing this loop via padding in C?
      if (C % 4 == 0) {
        ushort c_col = n * kernel_h * kernel_w * divRoundUp(C, 4) +
                       h_k * kernel_w * divRoundUp(C, 4) + w_k * divRoundUp(C, 4) + c;
        if (has_in_arr) {
          val += static_cast<float4>(ina.read(ushort2(w_col, h_col), c_col));
        }
        if (has_in_tex) {
          val += static_cast<float4>(in.read(ushort2(w_col, h_col), c_col));
        }
      } else {
        half4 components(0, 0, 0, 0);
        for (auto i = 0; i < 4; ++i) {
          ushort c_col_i = n * divRoundUp(kernel_h * kernel_w * C, 4) * 4 + h_k * kernel_w * C +
                           w_k * C + c * 4 + i;
          ushort c_col_i_z = c_col_i / 4;
          ushort c_col_i_off = c_col_i - c_col_i_z * 4;
          if (has_in_arr) {
            components[i] = ina.read(ushort2(w_col, h_col), c_col_i_z)[c_col_i_off];
          }
          if (has_in_tex) {
            components[i] = in.read(ushort2(w_col, h_col))[c_col_i_off];
          }
        }
        val += static_cast<float4>(components);
      }
    }
  }
  if (has_out_arr) {
    outa.write(static_cast<half4>(val), gid.xy, gid.z);
  }
  if (has_out_tex) {
    out.write(static_cast<half4>(val), gid.xy);
  }
}

kernel void preprocess_stylizer(device uchar4* in[[buffer(0)]],
                                constant half* mean[[buffer(1)]],
                                constant half4* noise[[buffer(2)]],
                                texture2d<half, access::write> out[[texture(0)]],
                                ushort2 gid[[thread_position_in_grid]]) {

  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  const ushort noise_size = ushort_arg_0;

  half4 mean_half(mean[0], mean[1], mean[2], 0.0h);
  uint input_noise_idx = ((uint)out.get_width() * (uint)gid.y + (uint)gid.x) % (noise_size / 4);
  const half4 input_noise = noise[input_noise_idx];
  const uint W = out.get_width();
#define in_at(h, w) in[(uint)(h)*W + (uint)(w)]
  uchar4 input = in_at(gid.y, gid.x);
#undef in_at
  half4 input_half = static_cast<half4>(input);
  out.write(input_half - mean_half + input_noise, gid);
}

kernel void deprocess_stylizer(texture2d<half, access::read> in[[texture(0)]],
                               device uchar4* out[[buffer(0)]],
                               constant half* mean[[buffer(1)]],
                               ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= in.get_width() || gid.y >= in.get_height()) {
    return;
  }

  half4 value = in.read(gid);

  half4 mean_h(mean[0], mean[1], mean[2], 0.0h);
  half4 min_h(0.0h, 0.0h, 0.0h, 255.0h);
  half4 max_h(255.0h, 255.0h, 255.0h, 255.0h);
  half4 clamped = clamp(value + mean_h, min_h, max_h);
  const uint W = in.get_width();
#define out_at(h, w, v) out[(uint)(h)*W + (uint)(w)] = (v)
  out_at(gid.y, gid.x, static_cast<uchar4>(clamped));
#undef out_at
}

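// The reflection padding kernels map an output coordinate back into the input
// by mirroring at the borders: h = max(h, -h) reflects coordinates that fall
// before the top/left edge, and h = min(h, 2 * H - h - 2) reflects coordinates
// past the bottom/right edge.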
kernel void reflection_padding_nonarray(texture2d<half, access::read> in[[texture(0)]],
                                        texture2d<half, access::write> out[[texture(1)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  ushort H = in.get_height();
  ushort PH = out.get_height();

  // Note: we assume symmetric padding on H/W here, which is verified
  // in the calling code.
  ushort pad_h = (PH - H) / 2;
  ushort W = in.get_width();
  ushort PW = out.get_width();
  ushort pad_w = (PW - W) / 2;

  short h = short(gid.y) - short(pad_h);
  h = max(h, short(-h));
  h = min(h, short(2 * H - h - 2));

  short w = short(gid.x) - short(pad_w);
  w = max(w, short(-w));
  w = min(w, short(2 * W - w - 2));

  ushort2 inid(w, h);
  out.write(in.read(inid), gid);
}

kernel void reflection_padding(texture2d_array<half, access::read> in[[texture(0)]],
                               texture2d_array<half, access::write> out[[texture(1)]],
                               ushort3 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  ushort H = in.get_height();
  ushort PH = out.get_height();

  // Note: we assume symmetric padding on H/W here, which is verified
  // in the calling code.
  ushort pad_h = (PH - H) / 2;
  ushort W = in.get_width();
  ushort PW = out.get_width();
  ushort pad_w = (PW - W) / 2;

  short h = short(gid.y) - short(pad_h);
  h = max(h, short(-h));
  h = min(h, short(2 * H - h - 2));

  short w = short(gid.x) - short(pad_w);
  w = max(w, short(-w));
  w = min(w, short(2 * W - w - 2));

  ushort2 inid(w, h);

  out.write(in.read(inid, gid.z), gid.xy, gid.z);
}

kernel void bilinear_upsample(texture2d<half, access::sample> in[[texture(0)]],
                              texture2d<half, access::write> out[[texture(1)]],
                              ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  ushort2 src = gid / 2;
  constexpr sampler sampler(address::clamp_to_edge, filter::linear, coord::pixel);
  half4 value = in.sample(sampler, static_cast<float2>(src));
  out.write(value, gid);
}

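// elementwise_mul / elementwise_sub broadcast a float vector (in1) against the
// texture input: in1 is indexed as (y * width + x) % last_dim, so it repeats
// along the flattened spatial index with period ushort_arg_2.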
constant bool in0_is_tex = ushort_arg_0 <= 1 && ushort_arg_1 <= 4;
constant bool in0_is_arr = !in0_is_tex;

kernel void elementwise_mul(texture2d<half, access::read> in0[[texture(0), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::read> ina0[[texture(0), function_constant(in0_is_arr)]],
                            texture2d<half, access::write> out[[texture(2), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::write> outa[[texture(2), function_constant(in0_is_arr)]],
                            constant float* in1[[buffer(1)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  ushort last_dim = ushort_arg_2;
  ushort idx;
  if (in0_is_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
    idx = gid.y * out.get_width() + gid.x;
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
    idx = gid.y * outa.get_width() + gid.x;
  }
  ushort2 gid_ = gid.xy;
  if (in0_is_tex) {
    out.write(in0.read(gid_) * in1[idx % last_dim], gid_);
  } else {
    outa.write(ina0.read(gid_, gid.z) * in1[idx % last_dim], gid_, gid.z);
  }
}

kernel void elementwise_sub(texture2d<half, access::read> in0[[texture(0), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::read> ina0[[texture(0), function_constant(in0_is_arr)]],
                            texture2d<half, access::write> out[[texture(2), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::write> outa[[texture(2), function_constant(in0_is_arr)]],
                            constant float* in1[[buffer(1)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  ushort last_dim = ushort_arg_2;
  ushort idx;
  if (in0_is_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
    idx = gid.y * out.get_width() + gid.x;
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
    idx = gid.y * outa.get_width() + gid.x;
  }
  ushort2 gid_ = gid.xy;
  if (in0_is_tex) {
    out.write(in0.read(gid_) - in1[idx % last_dim], gid_);
  } else {
    outa.write(ina0.read(gid_, gid.z) - in1[idx % last_dim], gid_, gid.z);
  }
}

kernel void elementwise_add_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  out.write(in0.read(gid) + in1.read(gid), gid);
}

kernel void elementwise_add(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  ushort2 gid_ = gid.xy;
  out.write(in0.read(gid_, gid.z) + in1.read(gid_, gid.z), gid_, gid.z);
}

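// concat below merges up to four inputs along the channel dimension. The
// function constants route each input to either a plain texture2d or a
// texture2d_array binding: ushort_arg_0..3 carry the per-input channel counts
// C0..C3 (0 means the input is absent), and ushort_arg_4 / ushort_arg_5 appear
// to carry the output channel count and batch size used to pick the output
// texture type.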
constant bool has_in0_arg = (ushort_arg_0 > 0);
constant bool has_in1_arg = (ushort_arg_1 > 0);
constant bool has_in2_arg = (ushort_arg_2 > 0);
constant bool has_in3_arg = (ushort_arg_3 > 0);

constant bool has_in0_tex = (has_in0_arg && ushort_arg_0 <= 4 && ushort_arg_5 <= 1);
constant bool has_in1_tex = (has_in1_arg && ushort_arg_1 <= 4 && ushort_arg_5 <= 1);
constant bool has_in2_tex = (has_in2_arg && ushort_arg_2 <= 4 && ushort_arg_5 <= 1);
constant bool has_in3_tex = (has_in3_arg && ushort_arg_3 <= 4 && ushort_arg_5 <= 1);

constant bool has_in0_array = (has_in0_arg && !has_in0_tex);
constant bool has_in1_array = (has_in1_arg && !has_in1_tex);
constant bool has_in2_array = (has_in2_arg && !has_in2_tex);
constant bool has_in3_array = (has_in3_arg && !has_in3_tex);

constant bool concat_has_out_tex = (ushort_arg_4 <= 4 && ushort_arg_5 <= 1);
constant bool concat_has_out_array = !concat_has_out_tex;

inline ushort idx_3(ushort z, ushort C0, ushort C1, ushort C2, ushort C3) {
  if (z < C0) {
    return 0;
  }
  if (z < (C0 + C1)) {
    return 1;
  }
  if (z < (C0 + C1 + C2)) {
    return 2;
  }
  return 3;
}

inline ushort idx_2(ushort z, ushort C0, ushort C1, ushort C2) {
  if (z < C0) {
    return 0;
  }
  if (z < (C0 + C1)) {
    return 1;
  }
  return 2;
}

inline ushort idx_1(ushort z, ushort C0, ushort C1) {
  if (z < C0) {
    return 0;
  } else {
    return 1;
  }
}

inline ushort idx_0(ushort z, ushort C0) { return 0; }

// in a texture_array with size C, find the offset for image N at plane c.
inline constexpr ushort z_off(ushort n, ushort c, ushort C) { return n * divRoundUp(C, 4) + c / 4; }

kernel void concat(
    texture2d<half, access::read> in0[[ texture(0), function_constant(has_in0_tex) ]],
    texture2d<half, access::read> in1[[ texture(1), function_constant(has_in1_tex) ]],
    texture2d<half, access::read> in2[[ texture(2), function_constant(has_in2_tex) ]],
    texture2d<half, access::read> in3[[ texture(3), function_constant(has_in3_tex) ]],
    texture2d_array<half, access::read> ina0[[ texture(0), function_constant(has_in0_array) ]],
    texture2d_array<half, access::read> ina1[[ texture(1), function_constant(has_in1_array) ]],
    texture2d_array<half, access::read> ina2[[ texture(2), function_constant(has_in2_array) ]],
    texture2d_array<half, access::read> ina3[[ texture(3), function_constant(has_in3_array) ]],
    texture2d<half, access::write> out[[texture(5),
                                        function_constant(concat_has_out_tex) ]],
    texture2d_array<half, access::write> outa[[texture(5),
                                               function_constant(concat_has_out_array) ]],
    ushort3 gid[[thread_position_in_grid]]) {
  if (concat_has_out_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
  }

  const ushort C0 = ushort_arg_0;
  const ushort C1 = ushort_arg_1;
  const ushort C2 = ushort_arg_2;
  const ushort C3 = ushort_arg_3;
  const ushort C = C0 + C1 + C2 + C3;
  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);
  // Fill channel 4*c to 4*(c+1) of nth image of output

  ushort2 gid_ = ushort2(gid.x, gid.y);
  half4 value;

  for (int off = 0; off < 4; ++off) {
    ushort cur_channel = c * 4 + off;
    ushort cur_idx = 0;
    if (cur_channel >= C) {
      break;
    }
    if (has_in3_arg) {
      cur_idx = idx_3(cur_channel, C0, C1, C2, C3);
    } else if (has_in2_arg) {
      cur_idx = idx_2(cur_channel, C0, C1, C2);
    } else if (has_in1_arg) {
      cur_idx = idx_1(cur_channel, C0, C1);
    } else if (has_in0_arg) {
      cur_idx = idx_0(cur_channel, C0);
    } else {
      // never reached.
      cur_idx = 0;
    }
    ushort src_off = 0;
    switch (cur_idx) {
      case 0:
        src_off = cur_channel % 4;
        break;
      case 1:
        src_off = (cur_channel - C0) % 4;
        break;
      case 2:
        src_off = (cur_channel - (C0 + C1)) % 4;
        break;
      case 3:
        src_off = (cur_channel - (C0 + C1 + C2)) % 4;
        break;
    }
    // try to see if we can only issue one read op for the 4 values
    bool fast_path = false;
    if (off == 0 && src_off == 0 && (cur_channel + 3) < C) {
      ushort last_idx = 0;
      if (has_in3_arg) {
        last_idx = idx_3(cur_channel + 3, C0, C1, C2, C3);
      } else if (has_in2_arg) {
        last_idx = idx_2(cur_channel + 3, C0, C1, C2);
      } else if (has_in1_arg) {
        last_idx = idx_1(cur_channel + 3, C0, C1);
      } else if (has_in0_arg) {
        last_idx = idx_0(cur_channel + 3, C0);
      } else {
        // never reached.
        last_idx = 0;
      }
      if (cur_idx == last_idx) {
        fast_path = true;
      }
    }
    switch (cur_idx) {
      case 0: {
        if (has_in0_tex) {
          if (fast_path) {
            value = in0.read(gid_);
          } else {
            value[off] = in0.read(gid_)[src_off];
          }
        }
        if (has_in0_array) {
          if (fast_path) {
            value = ina0.read(gid_, z_off(n, cur_channel, C0));
          } else {
            value[off] = ina0.read(gid_, z_off(n, cur_channel, C0))[src_off];
          }
        }
        break;
      }
      case 1: {
        if (has_in1_tex) {
          if (fast_path) {
            value = in1.read(gid_);
          } else {
            value[off] = in1.read(gid_)[src_off];
          }
        }
        if (has_in1_array) {
          if (fast_path) {
            value = ina1.read(gid_, z_off(n, cur_channel - C0, C1));
          } else {
            value[off] = ina1.read(gid_, z_off(n, cur_channel - C0, C1))[src_off];
          }
        }
        break;
      }
      case 2: {
        if (has_in2_tex) {
          if (fast_path) {
            value = in2.read(gid_);
          } else {
            value[off] = in2.read(gid_)[src_off];
          }
        }
        if (has_in2_array) {
          if (fast_path) {
            value = ina2.read(gid_, z_off(n, cur_channel - (C0 + C1), C2));
          } else {
            value[off] = ina2.read(gid_, z_off(n, cur_channel - (C0 + C1), C2))[src_off];
          }
        }
        break;
      }
      case 3: {
        if (has_in3_tex) {
          if (fast_path) {
            value = in3.read(gid_);
          } else {
            value[off] = in3.read(gid_)[src_off];
          }
        }
        if (has_in3_array) {
          if (fast_path) {
            value = ina3.read(gid_, z_off(n, cur_channel - (C0 + C1 + C2), C3));
          } else {
            value[off] = ina3.read(gid_, z_off(n, cur_channel - (C0 + C1 + C2), C3))[src_off];
          }
        }
        break;
      }
    }
    if (fast_path) {
      break;
    }
  }
  if (concat_has_out_tex) {
    out.write(value, gid_, gid.z);
  } else {
    outa.write(value, gid_, gid.z);
  }
}

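// roi_warp below performs RoIAlign-style pooling: spatial_scale arrives as a
// fixed-point ushort (value * 10000) in ushort_arg_0, and each output bin
// averages bilinear samples taken on a roi_bin_grid_h x roi_bin_grid_w grid.
// resize_nearest passes its height/width scales the same fixed-point way.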
using RoIT = half;
using RoIT4 = half4;
constant bool rw_has_in_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool rw_has_out_arr = (ushort_arg_4 > 1 || ushort_arg_2 > 4);
constant bool rw_has_in_tex = (!rw_has_in_arr);
constant bool rw_has_out_tex = (!rw_has_out_arr);
kernel void roi_warp(texture2d_array<half, access::sample> ina[[texture(0), function_constant(rw_has_in_arr)]],
                     texture2d<half, access::sample> in[[texture(0), function_constant(rw_has_in_tex)]],
                     texture2d_array<half, access::write> outa[[texture(1), function_constant(rw_has_out_arr)]],
                     texture2d<half, access::write> out[[texture(1), function_constant(rw_has_out_tex)]],
                     constant half4* rois[[buffer(0)]],
                     ushort3 gid[[thread_position_in_grid]]) {
  ushort out_width, out_height;
  if (rw_has_out_arr) {
    out_width = outa.get_width();
    out_height = outa.get_height();
  } else {
    out_width = out.get_width();
    out_height = out.get_height();
  }
  if (gid.x >= out_width || gid.y >= out_height) {
    return;
  }
  constexpr sampler s2(coord::pixel, address::clamp_to_edge, filter::linear);

  const half spatial_scale = half(ushort_arg_0) / 10000;
  const ushort sampling_ratio = ushort_arg_1;
  const ushort C = ushort_arg_2;
  const ushort pw = gid.x;
  const ushort ph = gid.y;
  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z % divRoundUp(C, 4);

  const RoIT4 roi_scaled = rois[n] * spatial_scale;
  const RoIT roi_start_w = roi_scaled[0];
  const RoIT roi_start_h = roi_scaled[1];
  const RoIT roi_end_w = roi_scaled[2];
  const RoIT roi_end_h = roi_scaled[3];

  // Force malformed ROIs to be 1x1
  const RoIT roi_width = max(roi_end_w - roi_start_w, (RoIT)1.);
  const RoIT roi_height = max(roi_end_h - roi_start_h, (RoIT)1.);

  const RoIT bin_size_h = static_cast<RoIT>(roi_height) / static_cast<RoIT>(out_height);
  const RoIT bin_size_w = static_cast<RoIT>(roi_width) / static_cast<RoIT>(out_width);
  const ushort roi_bin_grid_h = sampling_ratio > 0 ? sampling_ratio : ceil(roi_height / static_cast<RoIT>(out_height));
  const ushort roi_bin_grid_w = sampling_ratio > 0 ? sampling_ratio : ceil(roi_width / static_cast<RoIT>(out_width));
  const ushort iy_upper = (sampling_ratio > 0) ? roi_bin_grid_h : (roi_bin_grid_h + 1);
  const ushort ix_upper = (sampling_ratio > 0) ? roi_bin_grid_w : (roi_bin_grid_w + 1);

  const RoIT count = iy_upper * ix_upper;

  RoIT4 output_val = 0.0;
  for (int iy = 0; iy < iy_upper; iy++) {
    for (int ix = 0; ix < ix_upper; ix++) {
      const RoIT y =
          roi_start_h + ph * bin_size_h + iy * bin_size_h / static_cast<RoIT>(roi_bin_grid_h);
      const RoIT x =
          roi_start_w + pw * bin_size_w + ix * bin_size_w / static_cast<RoIT>(roi_bin_grid_w);
      if (rw_has_in_arr) {
        output_val += ina.sample(s2, float2(x + 0.5, y + 0.5), c);
      } else {
        output_val += in.sample(s2, float2(x + 0.5, y + 0.5));
      }
    }
  }
  output_val /= count;
  if (rw_has_out_arr) {
    outa.write(static_cast<half4>(output_val), gid.xy, gid.z);
  } else {
    out.write(static_cast<half4>(output_val), gid.xy);
  }
}

kernel void resize_nearest(texture2d_array<half, access::sample> in[[texture(0)]],
                           texture2d_array<half, access::write> out[[texture(1)]],
                           ushort3 gid[[thread_position_in_grid]]) {
  const ushort oH = ushort_arg_0;
  const ushort oW = ushort_arg_1;
  if (gid.x >= oW || gid.y >= oH) {
    return;
  }
  const float height_scale = float(ushort_arg_2) / 10000;
  const float width_scale = float(ushort_arg_3) / 10000;
  constexpr sampler s(coord::pixel, address::clamp_to_edge, filter::nearest);
  const int in_y = (int)(gid.y / height_scale);
  const int in_x = (int)(gid.x / width_scale);
  out.write(in.sample(s, float2(in_x, in_y), gid.z), gid.xy, gid.z);
}

kernel void resize_nearest_nonarray(texture2d<half, access::sample> in[[texture(0)]],
                                    texture2d<half, access::write> out[[texture(1)]],
                                    ushort2 gid[[thread_position_in_grid]]) {
  const ushort oH = ushort_arg_0;
  const ushort oW = ushort_arg_1;
  if (gid.x >= oW || gid.y >= oH) {
    return;
  }
  const float height_scale = float(ushort_arg_2) / 10000;
  const float width_scale = float(ushort_arg_3) / 10000;
  constexpr sampler s(coord::pixel, address::clamp_to_edge, filter::nearest);
  const int in_y = (int)(gid.y / height_scale);
  const int in_x = (int)(gid.x / width_scale);
  out.write(in.sample(s, float2(in_x, in_y)), gid.xy);
}

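// nms computes a pairwise suppression bitmask: each thread owns one proposal
// (one row) and sets bit i of its 32-bit word when the IoU with the i-th
// proposal of the current column block exceeds nms_thresh, which is also
// passed as a fixed-point ushort (value * 10000).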
kernel void nms(device uint* mask[[buffer(0)]],
                constant float* proposals[[buffer(1)]],
                constant int* indices[[buffer(2)]],
                ushort2 tgid[[threadgroup_position_in_grid]],
                ushort2 tid[[thread_position_in_threadgroup]]) {
  const ushort num_proposals = ushort_arg_0;
  const ushort threads_per_group = ushort_arg_1;
  float nms_thresh = float(ushort_arg_2) / 10000.0;
  const ushort global_offset = ushort_arg_3;
  const ushort row_start = tgid.y;
  const ushort col_start = tgid.x;
  const ushort trd_id = tid.x;

  const short row_size = min(short(32), short(num_proposals - row_start * threads_per_group));
  const short col_size = min(short(32), short(num_proposals - col_start * threads_per_group));

  // mask the bit if the IoU between two proposals exceeds the threshold
  if (trd_id < row_size) {
    const ushort cur_idx = global_offset + row_start * threads_per_group + trd_id;
    const ushort offset = indices[cur_idx] * 4;
    const float4 cur_proposal = float4(
        proposals[offset], proposals[offset + 1], proposals[offset + 2], proposals[offset + 3]);
    uint cur_mask = 0;
    ushort group_start = 0; // start index within group
    if (row_start == col_start) {
      // if in the same group, start from the next
      group_start = trd_id + 1;
    }
    for (ushort i = group_start; i < col_size; i++) {
      float4 a = cur_proposal;
      ushort idx = indices[global_offset + col_start * threads_per_group + i] * 4;
      float4 b = float4(proposals[idx], proposals[idx + 1], proposals[idx + 2], proposals[idx + 3]);
      float left = max(a[0], b[0]);
      float right = min(a[2], b[2]);
      float top = max(a[1], b[1]);
      float bottom = min(a[3], b[3]);
      float width = max(right - left + 1.0, 0.0);
      float height = max(bottom - top + 1.0, 0.0);
      float interS = width * height;
      float Sa = (a[2] - a[0] + 1.0) * (a[3] - a[1] + 1.0);
      float Sb = (b[2] - b[0] + 1.0) * (b[3] - b[1] + 1.0);
      float iou = interS / (Sa + Sb - interS);
      if (iou - nms_thresh > 0) {
        cur_mask |= 1U << i;
      }
    }
    ushort col_blocks = (num_proposals + threads_per_group - 1) / threads_per_group;
    mask[cur_idx * col_blocks + col_start] = cur_mask;
  }
}


)V0G0N";
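
The kernel source above is stored as a raw string rather than compiled offline, so the host can hand it to the Metal compiler at runtime and specialize the ushort_arg_N function constants per dispatch. The Objective-C++ sketch below is illustrative only and is not part of the generated header: the kernel name "affine", the constant value 32, and the helper name makeAffinePipeline are arbitrary examples, and error handling is elided.

// Illustrative sketch (assumes the MPSCNN_KERNELS string above is visible in
// this translation unit); not part of the generated file.
#import <Metal/Metal.h>

static id<MTLComputePipelineState> makeAffinePipeline() {
  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
  NSError* error = nil;

  // Compile the embedded Metal source at runtime.
  NSString* source = [NSString stringWithUTF8String:MPSCNN_KERNELS];
  id<MTLLibrary> library = [device newLibraryWithSource:source options:nil error:&error];

  // Specialize ushort_arg_0 (the channel count C read by the `affine` kernel).
  // The value 32 is an arbitrary example.
  MTLFunctionConstantValues* constants = [MTLFunctionConstantValues new];
  uint16_t C = 32;
  [constants setConstantValue:&C type:MTLDataTypeUShort atIndex:0];

  id<MTLFunction> fn = [library newFunctionWithName:@"affine"
                                     constantValues:constants
                                              error:&error];
  return [device newComputePipelineStateWithFunction:fn error:&error];
}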