from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import functools
@context.define_context(allow_default=True)
class TagContext(object):
    """
    Scope driven way to provide tags to the layers.

    Tags are kept in an ordered list used as a stack: a scope appends its
    tags on entry (``add_tags``) and must remove exactly those tags, in the
    same order, on exit (``remove_tags``).
    """

    def __init__(self, tags=None):
        """
        Args:
            tags: optional initial tags. They are copied, so later
                ``add_tags`` / ``remove_tags`` calls never mutate a list
                owned by the caller.
        """
        self.tags = list(tags or [])

    def add_tags(self, tags):
        """Append ``tags`` to the end of the tag stack."""
        self.tags.extend(tags)

    def remove_tags(self, tags):
        """
        Pop ``tags`` off the end of the tag stack.

        ``tags`` must equal the most recently added tags, in order (i.e.
        scopes must be properly nested).

        Raises:
            AssertionError: if the tail of the stack does not match ``tags``.
        """
        if not tags:
            # Nothing to pop. This must be handled explicitly: with
            # len(tags) == 0 the slices below degenerate to [-0:] (the whole
            # stack) and [:-0] (an empty list), which would fail the check
            # and/or wipe every tag pushed by enclosing scopes.
            return
        n = len(tags)
        assert self.tags[-n:] == list(tags), (
            'remove_tags(%r) does not match the most recently added tags %r'
            % (tags, self.tags[-n:]))
        # Rebind rather than delete in place, so a caller still holding the
        # previous list object does not see it shrink.
        self.tags = self.tags[:-n]
class Tags(object):
    """
    Tag string constants, plus a scope helper that applies them.

    An instance is a context manager: ``__enter__`` pushes its tags onto
    ``TagContext.current()`` and ``__exit__`` pops them again. An instance is
    also a decorator: the wrapped function runs inside that scope.
    """

    EXCLUDE_FROM_TRAIN = 'exclude_from_train'
    EXCLUDE_FROM_EVAL = 'exclude_from_eval'
    EXCLUDE_FROM_PREDICTION = 'exclude_from_prediction'
    EXCLUDE_FROM_ACCUMULATE_PRED = 'exclude_from_accumulate_pred'
    PREPROCESSING = 'preprocessing'
    HANDLE_AS_SPARSE_LAYER = 'handle_as_sparse_layer'
    PREFER_GPU = 'prefer_gpu'

    # Indicates a layer contains a sparse shardable parameter. The parameter
    # should be sharded and operators on those parameters should be done on
    # distributed parameter servers.
    SPARSE_SHARDED = 'sparse_sharded'

    # Indicates a layer contains sparse parameters among others, and that the
    # parameters should not be sharded (i.e. should be placed together on a
    # node).
    SPARSE_DONT_SHARD = 'sparse_dont_shard'

    # Used to manually indicate a component for an operator. Parameters for
    # all operators with the same component should be colocated on the same
    # parameter server.
    COMPONENT = 'component:'

    # Valid tag prefixes for the distributed training framework (presumably
    # the values ending in ':' -- COMPONENT above and EXTRA_INFO below --
    # with the remainder of the tag carrying the value; confirm).

    # Used to pass on info to the 'extra_info' field in the net proto.
    # Typically to provide info for distributed training.
    EXTRA_INFO = 'extra_info:'

    # An empty tag, used to make conditional statements on a `with Tags(...)`
    # block more concise.
    EMPTY_TAG = 'empty_tag'

    # Grouping of the distributed-training hints defined above.
    DT_TAGS = (SPARSE_SHARDED, SPARSE_DONT_SHARD, COMPONENT)

    PREDICTION_SCHEMA = 'prediction_schema'

    FEATURE_TRANSFORM = 'feature_transform'

    FEATURE_TRANSFORM_SCHEMA = 'feature_transform_schema'

    def __init__(self, tags):
        """
        Args:
            tags: a single tag, or a list of tags. Any non-list value is
                treated as one tag.
        """
        if not isinstance(tags, list):
            tags = [tags]
        # Copy so that mutating the caller's list between __enter__ and
        # __exit__ cannot make the push and the pop disagree.
        self.tags = list(tags)

    def __enter__(self):
        TagContext.current().add_tags(self.tags)
        return self

    def __exit__(self, type, value, traceback):
        # Returns None, so any exception raised inside the block propagates.
        TagContext.current().remove_tags(self.tags)

    def __call__(self, func):
        """Decorate ``func`` so that every call runs inside this tag scope."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return wrapper
# Predefined tag bundles. Each keeps a layer in exactly one phase by
# excluding it from the other two phases and from prediction accumulation.
# They are lists (not tuples) because Tags.__init__ treats non-list
# arguments differently from lists.
Tags.TRAIN_ONLY = [
    Tags.EXCLUDE_FROM_PREDICTION,
    Tags.EXCLUDE_FROM_EVAL,
    Tags.EXCLUDE_FROM_ACCUMULATE_PRED,
]
Tags.EVAL_ONLY = [
    Tags.EXCLUDE_FROM_PREDICTION,
    Tags.EXCLUDE_FROM_TRAIN,
    Tags.EXCLUDE_FROM_ACCUMULATE_PRED,
]
Tags.PREDICTION_ONLY = [
    Tags.EXCLUDE_FROM_TRAIN,
    Tags.EXCLUDE_FROM_EVAL,
    Tags.EXCLUDE_FROM_ACCUMULATE_PRED,
]