Caffe2 - Python API
A deep learning, cross-platform ML framework
tags.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package tags
17 # Module caffe2.python.layers.tags
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 import six
24 
25 from caffe2.python import context
26 
27 
@context.define_context(allow_default=True)
class TagContext(object):
    """
    Scope driven way to provide tags to the layers.

    Holds the ordered list of currently active tags. Callers push a group of
    tags with `add_tags` and pop the same group with `remove_tags`, so the
    list behaves like a stack of tag groups.
    """

    def __init__(self, tags=None):
        # Tags are kept in a list (not a set) to preserve the order of
        # adding/removing things. Copy the input so that extending this
        # context never mutates a list owned by the caller.
        self.tags = list(tags) if tags else []

    def add_tags(self, tags):
        """Append `tags` (an iterable of tags) to the end of the active list."""
        self.tags.extend(tags)

    def remove_tags(self, tags):
        """
        Remove `tags` from the end of the active list.

        Args:
            tags: iterable of tags that must equal, value for value and in
                order, the most recently added tags.

        Raises:
            AssertionError: if `tags` does not match the tail of the active
                list.

        Removing an empty group is a no-op.
        """
        tags = list(tags)
        if not tags:
            # With len(tags) == 0 the slices below would be self.tags[-0:]
            # (the whole list) and self.tags[:-0] (an empty list), which
            # would fail the check or wipe every active tag.
            return
        if self.tags[-len(tags):] != tags:
            # Raised explicitly (not via `assert`) so the check survives -O.
            raise AssertionError(
                "Tags to remove {} do not match the most recently added "
                "tags {}".format(tags, self.tags[-len(tags):]))
        # Rebind rather than delete in place, as before, so any list object
        # previously handed out is not modified by removal.
        self.tags = self.tags[:-len(tags)]
44 
45 
class Tags(object):
    """
    Activate a group of tags for the duration of a scope.

    Used as a context manager, the tags are appended to the current
    `TagContext` on entry and removed from it on exit. Used as a decorator,
    the wrapped function runs inside that scope. The class attributes are
    the known tag values.
    """
    # TODO(amalevich): Tags might need to live in their own contexts, add this
    # split later
    EXCLUDE_FROM_TRAIN = 'exclude_from_train'
    EXCLUDE_FROM_EVAL = 'exclude_from_eval'
    EXCLUDE_FROM_PREDICTION = 'exclude_from_prediction'
    EXCLUDE_FROM_ACCUMULATE_PRED = 'exclude_from_accumulate_pred'
    PREPROCESSING = 'preprocessing'
    HANDLE_AS_SPARSE_LAYER = 'handle_as_sparse_layer'
    GRADIENT_FROM_PS = 'gradient_from_ps'
    PREFER_GPU = 'prefer_gpu'
    CPU_ONLY = 'cpu_only'

    # The following three tags are hints to the **distributed training
    # framework**.

    # Indicates a layer contains a sparse shardable parameter. The parameter
    # should be sharded and operators on those parameters should be done on
    # distributed parameter servers.
    SPARSE_SHARDED = 'sparse_sharded'
    # Indicates a layer contains sparse parameters among others, and that the
    # parameters should not be sharded (i.e. should be placed together on a
    # node).
    SPARSE_DONT_SHARD = 'sparse_dont_shard'
    # Used to manually indicate a component for an operator. Parameters for
    # all operators with the same component should be colocated on the same
    # parameter server.
    COMPONENT = 'component:'
    # Valid tag prefixes for the distributed training framework.
    DT_TAGS = (SPARSE_SHARDED, SPARSE_DONT_SHARD, COMPONENT)

    # In certain cases we want to have different schema for training and
    # prediction; for example, in prediction we might need to have only a
    # subset of ids present in the original schema. This tag is one of the
    # ways to mark operators that will be removed from prediction and should
    # override schema for predictors.
    PREDICTION_SCHEMA = 'prediction_schema'

    def __init__(self, tags):
        """
        Args:
            tags: a single tag, or a list/tuple of tags.
        """
        if isinstance(tags, (list, tuple)):
            # Copy so that later mutation of the caller's sequence (or of a
            # shared group such as Tags.TRAIN_ONLY) cannot change this group.
            tags = list(tags)
        else:
            tags = [tags]
        self.tags = tags

    def __enter__(self):
        TagContext.current().add_tags(self.tags)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always pop our group, whether or not the body raised; exceptions
        # are not suppressed (implicit None return).
        TagContext.current().remove_tags(self.tags)

    def __call__(self, func):
        """Decorate `func` so that every call runs with these tags active."""
        @six.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return wrapper
107 
108 
# Convenience tag groups. Each group carries the EXCLUDE_FROM_* tags for every
# phase other than its namesake, plus EXCLUDE_FROM_ACCUMULATE_PRED. They are
# kept as lists because Tags treats a list argument as a group of tags.
Tags.TRAIN_ONLY = [
    Tags.EXCLUDE_FROM_PREDICTION,
    Tags.EXCLUDE_FROM_EVAL,
    Tags.EXCLUDE_FROM_ACCUMULATE_PRED,
]
Tags.EVAL_ONLY = [
    Tags.EXCLUDE_FROM_PREDICTION,
    Tags.EXCLUDE_FROM_TRAIN,
    Tags.EXCLUDE_FROM_ACCUMULATE_PRED,
]
Tags.PREDICTION_ONLY = [
    Tags.EXCLUDE_FROM_TRAIN,
    Tags.EXCLUDE_FROM_EVAL,
    Tags.EXCLUDE_FROM_ACCUMULATE_PRED,
]