Caffe2 - Python API
A deep learning, cross-platform ML framework
sampling_trainable_mixin.py
1 ## @package sampling_trainable_mixin
2 # Module caffe2.python.layers.sampling_trainable_mixin
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 import abc
9 import six
10 
11 
class SamplingTrainableMixin(six.with_metaclass(abc.ABCMeta, object)):
    """
    Mixin that keeps the parameter blobs used when adding training ops
    separate from those used when adding prediction ops.

    Subclasses must implement the `param_blobs` property and `_add_ops`.
    `train_param_blobs` can be assigned at most once; if it is read before
    being assigned, it defaults to `param_blobs` and is frozen at that value.
    """

    def __init__(self, *args, **kwargs):
        """
        Forward all arguments to the next class in the MRO and start with
        no training blobs assigned.
        """
        super(SamplingTrainableMixin, self).__init__(*args, **kwargs)
        self._train_param_blobs = None
        self._train_param_blobs_frozen = False

    @property
    @abc.abstractmethod
    def param_blobs(self):
        """
        List of parameter blobs for prediction net
        """
        pass

    @property
    def train_param_blobs(self):
        """
        If train_param_blobs is not set before used, default to param_blobs
        """
        if self._train_param_blobs is None:
            # Goes through the setter below, so the default is validated and
            # the value becomes frozen exactly like an explicit assignment.
            self.train_param_blobs = self.param_blobs
        return self._train_param_blobs

    @train_param_blobs.setter
    def train_param_blobs(self, blobs):
        """
        Assign the training parameter blobs; allowed only once.

        Raises:
            AssertionError: if the blobs were already set (explicitly or by
                the default in the getter), or if `blobs` is None.
        """
        # Explicit raises instead of `assert` statements: asserts are removed
        # when Python runs with -O, which would silently disable both guards.
        # AssertionError is kept so existing callers see the same exception.
        if self._train_param_blobs_frozen:
            raise AssertionError(
                "train_param_blobs is frozen and cannot be set again"
            )
        if blobs is None:
            raise AssertionError("train_param_blobs cannot be set to None")
        self._train_param_blobs_frozen = True
        self._train_param_blobs = blobs

    @abc.abstractmethod
    def _add_ops(self, net, param_blobs):
        """
        Add ops to the given net, using the given param_blobs
        """
        pass

    def add_ops(self, net):
        """
        Add ops to the given net, using param_blobs
        """
        self._add_ops(net, self.param_blobs)

    def add_train_ops(self, net):
        """
        Add ops to the given net, using train_param_blobs
        """
        self._add_ops(net, self.train_param_blobs)