Caffe2 - C++ API
A deep learning, cross-platform ML framework
caffe2/mobile/contrib/ios/mpscnn/mpscnn_kernels.h
// @generated

static const char* MPSCNN_KERNELS = R"V0G0N(

#include <metal_stdlib>

using namespace metal;

constant ushort ushort_arg_0[[function_constant(0)]];
constant ushort ushort_arg_1[[function_constant(1)]];
constant ushort ushort_arg_2[[function_constant(2)]];
constant ushort ushort_arg_3[[function_constant(3)]];
constant ushort ushort_arg_4[[function_constant(4)]];
constant ushort ushort_arg_5[[function_constant(5)]];
constant ushort ushort_arg_6[[function_constant(6)]];
constant ushort ushort_arg_7[[function_constant(7)]];
constant ushort ushort_arg_8[[function_constant(8)]];
constant ushort ushort_arg_9[[function_constant(9)]];

inline constexpr ushort divRoundUp(ushort x, ushort y) { return (x + (y - 1)) / y; }
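
// divRoundUp(C, 4) is the number of half4 texture slices covering C channels.
// Throughout this file, NCHW tensors are packed four channels per texel, so
// image n, channel c lives in array slice n * divRoundUp(C, 4) + c / 4 (see
// z_off() and the copy_* kernels below).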

kernel void affine(constant half4* scale[[buffer(0)]],
                   constant half4* shift[[buffer(1)]],
                   texture2d_array<half, access::read> in[[texture(0)]],
                   texture2d_array<half, access::write> out[[texture(1)]],
                   ushort3 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  const half4 scale_c = scale[gid.z % divRoundUp(C, 4)];
  const half4 shift_c = shift[gid.z % divRoundUp(C, 4)];
  ushort2 gid_(gid.x, gid.y);
  const half4 x = in.read(gid_, gid.z);
  const half4 y = scale_c * x + shift_c;
  out.write(y, gid_, gid.z);
}

kernel void affine_nonarray(constant half4* scale[[buffer(0)]],
                            constant half4* shift[[buffer(1)]],
                            texture2d<half, access::read> in[[texture(0)]],
                            texture2d<half, access::write> out[[texture(1)]],
                            ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  const half4 scale_c = scale[0];
  const half4 shift_c = shift[0];
  half4 x = in.read(gid);
  const half4 y = scale_c * x + shift_c;
  out.write(y, gid);
}
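
// PReLU: y = x for x > 0, else w * x. ushort_arg_1 (S) is the weight count;
// S == 1 means a single weight shared across all channels.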

kernel void prelu_nonshared(constant half4* weights[[buffer(0)]],
                            texture2d_array<half, access::read> in[[texture(0)]],
                            texture2d_array<half, access::write> out[[texture(1)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort S = ushort_arg_1;
  const bool channel_shared = S == 1;
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  half4 w = channel_shared ? half4(weights[0][0], weights[0][0], weights[0][0], weights[0][0])
                           : weights[gid.z % divRoundUp(C, 4)];
  ushort2 gid_(gid.x, gid.y);
  half4 x = in.read(gid_, gid.z);
  half4 y = select(x * w, x, x > 0.0h);
  out.write(y, gid_, gid.z);
}

kernel void prelu_nonshared_nonarray(constant half4* weights[[buffer(0)]],
                                     texture2d<half, access::read> in[[texture(0)]],
                                     texture2d<half, access::write> out[[texture(1)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
  // const ushort C = ushort_arg_0;
  const ushort S = ushort_arg_1;
  const bool channel_shared = S == 1;
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  half4 w = channel_shared ? half4(weights[0][0], weights[0][0], weights[0][0], weights[0][0])
                           : weights[0];
  half4 x = in.read(gid);
  half4 y = select(x * w, x, x > 0.0h);
  out.write(y, gid);
}
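
// The instance norm kernels below use a two-pass threadgroup reduction: 256
// threads accumulate partial sums over strided windows of the texture, the
// partials are reduced 256 -> 32 -> 1 to get the mean, a second pass does the
// same for the variance, and a final pass applies
// (x - mean) * inv_var * weight + bias with an optional fused PReLU.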

// One block per texture.
// 256 threads per block.
using AccT = float4;

constant const bool instance_norm_has_prelu = ushort_arg_1 > 0;

kernel void instance_norm(
    constant half4* weights[[buffer(0)]],
    constant half4* bias[[buffer(1)]],
    constant half4* preluWeights[[ buffer(2), function_constant(instance_norm_has_prelu) ]],
    texture2d_array<half, access::read> in[[texture(0)]],
    texture2d_array<half, access::write> out[[texture(1)]],
    ushort3 gid[[thread_position_in_grid]],
    ushort tid[[thread_index_in_threadgroup]],
    ushort3 tcount[[threads_per_threadgroup]]) {
  if (gid.z >= out.get_array_size()) {
    return;
  }
  const ushort C = ushort_arg_0;
  const ushort S = ushort_arg_1;
  const bool channel_shared = S == 1;
  const ushort c = gid.z % divRoundUp(C, 4);
  constexpr ushort THREADGROUP_SIZE = 256;

  threadgroup AccT per_thread_state[THREADGROUP_SIZE];
  // Each block handles a single texture.
  per_thread_state[tid] = 0;
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      per_thread_state[tid] += static_cast<AccT>(in.read(ushort2(x, y), gid.z));
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // 256 -> 32 reduction
  if (tid < 32) {
    per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
                             per_thread_state[tid + 96] + per_thread_state[tid + 128] +
                             per_thread_state[tid + 160] + per_thread_state[tid + 192] +
                             per_thread_state[tid + 224];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid == 0) {
    AccT sum = 0.0;
    for (ushort i = 0; i < 32; ++i) {
      sum += per_thread_state[i];
    }
    sum /= (in.get_width() * in.get_height());
    per_thread_state[0] = sum;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Broadcast to all threads.
  const AccT mean = per_thread_state[0];

  threadgroup_barrier(mem_flags::mem_threadgroup);

  per_thread_state[tid] = 0;
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      AccT delta = static_cast<AccT>(in.read(ushort2(x, y), gid.z)) - mean;
      per_thread_state[tid] += delta * delta;
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // 256 -> 32 reduction
  if (tid < 32) {
    per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
                             per_thread_state[tid + 96] + per_thread_state[tid + 128] +
                             per_thread_state[tid + 160] + per_thread_state[tid + 192] +
                             per_thread_state[tid + 224];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid == 0) {
    AccT sum = 0.0;
    for (ushort i = 0; i < 32; ++i) {
      sum += per_thread_state[i];
    }
    sum /= (in.get_width() * in.get_height());
    per_thread_state[0] = 1.0 / sqrt(max(sum, AccT(1e-5, 1e-5, 1e-5, 1e-5)) + 1.0e-5);
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Broadcast to all threads.
  const AccT inv_var = per_thread_state[0];

  const AccT c_weights = static_cast<AccT>(weights[c]);
  const AccT c_bias = static_cast<AccT>(bias[c]);

  const AccT scale = inv_var * c_weights;
  const AccT shift = c_bias - mean * scale;

  half4 w;
  if (instance_norm_has_prelu) {
    w = channel_shared ? half4(preluWeights[0][0]) : preluWeights[c];
  }
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      half4 scaled =
          static_cast<half4>(static_cast<AccT>(in.read(ushort2(x, y), gid.z)) * scale + shift);
      if (instance_norm_has_prelu) {
        scaled = select(scaled * w, scaled, scaled > 0.0h);
      }
      out.write(scaled, ushort2(x, y), gid.z);
    }
  }
}
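
// instance_norm_nonarray: the same two-pass scheme for tensors that fit a
// single texture (no array slices), so weights[0] / bias[0] are used directly.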

// One block per texture.
// 256 threads per block.
kernel void instance_norm_nonarray(
    constant half4* weights[[buffer(0)]],
    constant half4* bias[[buffer(1)]],
    constant half4* preluWeights[[ buffer(2), function_constant(instance_norm_has_prelu) ]],
    texture2d<half, access::read> in[[texture(0)]],
    texture2d<half, access::write> out[[texture(1)]],
    ushort3 gid[[thread_position_in_grid]],
    ushort tid[[thread_index_in_threadgroup]],
    ushort3 tcount[[threads_per_threadgroup]]) {
  // const ushort C = ushort_arg_0;
  const ushort S = ushort_arg_1;
  const bool channel_shared = S == 1;
  constexpr ushort THREADGROUP_SIZE = 256;

  threadgroup AccT per_thread_state[THREADGROUP_SIZE];
  // Each block handles a single texture.
  per_thread_state[tid] = 0;
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      per_thread_state[tid] += static_cast<AccT>(in.read(ushort2(x, y)));
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // 256 -> 32 reduction
  if (tid < 32) {
    per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
                             per_thread_state[tid + 96] + per_thread_state[tid + 128] +
                             per_thread_state[tid + 160] + per_thread_state[tid + 192] +
                             per_thread_state[tid + 224];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid == 0) {
    AccT sum = 0.0;
    for (ushort i = 0; i < 32; ++i) {
      sum += per_thread_state[i];
    }
    sum /= (in.get_width() * in.get_height());
    per_thread_state[0] = sum;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Broadcast to all threads.
  const AccT mean = per_thread_state[0];

  threadgroup_barrier(mem_flags::mem_threadgroup);

  per_thread_state[tid] = 0;
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      AccT delta = static_cast<AccT>(in.read(ushort2(x, y))) - mean;
      per_thread_state[tid] += delta * delta;
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // 256 -> 32 reduction
  if (tid < 32) {
    per_thread_state[tid] += per_thread_state[tid + 32] + per_thread_state[tid + 64] +
                             per_thread_state[tid + 96] + per_thread_state[tid + 128] +
                             per_thread_state[tid + 160] + per_thread_state[tid + 192] +
                             per_thread_state[tid + 224];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  if (tid == 0) {
    AccT sum = 0.0;
    for (ushort i = 0; i < 32; ++i) {
      sum += per_thread_state[i];
    }
    sum /= (in.get_width() * in.get_height());
    per_thread_state[0] = 1.0 / sqrt(max(sum, AccT(1e-5, 1e-5, 1e-5, 1e-5)) + 1.0e-5);
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Broadcast to all threads.
  const AccT inv_var = per_thread_state[0];

  const AccT c_weights = static_cast<AccT>(weights[0]);
  const AccT c_bias = static_cast<AccT>(bias[0]);

  const AccT scale = inv_var * c_weights;
  const AccT shift = c_bias - mean * scale;

  half4 w;
  if (instance_norm_has_prelu) {
    w = channel_shared ? half4(preluWeights[0][0]) : preluWeights[0];
  }
  for (ushort y = gid.y; y < in.get_height(); y += tcount.y) {
    for (ushort x = gid.x; x < in.get_width(); x += tcount.x) {
      half4 scaled = static_cast<half4>(static_cast<AccT>(in.read(ushort2(x, y))) * scale + shift);
      if (instance_norm_has_prelu) {
        scaled = select(scaled * w, scaled, scaled > 0.0h);
      }
      out.write(scaled, ushort2(x, y));
    }
  }
}
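
// copy_nchw_to_metal: packs four consecutive NCHW channels into one half4
// texel; channels past C are zero-filled by the CHW_TO_CHWP4 macro.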

kernel void copy_nchw_to_metal(constant float* in[[buffer(0)]],
                               texture2d_array<half, access::write> out[[texture(0)]],
                               ushort3 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort H = ushort_arg_1;
  const ushort W = ushort_arg_2;
  if (gid.x >= W || gid.y >= H) {
    return;
  }

  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);

  // TODO: are the `else` branches needed?
  // TODO: trick the optimizer for case where C == 4?
#define CHW_TO_CHWP4(idx, n, c_, h, w)                                     \
  if ((c_) < C) {                                                          \
    trns[idx] = in[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)]; \
  } else {                                                                 \
    trns[idx] = 0.0h;                                                      \
  }

  half4 trns;
  CHW_TO_CHWP4(0, n, c * 4 + 0, gid.y, gid.x);
  CHW_TO_CHWP4(1, n, c * 4 + 1, gid.y, gid.x);
  CHW_TO_CHWP4(2, n, c * 4 + 2, gid.y, gid.x);
  CHW_TO_CHWP4(3, n, c * 4 + 3, gid.y, gid.x);
#undef CHW_TO_CHWP4

  out.write(trns, gid.xy, gid.z);
}

kernel void copy_nchw_to_metal_nonarray(constant float* in[[buffer(0)]],
                                        texture2d<half, access::write> out[[texture(0)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort H = ushort_arg_1;
  const ushort W = ushort_arg_2;

  if (gid.x >= W || gid.y >= H) {
    return;
  }

  half4 trns;
  // TODO: are the `else` branches needed?
  // TODO: trick the optimizer for case where C % 4 == 0?

#define CHW_TO_CHWP4(idx, c, h, w)                        \
  if ((c) < C) {                                          \
    trns[idx] = in[int(c) * H * W + int(h) * W + int(w)]; \
  } else {                                                \
    trns[idx] = 0.0h;                                     \
  }

  CHW_TO_CHWP4(0, 0, gid.y, gid.x);
  CHW_TO_CHWP4(1, 1, gid.y, gid.x);
  CHW_TO_CHWP4(2, 2, gid.y, gid.x);
  CHW_TO_CHWP4(3, 3, gid.y, gid.x);
#undef CHW_TO_CHWP4

  out.write(trns, gid.xy);
}
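
// copy_metal_to_nchw: the inverse gather; the four half4 lanes map back to
// consecutive NCHW channels, and lanes past C are simply not written.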

kernel void copy_metal_to_nchw(texture2d_array<half, access::read> in[[texture(0)]],
                               device float* out[[buffer(0)]],
                               ushort3 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort H = ushort_arg_1;
  const ushort W = ushort_arg_2;

  if (gid.x >= W || gid.y >= H) {
    return;
  }
  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);

  half4 cs = in.read(gid.xy, gid.z);

#define CHWP4_TO_CHW(idx, n, c_, h, w)                                    \
  if ((c_) < C) {                                                         \
    out[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)] = cs[idx]; \
  }

  CHWP4_TO_CHW(0, n, c * 4 + 0, gid.y, gid.x);
  CHWP4_TO_CHW(1, n, c * 4 + 1, gid.y, gid.x);
  CHWP4_TO_CHW(2, n, c * 4 + 2, gid.y, gid.x);
  CHWP4_TO_CHW(3, n, c * 4 + 3, gid.y, gid.x);
#undef CHWP4_TO_CHW
}

kernel void copy_metal_to_nchw_nonarray(texture2d<half, access::read> in[[texture(0)]],
                                        device float* out[[buffer(0)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
  const ushort C = ushort_arg_0;
  const ushort H = ushort_arg_1;
  const ushort W = ushort_arg_2;

  if (gid.x >= W || gid.y >= H) {
    return;
  }

  half4 cs = in.read(gid.xy);

#define CHWP4_TO_CHW(idx, c, h, w)                       \
  if ((c) < C) {                                         \
    out[int(c) * H * W + int(h) * W + int(w)] = cs[idx]; \
  }

  CHWP4_TO_CHW(0, 0, gid.y, gid.x);
  CHWP4_TO_CHW(1, 1, gid.y, gid.x);
  CHWP4_TO_CHW(2, 2, gid.y, gid.x);
  CHWP4_TO_CHW(3, 3, gid.y, gid.x);
#undef CHWP4_TO_CHW
}
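
// convtranspose_upscale: writes each input pixel to its stride-aligned
// position in the output (offset by kernel - 1 - pad) and zeros elsewhere.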

kernel void convtranspose_upscale(texture2d_array<half, access::read> in[[texture(0)]],
                                  texture2d_array<half, access::write> out[[texture(1)]],
                                  ushort3 gid[[thread_position_in_grid]]) {
  // All resolved at compile time.
  // Assume symmetric kernel/stride/pad for now.
  const ushort kernel_ = ushort_arg_0;
  const ushort stride = ushort_arg_1;
  const ushort pad = ushort_arg_2;

  half4 zero(0.0h, 0.0h, 0.0h, 0.0h);

  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  const ushort2 gid_ = gid.xy;
  if (gid.x < kernel_ - 1 - pad || gid.y < kernel_ - 1 - pad) {
    out.write(zero, gid_, gid.z);
    return;
  }

  if (((gid.x - (kernel_ - 1 - pad)) % stride == 0) &&
      ((gid.y - (kernel_ - 1 - pad)) % stride == 0)) {
    ushort2 in_pos((gid.x - (kernel_ - 1 - pad)) / stride, (gid.y - (kernel_ - 1 - pad)) / stride);

    if (in_pos.x < in.get_width() && in_pos.y < in.get_height()) {
      half4 input = in.read(in_pos, gid.z);
      out.write(input, gid_, gid.z);
    } else {
      out.write(zero, gid_, gid.z);
    }
  } else {
    out.write(zero, gid_, gid.z);
  }
}
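
// The function constants below pick, at pipeline-build time, whether col2im
// binds plain or array textures; only one member of each in/out pair is live
// in a given specialization.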

constant bool has_in_arr = (ushort_arg_7 > 1 || ushort_arg_0 * ushort_arg_1 * ushort_arg_6 > 4);
constant bool has_out_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
constant bool has_in_tex = (!has_in_arr);
constant bool has_out_tex = (!has_out_arr);

kernel void col2im(
    texture2d_array<half, access::read> ina[[ texture(0), function_constant(has_in_arr) ]],
    texture2d<half, access::read> in[[ texture(0), function_constant(has_in_tex) ]],
    texture2d_array<half, access::write> outa[[ texture(1), function_constant(has_out_arr) ]],
    texture2d<half, access::write> out[[ texture(1), function_constant(has_out_tex) ]],
    constant half4* bias[[buffer(0)]],
    ushort3 gid[[thread_position_in_grid]]) {
  if (has_out_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
  }
  const ushort kernel_h = ushort_arg_0;
  const ushort kernel_w = ushort_arg_1;
  const ushort stride_h = ushort_arg_2;
  const ushort stride_w = ushort_arg_3;
  const ushort pad_l = ushort_arg_4;
  const ushort pad_t = ushort_arg_5;
  const ushort C = ushort_arg_6;
  // const int N = ushort_arg_7;
  const ushort height_col = ushort_arg_8;
  const ushort width_col = ushort_arg_9;

  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);

  const ushort w = gid.x + pad_l;
  const ushort h = gid.y + pad_t;

  // compute the start and end of the output
  const ushort w_col_start = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
  const ushort w_col_end = min(ushort(w / stride_w + 1), ushort(width_col));
  const ushort h_col_start = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
  const ushort h_col_end = min(ushort(h / stride_h + 1), ushort(height_col));

  float4 val = static_cast<float4>(bias[c]);
  for (ushort h_col = h_col_start; h_col < h_col_end; ++h_col) {
    for (ushort w_col = w_col_start; w_col < w_col_end; ++w_col) {
      const ushort w_k = w - w_col * stride_w;
      const ushort h_k = h - h_col * stride_h;

      // layout is essentially: [N][K][K][C][H][W]
      // - where the divRoundUp(K * K * C, 4) channels are interleaved as usual.
      // Thus, it's actually [N][divRoundUp(K * K * C, 4)][H][W].

      // If C % 4 is not zero, then we have to play some games via partial indexing.
      // TODO: is it worth optimizing this loop via padding in C?
      if (C % 4 == 0) {
        ushort c_col = n * kernel_h * kernel_w * divRoundUp(C, 4) +
                       h_k * kernel_w * divRoundUp(C, 4) + w_k * divRoundUp(C, 4) + c;
        if (has_in_arr) {
          val += static_cast<float4>(ina.read(ushort2(w_col, h_col), c_col));
        }
        if (has_in_tex) {
          val += static_cast<float4>(in.read(ushort2(w_col, h_col), c_col));
        }
      } else {
        half4 components(0, 0, 0, 0);
        for (auto i = 0; i < 4; ++i) {
          ushort c_col_i = n * divRoundUp(kernel_h * kernel_w * C, 4) * 4 + h_k * kernel_w * C +
                           w_k * C + c * 4 + i;
          ushort c_col_i_z = c_col_i / 4;
          ushort c_col_i_off = c_col_i - c_col_i_z * 4;
          if (has_in_arr) {
            components[i] = ina.read(ushort2(w_col, h_col), c_col_i_z)[c_col_i_off];
          }
          if (has_in_tex) {
            components[i] = in.read(ushort2(w_col, h_col))[c_col_i_off];
          }
        }
        val += static_cast<float4>(components);
      }
    }
  }
  if (has_out_arr) {
    outa.write(static_cast<half4>(val), gid.xy, gid.z);
  }
  if (has_out_tex) {
    out.write(static_cast<half4>(val), gid.xy);
  }
}
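
// preprocess_stylizer: widens packed uchar4 pixels to half4, subtracts the
// three-channel mean, and adds a noise vector tiled modulo noise_size / 4.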

kernel void preprocess_stylizer(device uchar4* in[[buffer(0)]],
                                constant half* mean[[buffer(1)]],
                                constant half4* noise[[buffer(2)]],
                                texture2d<half, access::write> out[[texture(0)]],
                                ushort2 gid[[thread_position_in_grid]]) {

  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  const ushort noise_size = ushort_arg_0;

  half4 mean_half(mean[0], mean[1], mean[2], 0.0h);
  uint input_noise_idx = ((uint)out.get_width() * (uint)gid.y + (uint)gid.x) % (noise_size / 4);
  const half4 input_noise = noise[input_noise_idx];
  const uint W = out.get_width();
#define in_at(h, w) in[(uint)(h)*W + (uint)(w)]
  uchar4 input = in_at(gid.y, gid.x);
#undef in_at
  half4 input_half = static_cast<half4>(input);
  out.write(input_half - mean_half + input_noise, gid);
}
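
// deprocess_stylizer: the inverse; adds the mean back, clamps to [0, 255]
// with alpha pinned at 255, and narrows to uchar4.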

kernel void deprocess_stylizer(texture2d<half, access::read> in[[texture(0)]],
                               device uchar4* out[[buffer(0)]],
                               constant half* mean[[buffer(1)]],
                               ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= in.get_width() || gid.y >= in.get_height()) {
    return;
  }

  half4 value = in.read(gid);

  half4 mean_h(mean[0], mean[1], mean[2], 0.0h);
  half4 min_h(0.0h, 0.0h, 0.0h, 255.0h);
  half4 max_h(255.0h, 255.0h, 255.0h, 255.0h);
  half4 clamped = clamp(value + mean_h, min_h, max_h);
  const uint W = in.get_width();
#define out_at(h, w, v) out[(uint)(h)*W + (uint)(w)] = (v)
  out_at(gid.y, gid.x, static_cast<uchar4>(clamped));
#undef out_at
}
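
// Reflection padding index math: h = max(h, -h) mirrors coordinates that fall
// before the leading edge, and min(h, 2 * H - h - 2) mirrors those past the
// trailing edge.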

kernel void reflection_padding_nonarray(texture2d<half, access::read> in[[texture(0)]],
                                        texture2d<half, access::write> out[[texture(1)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  ushort H = in.get_height();
  ushort PH = out.get_height();

  // Note: we assume symmetric padding on H/W here, which is verified
  // in the calling code.
  ushort pad_h = (PH - H) / 2;
  ushort W = in.get_width();
  ushort PW = out.get_width();
  ushort pad_w = (PW - W) / 2;

  short h = short(gid.y) - short(pad_h);
  h = max(h, short(-h));
  h = min(h, short(2 * H - h - 2));

  short w = short(gid.x) - short(pad_w);
  w = max(w, short(-w));
  w = min(w, short(2 * W - w - 2));

  ushort2 inid(w, h);
  out.write(in.read(inid), gid);
}

kernel void reflection_padding(texture2d_array<half, access::read> in[[texture(0)]],
                               texture2d_array<half, access::write> out[[texture(1)]],
                               ushort3 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  ushort H = in.get_height();
  ushort PH = out.get_height();

  // Note: we assume symmetric padding on H/W here, which is verified
  // in the calling code.
  ushort pad_h = (PH - H) / 2;
  ushort W = in.get_width();
  ushort PW = out.get_width();
  ushort pad_w = (PW - W) / 2;

  short h = short(gid.y) - short(pad_h);
  h = max(h, short(-h));
  h = min(h, short(2 * H - h - 2));

  short w = short(gid.x) - short(pad_w);
  w = max(w, short(-w));
  w = min(w, short(2 * W - w - 2));

  ushort2 inid(w, h);

  out.write(in.read(inid, gid.z), gid.xy, gid.z);
}
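
// bilinear_upsample: fixed 2x upscale. Each output pixel samples the input at
// gid / 2 through a pixel-coordinate linear sampler, which supplies the
// interpolation between neighboring texels.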

kernel void bilinear_upsample(texture2d<half, access::sample> in[[texture(0)]],
                              texture2d<half, access::write> out[[texture(1)]],
                              ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  ushort2 src = gid / 2;
  constexpr sampler sampler(address::clamp_to_edge, filter::linear, coord::pixel);
  half4 value = in.sample(sampler, static_cast<float2>(src));
  out.write(value, gid);
}
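
// elementwise_mul/sub broadcast the float buffer in1 over the texture: the
// flattened pixel index is reduced modulo last_dim (ushort_arg_2).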

constant bool in0_is_tex = ushort_arg_0 <= 1 && ushort_arg_1 <= 4;
constant bool in0_is_arr = !in0_is_tex;

kernel void elementwise_mul(texture2d<half, access::read> in0[[texture(0), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::read> ina0[[texture(0), function_constant(in0_is_arr)]],
                            texture2d<half, access::write> out[[texture(2), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::write> outa[[texture(2), function_constant(in0_is_arr)]],
                            constant float* in1[[buffer(1)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  ushort last_dim = ushort_arg_2;
  ushort idx;
  if (in0_is_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
    idx = gid.y * out.get_width() + gid.x;
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
    idx = gid.y * outa.get_width() + gid.x;
  }
  ushort2 gid_ = gid.xy;
  if (in0_is_tex) {
    out.write(in0.read(gid_) * in1[idx % last_dim], gid_);
  } else {
    outa.write(ina0.read(gid_, gid.z) * in1[idx % last_dim], gid_, gid.z);
  }
}

kernel void elementwise_sub(texture2d<half, access::read> in0[[texture(0), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::read> ina0[[texture(0), function_constant(in0_is_arr)]],
                            texture2d<half, access::write> out[[texture(2), function_constant(in0_is_tex)]],
                            texture2d_array<half, access::write> outa[[texture(2), function_constant(in0_is_arr)]],
                            constant float* in1[[buffer(1)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  ushort last_dim = ushort_arg_2;
  ushort idx;
  if (in0_is_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
    idx = gid.y * out.get_width() + gid.x;
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
    idx = gid.y * outa.get_width() + gid.x;
  }
  ushort2 gid_ = gid.xy;
  if (in0_is_tex) {
    out.write(in0.read(gid_) - in1[idx % last_dim], gid_);
  } else {
    outa.write(ina0.read(gid_, gid.z) - in1[idx % last_dim], gid_, gid.z);
  }
}

kernel void elementwise_add_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  out.write(in0.read(gid) + in1.read(gid), gid);
}

kernel void elementwise_add(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
  if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
    return;
  }
  ushort2 gid_ = gid.xy;
  out.write(in0.read(gid_, gid.z) + in1.read(gid_, gid.z), gid_, gid.z);
}
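
// concat: joins up to four inputs along the channel axis. The idx_* helpers
// map an output channel to the input that owns it; the "fast path" below
// copies a whole half4 in one read when four consecutive channels come
// 4-aligned from a single input.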

constant bool has_in0_arg = (ushort_arg_0 > 0);
constant bool has_in1_arg = (ushort_arg_1 > 0);
constant bool has_in2_arg = (ushort_arg_2 > 0);
constant bool has_in3_arg = (ushort_arg_3 > 0);

constant bool has_in0_tex = (has_in0_arg && ushort_arg_0 <= 4 && ushort_arg_5 <= 1);
constant bool has_in1_tex = (has_in1_arg && ushort_arg_1 <= 4 && ushort_arg_5 <= 1);
constant bool has_in2_tex = (has_in2_arg && ushort_arg_2 <= 4 && ushort_arg_5 <= 1);
constant bool has_in3_tex = (has_in3_arg && ushort_arg_3 <= 4 && ushort_arg_5 <= 1);

constant bool has_in0_array = (has_in0_arg && !has_in0_tex);
constant bool has_in1_array = (has_in1_arg && !has_in1_tex);
constant bool has_in2_array = (has_in2_arg && !has_in2_tex);
constant bool has_in3_array = (has_in3_arg && !has_in3_tex);

constant bool concat_has_out_tex = (ushort_arg_4 <= 4 && ushort_arg_5 <= 1);
constant bool concat_has_out_array = !concat_has_out_tex;

inline ushort idx_3(ushort z, ushort C0, ushort C1, ushort C2, ushort C3) {
  if (z < C0) {
    return 0;
  }
  if (z < (C0 + C1)) {
    return 1;
  }
  if (z < (C0 + C1 + C2)) {
    return 2;
  }
  return 3;
}

inline ushort idx_2(ushort z, ushort C0, ushort C1, ushort C2) {
  if (z < C0) {
    return 0;
  }
  if (z < (C0 + C1)) {
    return 1;
  }
  return 2;
}

inline ushort idx_1(ushort z, ushort C0, ushort C1) {
  if (z < C0) {
    return 0;
  } else {
    return 1;
  }
}

inline ushort idx_0(ushort z, ushort C0) { return 0; }

// in a texture_array with size C, find the offset for image N at plane c.
inline constexpr ushort z_off(ushort n, ushort c, ushort C) { return n * divRoundUp(C, 4) + c / 4; }

kernel void concat(
    texture2d<half, access::read> in0[[ texture(0), function_constant(has_in0_tex) ]],
    texture2d<half, access::read> in1[[ texture(1), function_constant(has_in1_tex) ]],
    texture2d<half, access::read> in2[[ texture(2), function_constant(has_in2_tex) ]],
    texture2d<half, access::read> in3[[ texture(3), function_constant(has_in3_tex) ]],
    texture2d_array<half, access::read> ina0[[ texture(0), function_constant(has_in0_array) ]],
    texture2d_array<half, access::read> ina1[[ texture(1), function_constant(has_in1_array) ]],
    texture2d_array<half, access::read> ina2[[ texture(2), function_constant(has_in2_array) ]],
    texture2d_array<half, access::read> ina3[[ texture(3), function_constant(has_in3_array) ]],
    texture2d<half, access::write> out[[ texture(5), function_constant(concat_has_out_tex) ]],
    texture2d_array<half, access::write> outa[[ texture(5), function_constant(concat_has_out_array) ]],
    ushort3 gid[[thread_position_in_grid]]) {
  if (concat_has_out_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
  }

  const ushort C0 = ushort_arg_0;
  const ushort C1 = ushort_arg_1;
  const ushort C2 = ushort_arg_2;
  const ushort C3 = ushort_arg_3;
  const ushort C = C0 + C1 + C2 + C3;
  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);
  // Fill channel 4*c to 4*(c+1) of nth image of output

  ushort2 gid_ = ushort2(gid.x, gid.y);
  half4 value;

  for (int off = 0; off < 4; ++off) {
    ushort cur_channel = c * 4 + off;
    ushort cur_idx = 0;
    if (cur_channel >= C) {
      break;
    }
    if (has_in3_arg) {
      cur_idx = idx_3(cur_channel, C0, C1, C2, C3);
    } else if (has_in2_arg) {
      cur_idx = idx_2(cur_channel, C0, C1, C2);
    } else if (has_in1_arg) {
      cur_idx = idx_1(cur_channel, C0, C1);
    } else if (has_in0_arg) {
      cur_idx = idx_0(cur_channel, C0);
    } else {
      // never reached.
      cur_idx = 0;
    }
    ushort src_off = 0;
    switch (cur_idx) {
      case 0:
        src_off = cur_channel % 4;
        break;
      case 1:
        src_off = (cur_channel - C0) % 4;
        break;
      case 2:
        src_off = (cur_channel - (C0 + C1)) % 4;
        break;
      case 3:
        src_off = (cur_channel - (C0 + C1 + C2)) % 4;
        break;
    }
    // try to see if we can only issue one read op for the 4 values
    bool fast_path = false;
    if (off == 0 && src_off == 0 && (cur_channel + 3) < C) {
      ushort last_idx = 0;
      if (has_in3_arg) {
        last_idx = idx_3(cur_channel + 3, C0, C1, C2, C3);
      } else if (has_in2_arg) {
        last_idx = idx_2(cur_channel + 3, C0, C1, C2);
      } else if (has_in1_arg) {
        last_idx = idx_1(cur_channel + 3, C0, C1);
      } else if (has_in0_arg) {
        last_idx = idx_0(cur_channel + 3, C0);
      } else {
        // never reached.
        last_idx = 0;
      }
      if (cur_idx == last_idx) {
        fast_path = true;
      }
    }
    switch (cur_idx) {
      case 0: {
        if (has_in0_tex) {
          if (fast_path) {
            value = in0.read(gid_);
          } else {
            value[off] = in0.read(gid_)[src_off];
          }
        }
        if (has_in0_array) {
          if (fast_path) {
            value = ina0.read(gid_, z_off(n, cur_channel, C0));
          } else {
            value[off] = ina0.read(gid_, z_off(n, cur_channel, C0))[src_off];
          }
        }
        break;
      }
      case 1: {
        if (has_in1_tex) {
          if (fast_path) {
            value = in1.read(gid_);
          } else {
            value[off] = in1.read(gid_)[src_off];
          }
        }
        if (has_in1_array) {
          if (fast_path) {
            value = ina1.read(gid_, z_off(n, cur_channel - C0, C1));
          } else {
            value[off] = ina1.read(gid_, z_off(n, cur_channel - C0, C1))[src_off];
          }
        }
        break;
      }
      case 2: {
        if (has_in2_tex) {
          if (fast_path) {
            value = in2.read(gid_);
          } else {
            value[off] = in2.read(gid_)[src_off];
          }
        }
        if (has_in2_array) {
          if (fast_path) {
            value = ina2.read(gid_, z_off(n, cur_channel - (C0 + C1), C2));
          } else {
            value[off] = ina2.read(gid_, z_off(n, cur_channel - (C0 + C1), C2))[src_off];
          }
        }
        break;
      }
      case 3: {
        if (has_in3_tex) {
          if (fast_path) {
            value = in3.read(gid_);
          } else {
            value[off] = in3.read(gid_)[src_off];
          }
        }
        if (has_in3_array) {
          if (fast_path) {
            value = ina3.read(gid_, z_off(n, cur_channel - (C0 + C1 + C2), C3));
          } else {
            value[off] = ina3.read(gid_, z_off(n, cur_channel - (C0 + C1 + C2), C3))[src_off];
          }
        }
        break;
      }
    }
    if (fast_path) {
      break;
    }
  }
  if (concat_has_out_tex) {
    out.write(value, gid_, gid.z);
  } else {
    outa.write(value, gid_, gid.z);
  }
}
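
// roi_warp: RoIAlign-style pooling. Each output cell averages bilinear
// samples over its bin of the (scaled) ROI; spatial_scale is fixed-point
// encoded as ushort_arg_0 / 10000.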

using RoIT = half;
using RoIT4 = half4;
constant bool rw_has_in_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool rw_has_out_arr = (ushort_arg_4 > 1 || ushort_arg_2 > 4);
constant bool rw_has_in_tex = (!rw_has_in_arr);
constant bool rw_has_out_tex = (!rw_has_out_arr);

kernel void roi_warp(texture2d_array<half, access::sample> ina[[texture(0), function_constant(rw_has_in_arr)]],
                     texture2d<half, access::sample> in[[texture(0), function_constant(rw_has_in_tex)]],
                     texture2d_array<half, access::write> outa[[texture(1), function_constant(rw_has_out_arr)]],
                     texture2d<half, access::write> out[[texture(1), function_constant(rw_has_out_tex)]],
                     constant half4* rois[[buffer(0)]],
                     ushort3 gid[[thread_position_in_grid]]) {
  ushort out_width, out_height;
  if (rw_has_out_arr) {
    out_width = outa.get_width();
    out_height = outa.get_height();
  } else {
    out_width = out.get_width();
    out_height = out.get_height();
  }
  if (gid.x >= out_width || gid.y >= out_height) {
    return;
  }
  constexpr sampler s2(coord::pixel, address::clamp_to_edge, filter::linear);

  const half spatial_scale = half(ushort_arg_0) / 10000;
  const ushort sampling_ratio = ushort_arg_1;
  const ushort C = ushort_arg_2;
  const ushort pw = gid.x;
  const ushort ph = gid.y;
  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z % divRoundUp(C, 4);

  const RoIT4 roi_scaled = rois[n] * spatial_scale;
  const RoIT roi_start_w = roi_scaled[0];
  const RoIT roi_start_h = roi_scaled[1];
  const RoIT roi_end_w = roi_scaled[2];
  const RoIT roi_end_h = roi_scaled[3];

  // Force malformed ROIs to be 1x1
  const RoIT roi_width = max(roi_end_w - roi_start_w, (RoIT)1.);
  const RoIT roi_height = max(roi_end_h - roi_start_h, (RoIT)1.);

  const RoIT bin_size_h = static_cast<RoIT>(roi_height) / static_cast<RoIT>(out_height);
  const RoIT bin_size_w = static_cast<RoIT>(roi_width) / static_cast<RoIT>(out_width);
  const ushort roi_bin_grid_h = sampling_ratio > 0 ? sampling_ratio : ceil(roi_height / static_cast<RoIT>(out_height));
  const ushort roi_bin_grid_w = sampling_ratio > 0 ? sampling_ratio : ceil(roi_width / static_cast<RoIT>(out_width));
  const ushort iy_upper = (sampling_ratio > 0) ? roi_bin_grid_h : (roi_bin_grid_h + 1);
  const ushort ix_upper = (sampling_ratio > 0) ? roi_bin_grid_w : (roi_bin_grid_w + 1);

  const RoIT count = iy_upper * ix_upper;

  RoIT4 output_val = 0.0;
  for (int iy = 0; iy < iy_upper; iy++) {
    for (int ix = 0; ix < ix_upper; ix++) {
      const RoIT y =
          roi_start_h + ph * bin_size_h + iy * bin_size_h / static_cast<RoIT>(roi_bin_grid_h);
      const RoIT x =
          roi_start_w + pw * bin_size_w + ix * bin_size_w / static_cast<RoIT>(roi_bin_grid_w);
      if (rw_has_in_arr) {
        output_val += ina.sample(s2, float2(x + 0.5, y + 0.5), c);
      } else {
        output_val += in.sample(s2, float2(x + 0.5, y + 0.5));
      }
    }
  }
  output_val /= count;
  if (rw_has_out_arr) {
    outa.write(static_cast<half4>(output_val), gid.xy, gid.z);
  } else {
    out.write(static_cast<half4>(output_val), gid.xy);
  }
}
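
// resize_nearest: nearest-neighbor resize; the output size (oH, oW) and the
// height/width scales arrive as function constants, the scales fixed-point
// encoded as ushort_arg / 10000.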

kernel void resize_nearest(texture2d_array<half, access::sample> in[[texture(0)]],
                           texture2d_array<half, access::write> out[[texture(1)]],
                           ushort3 gid[[thread_position_in_grid]]) {
  const ushort oH = ushort_arg_0;
  const ushort oW = ushort_arg_1;
  if (gid.x >= oW || gid.y >= oH) {
    return;
  }
  const float height_scale = float(ushort_arg_2) / 10000;
  const float width_scale = float(ushort_arg_3) / 10000;
  constexpr sampler s(coord::pixel, address::clamp_to_edge, filter::nearest);
  const int in_y = (int)(gid.y / height_scale);
  const int in_x = (int)(gid.x / width_scale);
  out.write(in.sample(s, float2(in_x, in_y), gid.z), gid.xy, gid.z);
}

kernel void resize_nearest_nonarray(texture2d<half, access::sample> in[[texture(0)]],
                                    texture2d<half, access::write> out[[texture(1)]],
                                    ushort2 gid[[thread_position_in_grid]]) {
  const ushort oH = ushort_arg_0;
  const ushort oW = ushort_arg_1;
  if (gid.x >= oW || gid.y >= oH) {
    return;
  }
  const float height_scale = float(ushort_arg_2) / 10000;
  const float width_scale = float(ushort_arg_3) / 10000;
  constexpr sampler s(coord::pixel, address::clamp_to_edge, filter::nearest);
  const int in_y = (int)(gid.y / height_scale);
  const int in_x = (int)(gid.x / width_scale);
  out.write(in.sample(s, float2(in_x, in_y)), gid.xy);
}
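
// nms: support kernel for greedy non-maximum suppression. `indices`
// presumably holds proposal indices pre-sorted by score; each thread owns one
// box of a 32-wide row block, tests it against a 32-wide column block, and
// sets bit i in its mask for every box whose IoU exceeds nms_thresh
// (fixed-point, ushort_arg_2 / 10000).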

kernel void nms(device uint* mask[[buffer(0)]],
                constant float* proposals[[buffer(1)]],
                constant int* indices[[buffer(2)]],
                ushort2 tgid[[threadgroup_position_in_grid]],
                ushort2 tid[[thread_position_in_threadgroup]]) {
  const ushort num_proposals = ushort_arg_0;
  const ushort threads_per_group = ushort_arg_1;
  float nms_thresh = float(ushort_arg_2) / 10000.0;
  const ushort global_offset = ushort_arg_3;
  const ushort row_start = tgid.y;
  const ushort col_start = tgid.x;
  const ushort trd_id = tid.x;

  const short row_size = min(short(32), short(num_proposals - row_start * threads_per_group));
  const short col_size = min(short(32), short(num_proposals - col_start * threads_per_group));

  // mask the bit if the IoU between two proposals exceeds the threshold
  if (trd_id < row_size) {
    const ushort cur_idx = global_offset + row_start * threads_per_group + trd_id;
    const ushort offset = indices[cur_idx] * 4;
    const float4 cur_proposal = float4(
        proposals[offset], proposals[offset + 1], proposals[offset + 2], proposals[offset + 3]);
    uint cur_mask = 0;
    ushort group_start = 0; // start index within group
    if (row_start == col_start) {
      // if in the same group, start from the next
      group_start = trd_id + 1;
    }
    for (ushort i = group_start; i < col_size; i++) {
      float4 a = cur_proposal;
      ushort idx = indices[global_offset + col_start * threads_per_group + i] * 4;
      float4 b = float4(proposals[idx], proposals[idx + 1], proposals[idx + 2], proposals[idx + 3]);
      float left = max(a[0], b[0]);
      float right = min(a[2], b[2]);
      float top = max(a[1], b[1]);
      float bottom = min(a[3], b[3]);
      float width = max(right - left + 1.0, 0.0);
      float height = max(bottom - top + 1.0, 0.0);
      float interS = width * height;
      float Sa = (a[2] - a[0] + 1.0) * (a[3] - a[1] + 1.0);
      float Sb = (b[2] - b[0] + 1.0) * (b[3] - b[1] + 1.0);
      float iou = interS / (Sa + Sb - interS);
      if (iou - nms_thresh > 0) {
        cur_mask |= 1U << i;
      }
    }
    ushort col_blocks = (num_proposals + threads_per_group - 1) / threads_per_group;
    mask[cur_idx * col_blocks + col_start] = cur_mask;
  }
}
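
// channel_shuffle: output channel c takes input channel
// (c % groups) * K + c / groups, the standard grouped-channel interleave
// (as in ShuffleNet).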

kernel void channel_shuffle(
    texture2d<half, access::read> in0[[texture(0), function_constant(in0_is_tex)]],
    texture2d_array<half, access::read> ina0[[texture(0), function_constant(in0_is_arr)]],
    texture2d<half, access::write> out[[texture(1), function_constant(in0_is_tex)]],
    texture2d_array<half, access::write> outa[[texture(1), function_constant(in0_is_arr)]],
    ushort3 gid[[thread_position_in_grid]]) {
  ushort C = ushort_arg_1;
  ushort K = ushort_arg_2;
  ushort groups = ushort_arg_3;

  if (in0_is_tex) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
      return;
    }
  } else {
    if (gid.x >= outa.get_width() || gid.y >= outa.get_height()) {
      return;
    }
  }
  const ushort n = gid.z / divRoundUp(C, 4);
  const ushort c = gid.z - n * divRoundUp(C, 4);
  half4 value;
  ushort2 gid_ = gid.xy;
  for (int off = 0; off < 4; ++off) {
    ushort cur_channel = c * 4 + off;
    if (cur_channel >= C) {
      break;
    }
    ushort channel_id = cur_channel / groups;
    ushort group_id = cur_channel % groups;
    ushort c0 = group_id * K + channel_id;
    if (in0_is_tex) {
      value[off] = in0.read(gid_)[c0 % 4];
    } else {
      value[off] = ina0.read(gid_, c0 / 4 + n * divRoundUp(C, 4))[c0 % 4];
    }
  }
  if (in0_is_tex) {
    out.write(value, gid_);
  } else {
    outa.write(value, gid_, gid.z);
  }
}

)V0G0N";
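
All of the kernels above are specialized at pipeline-creation time through the
ushort_arg_N function constants (channel counts, sizes, strides, and
fixed-point fractions such as spatial_scale = ushort_arg_0 / 10000). As a
rough illustration of that flow, and not the actual Caffe2 host code (the real
wrapper lives alongside this header), a minimal Objective-C++ sketch that
compiles this source and specializes the `affine` kernel for C channels might
look like the following; the helper name and error handling are ours, while
the kernel name, bindings, and constant index 0 come from the listing above:

#import <Metal/Metal.h>
#include <cstdint>

// Minimal sketch, assuming MPSCNN_KERNELS (this header) is in scope and a
// Metal device is available. Production code would cache the library and the
// pipeline states instead of rebuilding them per call.
static id<MTLComputePipelineState> makeAffinePipeline(id<MTLDevice> device,
                                                      uint16_t C,
                                                      NSError** error) {
  // Compile the embedded Metal source into a library.
  id<MTLLibrary> library =
      [device newLibraryWithSource:[NSString stringWithUTF8String:MPSCNN_KERNELS]
                           options:nil
                             error:error];
  if (!library) {
    return nil;
  }
  // `affine` reads its channel count C from ushort_arg_0, i.e.
  // function_constant(0); only constants the kernel uses must be set.
  MTLFunctionConstantValues* constants = [MTLFunctionConstantValues new];
  [constants setConstantValue:&C type:MTLDataTypeUShort atIndex:0];
  id<MTLFunction> affineFn = [library newFunctionWithName:@"affine"
                                           constantValues:constants
                                                    error:error];
  if (!affineFn) {
    return nil;
  }
  return [device newComputePipelineStateWithFunction:affineFn error:error];
}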