Caffe2 - C++ API
A deep learning, cross-platform ML framework
mpscnn_kernels.h
1 // @generated
2 
3 static const char* MPSCNN_KERNELS = R"V0G0N(
4 
5 
6 #include <metal_stdlib>
7 
8 using namespace metal;
9 
10 constant ushort ushort_arg_0[[function_constant(0)]];
11 constant ushort ushort_arg_1[[function_constant(1)]];
12 constant ushort ushort_arg_2[[function_constant(2)]];
13 constant ushort ushort_arg_3[[function_constant(3)]];
14 constant ushort ushort_arg_4[[function_constant(4)]];
15 constant ushort ushort_arg_5[[function_constant(5)]];
16 constant ushort ushort_arg_6[[function_constant(6)]];
17 constant ushort ushort_arg_7[[function_constant(7)]];
18 constant ushort ushort_arg_8[[function_constant(8)]];
19 constant ushort ushort_arg_9[[function_constant(9)]];
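// These function constants are placeholders for shape/configuration parameters.
// The host specializes each kernel when it creates the MTLFunction (e.g. via
// MTLFunctionConstantValues), so inside every kernel that reads them the values
// below behave as compile-time constants.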
20 
21 inline constexpr ushort divRoundUp(ushort x, ushort y) { return (x + (y - 1)) / y; }
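// divRoundUp(C, 4) is the number of half4 slices needed to hold C channels,
// e.g. divRoundUp(10, 4) == 3. It is used throughout to split a packed slice
// index gid.z into (image n, channel slice c).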
22 
23 kernel void affine(constant half4* scale[[buffer(0)]],
24  constant half4* shift[[buffer(1)]],
25  texture2d_array<half, access::read> in[[texture(0)]],
26  texture2d_array<half, access::write> out[[texture(1)]],
27  ushort3 gid[[thread_position_in_grid]]) {
28  const ushort C = ushort_arg_0;
29  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
30  return;
31  }
32  const half4 scale_c = scale[gid.z % divRoundUp(C, 4)];
33  const half4 shift_c = shift[gid.z % divRoundUp(C, 4)];
34  ushort2 gid_(gid.x, gid.y);
35  const half4 x = in.read(gid_, gid.z);
36  const half4 y = scale_c * x + shift_c;
37  out.write(y, gid_, gid.z);
38 }
39 
40 kernel void affine_nonarray(constant half4* scale[[buffer(0)]],
41  constant half4* shift[[buffer(1)]],
42  texture2d<half, access::read> in[[texture(0)]],
43  texture2d<half, access::write> out[[texture(1)]],
44  ushort2 gid[[thread_position_in_grid]]) {
45  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
46  return;
47  }
48  const half4 scale_c = scale[0];
49  const half4 shift_c = shift[0];
50  half4 x = in.read(gid);
51  const half4 y = scale_c * x + shift_c;
52  out.write(y, gid);
53 }
54 
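// PReLU: y = x where x > 0, otherwise y = w * x. When S == 1 the single slope
// weights[0][0] is broadcast to all channels; otherwise the slopes are packed
// four per half4 and indexed by the channel slice gid.z % divRoundUp(C, 4).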
55 kernel void prelu_nonshared(constant half4* weights[[buffer(0)]],
56  texture2d_array<half, access::read> in[[texture(0)]],
57  texture2d_array<half, access::write> out[[texture(1)]],
58  ushort3 gid[[thread_position_in_grid]]) {
59  const ushort C = ushort_arg_0;
60  const ushort S = ushort_arg_1;
61  const bool channel_shared = S == 1;
62  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
63  return;
64  }
65  half4 w = channel_shared ? half4(weights[0][0], weights[0][0], weights[0][0], weights[0][0])
66  : weights[gid.z % divRoundUp(C, 4)];
67  ushort2 gid_(gid.x, gid.y);
68  half4 x = in.read(gid_, gid.z);
69  half4 y = select(x * w, x, x > 0.0h);
70  out.write(y, gid_, gid.z);
71 }
72 
73 kernel void prelu_nonshared_nonarray(constant half4* weights[[buffer(0)]],
74  texture2d<half, access::read> in[[texture(0)]],
75  texture2d<half, access::write> out[[texture(1)]],
76  ushort2 gid[[thread_position_in_grid]]) {
77  // const ushort C = ushort_arg_0;
78  const ushort S = ushort_arg_1;
79  const bool channel_shared = S == 1;
80  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
81  return;
82  }
83  half4 w = channel_shared ? half4(weights[0][0], weights[0][0], weights[0][0], weights[0][0])
84  : weights[0];
85  half4 x = in.read(gid);
86  half4 y = select(x * w, x, x > 0.0h);
87  out.write(y, gid);
88 }
89 
90 // One block per texture.
91 // 256 threads per block.
92 using AccT = float4;
93 
94 constant const bool instance_norm_has_prelu = ushort_arg_1 > 0;
95 
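// Instance normalization over one texture slice (one image, one group of up to
// four channels) per threadgroup. Pass 1: each thread accumulates a strided
// partial sum of the slice into threadgroup memory, which is reduced
// 256 -> 32 -> 1 to give the per-channel mean. Pass 2: the same reduction over
// squared deviations gives the variance, from which
// inv_var = 1 / sqrt(var + eps) (with the variance clamped from below).
// The output is x * (inv_var * weight) + (bias - mean * inv_var * weight),
// optionally followed by PReLU when instance_norm_has_prelu is set.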
96 kernel void instance_norm(
97  constant half4* weights[[buffer(0)]],
98  constant half4* bias[[buffer(1)]],
99  constant half4* preluWeights[[ buffer(2), function_constant(instance_norm_has_prelu) ]],
100  texture2d_array<half, access::read> in[[texture(0)]],
101  texture2d_array<half, access::write> out[[texture(1)]],
102  ushort3 gid[[thread_position_in_grid]],
103  ushort tid[[thread_index_in_threadgroup]],
104  ushort3 tcount[[threads_per_threadgroup]]) {
105  if (gid.z >= out.get_array_size()) {
106  return;
107  }
108  const ushort C = ushort_arg_0;
109  const ushort S = ushort_arg_1;
110  const bool channel_shared = S == 1;
111  const ushort c = gid.z % divRoundUp(C, 4);
112  constexpr ushort THREADGROUP_SIZE = 256;
113 
114  threadgroup AccT per_thread_state[THREADGROUP_SIZE];
115  // Each block handles a single texture.
116  per_thread_state[tid] = 0;
117  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
118  for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
119  per_thread_state[tid] += static_cast<AccT>(in.read(ushort2(x, y), gid.z));
120  }
121  }
122 
123  threadgroup_barrier(mem_flags::mem_threadgroup);
124 
125  // 256 -> 32 reduction
126  if (tid < 32) {
127  per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
128  per_thread_state[tid + 96] + per_thread_state[tid + 128] +
129  per_thread_state[tid + 160] + per_thread_state[tid + 192] +
130  per_thread_state[tid + 224];
131  }
132 
133  threadgroup_barrier(mem_flags::mem_threadgroup);
134 
135  if (tid == 0) {
136  AccT sum = 0.0;
137  for (ushort i = 0; i < 32; ++i) {
138  sum += per_thread_state[i];
139  }
140  sum /= (in.get_width() * in.get_height());
141  per_thread_state[0] = sum;
142  }
143  threadgroup_barrier(mem_flags::mem_threadgroup);
144  // Broadcast to all threads.
145  const AccT mean = per_thread_state[0];
146 
147  threadgroup_barrier(mem_flags::mem_threadgroup);
148 
149  per_thread_state[tid] = 0;
150  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
151  for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
152  AccT delta = static_cast<AccT>(in.read(ushort2(x, y), gid.z)) - mean;
153  per_thread_state[tid] += delta * delta;
154  }
155  }
156 
157  threadgroup_barrier(mem_flags::mem_threadgroup);
158 
159  // 256 -> 32 reduction
160  if (tid < 32) {
161  per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
162  per_thread_state[tid + 96] + per_thread_state[tid + 128] +
163  per_thread_state[tid + 160] + per_thread_state[tid + 192] +
164  per_thread_state[tid + 224];
165  }
166 
167  threadgroup_barrier(mem_flags::mem_threadgroup);
168 
169  if (tid == 0) {
170  AccT sum = 0.0;
171  for (ushort i = 0; i < 32; ++i) {
172  sum += per_thread_state[i];
173  }
174  sum /= (in.get_width() * in.get_height());
175  per_thread_state[0] = 1.0 / sqrt(max(sum, AccT(1e-5, 1e-5, 1e-5, 1e-5)) + 1.0e-5);
176  }
177 
178  threadgroup_barrier(mem_flags::mem_threadgroup);
179  // Broadcast to all threads.
180  const AccT inv_var = per_thread_state[0];
181 
182  const AccT c_weights = static_cast<AccT>(weights[c]);
183  const AccT c_bias = static_cast<AccT>(bias[c]);
184 
185  const AccT scale = inv_var * c_weights;
186  const AccT shift = c_bias - mean * scale;
187 
188  half4 w;
189  if (instance_norm_has_prelu) {
190  w = channel_shared ? half4(preluWeights[0][0]) : preluWeights[c];
191  }
192  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
193  for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
194  half4 scaled =
195  static_cast<half4>(static_cast<AccT>(in.read(ushort2(x, y), gid.z)) * scale + shift);
196  if (instance_norm_has_prelu) {
197  scaled = select(scaled * w, scaled, scaled > 0.0h);
198  }
199  out.write(scaled, ushort2(x, y), gid.z);
200  }
201  }
202 }
203 
204 // One block per texture.
205 // 256 threads per block.
206 kernel void instance_norm_nonarray(
207  constant half4* weights[[buffer(0)]],
208  constant half4* bias[[buffer(1)]],
209  constant half4* preluWeights[[ buffer(2), function_constant(instance_norm_has_prelu) ]],
210  texture2d<half, access::read> in[[texture(0)]],
211  texture2d<half, access::write> out[[texture(1)]],
212  ushort3 gid[[thread_position_in_grid]],
213  ushort tid[[thread_index_in_threadgroup]],
214  ushort3 tcount[[threads_per_threadgroup]]) {
215  // const ushort C = ushort_arg_0;
216  const ushort S = ushort_arg_1;
217  const bool channel_shared = S == 1;
218  constexpr ushort THREADGROUP_SIZE = 256;
219 
220  threadgroup AccT per_thread_state[THREADGROUP_SIZE];
221  // Each block handles a single texture.
222  per_thread_state[tid] = 0;
223  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
224  for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
225  per_thread_state[tid] += static_cast<AccT>(in.read(ushort2(x, y)));
226  }
227  }
228 
229  threadgroup_barrier(mem_flags::mem_threadgroup);
230 
231  // 256 -> 32 reduction
232  if (tid < 32) {
233  per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
234  per_thread_state[tid + 96] + per_thread_state[tid + 128] +
235  per_thread_state[tid + 160] + per_thread_state[tid + 192] +
236  per_thread_state[tid + 224];
237  }
238 
239  threadgroup_barrier(mem_flags::mem_threadgroup);
240 
241  if (tid == 0) {
242  AccT sum = 0.0;
243  for (ushort i = 0; i < 32; ++i) {
244  sum += per_thread_state[i];
245  }
246  sum /= (in.get_width() * in.get_height());
247  per_thread_state[0] = sum;
248  }
249  threadgroup_barrier(mem_flags::mem_threadgroup);
250  // Broadcast to all threads.
251  const AccT mean = per_thread_state[0];
252 
253  threadgroup_barrier(mem_flags::mem_threadgroup);
254 
255  per_thread_state[tid] = 0;
256  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
257  for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
258  AccT delta = static_cast<AccT>(in.read(ushort2(x, y))) - mean;
259  per_thread_state[tid] += delta * delta;
260  }
261  }
262 
263  threadgroup_barrier(mem_flags::mem_threadgroup);
264 
265  // 256 -> 32 reduction
266  if (tid < 32) {
267  per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
268  per_thread_state[tid + 96] + per_thread_state[tid + 128] +
269  per_thread_state[tid + 160] + per_thread_state[tid + 192] +
270  per_thread_state[tid + 224];
271  }
272 
273  threadgroup_barrier(mem_flags::mem_threadgroup);
274 
275  if (tid == 0) {
276  AccT sum = 0.0;
277  for (ushort i = 0; i < 32; ++i) {
278  sum += per_thread_state[i];
279  }
280  sum /= (in.get_width() * in.get_height());
281  per_thread_state[0] = 1.0 / sqrt(max(sum, AccT(1e-5, 1e-5, 1e-5, 1e-5)) + 1.0e-5);
282  }
283 
284  threadgroup_barrier(mem_flags::mem_threadgroup);
285  // Broadcast to all threads.
286  const AccT inv_var = per_thread_state[0];
287 
288  const AccT c_weights = static_cast<AccT>(weights[0]);
289  const AccT c_bias = static_cast<AccT>(bias[0]);
290 
291  const AccT scale = inv_var * c_weights;
292  const AccT shift = c_bias - mean * scale;
293 
294  half4 w;
295  if (instance_norm_has_prelu) {
296  w = channel_shared ? half4(preluWeights[0][0]) : preluWeights[0];
297  }
298  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
299  for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
300  half4 scaled = static_cast<half4>(static_cast<AccT>(in.read(ushort2(x, y))) * scale + shift);
301  if (instance_norm_has_prelu) {
302  scaled = select(scaled * w, scaled, scaled > 0.0h);
303  }
304  out.write(scaled, ushort2(x, y));
305  }
306  }
307 }
308 
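// Converts a float NCHW buffer into the packed half4 texture layout: texel
// (x, y) of array slice z = n * divRoundUp(C, 4) + c holds channels
// 4*c .. 4*c+3 of image n at (h = y, w = x); channels beyond C are zero-filled.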
309 kernel void copy_nchw_to_metal(constant float* in[[buffer(0)]],
310  texture2d_array<half, access::write> out[[texture(0)]],
311  ushort3 gid[[thread_position_in_grid]]) {
312  const ushort C = ushort_arg_0;
313  const ushort H = ushort_arg_1;
314  const ushort W = ushort_arg_2;
315  if (gid.x >= W || gid.y >= H) {
316  return;
317  }
318 
319  const ushort n = gid.z / divRoundUp(C, 4);
320  const ushort c = gid.z - n * divRoundUp(C, 4);
321 
322  // TODO: are the `else` branches needed?
323  // TODO: trick the optimizer for case where C == 4?
324 #define CHW_TO_CHWP4(idx, n, c_, h, w) \
325 if ((c_) < C) { \
326 trns[idx] = in[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)]; \
327 } else { \
328 trns[idx] = 0.0h; \
329 }
330 
331  half4 trns;
332  CHW_TO_CHWP4(0, n, c * 4 + 0, gid.y, gid.x);
333  CHW_TO_CHWP4(1, n, c * 4 + 1, gid.y, gid.x);
334  CHW_TO_CHWP4(2, n, c * 4 + 2, gid.y, gid.x);
335  CHW_TO_CHWP4(3, n, c * 4 + 3, gid.y, gid.x);
336 #undef CHW_TO_CHWP4
337 
338  out.write(trns, gid.xy, gid.z);
339 }
340 
341 kernel void copy_nchw_to_metal_nonarray(constant float* in[[buffer(0)]],
342  texture2d<half, access::write> out[[texture(0)]],
343  ushort2 gid[[thread_position_in_grid]]) {
344  const ushort C = ushort_arg_0;
345  const ushort H = ushort_arg_1;
346  const ushort W = ushort_arg_2;
347 
348  if (gid.x >= W || gid.y >= H) {
349  return;
350  }
351 
352  half4 trns;
353  // TODO: are the `else` branches needed?
354  // TODO: trick the optimizer for case where C % 4 == 0?
355 
356 #define CHW_TO_CHWP4(idx, c, h, w) \
357 if ((c) < C) { \
358 trns[idx] = in[int(c) * H * W + int(h) * W + int(w)]; \
359 } else { \
360 trns[idx] = 0.0h; \
361 }
362 
363  CHW_TO_CHWP4(0, 0, gid.y, gid.x);
364  CHW_TO_CHWP4(1, 1, gid.y, gid.x);
365  CHW_TO_CHWP4(2, 2, gid.y, gid.x);
366  CHW_TO_CHWP4(3, 3, gid.y, gid.x);
367 #undef CHW_TO_CHWP4
368 
369  out.write(trns, gid.xy);
370 }
371 
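// Inverse of copy_nchw_to_metal: reads one half4 texel and scatters its valid
// lanes (channel index < C) back into the float NCHW buffer.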
372 kernel void copy_metal_to_nchw(texture2d_array<half, access::read> in[[texture(0)]],
373  device float* out[[buffer(0)]],
374  ushort3 gid[[thread_position_in_grid]]) {
375  const ushort C = ushort_arg_0;
376  const ushort H = ushort_arg_1;
377  const ushort W = ushort_arg_2;
378 
379  if (gid.x >= W || gid.y >= H) {
380  return;
381  }
382  const ushort n = gid.z / divRoundUp(C, 4);
383  const ushort c = gid.z - n * divRoundUp(C, 4);
384 
385  half4 cs = in.read(gid.xy, gid.z);
386 
387 #define CHWP4_TO_CHW(idx, n, c_, h, w) \
388 if ((c_) < C) { \
389 out[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)] = cs[idx]; \
390 }
391 
392  CHWP4_TO_CHW(0, n, c * 4 + 0, gid.y, gid.x);
393  CHWP4_TO_CHW(1, n, c * 4 + 1, gid.y, gid.x);
394  CHWP4_TO_CHW(2, n, c * 4 + 2, gid.y, gid.x);
395  CHWP4_TO_CHW(3, n, c * 4 + 3, gid.y, gid.x);
396 #undef CHWP4_TO_CHW
397 }
398 
399 kernel void copy_metal_to_nchw_nonarray(texture2d<half, access::read> in[[texture(0)]],
400  device float* out[[buffer(0)]],
401  ushort2 gid[[thread_position_in_grid]]) {
402  const ushort C = ushort_arg_0;
403  const ushort H = ushort_arg_1;
404  const ushort W = ushort_arg_2;
405 
406  if (gid.x >= W || gid.y >= H) {
407  return;
408  }
409 
410  half4 cs = in.read(gid.xy);
411 
412 #define CHWP4_TO_CHW(idx, c, h, w) \
413 if ((c) < C) { \
414 out[int(c) * H * W + int(h) * W + int(w)] = cs[idx]; \
415 }
416 
417  CHWP4_TO_CHW(0, 0, gid.y, gid.x);
418  CHWP4_TO_CHW(1, 1, gid.y, gid.x);
419  CHWP4_TO_CHW(2, 2, gid.y, gid.x);
420  CHWP4_TO_CHW(3, 3, gid.y, gid.x);
421 #undef CHWP4_TO_CHW
422 }
423 
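// Zero-insertion upscale for transposed convolution: input pixel (i, j) is
// copied to output position (i * stride + kernel - 1 - pad, j * stride +
// kernel - 1 - pad) and every other output pixel is written as zero, so that a
// regular convolution over the result can realize the transposed convolution.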
424 kernel void convtranspose_upscale(texture2d_array<half, access::read> in[[texture(0)]],
425  texture2d_array<half, access::write> out[[texture(1)]],
426  ushort3 gid[[thread_position_in_grid]]) {
427  // All resolved at compile time.
428  // Assume symmetric kernel/stride/pad for now.
429  const ushort kernel_ = ushort_arg_0;
430  const ushort stride = ushort_arg_1;
431  const ushort pad = ushort_arg_2;
432 
433  half4 zero(0.0h, 0.0h, 0.0h, 0.0h);
434 
435  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
436  return;
437  }
438  const ushort2 gid_ = gid.xy;
439  if (gid.x < kernel_ - 1 - pad || gid.y < kernel_ - 1 - pad) {
440  out.write(zero, gid_, gid.z);
441  return;
442  }
443 
444  if (((gid.x - (kernel_ - 1 - pad)) % stride == 0) &&
445  ((gid.y - (kernel_ - 1 - pad)) % stride == 0)) {
446  ushort2 in_pos((gid.x - (kernel_ - 1 - pad)) / stride, (gid.y - (kernel_ - 1 - pad)) / stride);
447 
448  if (in_pos.x < in.get_width() && in_pos.y < in.get_height()) {
449  half4 input = in.read(in_pos, gid.z);
450  out.write(input, gid_, gid.z);
451  } else {
452  out.write(zero, gid_, gid.z);
453  }
454  } else {
455  out.write(zero, gid_, gid.z);
456  }
457 }
458 
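// col2im: folds a column buffer back into the output image, adding bias.
// ushort_arg_0/1 are kernel_h/kernel_w, arg_2/3 the strides, arg_4/5 the
// left/top padding, arg_6 = C, arg_7 = N, arg_8/9 the col-buffer height/width.
// The input is bound as a texture2d_array when N > 1 or
// kernel_h * kernel_w * C > 4 (i.e. whenever more than one slice is needed),
// otherwise as a plain texture2d; likewise for the output with its C channels.
// Each output pixel accumulates every col-buffer entry whose kernel window
// covers it.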
459 constant bool has_in_arr = (ushort_arg_7 > 1 || ushort_arg_0 * ushort_arg_1 * ushort_arg_6 > 4);
460 constant bool has_out_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
461 constant bool has_in_tex = (!has_in_arr);
462 constant bool has_out_tex = (!has_out_arr);
463 
464 kernel void col2im(
465  texture2d_array<half, access::read> ina[[ texture(0), function_constant(has_in_arr) ]],
466  texture2d<half, access::read> in[[ texture(0), function_constant(has_in_tex) ]],
467  texture2d_array<half, access::write> outa[[ texture(1), function_constant(has_out_arr) ]],
468  texture2d<half, access::write> out[[ texture(1), function_constant(has_out_tex) ]],
469  constant half4* bias[[buffer(0)]],
470  ushort3 gid[[thread_position_in_grid]]) {
471  if (has_out_tex) {
472  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
473  return;
474  }
475  } else {
476  if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
477  return;
478  }
479  }
480  const ushort kernel_h = ushort_arg_0;
481  const ushort kernel_w = ushort_arg_1;
482  const ushort stride_h = ushort_arg_2;
483  const ushort stride_w = ushort_arg_3;
484  const ushort pad_l = ushort_arg_4;
485  const ushort pad_t = ushort_arg_5;
486  const ushort C = ushort_arg_6;
487  // const int N = ushort_arg_7;
488  const ushort height_col = ushort_arg_8;
489  const ushort width_col = ushort_arg_9;
490 
491  const ushort n = gid.z / divRoundUp(C, 4);
492  const ushort c = gid.z - n * divRoundUp(C, 4);
493 
494  const ushort w = gid.x + pad_l;
495  const ushort h = gid.y + pad_t;
496 
497  // compute the start and end of the output
498  const ushort w_col_start = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
499  const ushort w_col_end = min(ushort(w / stride_w + 1), ushort(width_col));
500  const ushort h_col_start = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
501  const ushort h_col_end = min(ushort(h / stride_h + 1), ushort(height_col));
502 
503  float4 val = static_cast<float4>(bias[c]);
504  for (ushort h_col = h_col_start; h_col < h_col_end; ++h_col) {
505  for (ushort w_col = w_col_start; w_col < w_col_end; ++w_col) {
506  const ushort w_k = w - w_col * stride_w;
507  const ushort h_k = h - h_col * stride_h;
508 
509  // layout is essentially: [N][K][K][C][H][W]
510  // - where the divRoundUp(K * K * C, 4) channels are interleaved as usual.
511  // Thus, it's actually [N][divRoundUp(K * K * C, 4)][H][W].
512 
513  // If C % 4 is not zero, then we have to play some games via partial indexing.
514  // TODO: is it worth optimizing this loop via padding in C?
515  if (C % 4 == 0) {
516  ushort c_col = n * kernel_h * kernel_w * divRoundUp(C, 4) +
517  h_k * kernel_w * divRoundUp(C, 4) + w_k * divRoundUp(C, 4) + c;
518  if (has_in_arr) {
519  val += static_cast<float4>(ina.read(ushort2(w_col, h_col), c_col));
520  }
521  if (has_in_tex) {
522  val += static_cast<float4>(in.read(ushort2(w_col, h_col), c_col));
523  }
524  } else {
525  half4 components(0, 0, 0, 0);
526  for (auto i = 0; i < 4; ++i) {
527  ushort c_col_i = n * divRoundUp(kernel_h * kernel_w * C, 4) * 4 + h_k * kernel_w * C +
528  w_k * C + c * 4 + i;
529  ushort c_col_i_z = c_col_i / 4;
530  ushort c_col_i_off = c_col_i - c_col_i_z * 4;
531  if (has_in_arr) {
532  components[i] = ina.read(ushort2(w_col, h_col), c_col_i_z)[c_col_i_off];
533  }
534  if (has_in_tex) {
535  components[i] = in.read(ushort2(w_col, h_col))[c_col_i_off];
536  }
537  }
538  val += static_cast<float4>(components);
539  }
540  }
541  }
542  if (has_out_arr) {
543  outa.write(static_cast<half4>(val), gid.xy, gid.z);
544  }
545  if (has_out_tex) {
546  out.write(static_cast<half4>(val), gid.xy);
547  }
548 }
549 
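// Stylizer preprocessing: loads an interleaved uchar4 image from a buffer,
// converts it to half, subtracts the per-channel mean (fourth lane untouched),
// and adds a noise value selected by the pixel's flattened index modulo
// noise_size / 4.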
550 kernel void preprocess_stylizer(device uchar4* in[[buffer(0)]],
551  constant half* mean[[buffer(1)]],
552  constant half4* noise[[buffer(2)]],
553  texture2d<half, access::write> out[[texture(0)]],
554  ushort2 gid[[thread_position_in_grid]]) {
555 
556  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
557  return;
558  }
559  const ushort noise_size = ushort_arg_0;
560 
561  half4 mean_half(mean[0], mean[1], mean[2], 0.0h);
562  uint input_noise_idx = ((uint)out.get_width() * (uint)gid.y + (uint)gid.x) % (noise_size / 4);
563  const half4 input_noise = noise[input_noise_idx];
564  const uint W = out.get_width();
565 #define in_at(h, w) in[(uint)(h)*W + (uint)(w)]
566  uchar4 input = in_at(gid.y, gid.x);
567 #undef in_at
568  half4 input_half = static_cast<half4>(input);
569  out.write(input_half - mean_half + input_noise, gid);
570 }
571 
572 kernel void deprocess_stylizer(texture2d<half, access::read> in[[texture(0)]],
573  device uchar4* out[[buffer(0)]],
574  constant half* mean[[buffer(1)]],
575  ushort2 gid[[thread_position_in_grid]]) {
576  if (gid.x >= in.get_width() || gid.y >= in.get_height()) {
577  return;
578  }
579 
580  half4 value = in.read(gid);
581 
582  half4 mean_h(mean[0], mean[1], mean[2], 0.0h);
583  half4 min_h(0.0h, 0.0h, 0.0h, 255.0h);
584  half4 max_h(255.0h, 255.0h, 255.0h, 255.0h);
585  half4 clamped = clamp(value + mean_h, min_h, max_h);
586  const uint W = in.get_width();
587 #define out_at(h, w, v) out[(uint)(h)*W + (uint)(w)] = (v)
588  out_at(gid.y, gid.x, static_cast<uchar4>(clamped));
589 #undef out_at
590 }
591 
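// Reflection padding: an output coordinate is mapped back into the input by
// mirroring across the borders without repeating the edge pixel. With
// h = gid.y - pad_h, max(h, -h) reflects the top/left overhang and
// min(h, 2*H - h - 2) the bottom/right one; e.g. H = 5, pad_h = 2 maps output
// rows 0..8 to input rows 2,1,0,1,2,3,4,3,2.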
592 kernel void reflection_padding_nonarray(texture2d<half, access::read> in[[texture(0)]],
593  texture2d<half, access::write> out[[texture(1)]],
594  ushort2 gid[[thread_position_in_grid]]) {
595  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
596  return;
597  }
598  ushort H = in.get_height();
599  ushort PH = out.get_height();
600 
601  // Note: we assume symmetric padding on H/W here, which is verified
602  // in the calling code.
603  ushort pad_h = (PH - H) / 2;
604  ushort W = in.get_width();
605  ushort PW = out.get_width();
606  ushort pad_w = (PW - W) / 2;
607 
608  short h = short(gid.y) - short(pad_h);
609  h = max(h, short(-h));
610  h = min(h, short(2 * H - h - 2));
611 
612  short w = short(gid.x) - short(pad_w);
613  w = max(w, short(-w));
614  w = min(w, short(2 * W - w - 2));
615 
616  ushort2 inid(w, h);
617  out.write(in.read(inid), gid);
618 }
619 
620 kernel void reflection_padding(texture2d_array<half, access::read> in[[texture(0)]],
621  texture2d_array<half, access::write> out[[texture(1)]],
622  ushort3 gid[[thread_position_in_grid]]) {
623  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
624  return;
625  }
626  ushort H = in.get_height();
627  ushort PH = out.get_height();
628 
629  // Note: we assume symmetric padding on H/W here, which is verified
630  // in the calling code.
631  ushort pad_h = (PH - H) / 2;
632  ushort W = in.get_width();
633  ushort PW = out.get_width();
634  ushort pad_w = (PW - W) / 2;
635 
636  short h = short(gid.y) - short(pad_h);
637  h = max(h, short(-h));
638  h = min(h, short(2 * H - h - 2));
639 
640  short w = short(gid.x) - short(pad_w);
641  w = max(w, short(-w));
642  w = min(w, short(2 * W - w - 2));
643 
644  ushort2 inid(w, h);
645 
646  out.write(in.read(inid, gid.z), gid.xy, gid.z);
647 }
648 
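// 2x upsample: each output pixel samples the input at gid / 2 (integer
// division) through a linear, clamp-to-edge sampler in pixel coordinates.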
649 kernel void bilinear_upsample(texture2d<half, access::sample> in[[texture(0)]],
650  texture2d<half, access::write> out[[texture(1)]],
651  ushort2 gid[[thread_position_in_grid]]) {
652  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
653  return;
654  }
655  ushort2 src = gid / 2;
656  constexpr sampler sampler(address::clamp_to_edge, filter::linear, coord::pixel);
657  half4 value = in.sample(sampler, static_cast<float2>(src));
658  out.write(value, gid);
659 }
660 
661 constant bool in0_is_tex = ushort_arg_0 <= 1 && ushort_arg_1 <= 4;
662 constant bool in0_is_arr = !in0_is_tex;
663 
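// Broadcast elementwise multiply/subtract against a float buffer in1: the
// value applied at (x, y) is in1[(y * width + x) % last_dim]. in0/out are
// bound as plain texture2d when ushort_arg_0 <= 1 and ushort_arg_1 <= 4
// (apparently batch and channel counts small enough for a single slice),
// otherwise as texture arrays.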
664 kernel void elementwise_mul(texture2d<half, access::read> in0[[texture(0), function_constant(in0_is_tex)]],
665  texture2d_array<half, access::read> ina0[[texture(0), function_constant(in0_is_arr)]],
666  texture2d<half, access::write> out[[texture(2), function_constant(in0_is_tex)]],
667  texture2d_array<half, access::write> outa[[texture(2), function_constant(in0_is_arr)]],
668  constant float* in1[[buffer(1)]],
669  ushort3 gid[[thread_position_in_grid]]) {
670  ushort last_dim = ushort_arg_2;
671  ushort idx;
672  if (in0_is_tex) {
673  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
674  return;
675  }
676  idx = gid.y * out.get_width() + gid.x;
677  } else {
678  if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
679  return;
680  }
681  idx = gid.y * outa.get_width() + gid.x;
682  }
683  ushort2 gid_ = gid.xy;
684  if (in0_is_tex) {
685  out.write(in0.read(gid_) * in1[idx % last_dim], gid_);
686  } else {
687  outa.write(ina0.read(gid_, gid.z) * in1[idx % last_dim], gid_, gid.z);
688  }
689 }
690 
691 kernel void elementwise_sub(texture2d<half, access::read> in0[[texture(0), function_constant(in0_is_tex)]],
692  texture2d_array<half, access::read> ina0[[texture(0), function_constant(in0_is_arr)]],
693  texture2d<half, access::write> out[[texture(2), function_constant(in0_is_tex)]],
694  texture2d_array<half, access::write> outa[[texture(2), function_constant(in0_is_arr)]],
695  constant float* in1[[buffer(1)]],
696  ushort3 gid[[thread_position_in_grid]]) {
697  ushort last_dim = ushort_arg_2;
698  ushort idx;
699  if (in0_is_tex) {
700  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
701  return;
702  }
703  idx = gid.y * out.get_width() + gid.x;
704  } else {
705  if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
706  return;
707  }
708  idx = gid.y * outa.get_width() + gid.x;
709  }
710  ushort2 gid_ = gid.xy;
711  if (in0_is_tex) {
712  out.write(in0.read(gid_) - in1[idx % last_dim], gid_);
713  } else {
714  outa.write(ina0.read(gid_, gid.z) - in1[idx % last_dim], gid_, gid.z);
715  }
716 }
717 
718 kernel void elementwise_add_nonarray(texture2d<half, access::read> in0[[texture(0)]],
719  texture2d<half, access::read> in1[[texture(1)]],
720  texture2d<half, access::write> out[[texture(2)]],
721  ushort2 gid[[thread_position_in_grid]]) {
722  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
723  return;
724  }
725  out.write(in0.read(gid) + in1.read(gid), gid);
726 }
727 
728 kernel void elementwise_add(texture2d_array<half, access::read> in0[[texture(0)]],
729  texture2d_array<half, access::read> in1[[texture(1)]],
730  texture2d_array<half, access::write> out[[texture(2)]],
731  ushort3 gid[[thread_position_in_grid]]) {
732  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
733  return;
734  }
735  ushort2 gid_ = gid.xy;
736  out.write(in0.read(gid_, gid.z) + in1.read(gid_, gid.z), gid_, gid.z);
737 }
738 
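// Channel concatenation of up to four inputs. ushort_arg_0..3 are the channel
// counts C0..C3 of the (optional) inputs; ushort_arg_4 appears to be the total
// output channel count and ushort_arg_5 the batch size. Each operand is bound
// as a plain texture2d when it fits in one slice, otherwise as a texture
// array. The idx_* helpers below locate which input a given output channel
// comes from, and the kernel takes a fast path (a single half4 read) when all
// four packed output channels come from the same input at a 4-aligned offset.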
739 constant bool has_in0_arg = (ushort_arg_0 > 0);
740 constant bool has_in1_arg = (ushort_arg_1 > 0);
741 constant bool has_in2_arg = (ushort_arg_2 > 0);
742 constant bool has_in3_arg = (ushort_arg_3 > 0);
743 
744 constant bool has_in0_tex = (has_in0_arg && ushort_arg_0 <= 4 && ushort_arg_5 <= 1);
745 constant bool has_in1_tex = (has_in1_arg && ushort_arg_1 <= 4 && ushort_arg_5 <= 1);
746 constant bool has_in2_tex = (has_in2_arg && ushort_arg_2 <= 4 && ushort_arg_5 <= 1);
747 constant bool has_in3_tex = (has_in3_arg && ushort_arg_3 <= 4 && ushort_arg_5 <= 1);
748 
749 constant bool has_in0_array = (has_in0_arg && !has_in0_tex);
750 constant bool has_in1_array = (has_in1_arg && !has_in1_tex);
751 constant bool has_in2_array = (has_in2_arg && !has_in2_tex);
752 constant bool has_in3_array = (has_in3_arg && !has_in3_tex);
753 
754 constant bool concat_has_out_tex = (ushort_arg_4 <= 4 && ushort_arg_5 <= 1);
755 constant bool concat_has_out_array = !concat_has_out_tex;
756 
757 inline ushort idx_3(ushort z, ushort C0, ushort C1, ushort C2, ushort C3) {
758  if (z < C0) {
759  return 0;
760  }
761  if (z < (C0 + C1)) {
762  return 1;
763  }
764  if (z < (C0 + C1 + C2)) {
765  return 2;
766  }
767  return 3;
768 }
769 
770 inline ushort idx_2(ushort z, ushort C0, ushort C1, ushort C2) {
771  if (z < C0) {
772  return 0;
773  }
774  if (z < (C0 + C1)) {
775  return 1;
776  }
777  return 2;
778 }
779 
780 inline ushort idx_1(ushort z, ushort C0, ushort C1) {
781  if (z < C0) {
782  return 0;
783  } else {
784  return 1;
785  }
786 }
787 
788 inline ushort idx_0(ushort z, ushort C0) { return 0; }
789 
790 // in a texture_array with size C, find the offset for image N at plane c.
791 inline constexpr ushort z_off(ushort n, ushort c, ushort C) { return n * divRoundUp(C, 4) + c / 4; }
792 
793 kernel void concat(
794  texture2d<half, access::read> in0[[ texture(0), function_constant(has_in0_tex) ]],
795  texture2d<half, access::read> in1[[ texture(1), function_constant(has_in1_tex) ]],
796  texture2d<half, access::read> in2[[ texture(2), function_constant(has_in2_tex) ]],
797  texture2d<half, access::read> in3[[ texture(3), function_constant(has_in3_tex) ]],
798  texture2d_array<half, access::read> ina0[[ texture(0), function_constant(has_in0_array) ]],
799  texture2d_array<half, access::read> ina1[[ texture(1), function_constant(has_in1_array) ]],
800  texture2d_array<half, access::read> ina2[[ texture(2), function_constant(has_in2_array) ]],
801  texture2d_array<half, access::read> ina3[[ texture(3), function_constant(has_in3_array) ]],
802  texture2d<half, access::write> out[[texture(5),
803  function_constant(concat_has_out_tex) ]],
804  texture2d_array<half, access::write> outa[[texture(5),
805  function_constant(concat_has_out_array) ]],
806  ushort3 gid[[thread_position_in_grid]]) {
807  if (concat_has_out_tex) {
808  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
809  return;
810  }
811  } else {
812  if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
813  return;
814  }
815  }
816 
817  const ushort C0 = ushort_arg_0;
818  const ushort C1 = ushort_arg_1;
819  const ushort C2 = ushort_arg_2;
820  const ushort C3 = ushort_arg_3;
821  const ushort C = C0 + C1 + C2 + C3;
822  const ushort n = gid.z / divRoundUp(C, 4);
823  const ushort c = gid.z - n * divRoundUp(C, 4);
824  // Fill channels 4*c .. 4*c+3 of the nth image of the output
825 
826  ushort2 gid_ = ushort2(gid.x, gid.y);
827  half4 value;
828 
829  for (int off = 0; off < 4; ++off) {
830  ushort cur_channel = c * 4 + off;
831  ushort cur_idx = 0;
832  if (cur_channel >= C) {
833  break;
834  }
835  if (has_in3_arg) {
836  cur_idx = idx_3(cur_channel, C0, C1, C2, C3);
837  } else if (has_in2_arg) {
838  cur_idx = idx_2(cur_channel, C0, C1, C2);
839  } else if (has_in1_arg) {
840  cur_idx = idx_1(cur_channel, C0, C1);
841  } else if (has_in0_arg) {
842  cur_idx = idx_0(cur_channel, C0);
843  } else {
844  // never reached.
845  cur_idx = 0;
846  }
847  ushort src_off = 0;
848  switch (cur_idx) {
849  case 0:
850  src_off = cur_channel % 4;
851  break;
852  case 1:
853  src_off = (cur_channel - C0) % 4;
854  break;
855  case 2:
856  src_off = (cur_channel - (C0 + C1)) % 4;
857  break;
858  case 3:
859  src_off = (cur_channel - (C0 + C1 + C2)) % 4;
860  break;
861  }
862  // try to see if we can only issue one read op for the 4 values
863  bool fast_path = false;
864  if (off == 0 && src_off == 0 && (cur_channel + 3) < C) {
865  ushort last_idx = 0;
866  if (has_in3_arg) {
867  last_idx = idx_3(cur_channel + 3, C0, C1, C2, C3);
868  } else if (has_in2_arg) {
869  last_idx = idx_2(cur_channel + 3, C0, C1, C2);
870  } else if (has_in1_arg) {
871  last_idx = idx_1(cur_channel + 3, C0, C1);
872  } else if (has_in0_arg) {
873  last_idx = idx_0(cur_channel + 3, C0);
874  } else {
875  // never reached.
876  last_idx = 0;
877  }
878  if (cur_idx == last_idx) {
879  fast_path = true;
880  }
881  }
882  switch (cur_idx) {
883  case 0: {
884  if (has_in0_tex) {
885  if (fast_path) {
886  value = in0.read(gid_);
887  } else {
888  value[off] = in0.read(gid_)[src_off];
889  }
890  }
891  if (has_in0_array) {
892  if (fast_path) {
893  value = ina0.read(gid_, z_off(n, cur_channel, C0));
894  } else {
895  value[off] = ina0.read(gid_, z_off(n, cur_channel, C0))[src_off];
896  }
897  }
898  break;
899  }
900  case 1: {
901  if (has_in1_tex) {
902  if (fast_path) {
903  value = in1.read(gid_);
904  } else {
905  value[off] = in1.read(gid_)[src_off];
906  }
907  }
908  if (has_in1_array) {
909  if (fast_path) {
910  value = ina1.read(gid_, z_off(n, cur_channel - C0, C1));
911  } else {
912  value[off] = ina1.read(gid_, z_off(n, cur_channel - C0, C1))[src_off];
913  }
914  }
915  break;
916  }
917  case 2: {
918  if (has_in2_tex) {
919  if (fast_path) {
920  value = in2.read(gid_);
921  } else {
922  value[off] = in2.read(gid_)[src_off];
923  }
924  }
925  if (has_in2_array) {
926  if (fast_path) {
927  value = ina2.read(gid_, z_off(n, cur_channel - (C0 + C1), C2));
928  } else {
929  value[off] = ina2.read(gid_, z_off(n, cur_channel - (C0 + C1), C2))[src_off];
930  }
931  }
932  break;
933  }
934  case 3: {
935  if (has_in3_tex) {
936  if (fast_path) {
937  value = in3.read(gid_);
938  } else {
939  value[off] = in3.read(gid_)[src_off];
940  }
941  }
942  if (has_in3_array) {
943  if (fast_path) {
944  value = ina3.read(gid_, z_off(n, cur_channel - (C0 + C1 + C2), C3));
945  } else {
946  value[off] = ina3.read(gid_, z_off(n, cur_channel - (C0 + C1 + C2), C3))[src_off];
947  }
948  }
949  break;
950  }
951  }
952  if (fast_path) {
953  break;
954  }
955  }
956  if (concat_has_out_tex) {
957  out.write(value, gid_, gid.z);
958  } else {
959  outa.write(value, gid_, gid.z);
960  }
961 }
962 
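// RoIAlign-style warping: rois holds (x1, y1, x2, y2) boxes, scaled here by
// spatial_scale (passed in fixed point as ushort_arg_0 / 10000). Each output
// pixel averages bilinear samples taken on a roi_bin_grid_h x roi_bin_grid_w
// grid inside its bin of the (at least 1x1) RoI; the textures are bound as
// arrays whenever more than one slice is needed.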
963 using RoIT = half;
964 using RoIT4 = half4;
965 constant bool rw_has_in_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
966 constant bool rw_has_out_arr = (ushort_arg_4 > 1 || ushort_arg_2 > 4);
967 constant bool rw_has_in_tex = (!rw_has_in_arr);
968 constant bool rw_has_out_tex = (!rw_has_out_arr);
969 kernel void roi_warp(texture2d_array<half, access::sample> ina[[texture(0), function_constant(rw_has_in_arr)]],
970  texture2d<half, access::sample> in[[texture(0), function_constant(rw_has_in_tex)]],
971  texture2d_array<half, access::write> outa[[texture(1), function_constant(rw_has_out_arr)]],
972  texture2d<half, access::write> out[[texture(1), function_constant(rw_has_out_tex)]],
973  constant half4* rois[[buffer(0)]],
974  ushort3 gid[[thread_position_in_grid]]) {
975  ushort out_width, out_height;
976  if (rw_has_out_arr) {
977  out_width = outa.get_width();
978  out_height = outa.get_height();
979  } else {
980  out_width = out.get_width();
981  out_height = out.get_height();
982  }
983  if (gid.x >= out_width || gid.y >= out_height) {
984  return;
985  }
986  constexpr sampler s2(coord::pixel, address::clamp_to_edge, filter::linear);
987 
988  const half spatial_scale = half(ushort_arg_0) / 10000;
989  const ushort sampling_ratio = ushort_arg_1;
990  const ushort C = ushort_arg_2;
991  const ushort pw = gid.x;
992  const ushort ph = gid.y;
993  const ushort n = gid.z / divRoundUp(C, 4);
994  const ushort c = gid.z % divRoundUp(C, 4);
995 
996  const RoIT4 roi_scaled = rois[n] * spatial_scale;
997  const RoIT roi_start_w = roi_scaled[0];
998  const RoIT roi_start_h = roi_scaled[1];
999  const RoIT roi_end_w = roi_scaled[2];
1000  const RoIT roi_end_h = roi_scaled[3];
1001 
1002  // Force malformed ROIs to be 1x1
1003  const RoIT roi_width = max(roi_end_w - roi_start_w, (RoIT)1.);
1004  const RoIT roi_height = max(roi_end_h - roi_start_h, (RoIT)1.);
1005 
1006  const RoIT bin_size_h = static_cast<RoIT>(roi_height) / static_cast<RoIT>(out_height);
1007  const RoIT bin_size_w = static_cast<RoIT>(roi_width) / static_cast<RoIT>(out_width);
1008  const ushort roi_bin_grid_h = sampling_ratio > 0 ? sampling_ratio : ceil(roi_height / static_cast<RoIT>(out_height));
1009  const ushort roi_bin_grid_w = sampling_ratio > 0 ? sampling_ratio : ceil(roi_width / static_cast<RoIT>(out_width));
1010  const ushort iy_upper = (sampling_ratio > 0) ? roi_bin_grid_h : (roi_bin_grid_h + 1);
1011  const ushort ix_upper = (sampling_ratio > 0) ? roi_bin_grid_w : (roi_bin_grid_w + 1);
1012 
1013  const RoIT count = iy_upper * ix_upper;
1014 
1015  RoIT4 output_val = 0.0;
1016  for (int iy = 0; iy < iy_upper; iy++) {
1017  for (int ix = 0; ix < ix_upper; ix++) {
1018  const RoIT y =
1019  roi_start_h + ph * bin_size_h + iy * bin_size_h / static_cast<RoIT>(roi_bin_grid_h);
1020  const RoIT x =
1021  roi_start_w + pw * bin_size_w + ix * bin_size_w / static_cast<RoIT>(roi_bin_grid_w);
1022  if (rw_has_in_arr) {
1023  output_val += ina.sample(s2, float2(x + 0.5, y + 0.5), c);
1024  } else {
1025  output_val += in.sample(s2, float2(x + 0.5, y + 0.5));
1026  }
1027  }
1028  }
1029  output_val /= count;
1030  if (rw_has_out_arr) {
1031  outa.write(static_cast<half4>(output_val), gid.xy, gid.z);
1032  } else {
1033  out.write(static_cast<half4>(output_val), gid.xy);
1034  }
1035 }
1036 
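// Nearest-neighbour resize to oH x oW: the height/width scale factors arrive
// in fixed point as ushort_arg_2 / 10000 and ushort_arg_3 / 10000, and each
// output pixel reads the input at (x / width_scale, y / height_scale).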
1037 kernel void resize_nearest(texture2d_array<half, access::sample> in[[texture(0)]],
1038  texture2d_array<half, access::write> out[[texture(1)]],
1039  ushort3 gid[[thread_position_in_grid]]) {
1040  const ushort oH = ushort_arg_0;
1041  const ushort oW = ushort_arg_1;
1042  if (gid.x >= oW || gid.y >= oH) {
1043  return;
1044  }
1045  const float height_scale = float(ushort_arg_2) / 10000;
1046  const float width_scale = float(ushort_arg_3) / 10000;
1047  constexpr sampler s(coord::pixel, address::clamp_to_edge, filter::nearest);
1048  const int in_y = (int)(gid.y / height_scale);
1049  const int in_x = (int)(gid.x / width_scale);
1050  out.write(in.sample(s, float2(in_x, in_y), gid.z), gid.xy, gid.z);
1051 }
1052 
1053 kernel void resize_nearest_nonarray(texture2d<half, access::sample> in[[texture(0)]],
1054  texture2d<half, access::write> out[[texture(1)]],
1055  ushort2 gid[[thread_position_in_grid]]) {
1056  const ushort oH = ushort_arg_0;
1057  const ushort oW = ushort_arg_1;
1058  if (gid.x >= oW || gid.y >= oH) {
1059  return;
1060  }
1061  const float height_scale = float(ushort_arg_2) / 10000;
1062  const float width_scale = float(ushort_arg_3) / 10000;
1063  constexpr sampler s(coord::pixel, address::clamp_to_edge, filter::nearest);
1064  const int in_y = (int)(gid.y / height_scale);
1065  const int in_x = (int)(gid.x / width_scale);
1066  out.write(in.sample(s, float2(in_x, in_y)), gid.xy);
1067 }
1068 
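// Greedy NMS support kernel: proposals are (x1, y1, x2, y2) boxes addressed
// through the indices buffer (presumably score-sorted). Each threadgroup
// compares one block of up to 32 row boxes against one block of up to 32
// column boxes; thread i builds a 32-bit mask whose bit j is set when
// IoU(row box i, column box j) exceeds nms_thresh (ushort_arg_2 / 10000) and
// writes it to mask[cur_idx * col_blocks + col_start]. The host is then
// expected to combine these masks to perform the actual suppression.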
1069 kernel void nms(device uint* mask[[buffer(0)]],
1070  constant float* proposals[[buffer(1)]],
1071  constant int* indices[[buffer(2)]],
1072  ushort2 tgid[[threadgroup_position_in_grid]],
1073  ushort2 tid[[thread_position_in_threadgroup]]) {
1074  const ushort num_proposals = ushort_arg_0;
1075  const ushort threads_per_group = ushort_arg_1;
1076  float nms_thresh = float(ushort_arg_2) / 10000.0;
1077  const ushort global_offset = ushort_arg_3;
1078  const ushort row_start = tgid.y;
1079  const ushort col_start = tgid.x;
1080  const ushort trd_id = tid.x;
1081 
1082  const short row_size = min(short(32), short(num_proposals - row_start * threads_per_group));
1083  const short col_size = min(short(32), short(num_proposals - col_start * threads_per_group));
1084 
1085  // mask the bit if the IoU between two proposals exceeds the threshold
1086  if (trd_id < row_size) {
1087  const ushort cur_idx = global_offset + row_start * threads_per_group + trd_id;
1088  const ushort offset = indices[cur_idx] * 4;
1089  const float4 cur_proposal = float4(
1090  proposals[offset], proposals[offset + 1], proposals[offset + 2], proposals[offset + 3]);
1091  uint cur_mask = 0;
1092  ushort group_start = 0; // start index within group
1093  if (row_start == col_start) {
1094  // if in the same group, start from the next
1095  group_start = trd_id + 1;
1096  }
1097  for (ushort i = group_start; i < col_size; i++) {
1098  float4 a = cur_proposal;
1099  ushort idx = indices[global_offset + col_start * threads_per_group + i] * 4;
1100  float4 b = float4(proposals[idx], proposals[idx + 1], proposals[idx + 2], proposals[idx + 3]);
1101  float left = max(a[0], b[0]);
1102  float right = min(a[2], b[2]);
1103  float top = max(a[1], b[1]);
1104  float bottom = min(a[3], b[3]);
1105  float width = max(right - left + 1.0, 0.0);
1106  float height = max(bottom - top + 1.0, 0.0);
1107  float interS = width * height;
1108  float Sa = (a[2] - a[0] + 1.0) * (a[3] - a[1] + 1.0);
1109  float Sb = (b[2] - b[0] + 1.0) * (b[3] - b[1] + 1.0);
1110  float iou = interS / (Sa + Sb - interS);
1111  if (iou - nms_thresh > 0) {
1112  cur_mask |= 1U << i;
1113  }
1114  }
1115  ushort col_blocks = (num_proposals + threads_per_group - 1) / threads_per_group;
1116  mask[cur_idx * col_blocks + col_start] = cur_mask;
1117  }
1118 }
1119 
1120 
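// ShuffleNet-style channel shuffle: output channel i reads input channel
// c0 = (i % groups) * K + (i / groups), i.e. the channel dimension is viewed
// as a [groups][K] matrix and transposed (K = ushort_arg_2, presumably
// C / groups).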
1121 kernel void channel_shuffle(
1122  texture2d<half, access::read> in0[[texture(0), function_constant(in0_is_tex)]],
1123  texture2d_array<half, access::read> ina0[[texture(0), function_constant(in0_is_arr)]],
1124  texture2d<half, access::write> out[[texture(1), function_constant(in0_is_tex)]],
1125  texture2d_array<half, access::write> outa[[texture(1), function_constant(in0_is_arr)]],
1126  ushort3 gid[[thread_position_in_grid]]) {
1127  ushort C = ushort_arg_1;
1128  ushort K = ushort_arg_2;
1129  ushort groups = ushort_arg_3;
1130 
1131  if (in0_is_tex) {
1132  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
1133  return;
1134  }
1135  } else {
1136  if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
1137  return;
1138  }
1139  }
1140  const ushort n = gid.z / divRoundUp(C, 4);
1141  const ushort c = gid.z - n * divRoundUp(C, 4);
1142  half4 value;
1143  ushort2 gid_ = gid.xy;
1144  for (int off = 0; off < 4; ++off) {
1145  ushort cur_channel = c * 4 + off;
1146  if (cur_channel >= C) {
1147  break;
1148  }
1149  ushort channel_id = cur_channel / groups;
1150  ushort group_id = cur_channel % groups;
1151  ushort c0 = group_id * K + channel_id;
1152  if (in0_is_tex) {
1153  value[off] = in0.read(gid_)[c0 % 4];
1154  } else {
1155  value[off] = ina0.read(gid_, c0 / 4 + n * divRoundUp(C, 4))[c0 % 4];
1156  }
1157  }
1158  if (in0_is_tex) {
1159  out.write(value, gid_);
1160  } else {
1161  outa.write(value, gid_, gid.z);
1162  }
1163 }
1164 
1165 )V0G0N";