Caffe2 - Python API
A deep learning, cross-platform ML framework
nonlinearity.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package nonlinearity
17 # Module caffe2.python.helpers.nonlinearity
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.python import core
24 
25 
def prelu(model, blob_in, blob_out, num_channels=1, slope_init=None,
          **kwargs):
    """PRelu: apply a parametric ReLU whose slope blob is a model parameter.

    The slope blob is named ``blob_out + '_slope'``. If ``model.init_params``
    is truthy, the slope is produced by the fill operator named in
    ``slope_init``, created on ``model.param_init_net`` with shape
    ``[num_channels]``. Otherwise, ``core.ScopedBlobReference`` is called
    with that name and ``model.param_init_net``, and no fill op is created.
    Either way the slope is passed to ``model.AddParameter``.

    Args:
        model: model helper; this function reads ``init_params``,
            ``param_init_net`` and ``net`` and calls ``AddParameter``.
        blob_in: input blob of the PRelu op.
        blob_out: output blob of the PRelu op; also the prefix of the slope
            blob's name.
        num_channels: length of the slope vector (default 1).
        slope_init: optional ``(op_type, op_kwargs)`` pair. ``op_type`` names
            an operator on ``param_init_net``. ``op_kwargs`` is splatted into
            that call alongside ``shape``. A falsy value selects
            ``('ConstantFill', {'value': 0.25})``.
        **kwargs: accepted for signature parity but not used by this helper.

    Returns:
        The result of ``model.net.PRelu([blob_in, slope], [blob_out])``.
    """
    # A falsy slope_init (None, empty tuple, ...) falls back to a constant
    # fill of 0.25.
    if not slope_init:
        slope_init = ('ConstantFill', {'value': 0.25})

    if not model.init_params:
        # No init op is emitted; refer to the slope blob by name instead.
        slope = core.ScopedBlobReference(
            blob_out + '_slope', model.param_init_net)
    else:
        # Resolve the fill operator by name through param_init_net's
        # __getattr__, then create the slope blob with it.
        make_slope = model.param_init_net.__getattr__(slope_init[0])
        slope = make_slope(
            [], blob_out + '_slope', shape=[num_channels], **slope_init[1])

    model.AddParameter(slope)
    return model.net.PRelu([blob_in, slope], [blob_out])
45 
46 
def relu(model, blob_in, blob_out, use_cudnn=False, order="NCHW", **kwargs):
    """Relu: apply ReLU to ``blob_in`` via ``model.net.Relu``.

    Args:
        model: model helper whose ``net.Relu`` is called.
        blob_in: input blob.
        blob_out: output blob.
        use_cudnn: if true, the ``engine`` keyword is set to ``'CUDNN'``,
            overriding any ``engine`` supplied in ``kwargs``.
        order: forwarded as the ``order`` keyword (default ``"NCHW"``).
        **kwargs: forwarded to ``model.net.Relu`` after ``order``.

    Returns:
        Whatever ``model.net.Relu`` returns.
    """
    # Keep keyword order stable: ``order`` first, then caller kwargs, then
    # ``engine`` (replaced in place if the caller already passed one).
    op_kwargs = {'order': order}
    op_kwargs.update(kwargs)
    if use_cudnn:
        op_kwargs['engine'] = 'CUDNN'
    return model.net.Relu(blob_in, blob_out, **op_kwargs)
52 
53 
def tanh(model, blob_in, blob_out, use_cudnn=False, order="NCHW", **kwargs):
    """Tanh: apply tanh to ``blob_in`` via ``model.net.Tanh``.

    Args:
        model: model helper whose ``net.Tanh`` is called.
        blob_in: input blob.
        blob_out: output blob.
        use_cudnn: if true, the ``engine`` keyword is set to ``'CUDNN'``,
            overriding any ``engine`` supplied in ``kwargs``.
        order: forwarded as the ``order`` keyword (default ``"NCHW"``).
        **kwargs: forwarded to ``model.net.Tanh`` after ``order``.

    Returns:
        Whatever ``model.net.Tanh`` returns.
    """
    # ``order`` cannot appear in kwargs (it is a named parameter), so this
    # merge never collides; ``order`` stays the first keyword.
    forwarded = dict(order=order, **kwargs)
    if use_cudnn:
        forwarded['engine'] = 'CUDNN'
    return model.net.Tanh(blob_in, blob_out, **forwarded)