Caffe2 - Python API
A deep learning, cross-platform ML framework
checkpoint.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package checkpoint
17 # Module caffe2.python.checkpoint
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 import os
24 import logging
25 from caffe2.python import core, context
26 from caffe2.python.net_builder import ops
27 from caffe2.python.task import Node, Task, TaskGroup, TaskOutput, WorkspaceType
28 
29 logger = logging.getLogger(__name__)
30 logger.setLevel(logging.INFO)
31 
32 
33 @context.define_context()
34 class Job(object):
35  """
36  A Job defines four TaskGroups: the `init_group`, the `epoch_group`, the
37  `download_group` and the `exit_group`, which will be run by a JobRunner.
38 
39  The `init_group` will be run only once at startup. Its role is to
40  initialize globally persistent blobs such as model weights, accumulators
41  and data file lists.
42 
43  The `epoch_group` will be run in a loop after init_group. The loop will
44  exit when any of the stop signals added with `add_stop_signal` is True
45  at the end of an epoch.
46 
47  The `download_group` will be run only once, after all executions of the
48  `epoch_group` finish. Its role is to collect the parameters scattered
49  across nodes back after distributed training.
50 
51  The `exit_group` will be run only once at the very end of the job; its
52  role is to save the results of training.
53 
54  Jobs are context-driven, so that Tasks can be added to the active Job
55  without having to explicitly pass the job object around.
56 
57  Example of usage:
58 
59  def build_reader(partitions):
60      with Job.current().init_group:
61          reader = HiveReader(init_reader, ..., partitions)
62          Task(step=init_reader)
63      with Job.current().epoch_group:
64          limited_reader = ReaderWithLimit(reader, num_iter=10000)
65          data_queue = pipe(limited_reader, num_threads=8)
66          Job.current().add_stop_signal(limited_reader.data_finished())
67      return data_queue
68 
69  def build_hogwild_trainer(reader, model):
70      with Job.current().init_group:
71          Task(step=model.param_init_net)
72      with Job.current().epoch_group:
73          pipe(reader, processor=model, num_threads=8)
74      with Job.current().exit_group:
75          Task(step=model.save_model_net)
76 
77  with Job() as job:
78      reader = build_reader(partitions)
79      model = build_model(params)
80      build_hogwild_trainer(reader, model)
81  """
82  def __init__(self,
83  init_group=None, epoch_group=None,
84  download_group=None, exit_group=None,
85  stop_signals=None, nodes_to_checkpoint=None):
86  self.init_group = init_group or TaskGroup(
87  workspace_type=WorkspaceType.GLOBAL)
88  self.epoch_group = epoch_group or TaskGroup()
89  self.download_group = download_group or TaskGroup()
90  self.exit_group = exit_group or TaskGroup()
91  self.stop_signals = stop_signals or []
92  self._nodes_to_checkpoint = nodes_to_checkpoint
93 
94  def nodes_to_checkpoint(self):
95  if self._nodes_to_checkpoint:
96  return self._nodes_to_checkpoint
97  else:
98  return self.init_group.used_nodes()
99 
100  def compile(self, session_class):
101  return Job(
102  init_group=session_class.compile(self.init_group),
103  epoch_group=session_class.compile(self.epoch_group),
104  download_group=session_class.compile(self.download_group),
105  exit_group=session_class.compile(self.exit_group),
106  stop_signals=self.stop_signals,
107  nodes_to_checkpoint=self.nodes_to_checkpoint())
108 
109  def __enter__(self):
110  self.epoch_group.__enter__()
111  return self
112 
113  def __exit__(self, *args):
114  self.epoch_group.__exit__()
115 
116  def add_stop_signal(self, output):
117  if isinstance(output, core.BlobReference):
118  t = Task(outputs=[output], group=self.epoch_group)
119  output = t.outputs()[0]
120  assert isinstance(output, TaskOutput)
121  self.stop_signals.append(output)
122 
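
To make the lifecycle above concrete, here is a stripped-down sketch of the order in which a JobRunner (defined later in this file) runs these groups. It assumes `session` is a Session and `job` a Job built as in the docstring example; checkpointing and error handling are omitted:

    # Illustrative sketch only -- the real logic lives in JobRunner.__call__ below.
    session.run(job.init_group)                  # one-time initialization
    epoch = 1
    while True:
        session.run(job.epoch_group)             # one epoch of work
        stop_signals = [o.fetch() for o in job.stop_signals]
        if any(stop_signals):
            break
        epoch += 1
    session.run(job.download_group)              # gather distributed parameters
    session.run(job.exit_group)                  # save final results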
123 
124 def get_ckpt_filename(node_name, epoch):
125  """Returns the checkpoint filename.
126 
127  Args:
128  node_name: A string. The name of the node.
129  epoch: An integer. The checkpoint epoch.
130 
131  Returns:
132  ckpt_filename: A string. The filename of the checkpoint.
133  """
134  return node_name + '.' + str(epoch)
135 
136 
137 def db_name(epoch, node_name, db_prefix, path_prefix=None):
138  """Returns the full db name where checkpoint files are saved.
139 
140  Args:
141  epoch: An integer. The checkpoint epoch.
142  node_name: A string. The name of the node.
143  db_prefix: A string. The prefix used to construct full db name.
144  path_prefix: A string. Optional param used to construct db name or path
145  where checkpoint files are stored.
146  Returns:
147  db_name: A string. The absolute path of full_db_name where checkpoint
148  files are saved
149  """
150  if path_prefix:
151  db_name = path_prefix + get_ckpt_filename(node_name, epoch)
152  else:
153  ckpt_filename = get_ckpt_filename(node_name, epoch)
154  db_name = os.path.join(db_prefix, ckpt_filename)
155  return db_name
156 
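
For example, the two helpers above compose checkpoint paths as follows (the prefixes and node name are hypothetical):

    get_ckpt_filename('trainer_0', 5)
    # -> 'trainer_0.5'
    db_name(5, 'trainer_0', '/checkpoints/run1')
    # -> '/checkpoints/run1/trainer_0.5'   (db_prefix joined with the filename)
    db_name(5, 'trainer_0', '/checkpoints/run1', path_prefix='gfs://bucket/run1/')
    # -> 'gfs://bucket/run1/trainer_0.5'   (path_prefix is simply prepended)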
157 
158 class CheckpointManager(object):
159  """
160  Controls saving and loading of workspaces on every epoch boundary of a job.
161  If a CheckpointManager instance is passed to JobRunner, then JobRunner will
162  call `init`, `load` and `save` at different moments in between epoch runs.
163 
164  Args:
165  db_prefix: The prefix used to construct full db name. Since `absolute_path`
166  is set to True, this will be used as db_name in SaveOp.
167  node_name: Name of the node where this checkpoint_manager is used.
168  db_type: Type of database to use for storing checkpoint.
169  metadata_handler: An optional object capable of reading/writing
170  checkpoint info in storage of choice.
171  """
172  def __init__(self, db_prefix, node_name, db_type, metadata_handler=None):
173  self._db_prefix = db_prefix
174  self._node_name = node_name
175  self._db_type = db_type
176  self._metadata_handler = metadata_handler
177  # make sure these blobs are the first in the checkpoint file.
178  self._net = core.Net('!!checkpoint_mngr')
179  self._blob_names = self._net.AddExternalInput('blob_names')
180  self._names_output = None
181  self._path_prefix = None
182  self._path_type = None
183 
184  """
185  Initialize the checkpoint manager. Determines all blobs that need to be saved
186  or loads from a checkpoint.
187 
188  Args:
189  nodes: An array of nodes where this checkpoint manager is running. Should
190  only contain a single node.
191  retrieve_from_epoch: Set to a number to load blobs from this epoch.
192  path_prefix: Used to construct db name or path where checkpoint files are
193  stored.
194  path_type: Indicate the type of path where checkpoint files are stored.
195  """
196  def init(
197  self,
198  nodes=None,
199  retrieve_from_epoch=None,
200  path_prefix=None,
201  path_type=None
202  ):
203  """
204  Build a Task that will be run once after the job's `init_group` is run.
205  This task will determine which blobs need to be checkpointed.
206  If retrieve_from_epoch is not None, then the checkpoint metadata is
207  retrieved from a previously saved checkpoint.
208  """
209  assert nodes is None or len(nodes) == 1, (
210  'CheckpointManager only supports single node.')
211 
212  with Task(outputs=[self._blob_names]) as task:
213  if retrieve_from_epoch is None:
214  ops.GetAllBlobNames(
215  [],
216  self._blob_names,
217  include_shared=False)
218  else:
219  full_db_name = db_name(retrieve_from_epoch,
220  self._node_name, self._db_prefix, path_prefix)
221  db_type = path_type or self._db_type
222  logger.info("Initializing checkpoints from = %s"
223  % full_db_name)
224  ops.Load(
225  [], self._blob_names,
226  db=full_db_name,
227  db_type=db_type,
228  absolute_path=True)
229  self._names_output = task.outputs()[0]
230  return task
231 
232  def blob_list(self):
233  assert self._names_output
234  return self._names_output.fetch().tolist()
235 
236  def load(self, epoch, path_prefix=None, path_type=None):
237  """
238  Build a Task that will be run by JobRunner when the job is to be
239  resumed from a given epoch. This task will run a Load op that will
240  load and deserialize all relevant blobs from a persistent storage.
241  """
242  full_db_name = db_name(epoch, self._node_name, self._db_prefix, path_prefix)
243  db_type = path_type or self._db_type
244  logger.info("Loading checkpoints from = %s" % full_db_name)
245  with Task() as task:
246  ops.Load(
247  [],
248  self.blob_list(),
249  db=full_db_name,
250  db_type=db_type,
251  absolute_path=True)
252  return task
253 
254  def load_blobs_from_checkpoint(self, blob_names, epoch):
255  """
256  Builds a Task that loads only the necessary blobs from a checkpoint of
257  the given epoch. The necessary blobs are given in the blob_names
258  argument.
259 
260  Args:
261  blob_names: A list of strings. Each string is the name of a
262  blob.
263  epoch: The checkpoint epoch to load from.
264 
265  Returns:
266  A Task which loads the specified blobs from the checkpoint of the
267  given epoch.
268  """
269  logger.info('Load from %s' % db_name(epoch, self._node_name, self._db_prefix))
270  with Task() as task:
271  ops.Load(
272  [],
273  blob_names,
274  db=db_name(epoch, self._node_name, self._db_prefix),
275  db_type=self._db_type,
276  absolute_path=True,
277  allow_incomplete=True)
278  return task
279 
280  def check_db_exists(self, epoch):
281  logger.info('Check existence of %s' %
282  db_name(epoch, self._node_name, self._db_prefix))
283  with Task() as task:
284  existence = ops.Const(False)
285  ops.DBExists(
286  [],
287  [existence],
288  db_name=db_name(epoch, self._node_name, self._db_prefix),
289  db_type=self._db_type,
290  absolute_path=True)
291  task.add_output(existence)
292  return task
293 
294  def save(self, epoch):
295  """
296  Build a Task that is run once after `init_group` and after each
297  epoch is run. This will execute a Save op to serialize and persist
298  blobs present in the global workspace.
299  """
300  logger.info('Saving to %s' % db_name(epoch, self._node_name, self._db_prefix))
301  with Task() as task:
302  ops.Save(
303  self.blob_list(), [],
304  db=db_name(epoch, self._node_name, self._db_prefix),
305  db_type=self._db_type, absolute_path=True)
306  return task
307 
308  def write_checkpoint_metadata(self, epoch):
309  """
310  Write metadata for checkpoint
311 
312  Args:
313  epoch: An integer. The epoch-id for which checkpoint metadata is
314  written
315  """
316  if self._metadata_handler is not None:
317  self._metadata_handler.write(epoch=epoch)
318 
319  def get_resume_from_epoch_id(self, user_epoch=None):
320  """
321  Identify the epoch-id from which Job must resume
322 
323  Args:
324  user_epoch: An integer. Optional parameter for user to explicitly
325  identify the epoch-id to load checkpoint from
326  Returns:
327  epoch: the epoch-id to load checkpoints from
328  or None if no checkpoints were written
329  """
330  last_epoch = user_epoch
331  if self._metadata_handler is not None:
332  last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
333  return last_epoch
334 
335  def set_params(self, nodes, path_prefix=None, path_type=None):
336  """Set parameters associated with CP manager
337 
338  Args:
339  nodes: An array of nodes where this checkpoint manager is running.
340  path_prefix: Used to construct db name or path where checkpoint files are
341  stored.
342  path_type: Indicate the type of path where checkpoint files are stored.
343  """
344  if path_prefix:
345  self._path_prefix = path_prefix
346  if path_type:
347  self._path_type = path_type
348  if self._metadata_handler:
349  self._metadata_handler.set_params(
350  db_prefix=self._db_prefix,
351  db_type=self._db_type,
352  node_names=[str(self._node_name)],
353  path_prefix=self._path_prefix,
354  path_type=self._path_type)
355 
356  def cp_accessible(self, epoch=None):
357  """Returns True if Checkpoint data is accessible
358 
359  Args:
360  epoch: An integer. The epoch of the checkpoint. If None,
361  it implies we need to check if checkpoint directory is accessible
362 
363  Returns:
364  is_cp_accessible: A boolean. Returns True if Checkpoint data is accessible
365  """
366  if self._metadata_handler is not None:
367  return self._metadata_handler.cp_accessible(epoch)
368  else:
369  return True
370 
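
A minimal sketch of how these methods fit together when driven by hand. In practice JobRunner issues these calls; the session, db path and db type below are assumptions, not values from this module:

    # Conceptual order only -- mirrors what JobRunner does with a CheckpointManager.
    cm = CheckpointManager(db_prefix='/tmp/ckpts/', node_name='trainer_0',
                           db_type='leveldb')
    session.run(cm.init(nodes=['trainer_0']))   # record the names of all blobs to checkpoint
    session.run(cm.save(epoch=0))               # persist them to /tmp/ckpts/trainer_0.0
    session.run(cm.load(epoch=0))               # in a later run: restore the same blobs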
371 
372 class MultiNodeCheckpointManager(object):
373  """
374  Coordinates checkpointing across multiple nodes.
375  Each of `init`, `load` and `save` will build TaskGroups which will
376  trigger checkpointing on each of the nodes involved in a distributed job.
377 
378  Args:
379  db_prefix: The prefix used to construct full db name. Since `absolute_path`
380  is set to True, this will be used as db_name in SaveOp.
381  db_type: Type of database to use for storing checkpoint.
382  metadata_handler: An optional object capable of reading/writing
383  checkpoint info in storage of choice.
384  """
385  def __init__(self, db_prefix, db_type, metadata_handler=None):
386  self._node_managers = None
387  self._db_prefix = db_prefix
388  self._db_type = db_type
389  self._metadata_handler = metadata_handler
390  self._path_prefix = None
391  self._path_type = None
392 
393  def _task_group(self, func, *args, **kw):
394  assert self._node_managers is not None, 'init must be called first.'
395  with TaskGroup(WorkspaceType.GLOBAL) as task_group:
396  for node, manager in self._node_managers:
397  with Node(node):
398  func(manager, *args, **kw)
399  return task_group
400 
401  """
402  Args:
403  nodes: An array of nodes where this checkpoint manager is running.
404  retrieve_from_epoch: Set to a number to load blobs from this epoch.
405  path_prefix: Used to construct db name or path where checkpoint files are
406  stored.
407  path_type: Indicate the type of path where checkpoint files are stored.
408  """
409  def init(
410  self, nodes, retrieve_from_epoch=None, path_prefix=None, path_type=None
411  ):
412  if self._node_managers is not None:
413  assert [node for node, _ in self._node_managers] == nodes
414  return TaskGroup(WorkspaceType.GLOBAL)
415  self._node_managers = []
416  for node in nodes:
417  with Node(node):
418  manager = CheckpointManager(
419  db_prefix=self._db_prefix,
420  node_name=str(node),
421  db_type=self._db_type)
422  self._node_managers.append((node, manager))
423  return self._task_group(
424  CheckpointManager.init,
425  nodes=[node],
426  retrieve_from_epoch=retrieve_from_epoch,
427  path_prefix=path_prefix,
428  path_type=path_type)
429 
430  def load(self, epoch, path_prefix=None, path_type=None):
431  return self._task_group(
432  CheckpointManager.load,
433  epoch,
434  path_prefix=path_prefix,
435  path_type=path_type)
436 
437  def load_blobs_locally(self, nodes, blob_names, epoch, session):
438  """Loads the necessary blobs from the checkpoints to the current node.
439 
440  Args:
441  blob_names: A list of strings. Each string is the name of a
442  blob.
443  epoch: An integer. The checkpoint epoch to load from.
444  session: A Session object to execute the Load ops.
445  """
446  if self._node_managers is not None:
447  assert [node for node, _ in self._node_managers] == nodes
448  else:
449  self._node_managers = []
450  for node in nodes:
451  with Node(node):
452  manager = CheckpointManager(
453  db_prefix=self._db_prefix,
454  node_name=str(node),
455  db_type=self._db_type)
456  self._node_managers.append((node, manager))
457  assert self._node_managers is not None, 'must initialize node managers'
458  for _, manager in self._node_managers:
459  existence_task = manager.check_db_exists(epoch)
460  session.run(existence_task)
461  existence = existence_task.outputs()[0].fetch()
462  if not existence:
463  logger.info('DB %s does not exist!' %
464  db_name(epoch, manager._node_name, manager._db_prefix))
465  return False
466  load_task = manager.load_blobs_from_checkpoint(blob_names, epoch)
467  session.run(load_task)
468  logger.info('Successfully loaded from checkpoints.')
469  return True
470 
471  def get_ckpt_db_name(self, node_name, epoch):
472  """Returns the DB name of the given node and the given epoch.
473 
474  The DB name is effectively the checkpoint path of the given node and
475  the given epoch.
476 
477  Args:
478  node_name: A string. The node name of interest.
479  epoch: An integer. The epoch of the checkpoint.
480 
481  Returns:
482  checkpoint_db_name: A string. The checkpoint path of the given
483  node and the given epoch.
484  """
485  for node, manager in self._node_managers:
486  if str(node) == node_name:
487  return db_name(epoch, manager._node_name, manager._db_prefix)
488 
489  def save(self, epoch):
490  """
491  Build a TaskGroup that will execute Save ops to serialize and persist
492  blobs present in the global workspace.
493  """
494  return self._task_group(CheckpointManager.save, epoch)
495 
496  def write_checkpoint_metadata(self, epoch):
497  """
498  Write metadata for checkpoint
499 
500  Args:
501  epoch: An integer. The epoch-id for which checkpoint metadata is
502  written
503  """
504  if self._metadata_handler is not None:
505  self._metadata_handler.write(epoch=epoch)
506 
507  def get_resume_from_epoch_id(self, user_epoch=None):
508  """
509  Identify the epoch-id from which Job must resume
510 
511  Args:
512  user_epoch: An integer. Optional parameter for user to explicitly
513  identify the epoch-id to load checkpoint from
514  Returns:
515  epoch: the epoch-id to load checkpoints from
516  or None if no checkpoints were written
517  """
518  last_epoch = user_epoch
519  if self._metadata_handler is not None:
520  last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
521  return last_epoch
522 
523  def set_params(self, nodes, path_prefix=None, path_type=None):
524  """Set parameters associated with CP manager
525 
526  Args:
527  nodes: An array of nodes where this checkpoint manager is running.
528  path_prefix: Used to construct db name or path where checkpoint files are
529  stored.
530  path_type: Indicate the type of path where checkpoint files are stored.
531  """
532  self._node_names = [str(node) for node in nodes]
533  if path_prefix:
534  self._path_prefix = path_prefix
535  if path_type:
536  self._path_type = path_type
537  if self._metadata_handler:
538  self._metadata_handler.set_params(
539  db_prefix=self._db_prefix,
540  db_type=self._db_type,
541  node_names=self._node_names,
542  path_prefix=self._path_prefix,
543  path_type=self._path_type)
544 
545  def cp_accessible(self, epoch=None):
546  """Returns True if Checkpoint data is accessible
547 
548  Args:
549  epoch: An integer. The epoch of the checkpoint. If None,
550  it implies we need to check if checkpoint directory is accessible
551 
552  Returns:
553  is_cp_accessible: A boolean. Returns True if Checkpoint data is accessible
554  """
555  if self._metadata_handler is not None:
556  return self._metadata_handler.cp_accessible(epoch)
557  else:
558  return True
559 
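
The manager above is constructed without a node_name; its init() fans out one CheckpointManager per node. A hedged sketch (node names and path are made up, and the session is assumed to exist):

    # Illustrative only: JobRunner normally drives these calls.
    mncm = MultiNodeCheckpointManager(db_prefix='/tmp/ckpts/', db_type='leveldb')
    session.run(mncm.init(nodes=['trainer_0', 'trainer_1']))  # one manager per node
    session.run(mncm.save(epoch=0))   # each node writes /tmp/ckpts/<node_name>.0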
560 
561 class UploadTaskGroupBuilder(object):
562  """A simple class to upload checkpoints."""
563  def build(self, epoch, checkpoint_manager):
564  """Builds the task group to upload checkpoints.
565 
566  Args:
567  epoch: An integer. The checkpoint epoch to be uploaded.
568  checkpoint_manager: Can be a CheckpointManager for single machine
569  or a MultiNodeCheckpointManager for multi-machine. The manager
570  that initializes/saves/loads checkpoints.
571 
572  Raises:
573  NotImplementedError: This base class only has the interface,
574  the implementation will be in the subclasses.
575  """
576  raise NotImplementedError()
577 
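
Concrete subclasses override build() to return a TaskGroup that uploads the checkpoint dbs. A minimal, hypothetical subclass just to show the shape of the interface (it uploads nothing):

    # Hypothetical example. A real subclass would add per-node upload Tasks,
    # typically locating each node's checkpoint file via
    # checkpoint_manager.get_ckpt_db_name(node_name, epoch).
    class NoopUploadTaskGroupBuilder(UploadTaskGroupBuilder):
        def build(self, epoch, checkpoint_manager):
            return TaskGroup(WorkspaceType.GLOBAL)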
578 
579 class JobRunner(object):
580  """
581  Implements the runtime logic for jobs with checkpointing at the epoch
582  level. Can be used to run either single-host or distributed jobs. The job
583  runner is a callable to be invoked once from the master, passing a session
584  as an argument; the call blocks until the Job execution is complete.
585 
586  If a checkpoint_manager is passed, checkpoints will be taken after
587  initialization and after each epoch execution. If, in addition,
588  `resume_from_epoch` is an epoch number, the corresponding checkpoint will
589  be loaded and job execution will continue from the given epoch. In
590  this case, the job's init_group will not be run.
591 
592  Refer to checkpoint_test.py for an example.
593  """
594  def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None,
595  upload_task_group_builder=None):
596  """Initializes the JobRunner.
597 
598  Args:
599  job: A Job object. The job to be executed.
600  checkpoint_manager: Can be a CheckpointManager for single machine
601  or a MultiNodeCheckpointManager for multi-machine. The manager
602  that initializes/saves/loads checkpoints.
603  resume_from_epoch: An integer. The epoch to resume from.
604  upload_task_group_builder: A subclass of the
605  UploadTaskGroupBuilder. Creates a task group to upload
606  checkpoints.
607  """
608  self.resume_from_epoch = resume_from_epoch
609  self.checkpoint_manager = checkpoint_manager
610  self.job = job
611  self.upload_task_group_builder = upload_task_group_builder
612 
613  def __call__(self, session):
614  """Runs the training flow.
615 
616  Args:
617  session: A Session object. Valid choices are LocalSession,
618  LocalHostScheduler, and DistributedSession. It is used to
619  execute one TaskGroup at a time.
620  """
621  # identify the epoch we must resume from
622  if self.checkpoint_manager:
623  self.checkpoint_manager.set_params(nodes=self.job.nodes_to_checkpoint())
624  self.resume_from_epoch = self.checkpoint_manager.\
625  get_resume_from_epoch_id(self.resume_from_epoch)
626  if self.resume_from_epoch is not None:
627  logger.info('Resuming from epoch {}'.format(self.resume_from_epoch))
628 
629  # Initialize all the nodes.
630  from_scratch = self.resume_from_epoch is None
631  if from_scratch:
632  session.run(self.job.init_group)
633 
634  if self.checkpoint_manager:
635  logger.info('Preparing checkpoints ...')
636  session.run(self.checkpoint_manager.init(
637  self.job.nodes_to_checkpoint(),
638  retrieve_from_epoch=self.resume_from_epoch))
639  # Save the first checkpoint before training starts, or resume from
640  # a previously saved checkpoint.
641  if from_scratch:
642  self.save_checkpoints(0, session)
643  else:
644  logger.info('Loading checkpoints for epoch {} ...'.format(
645  self.resume_from_epoch))
646  session.run(
647  self.checkpoint_manager.load(self.resume_from_epoch))
648  logger.info('Checkpoint loaded')
649 
650  logger.info("Finished initializing")
651 
652  # Start training.
653  epoch = 1 if from_scratch else self.resume_from_epoch + 1
654  while True:
655  logger.info('Starting epoch %d' % epoch)
656  session.run(self.job.epoch_group)
657  logger.info('Finished epoch %d' % epoch)
658  stop_signals = [o.fetch() for o in self.job.stop_signals]
659 
660  if self.checkpoint_manager:
661  self.save_checkpoints(epoch, session)
662 
663  if any(stop_signals):
664  logger.info('Stopping')
665  break
666  epoch += 1
667  logger.info('Finished training')
668  # Upload the checkpoints.
669  if self.upload_task_group_builder:
670  upload_task_group = self.upload_task_group_builder.build(
671  epoch, self.checkpoint_manager)
672  session.run(upload_task_group)
673  logger.info('Finished uploading the checkpoints')
674 
675  # Download the parameters to save
676  session.run(self.job.download_group)
677  logger.info('Finished downloading the parameters')
678 
679  # Finally run the exit step to save nets
680  session.run(self.job.exit_group)
681  logger.info('Finished running the exit group')
682  return epoch
683 
684  def load_blobs_from_checkpoints(self, blob_names, epoch, session):
685  """Loads the necessary blobs from the checkpoints.
686 
687  Checkpoints store the snapshots of the workspace in each node.
688  Sometimes we only need to load a subset of the blobs from the
689  checkpoints. One common scenario is to load only the model blobs from
690  the checkpoints for evaluation purpose. Given the names of the
691  necessary blobs, this function goes over all the checkpoints of all the
692  nodes, but only loads the blobs specified in the blob_names to the
693  current workspace.
694 
695  Args:
696  blob_names: A list of strings. Each string is the name of a
697  blob.
698  epoch: An integer. The checkpoint epoch to load from.
699  session: A Session object to execute the load ops.
700 
701  Raises:
702  ValueError: When the checkpoint manager is invalid.
703  """
704  if not self.checkpoint_manager:
705  raise ValueError('Checkpoint manager is None')
706  logger.info('Loading checkpoint for epoch {} ...'.format(epoch))
707  return self.checkpoint_manager.load_blobs_locally(
708  self.job.nodes_to_checkpoint(), blob_names, epoch, session)
709 
710  def save_checkpoints(self, epoch, session):
711  """Triggers operation to save checkpoints
712 
713  This method will trigger the Save ops to serialize and persist the
714  blobs present in the global workspace.
715 
716  Args:
717  epoch: An integer. The checkpoint epoch-id that we are saving.
718  session: A Session object to execute the save ops.
719 
720  Raises:
721  ValueError: When the checkpoint manager is invalid.
722  """
723  if not self.checkpoint_manager:
724  raise ValueError('Checkpoint manager is None')
725  try:
726  is_accessible = self.checkpoint_manager.cp_accessible(epoch=None)
727  if is_accessible:
728  logger.info('Saving checkpoints for epoch {}'.format(epoch))
729  session.run(self.checkpoint_manager.save(epoch))
730  self.checkpoint_manager.write_checkpoint_metadata(epoch)
731  logger.info('Checkpoints saved')
732  else:
733  logger.warning("Checkpoint files cannot be accessed!")
734  except Exception as ex:
735  logger.warning("Unable to write checkpoint for epoch {}. Error={}".
736  format(epoch, ex))
737 
738 
739 def epoch_limiter(num_epochs):
740  """
741  Creates a task that will output True when a given
742  number of epochs has finished.
743  """
744  with Job.current().init_group:
745  init_net = core.Net('epoch_counter_init')
746  counter = init_net.CreateCounter([], init_count=num_epochs - 1)
747  Task(step=init_net)
748  epoch_net = core.Net('epoch_countdown')
749  finished = epoch_net.CountDown(counter)
750  output = Task(step=epoch_net, outputs=finished).outputs()[0]
751  Job.current().add_stop_signal(output)
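
Putting it all together, a rough single-machine usage sketch. The LocalSession construction, the db path and the task-building code are assumptions; checkpoint_test.py remains the authoritative example:

    # Hedged sketch, not verbatim library usage.
    from caffe2.python.session import LocalSession

    with Job() as job:
        epoch_limiter(num_epochs=10)   # request a stop signal after 10 epochs
        # ... add Tasks to job.init_group / job.epoch_group / job.exit_group ...

    checkpoint_manager = CheckpointManager(
        db_prefix='/tmp/ckpts/', node_name='local', db_type='leveldb')
    runner = JobRunner(job, checkpoint_manager)
    session = LocalSession()           # assuming the default local workspace
    last_epoch = runner(session)       # blocks until the job completes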