from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.proto.caffe2_pb2 import OperatorDef, NetDef
from caffe2.python.checkpoint import Job
from caffe2.python.core import Net, ExecutionStep, Plan
from caffe2.python.task import Task, TaskGroup, WorkspaceType, TaskOutput
from collections import defaultdict
from contextlib import contextmanager
from copy import copy
from future.utils import viewkeys
from itertools import chain
from six import binary_type, text_type

class Visitor(object):
    @classmethod
    def register(cls, Type):
        if not hasattr(cls, 'visitors'):
            cls.visitors = {}
        else:
            assert Type not in cls.visitors, \
                '{} already registered!'.format(Type)

        def _register(func):
            cls.visitors[Type] = func
            return func

        return _register

    def __call__(self, obj, *args, **kwargs):
        Type = type(obj)
        if Type not in self.__class__.visitors:
            raise TypeError('%s: unsupported object type: %s' % (
                self.__class__.__name__, Type))
        func = self.__class__.visitors[Type]
        return func(self, obj, *args, **kwargs)

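# Illustrative sketch of the dispatch mechanism above (not part of the
# original API; `Counter` and its handler are hypothetical names):
#
#     class Counter(Visitor):
#         pass
#
#     @Counter.register(list)
#     def count_list(counter, obj):
#         return len(obj)
#
#     Counter()([1, 2, 3])  # -> 3
#     Counter()('foo')      # raises TypeError: unsupported object type
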
class Analyzer(Visitor):
    PREFIXES_TO_IGNORE = {'distributed_ctx_init'}

    def __init__(self):
        self.workspaces = defaultdict(lambda: defaultdict(lambda: 0))
        self.workspace_ctx = []

    @property
    def workspace(self):
        return self.workspace_ctx[-1]

    @contextmanager
    def set_workspace(self, node=None, ws=None, do_copy=False):
        if ws is None:
            if node is not None:
                ws = self.workspaces[str(node)]
            else:
                ws = self.workspace
        if do_copy:
            ws = copy(ws)
        self.workspace_ctx.append(ws)
        yield ws
        del self.workspace_ctx[-1]

    def define_blob(self, blob):
        self.workspace[blob] += 1

    def need_blob(self, blob):
        if any(blob.startswith(p) for p in Analyzer.PREFIXES_TO_IGNORE):
            return
        assert blob in self.workspace, 'Blob undefined: %s' % blob

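# Hedged sketch of the workspace stack above ('trainer:0' and the blob
# names are hypothetical):
#
#     a = Analyzer()
#     with a.set_workspace(node='trainer:0'):
#         a.define_blob('w')
#         a.need_blob('w')          # ok: 'w' is defined in this workspace
#         with a.set_workspace(do_copy=True):
#             a.define_blob('tmp')  # visible only inside the copied scope
#     # blobs starting with 'distributed_ctx_init' are never checked
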
@Analyzer.register(OperatorDef)
def analyze_op(analyzer, op):
    for x in op.input:
        analyzer.need_blob(x)
    for x in op.output:
        analyzer.define_blob(x)


@Analyzer.register(Net)
def analyze_net(analyzer, net):
    for x in net.Proto().op:
        analyzer(x)

@Analyzer.register(ExecutionStep)
def analyze_step(analyzer, step):
    proto = step.Proto()
    with analyzer.set_workspace(do_copy=proto.create_workspace):
        if proto.report_net:
            with analyzer.set_workspace(do_copy=True):
                analyzer(step.get_net(proto.report_net))
        all_new_blobs = set()
        substeps = step.Substeps() + [step.get_net(n) for n in proto.network]
        for substep in substeps:
            with analyzer.set_workspace(
                    do_copy=proto.concurrent_substeps) as ws_in:
                analyzer(substep)
                if proto.should_stop_blob:
                    analyzer.need_blob(proto.should_stop_blob)
            if proto.concurrent_substeps:
                new_blobs = (
                    set(viewkeys(ws_in)) - set(viewkeys(analyzer.workspace)))
                assert len(all_new_blobs & new_blobs) == 0, (
                    'Error: Blobs created by multiple parallel steps: %s' %
                    ', '.join(all_new_blobs & new_blobs))
                all_new_blobs |= new_blobs
    for x in all_new_blobs:
        analyzer.define_blob(x)

@Analyzer.register(Task)
def analyze_task(analyzer, task):
    # check that the serialized plan fits under the protobuf ceiling of
    # 2 ** 26 bytes (64 MiB)
    step = task.get_step()
    plan = Plan(task.node)
    plan.AddStep(step)
    proto_len = len(plan.Proto().SerializeToString())
    assert proto_len < 2 ** 26, (
        'Due to a protobuf limitation, serialized tasks must be smaller '
        'than 64Mb, but this task has {} bytes.'.format(proto_len))

    is_private = task.workspace_type() != WorkspaceType.GLOBAL
    with analyzer.set_workspace(do_copy=is_private):
        analyzer(step)

@Analyzer.register(TaskGroup)
def analyze_task_group(analyzer, tg):
    for task in tg.tasks_by_node().tasks():
        with analyzer.set_workspace(node=task.node):
            analyzer(task)


@Analyzer.register(Job)
def analyze_job(analyzer, job):
    analyzer(job.init_group)
    analyzer(job.epoch_group)

def analyze(obj):
    """
    Given a Job, visits all the execution steps making sure that:
      - no undefined blobs will be found during execution
      - no blob with the same name is defined in concurrent steps
    """
    Analyzer()(obj)


class Text(object):
    def __init__(self):
        self._indent = 0
        self._lines_in_context = [0]
        self.lines = []

    @contextmanager
    def context(self, text):
        if text is not None:
            self.add('with %s:' % text)
            self._indent += 4
            self._lines_in_context.append(0)
        yield
        if text is not None:
            if self._lines_in_context[-1] == 0:
                self.add('pass')
            self._indent -= 4
            del self._lines_in_context[-1]

    def add(self, text):
        self._lines_in_context[-1] += 1
        self.lines.append((' ' * self._indent) + text)

    def __str__(self):
        return '\n'.join(self.lines)

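# A small sketch of how the builder above renders nested contexts:
#
#     t = Text()
#     with t.context('loop(10)'):
#         t.add('net.Sum(...)')
#     print(str(t))
#     # with loop(10):
#     #     net.Sum(...)
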
class Printer(Visitor, Text):
    def __init__(self, factor_prefixes=False, c2_syntax=True):
        # along the MRO (Printer -> Visitor -> Text), these two calls
        # initialize the Text base; Visitor itself has no state
        super(Visitor, self).__init__()
        super(Text, self).__init__()
        self.factor_prefixes = factor_prefixes
        self.c2_syntax = c2_syntax

def _sanitize_str(s):
    if isinstance(s, text_type):
        sanitized = s
    elif isinstance(s, binary_type):
        sanitized = s.decode('ascii', errors='ignore')
    else:
        sanitized = str(s)
    if len(sanitized) < 64:
        return "'%s'" % sanitized
    return "'%s'" % sanitized[:64] + '...<+len=%d>' % (len(sanitized) - 64)

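# For reference, long strings are elided after 64 characters, e.g.:
#
#     _sanitize_str(b'x' * 100)
#     # -> quoted first 64 chars followed by "...<+len=36>"
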
def _arg_val(arg):
    if arg.HasField('f'):
        return str(arg.f)
    if arg.HasField('i'):
        return str(arg.i)
    if arg.HasField('s'):
        return _sanitize_str(arg.s)
    if arg.floats:
        return str(list(arg.floats))
    if arg.ints:
        return str(list(arg.ints))
    if arg.strings:
        return str([_sanitize_str(s) for s in arg.strings])
    return '[]'

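# An illustrative sketch of argument rendering (the argument name 'axis'
# is hypothetical):
#
#     from caffe2.proto import caffe2_pb2
#     a = caffe2_pb2.Argument(name='axis', i=2)
#     _arg_val(a)  # -> '2'
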
225 "Given a list of strings, returns the longest common prefix" 230 for i, c
in enumerate(s1):
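# For example (blob names are hypothetical):
#
#     commonprefix(['net/fc1_w', 'net/fc1_b'])  # -> 'net/fc1_'
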
def format_value(val):
    if isinstance(val, list):
        return '[%s]' % ', '.join("'%s'" % str(v) for v in val)
    return str(val)

def factor_prefix(vals, do_it):
    vals = [format_value(v) for v in vals]
    prefix = commonprefix(vals) if len(vals) > 1 and do_it else ''
    joined = ', '.join(v[len(prefix):] for v in vals)
    return '%s[%s]' % (prefix, joined) if prefix else joined

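# For example (blob names are hypothetical):
#
#     factor_prefix(['net/fc1_w', 'net/fc1_b'], True)   # -> 'net/fc1_[w, b]'
#     factor_prefix(['net/fc1_w', 'net/fc1_b'], False)  # -> 'net/fc1_w, net/fc1_b'
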
def call(op, inputs=None, outputs=None, factor_prefixes=False):
    if not inputs:
        inputs = ''
    else:
        inputs_v = [a for a in inputs if not isinstance(a, tuple)]
        inputs_kv = [a for a in inputs if isinstance(a, tuple)]
        inputs = ', '.join(
            x
            for x in chain(
                [factor_prefix(inputs_v, factor_prefixes)],
                ('%s=%s' % kv for kv in inputs_kv),
            )
            if x
        )
    call = '%s(%s)' % (op, inputs)
    return call if not outputs else '%s = %s' % (
        factor_prefix(outputs, factor_prefixes), call)

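# For example (operator and blob names are hypothetical; tuples render as
# keyword arguments):
#
#     call('Sum', ['x1', 'x2'], ['y'])       # -> 'y = Sum(x1, x2)'
#     call('FC', ['x', ('axis', 1)], ['y'])  # -> 'y = FC(x, axis=1)'
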
def format_device_option(dev_opt):
    if not dev_opt or not (
            dev_opt.device_type or dev_opt.device_id or dev_opt.node_name):
        return None
    return call(
        'DeviceOption',
        [dev_opt.device_type, dev_opt.device_id, "'%s'" % dev_opt.node_name])

@Printer.register(OperatorDef)
def print_op(text, op):
    args = [(a.name, _arg_val(a)) for a in op.arg]
    dev_opt_txt = format_device_option(op.device_option)
    if dev_opt_txt:
        args.append(('device_option', dev_opt_txt))

    if text.c2_syntax:
        # render as a Caffe2 Python call: net.OpType([ins], [outs], args)
        text.add(call(
            text.c2_net_name + '.' + op.type,
            [list(op.input), list(op.output)] + args))
    else:
        text.add(call(
            op.type,
            list(op.input) + args,
            list(op.output),
            factor_prefixes=text.factor_prefixes))
    for arg in op.arg:
        if arg.HasField('n'):
            with text.context('arg: %s' % arg.name):
                text(arg.n)

@Printer.register(NetDef)
def print_net_def(text, net_def):
    if text.c2_syntax:
        text.add(call('core.Net', ["'%s'" % net_def.name], [net_def.name]))
        text.c2_net_name = net_def.name
    else:
        text.add('# net: %s' % net_def.name)
    for op in net_def.op:
        text(op)
    if text.c2_syntax:
        text.c2_net_name = None


@Printer.register(Net)
def print_net(text, net):
    text(net.Proto())

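# For instance, printing a two-op NetDef named 'train' in c2_syntax mode
# yields something like the following (an illustrative sketch, with
# hypothetical blob names, not verbatim output):
#
#     train = core.Net('train')
#     train.FC(['x', 'w', 'b'], ['y'])
#     train.Relu(['y'], ['y'])
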
def _get_step_context(step):
    # infer a loop/parallel/run_once header for the step, and whether
    # each substep should get its own header
    proto = step.Proto()
    if proto.should_stop_blob:
        return call('loop'), False
    if proto.num_iter and proto.num_iter != 1:
        return call('loop', [proto.num_iter]), False
    if proto.num_concurrent_instances > 1:
        return (
            call('parallel',
                 [('num_instances', proto.num_concurrent_instances)]),
            len(step.Substeps()) > 1)
    concurrent = proto.concurrent_substeps and len(step.Substeps()) > 1
    if concurrent:
        return call('parallel'), True
    if proto.report_net:
        return call('run_once'), False
    return None, False


@Printer.register(ExecutionStep)
def print_step(text, step):
    proto = step.Proto()
    step_ctx, do_substep = _get_step_context(step)
    with text.context(step_ctx):
        if proto.report_net:
            with text.context(call('report_net', [proto.report_interval])):
                text(step.get_net(proto.report_net))
        substeps = step.Substeps() + [step.get_net(n) for n in proto.network]
        for substep in substeps:
            sub_proto = (
                substep.Proto() if isinstance(substep, ExecutionStep)
                else None)
            if sub_proto is not None and sub_proto.run_every_ms:
                substep_ctx = call(
                    'reporter',
                    [str(substep), ('interval_ms', sub_proto.run_every_ms)])
            elif do_substep:
                title = (
                    'workspace'
                    if sub_proto is not None and sub_proto.create_workspace
                    else 'step')
                substep_ctx = call(title, [str(substep)])
            else:
                substep_ctx = None
            with text.context(substep_ctx):
                text(substep)
                if proto.should_stop_blob:
                    text.add(call('yield stop_if', [proto.should_stop_blob]))

def _print_task_output(x):
    assert isinstance(x, TaskOutput)
    return 'Output[' + ', '.join(str(name) for name in x.names) + ']'


@Printer.register(Task)
def print_task(text, task):
    outs = ', '.join(_print_task_output(o) for o in task.outputs())
    context = [('node', task.node), ('name', task.name), ('outputs', outs)]
    with text.context(call('Task', context)):
        text(task.get_step())

@Printer.register(TaskGroup)
def print_task_group(text, tg, header=None):
    with text.context(header or call('TaskGroup')):
        for task in tg.tasks_by_node().tasks():
            text(task)

@Printer.register(Job)
def print_job(text, job):
    text(job.init_group, 'Job.current().init_group')
    text(job.epoch_group, 'Job.current().epoch_group')
    with text.context('Job.current().stop_conditions'):
        for out in job.stop_conditions:
            text.add(_print_task_output(out))
    text(job.download_group, 'Job.current().download_group')
    text(job.exit_group, 'Job.current().exit_group')

def to_string(obj, **kwargs):
    """
    Given a Net, ExecutionStep, Task, TaskGroup or Job, produces a string
    with a detailed description of the execution steps.
    """
    printer = Printer(**kwargs)
    printer(obj)
    return str(printer)


def debug_net(net):
    """
    Given a Net, produce another net that logs info about the operator call
    before each operator execution. Use for debugging purposes.
    """
    assert isinstance(net, Net)
    debug_net = Net(str(net))
    for op in net.Proto().op:
        text = Text()
        print_op(text, op)
        debug_net.LogInfo(str(text))
        debug_net.Proto().op.extend([op])
    return debug_net

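# Sketch of typical usage (assumes this module is importable as
# caffe2.python.net_printer and that `job` is a Job built by the caller):
#
#     from caffe2.python import net_printer
#     net_printer.analyze(job)           # check blob definitions up front
#     print(net_printer.to_string(job))  # pseudo-code view of execution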