4 from __future__
import absolute_import
5 from __future__
import division
6 from __future__
import print_function
12 Utility for creating ResNe(X)t 13 "Deep Residual Learning for Image Recognition" by He, Zhang et. al. 2015 14 "Aggregated Residual Transformations for Deep Neural Networks" by Xie et. al. 2016 20 Helper class for constructing residual blocks. 39 self.
no_bias = 1
if no_bias
else 0
57 weight_init=(
"MSRAFill", {}),
74 def add_spatial_bn(self, num_filters):
87 Add a "bottleneck" component as decribed in He et. al. Figure 3 (right) 97 spatial_batch_norm=
True,
110 if spatial_batch_norm:
125 if spatial_batch_norm:
130 last_conv = self.
add_conv(base_filters, output_filters, kernel=1)
131 if spatial_batch_norm:
138 if output_filters != input_filters:
139 shortcut_blob = brew.conv(
145 weight_init=(
"MSRAFill", {}),
150 if spatial_batch_norm:
151 shortcut_blob = brew.spatial_bn(
154 'shortcut_projection_%d_spatbn' % self.
comp_count,
162 self.
model, [shortcut_blob, last_conv],
171 return output_filters
173 def add_simple_block(
178 spatial_batch_norm=
True 188 stride=(1
if down_sampling
is False else 2),
192 if spatial_batch_norm:
196 last_conv = self.
add_conv(num_filters, num_filters, kernel=3, pad=1)
197 if spatial_batch_norm:
201 if (num_filters != input_filters):
202 shortcut_blob = brew.conv(
208 weight_init=(
"MSRAFill", {}),
210 stride=(1
if down_sampling
is False else 2),
213 if spatial_batch_norm:
214 shortcut_blob = brew.spatial_bn(
217 'shortcut_projection_%d_spatbn' % self.
comp_count,
224 self.
model, [shortcut_blob, last_conv],
234 def create_resnet_32x32(
235 model, data, num_input_channels, num_groups, num_labels, is_test=
False 238 Create residual net for smaller images (sec 4.2 of He et. al (2015)) 239 num_groups = 'n' in the paper 243 model, data,
'conv1', num_input_channels, 16, kernel=3, stride=1
246 model,
'conv1',
'conv1_spatbn', 16, epsilon=1e-3, is_test=is_test
248 brew.relu(model,
'conv1_spatbn',
'relu1')
251 filters = [16, 32, 64]
253 builder =
ResNetBuilder(model,
'relu1', no_bias=0, is_test=is_test)
255 for groupidx
in range(0, 3):
256 for blockidx
in range(0, 2 * num_groups):
257 builder.add_simple_block(
258 prev_filters
if blockidx == 0
else filters[groupidx],
260 down_sampling=(
True if blockidx == 0
and 261 groupidx > 0
else False))
262 prev_filters = filters[groupidx]
266 model, builder.prev_blob,
'final_avg', kernel=8, stride=1
268 brew.fc(model,
'final_avg',
'last_out', 64, num_labels)
269 softmax = brew.softmax(model,
'last_out',
'softmax')
273 RESNEXT_BLOCK_CONFIG = {
282 RESNEXT_STRIDES = [1, 2, 2, 2]
284 logging.basicConfig()
285 log = logging.getLogger(
"resnext_builder")
286 log.setLevel(logging.DEBUG)
310 if num_layers
not in RESNEXT_BLOCK_CONFIG:
311 log.error(
"{}-layer is invalid for resnext config".format(num_layers))
313 num_blocks = RESNEXT_BLOCK_CONFIG[num_layers]
314 strides = RESNEXT_STRIDES
315 num_filters = [64, 256, 512, 1024, 2048]
317 if num_layers
in [18, 34]:
318 num_filters = [64, 64, 128, 256, 512]
321 num_features = num_filters[-1]
324 conv_blob = brew.conv(
330 weight_init=(
"MSRAFill", {}),
337 bn_blob = brew.spatial_bn(
343 momentum=bn_momentum,
346 relu_blob = brew.relu(model, bn_blob, bn_blob)
347 max_pool = brew.max_pool(model, relu_blob,
'pool1', kernel=3, stride=2, pad=1)
351 is_test=is_test, bn_epsilon=1e-5, bn_momentum=0.9)
353 inner_dim = num_groups * num_width_per_group
356 for residual_idx
in range(4):
357 residual_num = num_blocks[residual_idx]
358 residual_stride = strides[residual_idx]
359 dim_in = num_filters[residual_idx]
361 for blk_idx
in range(residual_num):
362 dim_in = builder.add_bottleneck(
365 num_filters[residual_idx + 1],
366 stride=residual_stride
if blk_idx == 0
else 1,
373 final_avg = brew.average_pool(
377 kernel=final_avg_kernel,
384 model, final_avg,
'last_out_L{}'.format(num_labels), num_features, num_labels
391 if (label
is not None):
392 (softmax, loss) = model.SoftmaxWithLoss(
397 return (softmax, loss)
400 return brew.softmax(model, last_out,
"softmax")
419 return create_resnext(
426 num_width_per_group=64,
431 conv1_kernel=conv1_kernel,
432 conv1_stride=conv1_stride,
433 final_avg_kernel=final_avg_kernel,
def add_conv(self, in_filters, out_filters, kernel, stride=1, group=1, pad=0)
def add_spatial_bn(self, num_filters)