from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.proto import caffe2_pb2
from caffe2.python import core, scope


def _get_weights(model, namescope=None):
    """Return the model weights that belong to a name scope.

    Args:
        model: object exposing a ``weights`` list whose items provide a
            ``GetNameScope()`` method.
        namescope: scope to filter by. When ``None`` (the default) the value
            of ``scope.CurrentNameScope()`` is used instead.

    Returns:
        A new list. For the empty scope it is a shallow copy of
        ``model.weights``; otherwise it holds only the weights whose
        ``GetNameScope()`` equals ``namescope``.
    """
    if namescope is None:
        namescope = scope.CurrentNameScope()

    # NOTE(review): the condition guarding this early return was missing from
    # the damaged listing; an empty scope meaning "no filtering" is the
    # reconstruction used here -- confirm against the upstream file.
    if namescope == '':
        # Slice so callers can mutate the result without touching the model.
        return model.weights[:]
    return [w for w in model.weights if w.GetNameScope() == namescope]
def iter(model, blob_out, **kwargs):
    """Create an INT64 counter blob on CPU and add an ``Iter`` op over it.

    Args:
        model: object exposing ``param_init_net`` and ``net``.
        blob_out: name of the counter blob; it is both the input and the
            output of the ``Iter`` operator.
        **kwargs: extra operator arguments. Any ``device_option`` entry is
            discarded because the fill is explicitly placed on CPU.

    Returns:
        Whatever ``model.net.Iter`` returns.
    """
    # Drop a caller-supplied device so it cannot clash with the explicit CPU
    # device_option passed to the fill below.
    kwargs.pop('device_option', None)

    # NOTE(review): the leading arguments of this call (inputs, output blob,
    # shape, value) and the trailing **kwargs were missing from the damaged
    # listing; only dtype and device_option were visible. The values below
    # are a reconstruction -- confirm against the upstream file.
    model.param_init_net.ConstantFill(
        [],
        blob_out,
        shape=[1],
        value=0,
        dtype=core.DataType.INT64,
        device_option=core.DeviceOption(caffe2_pb2.CPU, 0),
        **kwargs
    )
    return model.net.Iter(blob_out, blob_out, **kwargs)
def accuracy(model, blob_in, blob_out, **kwargs):
    """Add an ``Accuracy`` operator to ``model.net``.

    Args:
        model: object exposing ``net``.
        blob_in: sequence of two blobs, indexed as ``blob_in[0]`` and
            ``blob_in[1]`` (named prediction and label in the host path).
        blob_out: output blob for the accuracy value.
        **kwargs: optional ``device_option`` and ``top_k``; see below.

    When the effective device is not CPU and ``top_k`` is greater than 1,
    both inputs are copied to ``<name>_host`` blobs and ``Accuracy`` runs
    with a CPU device option. In every other case ``Accuracy`` is added
    directly on ``blob_in``.
    """
    if 'device_option' in kwargs:
        dev = kwargs['device_option']
    else:
        dev = scope.CurrentDeviceScope()
    is_cpu = dev is None or dev.device_type == caffe2_pb2.CPU

    needs_host_copy = not is_cpu and kwargs.get('top_k', 0) > 1
    if needs_host_copy:
        pred_host = model.net.CopyGPUToCPU(blob_in[0], blob_in[0] + "_host")
        label_host = model.net.CopyGPUToCPU(blob_in[1], blob_in[1] + "_host")

        # The host op gets an explicit CPU device_option, so a caller-supplied
        # one must be filtered out; passing the keyword twice is a TypeError.
        host_kwargs = {
            key: value
            for key, value in kwargs.items()
            if key != 'device_option'
        }
        # NOTE(review): the opening of this call, its output argument and the
        # forwarded kwargs were missing from the damaged listing; only the
        # input list and device_option were visible -- confirm upstream.
        model.net.Accuracy(
            [pred_host, label_host],
            blob_out,
            device_option=core.DeviceOption(caffe2_pb2.CPU, 0),
            **host_kwargs
        )
    else:
        # NOTE(review): kwargs (including top_k) are not forwarded here; kept
        # exactly as in the visible source -- confirm this is intended.
        model.net.Accuracy(blob_in, blob_out)
58 def add_weight_decay(model, weight_decay):
59 """Adds a decay to weights in the model. 61 This is a form of L2 regularization. 64 weight_decay: strength of the regularization 66 if weight_decay <= 0.0:
68 wd = model.param_init_net.ConstantFill(
69 [],
'wd', shape=[1], value=weight_decay
71 ONE = model.param_init_net.ConstantFill([],
"ONE", shape=[1], value=1.0)
72 for param
in _get_weights(model):
74 grad = model.param_to_grad[param]
75 model.net.WeightedSum(
76 [grad, ONE, param, wd],