Caffe2 - Python API
A deep learning, cross-platform ML framework
sampling_trainable_mixin.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package sampling_trainable_mixin
17 # Module caffe2.python.layers.sampling_trainable_mixin
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 import abc
24 import six
25 
26 
class SamplingTrainableMixin(six.with_metaclass(abc.ABCMeta, object)):
    """Mixin for classes whose training net may use different parameter
    blobs than their prediction net.

    Subclasses implement ``param_blobs`` (the blobs used by the prediction
    net) and ``_add_ops`` (which adds ops to a net using a given list of
    blobs). ``add_ops`` builds ops from ``param_blobs``; ``add_train_ops``
    builds ops from ``train_param_blobs``.

    ``train_param_blobs`` may be assigned at most once. If it is read before
    being assigned, it defaults to ``param_blobs``. Once it has been assigned
    or read, it is frozen, so every later reader sees the same blobs.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative init: forward all arguments along the MRO so this
        # mixin can be combined with other base classes.
        super(SamplingTrainableMixin, self).__init__(*args, **kwargs)
        # Training blobs; None means "not decided yet".
        self._train_param_blobs = None
        # Set to True by the setter, which also runs on first read.
        self._train_param_blobs_frozen = False

    @property
    @abc.abstractmethod
    def param_blobs(self):
        """
        List of parameter blobs for prediction net
        """
        pass

    @property
    def train_param_blobs(self):
        """Parameter blobs used when building training ops.

        If train_param_blobs is not set before it is first used, it defaults
        to ``param_blobs``. That default goes through the setter, so it is
        validated and frozen exactly like an explicit assignment.
        """
        if self._train_param_blobs is None:
            # Route through the setter so the default is validated + frozen.
            self.train_param_blobs = self.param_blobs
        return self._train_param_blobs

    @train_param_blobs.setter
    def train_param_blobs(self, blobs):
        """Set the training parameter blobs. This is allowed only once.

        Raises:
            AssertionError: if the blobs were already set or read, or if
                ``blobs`` is None. The error is raised explicitly rather than
                via ``assert``, so the check still runs under ``python -O``.
        """
        if self._train_param_blobs_frozen:
            raise AssertionError(
                "train_param_blobs is frozen: it was already set or read")
        if blobs is None:
            raise AssertionError("train_param_blobs cannot be None")
        self._train_param_blobs_frozen = True
        self._train_param_blobs = blobs

    @abc.abstractmethod
    def _add_ops(self, net, param_blobs):
        """
        Add ops to the given net, using the given param_blobs
        """
        pass

    def add_ops(self, net):
        """Add ops to ``net`` using the prediction ``param_blobs``."""
        self._add_ops(net, self.param_blobs)

    def add_train_ops(self, net):
        """Add ops to ``net`` using ``train_param_blobs``.

        Reading ``train_param_blobs`` here freezes it. If it was never set,
        it defaults to ``param_blobs``.
        """
        self._add_ops(net, self.train_param_blobs)