Caffe2 - Python API
A deep learning, cross-platform ML framework
checkpoint.py
## @package checkpoint
# Module caffe2.python.checkpoint
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import logging
from caffe2.python import core, context
from caffe2.python.net_builder import ops
from caffe2.python.task import (
    final_output,
    Node,
    Task,
    TaskGroup,
    TaskOutput,
    WorkspaceType,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@context.define_context()
class Job(object):
    """
    A Job defines four TaskGroups: the `init_group`, the `epoch_group`, the
    `download_group` and the `exit_group`, which will be run by a JobRunner.

    The `init_group` will be run only once at startup. Its role is to
    initialize globally persistent blobs such as model weights, accumulators
    and data file lists.

    The `epoch_group` will be run in a loop after the `init_group`. The loop
    will exit when any of the stop signals added with `add_stop_condition` is
    True at the end of an epoch.

    The `download_group` will be run only once, after all the executions of
    the `epoch_group` finish. Its role is to collect the distributed,
    scattered parameters back after training.

    The `exit_group` will be run only once at the very end of the job. Its
    role is to save the results of training.

    Jobs are context-driven, so that Tasks can be added to the active Job
    without having to explicitly pass the job object around.

    Example of usage:

        def build_reader(partitions):
            with Job.current().init_group:
                reader = HiveReader(init_reader, ..., partitions)
                Task(step=init_reader)
            with Job.current().epoch_group:
                limited_reader = ReaderWithLimit(reader, num_iter=10000)
                data_queue = pipe(limited_reader, num_threads=8)
                Job.current().add_stop_condition(limited_reader.data_finished())
            return data_queue

        def build_hogwild_trainer(reader, model):
            with Job.current().init_group:
                Task(step=model.param_init_net)
            with Job.current().epoch_group:
                pipe(reader, processor=model, num_threads=8)
            with Job.current().exit_group:
                Task(step=model.save_model_net)

        with Job() as job:
            reader = build_reader(partitions)
            model = build_model(params)
            build_hogwild_trainer(reader, model)
    """
    def __init__(self,
                 init_group=None, epoch_group=None,
                 download_group=None, exit_group=None,
                 stop_conditions=None, nodes_to_checkpoint=None):
        self.init_group = init_group or TaskGroup(
            workspace_type=WorkspaceType.GLOBAL)
        self.epoch_group = epoch_group or TaskGroup()
        self.download_group = download_group or TaskGroup()
        self.exit_group = exit_group or TaskGroup()
        self.stop_conditions = stop_conditions or []
        self._nodes_to_checkpoint = nodes_to_checkpoint

    def nodes_to_checkpoint(self):
        if self._nodes_to_checkpoint:
            return self._nodes_to_checkpoint
        else:
            return self.init_group.used_nodes()
    def compile(self, session_class):
        self.init_group = session_class.compile(self.init_group)
        self.epoch_group = session_class.compile(self.epoch_group)
        self.download_group = session_class.compile(self.download_group)
        self.exit_group = session_class.compile(self.exit_group)

    def __enter__(self):
        self.epoch_group.__enter__()
        return self

    def __exit__(self, *args):
        self.epoch_group.__exit__()

    def add_stop_condition(self, output):
        if isinstance(output, core.BlobReference):
            t = Task(outputs=[output], group=self.epoch_group)
            output = t.outputs()[0]
        assert isinstance(output, TaskOutput)
        self.stop_conditions.append(output)


def get_ckpt_filename(node_name, epoch):
    """Returns the checkpoint filename.

    Args:
        node_name: A string. The name of the node.
        epoch: An integer. The checkpoint epoch.

    Returns:
        ckpt_filename: A string. The filename of the checkpoint.
    """
    return node_name + '.' + str(epoch)


def db_name(epoch, node_name, db_prefix, path_prefix=None):
    """Returns the full db name where checkpoint files are saved.

    Args:
        epoch: An integer. The checkpoint epoch.
        node_name: A string. The name of the node.
        db_prefix: A string. The prefix used to construct full db name.
        path_prefix: A string. Optional param used to construct db name or path
            where checkpoint files are stored.

    Returns:
        db_name: A string. The absolute path of full_db_name where checkpoint
            files are saved.
    """
    if path_prefix:
        db_name = path_prefix + get_ckpt_filename(node_name, epoch)
    else:
        ckpt_filename = get_ckpt_filename(node_name, epoch)
        db_name = os.path.join(db_prefix, ckpt_filename)
    return db_name

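# A quick illustration of how the two helpers above compose (hypothetical
# values, not part of the original module):
#
#   get_ckpt_filename('trainer_0', 5)           -> 'trainer_0.5'
#   db_name(5, 'trainer_0', '/tmp/checkpoints') -> '/tmp/checkpoints/trainer_0.5'
#   db_name(5, 'trainer_0', '/tmp/checkpoints',
#           path_prefix='hdfs://some/path/')    -> 'hdfs://some/path/trainer_0.5'
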

class CheckpointManager(object):
    """
    Controls saving and loading of workspaces on every epoch boundary of a job.
    If a CheckpointManager instance is passed to JobRunner, then JobRunner will
    call `init`, `load` and `save` at different moments in between epoch runs.

    Args:
        db_prefix: The prefix used to construct full db name. Since
            `absolute_path` is set to True, this will be used as db_name
            in SaveOp.
        node_name: Name of the node where this checkpoint_manager is used.
        db_type: Type of database to use for storing checkpoint.
        metadata_handler: An optional object capable of reading/writing
            checkpoint info in storage of choice.
    """

    BLOB_NAMES = "blob_names"

    def __init__(self, db_prefix, node_name, db_type, metadata_handler=None):
        self._db_prefix = db_prefix
        self._node_name = node_name
        self._db_type = db_type
        self._metadata_handler = metadata_handler
        # Make sure these blobs are the first in the checkpoint file.
        self._net = core.Net('!!checkpoint_mngr')
        self._blob_names = self._net.AddExternalInput(self.BLOB_NAMES)
        self._names_output = None
        self._path_prefix = None
        self._path_type = None
        self._current_db_name = None
        self._current_checkpoint_duration = None

    """
    Initialize the checkpoint manager. Determines all blobs that need to be
    saved, or loads the blob list from a previously saved checkpoint.

    Args:
        nodes: An array of nodes where this checkpoint manager is running.
            Should only contain a single node.
        retrieve_from_epoch: Set to a number to load blobs from this epoch.
        path_prefix: Used to construct db name or path where checkpoint files
            are stored.
        path_type: Indicate the type of path where checkpoint files are stored.
    """
    def init(
        self,
        nodes=None,
        retrieve_from_epoch=None,
        path_prefix=None,
        path_type=None
    ):
        """
        Build a Task that will be run once after the job's `init_group` is run.
        This task will determine which blobs need to be checkpointed.
        If retrieve_from_epoch is not None, then the checkpoint metadata is
        retrieved from a previously saved checkpoint.
        """
        assert nodes is None or len(nodes) == 1, (
            'CheckpointManager only supports single node.')

        with Task(outputs=[self._blob_names]) as task:
            if retrieve_from_epoch is None:
                ops.GetAllBlobNames(
                    [],
                    self._blob_names,
                    include_shared=False)
            else:
                full_db_name = db_name(retrieve_from_epoch,
                                       self._node_name, self._db_prefix,
                                       path_prefix)
                db_type = path_type or self._db_type
                logger.info("Initializing checkpoints from = %s"
                            % full_db_name)
                ops.Load(
                    [], self._blob_names,
                    db=full_db_name,
                    db_type=db_type,
                    absolute_path=True,
                    keep_device=True,
                )
        self._names_output = task.outputs()[0]
        return task

    def blob_list(self):
        assert self._names_output
        return self._names_output.fetch().tolist()

    def _timed_task(self, cp_op_name, add_op):
        """
        Build a Task that will measure the time span of checkpoint operations.
        Once the operation is done, the time can be read from
        `_current_checkpoint_duration`.

        Args:
            cp_op_name: A string name of the checkpoint operation.
            add_op: A functor to add the checkpoint operation.

        Returns:
            A task with timer.
        """
        with Task(name=cp_op_name) as task:
            with ops.task_init():
                timer = ops.TimerBegin([], counter_name=self._node_name)
            add_op()
            with ops.task_exit():
                time_span_blob = ops.TimerGetAndEnd(timer)
            self._current_checkpoint_duration = final_output(time_span_blob)
        return task

    def collect_checkpoint_stats(self, stats):
        """
        Add one checkpoint's stats into the given stats dict.

        Args:
            stats: A dict of checkpoint stats that will be reported.
        """
        if self._current_db_name and self._current_checkpoint_duration:
            stats[self._current_db_name] = \
                self._current_checkpoint_duration.fetch()[0]
        else:
            logger.info(
                "Failed to collect checkpoint stats: {}".format(
                    self._current_db_name
                )
            )

    def load(self, epoch, path_prefix=None, path_type=None):
        """
        Build a Task that will be run by JobRunner when the job is to be
        resumed from a given epoch. This task will run a Load op that will
        load and deserialize all relevant blobs from a persistent storage.
        """
        self._current_db_name = db_name(
            epoch, self._node_name, self._db_prefix, path_prefix
        )
        db_type = path_type or self._db_type
        logger.info("Loading checkpoints from = %s" % self._current_db_name)

        def add_op():
            ops.Load(
                [],
                self.blob_list(),
                db=self._current_db_name,
                db_type=db_type,
                absolute_path=True,
                keep_device=True,
            )

        return self._timed_task('checkpoint_load', add_op)

    def load_blobs_from_checkpoint(self, blob_names, epoch):
        """
        Builds a Task that loads only the necessary blobs from a checkpoint of
        the given epoch. The necessary blobs are given in the blob_names
        argument.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: The checkpoint epoch to load from.

        Returns:
            A Task which loads the specified blobs from the checkpoint of the
            given epoch.
        """
        self._current_db_name = db_name(epoch, self._node_name, self._db_prefix)
        logger.info('Load from %s' % self._current_db_name)

        def add_op():
            ops.Load(
                [],
                blob_names,
                db=self._current_db_name,
                db_type=self._db_type,
                absolute_path=True,
                allow_incomplete=True)

        return self._timed_task('checkpoint_partial_load', add_op)

    def check_db_exists(self, epoch):
        logger.info('Check existence of %s' %
                    db_name(epoch, self._node_name, self._db_prefix))
        with Task() as task:
            existence = ops.Const(False)
            ops.DBExists(
                [],
                [existence],
                db_name=db_name(epoch, self._node_name, self._db_prefix),
                db_type=self._db_type,
                absolute_path=True)
            task.add_output(existence)
        return task

    def report_checkpoint_stats(self, action_name):
        """
        Report checkpoint operation stats for the current node.

        Args:
            action_name: A string of the name of checkpoint operation.
        """
        all_stats = {}
        self.collect_checkpoint_stats(all_stats)
        if self._metadata_handler:
            self._metadata_handler.report(action_name, all_stats)

    def save(self, epoch):
        """
        Build a Task that is run once after `init_group` and after each
        epoch is run. This will execute Save ops to serialize and persist
        blobs present in the global workspace.
        """
        self._current_db_name = db_name(epoch, self._node_name, self._db_prefix)
        logger.info('Saving to %s' % self._current_db_name)

        def add_op():
            ops.Save(
                self.blob_list(), [],
                db=self._current_db_name,
                db_type=self._db_type,
                absolute_path=True)

        return self._timed_task('checkpoint_save', add_op)

    def write_checkpoint_metadata(self, epoch):
        """
        Write metadata for checkpoint.

        Args:
            epoch: An integer. The epoch-id for which checkpoint metadata is
                written.
        """
        if self._metadata_handler is not None:
            self._metadata_handler.write(epoch=epoch)

    def get_resume_from_epoch_id(self, user_epoch=None):
        """
        Identify the epoch-id from which Job must resume.

        Args:
            user_epoch: An integer. Optional parameter for user to explicitly
                identify the epoch-id to load checkpoint from.

        Returns:
            epoch: The epoch-id to load checkpoints from, or None if no
                checkpoints were written.
        """
        last_epoch = user_epoch
        if self._metadata_handler is not None:
            last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
        return last_epoch

    def set_params(self, nodes, path_prefix=None, path_type=None):
        """Set parameters associated with CP manager.

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
            path_prefix: Used to construct db name or path where checkpoint
                files are stored.
            path_type: Indicate the type of path where checkpoint files are
                stored.
        """
        if path_prefix:
            self._path_prefix = path_prefix
        if path_type:
            self._path_type = path_type
        if self._metadata_handler:
            self._metadata_handler.set_params(
                db_prefix=self._db_prefix,
                db_type=self._db_type,
                node_names=[str(self._node_name)],
                path_prefix=self._path_prefix,
                path_type=self._path_type)

    def cp_accessible(self, epoch=None):
        """Returns True if checkpoint data is accessible.

        Args:
            epoch: An integer. The epoch of the checkpoint. If None,
                it implies we need to check if the checkpoint directory
                is accessible.

        Returns:
            is_cp_accessible: A boolean. Returns True if checkpoint data is
                accessible.
        """
        if self._metadata_handler is not None:
            return self._metadata_handler.cp_accessible(epoch)
        else:
            return True
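
# Sketch of single-node usage (illustrative; a CheckpointManager is normally
# driven by JobRunner, and the paths, node name and session below are
# assumptions, not part of this module):
#
#   from caffe2.python.session import LocalSession
#
#   manager = CheckpointManager(
#       db_prefix='/tmp/checkpoints/', node_name='trainer_0',
#       db_type='leveldb')
#   session = LocalSession()
#   session.run(manager.init(nodes=['trainer_0']))  # record blob names
#   session.run(manager.save(epoch=0))   # write /tmp/checkpoints/trainer_0.0
#   session.run(manager.load(epoch=0))   # restore the saved blobs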


class MultiNodeCheckpointManager(object):
    """
    Coordinates checkpointing across multiple nodes.
    Each of `init`, `load` and `save` will build TaskGroups which will
    trigger checkpointing on each of the nodes involved in a distributed job.

    Args:
        db_prefix: The prefix used to construct full db name. Since
            `absolute_path` is set to True, this will be used as db_name
            in SaveOp.
        db_type: Type of database to use for storing checkpoint.
        metadata_handler: An optional object capable of reading/writing
            checkpoint info in storage of choice.
    """
    def __init__(self, db_prefix, db_type, metadata_handler=None):
        self._node_managers = None
        self._db_prefix = db_prefix
        self._db_type = db_type
        self._metadata_handler = metadata_handler
        self._path_prefix = None
        self._path_type = None

    def _task_group(self, func, *args, **kw):
        assert self._node_managers is not None, 'init must be called first.'
        with TaskGroup(WorkspaceType.GLOBAL) as task_group:
            for node, manager in self._node_managers:
                with Node(node):
                    func(manager, *args, **kw)
            return task_group

    """
    Args:
        nodes: An array of nodes where this checkpoint manager is running.
        retrieve_from_epoch: Set to a number to load blobs from this epoch.
        path_prefix: Used to construct db name or path where checkpoint files
            are stored.
        path_type: Indicate the type of path where checkpoint files are stored.
    """
    def init(
        self, nodes, retrieve_from_epoch=None, path_prefix=None, path_type=None
    ):
        if self._node_managers is not None:
            assert [node for node, _ in self._node_managers] == nodes
            return TaskGroup(WorkspaceType.GLOBAL)
        self._node_managers = []
        for node in nodes:
            with Node(node):
                manager = CheckpointManager(
                    db_prefix=self._db_prefix,
                    node_name=str(node),
                    db_type=self._db_type)
                self._node_managers.append((node, manager))
        return self._task_group(
            CheckpointManager.init,
            nodes=[node],
            retrieve_from_epoch=retrieve_from_epoch,
            path_prefix=path_prefix,
            path_type=path_type)

    def load(self, epoch, path_prefix=None, path_type=None):
        return self._task_group(
            CheckpointManager.load,
            epoch,
            path_prefix=path_prefix,
            path_type=path_type)

    def load_blobs_locally(self, nodes, blob_names, epoch, session):
        """Loads the necessary blobs from the checkpoints to the current node.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: An integer. The checkpoint epoch to load from.
            session: A Session object to execute the Load ops.
        """
        if self._node_managers is not None:
            assert [node for node, _ in self._node_managers] == nodes
        else:
            self._node_managers = []
            for node in nodes:
                with Node(node):
                    manager = CheckpointManager(
                        db_prefix=self._db_prefix,
                        node_name=str(node),
                        db_type=self._db_type)
                    self._node_managers.append((node, manager))
        assert self._node_managers is not None, 'must initialize node managers'
        for _, manager in self._node_managers:
            existence_task = manager.check_db_exists(epoch)
            session.run(existence_task)
            existence = existence_task.outputs()[0].fetch()
            if not existence:
                logger.info('DB %s does not exist!' %
                            db_name(epoch, manager._node_name, manager._db_prefix))
                return False
            load_task = manager.load_blobs_from_checkpoint(blob_names, epoch)
            session.run(load_task)
        logger.info('Successfully loaded from checkpoints.')
        return True

    def get_ckpt_db_name(self, node_name, epoch):
        """Returns the DB name of the given node and the given epoch.

        The DB name is effectively the checkpoint path of the given node and
        the given epoch.

        Args:
            node_name: A string. The node name of interest.
            epoch: An integer. The epoch of the checkpoint.

        Returns:
            checkpoint_db_name: A string. The checkpoint path of the given
                node and the given epoch.
        """
        for node, manager in self._node_managers:
            if str(node) == node_name:
                return db_name(epoch, manager._node_name, manager._db_prefix)

    def report_checkpoint_stats(self, action_name):
        """
        Report the checkpoint stats for all the nodes. We aggregate all the
        nodes' stats together so that we know which node's checkpoint
        operation dominates.

        Args:
            action_name: A string of the name of checkpoint operation.
        """
        all_stats = {}
        for _, manager in self._node_managers:
            manager.collect_checkpoint_stats(all_stats)
        logger.debug("checkpoint stats: {}".format(all_stats))
        if self._metadata_handler:
            self._metadata_handler.report(action_name, all_stats)

    def save(self, epoch):
        """
        Build a Task that will execute Save ops to serialize and persist
        blobs present in the global workspace.
        """
        return self._task_group(CheckpointManager.save, epoch)

    def write_checkpoint_metadata(self, epoch):
        """
        Write metadata for checkpoint.

        Args:
            epoch: An integer. The epoch-id for which checkpoint metadata is
                written.
        """
        if self._metadata_handler is not None:
            self._metadata_handler.write(epoch=epoch)

    def get_resume_from_epoch_id(self, user_epoch=None):
        """
        Identify the epoch-id from which Job must resume.

        Args:
            user_epoch: An integer. Optional parameter for user to explicitly
                identify the epoch-id to load checkpoint from.

        Returns:
            epoch: The epoch-id to load checkpoints from, or None if no
                checkpoints were written.
        """
        last_epoch = user_epoch
        if self._metadata_handler is not None:
            last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
        return last_epoch

    def set_params(self, nodes, path_prefix=None, path_type=None):
        """Set parameters associated with CP manager.

        Args:
            nodes: An array of nodes where this checkpoint manager is running.
            path_prefix: Used to construct db name or path where checkpoint
                files are stored.
            path_type: Indicate the type of path where checkpoint files are
                stored.
        """
        self._node_names = [str(node) for node in nodes]
        if path_prefix:
            self._path_prefix = path_prefix
        if path_type:
            self._path_type = path_type
        if self._metadata_handler:
            self._metadata_handler.set_params(
                db_prefix=self._db_prefix,
                db_type=self._db_type,
                node_names=self._node_names,
                path_prefix=self._path_prefix,
                path_type=self._path_type)

    def cp_accessible(self, epoch=None):
        """Returns True if checkpoint data is accessible.

        Args:
            epoch: An integer. The epoch of the checkpoint. If None,
                it implies we need to check if the checkpoint directory
                is accessible.

        Returns:
            is_cp_accessible: A boolean. Returns True if checkpoint data is
                accessible.
        """
        if self._metadata_handler is not None:
            return self._metadata_handler.cp_accessible(epoch)
        else:
            return True
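
# Multi-node sketch (illustrative; node names and the session are
# assumptions): the calls mirror CheckpointManager, but each of
# init/load/save returns a TaskGroup that fans the operation out to every
# node in the job.
#
#   manager = MultiNodeCheckpointManager(
#       db_prefix='/tmp/checkpoints/', db_type='leveldb')
#   session.run(manager.init(nodes=['trainer_0', 'trainer_1']))
#   session.run(manager.save(epoch=0))  # writes trainer_0.0 and trainer_1.0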


class UploadTaskGroupBuilder(object):
    """A simple class to upload checkpoints."""
    def build(self, epoch, checkpoint_manager):
        """Builds the task group to upload checkpoints.

        Args:
            epoch: An integer. The checkpoint epoch to be uploaded.
            checkpoint_manager: Can be a CheckpointManager for single machine
                or a MultiNodeCheckpointManager for multi-machine. The manager
                that initializes/saves/loads checkpoints.

        Raises:
            NotImplementedError: This base class only has the interface,
                the implementation will be in the subclasses.
        """
        raise NotImplementedError()
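
# A minimal subclass sketch (hypothetical; the class name, dest_dir parameter
# and the copy step are assumptions): build() should return a TaskGroup that
# ships each node's checkpoint db for the given epoch to its destination.
#
#   class CopyUploadTaskGroupBuilder(UploadTaskGroupBuilder):
#       def __init__(self, dest_dir):
#           self.dest_dir = dest_dir
#
#       def build(self, epoch, checkpoint_manager):
#           with TaskGroup(WorkspaceType.GLOBAL) as upload_task_group:
#               for node in checkpoint_manager._node_names:
#                   src = checkpoint_manager.get_ckpt_db_name(node, epoch)
#                   with Node(node), Task():
#                       pass  # add a backend-specific copy op for `src` here
#           return upload_task_group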


class JobRunner(object):
    """
    Implement the runtime logic for jobs with checkpointing at the epoch
    level. Can be used to run either single-host or distributed jobs. The
    `train` method is to be called once from the master, passing a session
    as an argument. This call will block until the Job execution is complete.

    If a checkpoint_manager is passed, checkpoints will be taken after
    initialization and after each epoch execution. If, in addition,
    `resume_from_epoch` is an epoch number, the corresponding checkpoint will
    be loaded and job execution will continue from the given epoch. In
    this case, the job's init_group will not be run.

    Refer to checkpoint_test.py for an example.
    """
    def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None,
                 upload_task_group_builder=None):
        """Initializes the JobRunner.

        Args:
            job: A Job object. The job to be executed.
            checkpoint_manager: Can be a CheckpointManager for single machine
                or a MultiNodeCheckpointManager for multi-machine. The manager
                that initializes/saves/loads checkpoints.
            resume_from_epoch: An integer. The epoch to resume from.
            upload_task_group_builder: A subclass of the
                UploadTaskGroupBuilder. Creates a task group to upload
                checkpoints.
        """
        self.resume_from_epoch = resume_from_epoch
        self.checkpoint_manager = checkpoint_manager
        self.job = job
        self.upload_task_group_builder = upload_task_group_builder

    def train(self, session):
        """Runs the training flow.

        Args:
            session: A Session object. Valid choices are: LocalSession,
                LocalHostScheduler, and DistributedSession. It is used to
                execute one TaskGroup at a time.
        """
        # Identify the epoch we must resume from.
        if self.checkpoint_manager:
            self.checkpoint_manager.set_params(nodes=self.job.nodes_to_checkpoint())
            self.resume_from_epoch = self.checkpoint_manager.\
                get_resume_from_epoch_id(self.resume_from_epoch)
            if self.resume_from_epoch is not None:
                logger.info('Resuming from epoch {}'.format(self.resume_from_epoch))

        # Initialize all the nodes.
        from_scratch = self.resume_from_epoch is None
        if from_scratch:
            session.run(self.job.init_group)

        if self.checkpoint_manager:
            logger.info('Preparing checkpoints ...')
            session.run(self.checkpoint_manager.init(
                self.job.nodes_to_checkpoint(),
                retrieve_from_epoch=self.resume_from_epoch))
            # Save the first checkpoint before training starts, or resume from
            # a previously saved checkpoint.
            if from_scratch:
                self.save_checkpoints(0, session)
            else:
                logger.info('Loading checkpoints for epoch {} ...'.format(
                    self.resume_from_epoch))
                session.run(
                    self.checkpoint_manager.load(self.resume_from_epoch))
                self.checkpoint_manager.report_checkpoint_stats('checkpoint_load')
                logger.info('Checkpoint loaded')

        logger.info("Finished initializing")

        # Start training.
        epoch = 1 if from_scratch else self.resume_from_epoch + 1
        while True:
            logger.info('Starting epoch %d' % epoch)
            session.run(self.job.epoch_group)
            logger.info('Finished epoch %d' % epoch)
            stop_conditions = [o.fetch() for o in self.job.stop_conditions]

            if self.checkpoint_manager:
                self.save_checkpoints(epoch, session)

            if any(stop_conditions):
                logger.info('Stopping')
                break
            epoch += 1
        logger.info('Finished training')
        # Upload the checkpoints.
        if self.upload_task_group_builder:
            upload_task_group = self.upload_task_group_builder.build(
                epoch, self.checkpoint_manager)
            session.run(upload_task_group)
            logger.info('Finished uploading the checkpoints')

        # Download the parameters to save.
        session.run(self.job.download_group)
        logger.info('Finished downloading the parameters')

        # Finally run the exit step to save nets.
        session.run(self.job.exit_group)
        logger.info('Finished running the exit group')
        return epoch

    def load_blobs_from_checkpoints(self, blob_names, epoch, session):
        """Loads the necessary blobs from the checkpoints.

        Checkpoints store the snapshots of the workspace in each node.
        Sometimes we only need to load a subset of the blobs from the
        checkpoints. One common scenario is to load only the model blobs from
        the checkpoints for evaluation purposes. Given the names of the
        necessary blobs, this function goes over all the checkpoints of all
        the nodes, but only loads the blobs specified in blob_names to the
        current workspace.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: An integer. The checkpoint epoch to load from.
            session: A Session object to execute the load ops.

        Raises:
            ValueError: When the checkpoint manager is invalid.
        """
        if not self.checkpoint_manager:
            raise ValueError('Checkpoint manager is None')
        logger.info('Loading checkpoint for epoch {} ...'.format(epoch))
        result = self.checkpoint_manager.load_blobs_locally(
            self.job.nodes_to_checkpoint(), blob_names, epoch, session)
        self.checkpoint_manager.report_checkpoint_stats('checkpoint_partial_load')
        return result

    def save_checkpoints(self, epoch, session):
        """Triggers operation to save checkpoints.

        This method will trigger the Save ops to serialize and persist the
        blobs present in the global workspace.

        Args:
            epoch: An integer. The checkpoint epoch-id that we are saving.
            session: A Session object to execute the save ops.

        Raises:
            ValueError: When the checkpoint manager is invalid.
        """
        if not self.checkpoint_manager:
            raise ValueError('Checkpoint manager is None')
        try:
            is_accessible = self.checkpoint_manager.cp_accessible(epoch=None)
            if is_accessible:
                logger.info('Saving checkpoints for epoch {}'.format(epoch))
                session.run(self.checkpoint_manager.save(epoch))
                self.checkpoint_manager.write_checkpoint_metadata(epoch)
                logger.info('Checkpoints saved')
                self.checkpoint_manager.report_checkpoint_stats('checkpoint_save')
            else:
                logger.warning("Checkpoint files cannot be accessed!")
        except Exception as ex:
            logger.warning("Unable to write checkpoint for epoch {}. Error={}".
                           format(epoch, ex))

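# End-to-end sketch of driving a Job with checkpointing (illustrative; the
# populated job and the session are assumptions, not part of this module):
#
#   job = ...  # a Job whose init/epoch/exit groups have been populated
#   manager = CheckpointManager(
#       db_prefix='/tmp/checkpoints/', node_name='trainer_0',
#       db_type='leveldb')
#   runner = JobRunner(job, checkpoint_manager=manager)
#   last_epoch = runner.train(session)  # checkpoints after init and each epoch
#
#   # To resume after a failure, pass the epoch of the last good checkpoint:
#   runner = JobRunner(job, checkpoint_manager=manager, resume_from_epoch=3)
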

def epoch_limiter(job, num_epochs):
    """
    Creates a task that will output True when a given
    number of epochs has finished.
    """
    with job.init_group:
        init_net = core.Net('epoch_counter_init')
        counter = init_net.CreateCounter([], init_count=num_epochs - 1)
        Task(step=init_net)

    with job.epoch_group:
        epoch_net = core.Net('epoch_countdown')
        finished = epoch_net.CountDown(counter)
        output = Task(step=epoch_net, outputs=finished).outputs()[0]
        job.add_stop_condition(output)