Caffe2 - Python API
A deep learning, cross-platform ML framework
tags.py
1 ## @package tags
2 # Module caffe2.python.layers.tags
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 import six
9 
10 from caffe2.python import context
11 
12 
@context.define_context(allow_default=True)
class TagContext(object):
    """
    Scope driven way to provide tags to the layers.

    Tags are kept in an ordered list used as a stack: a scope appends its
    tags on entry (``add_tags``) and removes exactly those trailing tags again
    on exit (``remove_tags``).
    """

    def __init__(self, tags=None):
        # Tags is expected to be a list to keep the order of adding/removing
        # things. Copy the input so that add_tags(), which extends in place,
        # never mutates a list owned by the caller (and so tuples work too).
        self.tags = list(tags) if tags else []

    def add_tags(self, tags):
        """Append ``tags`` (an iterable of tags) to the end of the tag list."""
        self.tags.extend(tags)

    def remove_tags(self, tags):
        """
        Remove ``tags`` from the end of the tag list.

        ``tags`` must equal the trailing tags of the list, in the same order
        (i.e. the ones most recently added). Otherwise AssertionError is
        raised and the list is left unchanged. Removing an empty sequence is
        a no-op.
        """
        tags = list(tags)
        if not tags:
            # len(tags) == 0 would give self.tags[-0:] (the whole list) and
            # self.tags[:-0] (an empty list): the check would fail, or, with
            # the check skipped, every tag would be dropped.
            return
        trailing = self.tags[-len(tags):]
        if trailing != tags:
            # Raised explicitly rather than via `assert` so the consistency
            # check still runs under `python -O`.
            raise AssertionError(
                "Tags to remove {} do not match the trailing tags {}".format(
                    tags, trailing))
        self.tags = self.tags[:-len(tags)]
26  def remove_tags(self, tags):
27  assert self.tags[-len(tags):] == tags
28  self.tags = self.tags[:-len(tags)]
29 
30 
class Tags(object):
    """
    Context manager and decorator that adds tags to the current TagContext.

    On ``__enter__`` the tags are appended to ``TagContext.current()``; on
    ``__exit__`` they are removed again. Used as a decorator, the wrapped
    function runs inside such a ``with`` block. Example::

        with Tags(Tags.EXCLUDE_FROM_PREDICTION):
            ...
    """
    # TODO(amalevich): Tags might need to live in their own contexts, add this
    # split later
    EXCLUDE_FROM_TRAIN = 'exclude_from_train'
    EXCLUDE_FROM_EVAL = 'exclude_from_eval'
    EXCLUDE_FROM_PREDICTION = 'exclude_from_prediction'
    EXCLUDE_FROM_ACCUMULATE_PRED = 'exclude_from_accumulate_pred'
    PREPROCESSING = 'preprocessing'
    HANDLE_AS_SPARSE_LAYER = 'handle_as_sparse_layer'
    PREFER_GPU = 'prefer_gpu'
    CPU_ONLY = 'cpu_only'

    # The following tags are hints to the **distributed training framework**.

    # Indicates a layer contains a sparse shardable parameter. The parameter
    # should be sharded and operators on those parameters should be done on
    # distributed parameter servers.
    SPARSE_SHARDED = 'sparse_sharded'
    # Indicates a layer contains sparse parameters among others, and that the
    # parameters should not be sharded (i.e. should be placed together on a
    # node).
    SPARSE_DONT_SHARD = 'sparse_dont_shard'
    # Used to manually indicate a component for an operator. Parameters for
    # all operators with the same component should be colocated on the same
    # parameter server. The trailing ':' marks this as a prefix (presumably
    # followed by the component name -- confirm against callers).
    COMPONENT = 'component:'
    # Used to pass on info to the 'extra_info' field in the net proto.
    # Typically to provide info for distributed training. Also a prefix.
    EXTRA_INFO = 'extra_info:'
    # An empty tag, used to make conditional statements on `with Tags(...)`
    # blocks more concise.
    EMPTY_TAG = 'empty_tag'

    # Valid tags / tag prefixes for the distributed training framework.
    DT_TAGS = (SPARSE_SHARDED, SPARSE_DONT_SHARD, COMPONENT)

    # In certain cases we want to have different schema for training and
    # prediction, as an example in prediction we might need to have only
    # subset of ids present in the original schema. This tag is one of the
    # ways to mark operators that will be removed from prediction and should
    # override schema for predictors.
    PREDICTION_SCHEMA = 'prediction_schema'

    # This is to mark layers in the feature transform process.
    FEATURE_TRANSFORM = 'feature_transform'
    # This is to mark the output layers in the feature transform process
    FEATURE_TRANSFORM_SCHEMA = 'feature_transform_schema'

    def __init__(self, tags):
        """
        Args:
            tags: a single tag, or a list/tuple of tags.
        """
        # Copy into a fresh list so the stored tags never alias a
        # caller-owned or shared class-level list (e.g. Tags.TRAIN_ONLY), and
        # so tuples such as Tags.DT_TAGS are treated as several tags rather
        # than one nested tuple tag.
        if isinstance(tags, (list, tuple)):
            self.tags = list(tags)
        else:
            self.tags = [tags]

    def __enter__(self):
        TagContext.current().add_tags(self.tags)
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Nothing was added for an empty tag list, so skip the call rather
        # than rely on remove_tags handling an empty sequence.
        if self.tags:
            TagContext.current().remove_tags(self.tags)

    def __call__(self, func):
        """Decorate ``func`` so every call runs with these tags active."""
        @six.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return wrapper
106 
107 
# Convenience bundles of the EXCLUDE_FROM_* tags: each one excludes a layer
# from every phase except the one in its name.
Tags.TRAIN_ONLY = [
    Tags.EXCLUDE_FROM_PREDICTION,
    Tags.EXCLUDE_FROM_EVAL,
    Tags.EXCLUDE_FROM_ACCUMULATE_PRED,
]
Tags.EVAL_ONLY = [
    Tags.EXCLUDE_FROM_PREDICTION,
    Tags.EXCLUDE_FROM_TRAIN,
    Tags.EXCLUDE_FROM_ACCUMULATE_PRED,
]
Tags.PREDICTION_ONLY = [
    Tags.EXCLUDE_FROM_TRAIN,
    Tags.EXCLUDE_FROM_EVAL,
    Tags.EXCLUDE_FROM_ACCUMULATE_PRED,
]