2 from __future__
import absolute_import
3 from __future__
import division
4 from __future__
import print_function
5 from __future__
import unicode_literals
19 dropout_for_eval=
False,
22 super(Dropout, self).__init__(model, name, input_record, **kwargs)
23 assert isinstance(input_record,
schema.Scalar),
"Incorrect input type" 24 assert (ratio >= 0
and ratio < 1.0), \
25 "Expected 0 <= ratio < 1, but got ratio of %s" % ratio
33 def _add_ops(self, net, is_test):
34 input_blob = self.input_record.field_blobs()
35 output_blobs = self.output_schema.field_blobs() \
36 + [net.NextScopedBlob(
'd_mask')]
38 net.Dropout(input_blob,
43 def add_train_ops(self, net):
46 def add_eval_ops(self, net):
49 def add_ops(self, net):
def get_next_blob_reference(self, name)
def add_eval_ops(self, net)
def _add_ops(self, net, is_test)