Caffe2 - Python API
A deep learning, cross-platform ML framework
train.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package train
17 # Module caffe2.python.helpers.train
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python import core, scope
24 from caffe2.proto import caffe2_pb2
25 
26 
def _get_weights(model, namescope=None):
    """Return the weights of ``model`` that live in a given name scope.

    Args:
        model: object exposing a ``weights`` sequence whose items provide a
            ``GetNameScope()`` method.
        namescope: name scope to select. When ``None``, the value returned
            by ``scope.CurrentNameScope()`` is used instead.

    Returns:
        ``model.weights[:]`` when the effective scope is the empty string;
        otherwise a new list holding only the weights whose
        ``GetNameScope()`` equals the effective scope.
    """
    target_scope = (
        scope.CurrentNameScope() if namescope is None else namescope
    )

    # The empty scope selects every weight; the result is a slice of
    # model.weights rather than the attribute itself.
    if target_scope == '':
        return model.weights[:]

    matching = []
    for weight in model.weights:
        if weight.GetNameScope() == target_scope:
            matching.append(weight)
    return matching
35 
36 
def iter(model, blob_out, **kwargs):
    """Add an INT64 counter blob and an ``Iter`` op that uses it.

    A ``ConstantFill`` op is added to ``model.param_init_net`` that creates
    ``blob_out`` with shape ``[1]``, value ``0`` and dtype INT64, using the
    device option ``core.DeviceOption(caffe2_pb2.CPU, 0)``. An ``Iter`` op
    taking ``blob_out`` as both input and output is then added to
    ``model.net``.

    NOTE: inside this module the name ``iter`` shadows the Python builtin.

    Args:
        model: object exposing ``param_init_net`` and ``net``.
        blob_out: blob used for the counter.
        **kwargs: forwarded to both ops. A ``device_option`` entry is
            discarded before forwarding.

    Returns:
        Whatever ``model.net.Iter`` returns.
    """
    # Any caller-supplied device option is dropped: the fill op always gets
    # the CPU option built below, and the Iter op gets none from kwargs.
    kwargs.pop('device_option', None)

    counter_device = core.DeviceOption(caffe2_pb2.CPU, 0)
    init_net = model.param_init_net
    init_net.ConstantFill(
        [],
        blob_out,
        shape=[1],
        value=0,
        dtype=core.DataType.INT64,
        device_option=counter_device,
        **kwargs
    )

    run_net = model.net
    return run_net.Iter(blob_out, blob_out, **kwargs)
50 
51 
def accuracy(model, blob_in, blob_out, **kwargs):
    """Add an ``Accuracy`` op to ``model.net``.

    The target device is taken from ``kwargs['device_option']`` when that
    key is present, otherwise from ``scope.CurrentDeviceScope()``. A device
    of ``None`` is treated as CPU.

    When the device is not CPU and ``top_k`` is greater than 1, both inputs
    are first copied with ``CopyGPUToCPU`` into blobs named
    ``<input> + "_host"`` and the ``Accuracy`` op is created with a CPU
    device option. In every other case the op is created directly on
    ``blob_in``.

    Args:
        model: object exposing ``net``.
        blob_in: indexable pair; element 0 and element 1 are the two inputs
            of the ``Accuracy`` op (named prediction and label here).
        blob_out: output blob of the ``Accuracy`` op.
        **kwargs: operator arguments forwarded to ``Accuracy``.
            ``device_option`` and ``top_k`` are also inspected here.
    """
    if 'device_option' in kwargs:
        dev = kwargs['device_option']
    else:
        dev = scope.CurrentDeviceScope()
    is_cpu = dev is None or dev.device_type == caffe2_pb2.CPU

    # We support top_k > 1 only on CPU
    if not is_cpu and 'top_k' in kwargs and kwargs['top_k'] > 1:
        pred_host = model.net.CopyGPUToCPU(blob_in[0], blob_in[0] + "_host")
        label_host = model.net.CopyGPUToCPU(blob_in[1], blob_in[1] + "_host")

        # Now use the Host version of the accuracy op. The CPU device option
        # replaces any caller-supplied one in a copy of kwargs; passing both
        # ``device_option=...`` and ``**kwargs`` would raise TypeError for a
        # duplicated keyword whenever the caller set device_option.
        host_kwargs = dict(kwargs)
        host_kwargs['device_option'] = core.DeviceOption(caffe2_pb2.CPU, 0)
        model.net.Accuracy(
            [pred_host, label_host],
            blob_out,
            **host_kwargs
        )
    else:
        # Forward the caller's arguments so that top_k (and an explicit
        # device_option) are not silently dropped on this path.
        model.net.Accuracy(blob_in, blob_out, **kwargs)
71 
72 
def add_weight_decay(model, weight_decay):
    """Add a weight-decay term to the gradient of each model weight.

    This is a form of L2 regularization. Nothing is added when
    ``weight_decay`` is zero or negative.

    Two constant blobs, ``'wd'`` (holding ``weight_decay``) and ``"ONE"``
    (holding ``1.0``), are created in ``model.param_init_net``. For every
    weight returned by ``_get_weights(model)``, a ``WeightedSum`` op is
    added to ``model.net`` that writes its result back into that weight's
    gradient blob.

    Args:
        model: object exposing ``param_init_net``, ``net``, ``weights`` and
            ``param_to_grad``.
        weight_decay: strength of the regularization.
    """
    if weight_decay <= 0.0:
        return

    init_net = model.param_init_net
    decay_blob = init_net.ConstantFill(
        [], 'wd', shape=[1], value=weight_decay
    )
    unit_blob = init_net.ConstantFill([], "ONE", shape=[1], value=1.0)

    for weight in _get_weights(model):
        weight_grad = model.param_to_grad[weight]
        # Equivalent to: grad += wd * param
        summands = [weight_grad, unit_blob, weight, decay_blob]
        model.net.WeightedSum(summands, weight_grad)
91  [grad, ONE, param, wd],
92  grad,
93  )