Caffe2 - Python API
A deep-learning, cross-platform ML framework
dropout.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 # Module caffe2.python.layers.dropout
17 from __future__ import absolute_import
18 from __future__ import division
19 from __future__ import print_function
20 from __future__ import unicode_literals
21 
22 from caffe2.python import schema
23 from caffe2.python.layers.layers import ModelLayer
24 
25 
27 
def __init__(
        self,
        model,
        input_record,
        name='dropout',
        ratio=0.5,
        **kwargs):
    """Build a Dropout layer on top of a scalar input record.

    Args:
        model: model the layer belongs to; forwarded to the base
            ``ModelLayer`` initializer.
        input_record: layer input; must be a ``schema.Scalar``.
        name: layer name passed to the base initializer
            (``'dropout'`` by default).
        ratio: dropout ratio stored on the layer and later forwarded to
            the Dropout operator; must satisfy ``0 <= ratio < 1``.
        **kwargs: forwarded unchanged to the base initializer.

    Raises:
        AssertionError: if ``input_record`` is not a ``schema.Scalar`` or
            ``ratio`` lies outside ``[0, 1)``.
    """
    super(Dropout, self).__init__(model, name, input_record, **kwargs)

    # NOTE(review): these checks use ``assert`` and therefore vanish when
    # Python runs with ``-O``; they also run only after the base
    # initializer has executed.
    assert isinstance(input_record, schema.Scalar), "Incorrect input type"
    assert 0 <= ratio < 1.0, \
        "Expected 0 <= ratio < 1, but got ratio of %s" % ratio

    # The output mirrors the input's schema but points at a fresh blob.
    self.output_schema = input_record.clone_schema()
    output_ref = self.get_next_blob_reference('output')
    self.output_schema.set_value(output_ref)

    self.ratio = ratio
45 
def _add_ops(self, net, is_test):
    """Append one Dropout operator to ``net``.

    The operator reads the input record's blobs. It writes the layer's
    output blob(s) plus an extra scoped blob named ``'d_mask'``, which is
    passed as the last output.

    Args:
        net: net that receives the operator.
        is_test: forwarded verbatim as the operator's ``is_test`` argument.
    """
    inputs = self.input_record.field_blobs()
    outputs = self.output_schema.field_blobs()
    mask = net.NextScopedBlob('d_mask')

    net.Dropout(
        inputs,
        outputs + [mask],
        ratio=self.ratio,
        is_test=is_test,
    )
55 
def add_train_ops(self, net):
    """Emit the training-time Dropout operator into ``net``.

    This is a thin wrapper: it calls ``_add_ops`` with ``is_test=False``.
    """
    self._add_ops(net, is_test=False)
58 
def add_eval_ops(self, net):
    """Emit the evaluation-time Dropout operator into ``net``.

    This is a thin wrapper: it calls ``_add_ops`` with ``is_test=True``.
    """
    self._add_ops(net, is_test=True)
61 
def add_ops(self, net):
    """Emit this layer's default operators into ``net``.

    Delegates to ``add_eval_ops``, so the default op set is the same one
    used for evaluation.
    """
    self.add_eval_ops(net)
def get_next_blob_reference(self, name)
Definition: layers.py:352
def _add_ops(self, net, is_test)
Definition: dropout.py:46