from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
# Module-level logger for checkpoint progress messages. INFO level is set
# explicitly so save/load progress is visible even when the root logger is
# configured more quietly.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
26 @context.define_context()
29 A Job defines three TaskGroups: the `init_group`, the `epoch_group` and the 30 `exit_group` which will be run by a JobRunner. 32 The `init_group` will be run only once at startup. Its role is to 33 initialize globally persistent blobs such as model weights, accumulators 36 The `epoch_group` will be run in a loop after init_group. The loop will 37 exit when any of the stop signals added with `add_stop_condition` is True 38 at the end of an epoch. 40 The download_group will be run only once, after all the executions of 41 epoch_group finish. Its role is to collect the distribute scattered 42 parameters back after training. 44 The `exit_group` will be run only once at the very end of the job, the 45 role of this group is to save the results of training in the end of the job. 47 Jobs are context-driven, so that Tasks can be added to the active Job 48 without having to explicitly pass the job object around. 52 def build_reader(partitions): 53 with Job.current().init_group: 54 reader = HiveReader(init_reader, ..., partitions) 55 Task(step=init_reader) 56 with Job.current().epoch_group: 57 limited_reader = ReaderWithLimit(reader, num_iter=10000) 58 data_queue = pipe(limited_reader, num_threads=8) 59 Job.current().add_stop_condition(limited_reader.data_finished()) 62 def build_hogwild_trainer(reader, model): 63 with Job.current().init_group: 64 Task(step=model.param_init_net) 65 with Job.current().epoch_group: 66 pipe(reader, processor=model, num_threads=8) 67 with Job.current().exit_group: 68 Task(step=model.save_model_net) 71 reader = build_reader(partitions) 72 model = build_model(params) 73 build_hogwild_trainer(reader, model) 76 init_group=
None, epoch_group=
None,
77 download_group=
None, exit_group=
None,
78 stop_conditions=
None, nodes_to_checkpoint=
None):
80 workspace_type=WorkspaceType.GLOBAL)
def nodes_to_checkpoint(self):
    # Nodes that participated in `init_group` are the ones whose
    # workspaces hold persistent state and therefore need checkpointing.
    # NOTE(review): interior lines of this method are missing from this
    # view — the original may first return an explicit override
    # (e.g. a `nodes_to_checkpoint` constructor argument); confirm
    # against the full file.
    return self.init_group.used_nodes()
93 def compile(self, session_class):
101 self.epoch_group.__enter__()
def __exit__(self, *args):
    # Leaving the Job context closes the epoch_group context opened by
    # __enter__, so subsequently created Tasks no longer attach to this Job.
    self.epoch_group.__exit__()
def add_stop_condition(self, output):
    """Registers a stop condition for the epoch loop.

    The JobRunner fetches every registered condition at the end of each
    epoch; if any is True, the epoch loop exits.

    Args:
        output: Either a TaskOutput, or a core.BlobReference that will be
            wrapped in a Task attached to `epoch_group` so its value can
            be fetched.
    """
    # As visible in SOURCE, `t` was referenced without ever being assigned
    # (the wrapping of a raw BlobReference into a Task was lost); restore
    # that branch so blob references are accepted as documented.
    if isinstance(output, core.BlobReference):
        t = Task(outputs=[output], group=self.epoch_group)
        output = t.outputs()[0]
    assert isinstance(output, TaskOutput)
    self.stop_conditions.append(output)
def get_ckpt_filename(node_name, epoch):
    """Returns the checkpoint filename.

    Args:
        node_name: A string. The name of the node.
        epoch: An integer. The checkpoint epoch.

    Returns:
        ckpt_filename: A string. The filename of the checkpoint, formed by
            joining the node name and the epoch with a dot.
    """
    return '{}.{}'.format(node_name, epoch)
def db_name(epoch, node_name, db_prefix, path_prefix=None):
    """Returns the full db name where checkpoint files are saved.

    Args:
        epoch: An integer. The checkpoint epoch.
        node_name: A string. The name of the node.
        db_prefix: A string. The prefix used to construct full db name.
        path_prefix: A string. Optional param used to construct db name or
            path where checkpoint files are stored.

    Returns:
        db_name: A string. The full db name where checkpoint files are
            saved.
    """
    # As visible in SOURCE, both assignments ran unconditionally and the
    # function returned nothing; restore the path_prefix branch and the
    # return so the `path_prefix=None` default is meaningful.
    if path_prefix:
        # Caller supplied an explicit path prefix: concatenate directly.
        db_name = path_prefix + get_ckpt_filename(node_name, epoch)
    else:
        ckpt_filename = get_ckpt_filename(node_name, epoch)
        db_name = os.path.join(db_prefix, ckpt_filename)
    return db_name
151 Controls saving and loading of workspaces on every epoch boundary of a job. 152 If a CheckpointManager instance is passed to JobRunner, then JobRunner will 153 call `init`, `read` and `save` at different moments in between epoch runs. 156 db_prefix: The prefix used to construct full db name. Since `absolute_path` 157 is set to True, this will be used as db_name in SaveOp. 158 node_name: Name of the node where this checkpoint_manager is used. 159 db_type: Type of database to use for storing checkpoint. 160 metadata_handler: An optional object capable of reading/writing 161 checkpoint info in storage of choice. 164 BLOB_NAMES =
"blob_names" 166 def __init__(self, db_prefix, node_name, db_type, metadata_handler=None):
181 Initialize the checkpoint manager. Determines all blobs that need to be saved 182 or loads from a checkpoint. 185 nodes: An array of nodes where this checkpoint manager is running. Should 186 only contain a single node. 187 retrieve_from_epoch: Set to a number to load blobs from this epoch. 188 path_prefix: Used to construct db name or path where checkpoint files are 190 path_type: Indicate the type of path where checkpoint files are stored. 195 retrieve_from_epoch=
None,
200 Build a Task that will be run once after the job's `init_group` is run. 201 This task will determine which blobs need to be checkpointed. 202 If retrieve_from_epoch is not None, then the checkpoint metadata is 203 retrieved from a previously saved checkpoint. 205 assert nodes
is None or len(nodes) == 1, (
206 'CheckpointManager only supports single node.')
209 if retrieve_from_epoch
is None:
213 include_shared=
False)
215 full_db_name = db_name(retrieve_from_epoch,
217 db_type = path_type
or self.
_db_type 218 logger.info(
"Initializing checkpoints from = %s" 232 return self._names_output.fetch().tolist()
234 def _timed_task(self, cp_op_name, add_op):
236 Build a Task that will measure the time span of checkpoint operations, 237 once operation is done, time can be read from _current_checkpoint_duration. 240 cp_op_name: A string name of the checkpoint operation. 241 add_op: A functor to add the checkpoint operation. 246 with Task(name=cp_op_name)
as task:
247 with ops.task_init():
248 timer = ops.TimerBegin([], counter_name=self.
_node_name)
250 with ops.task_exit():
251 time_span_blob = ops.TimerGetAndEnd(timer)
257 Add one checkpoint stats into the stats. 260 stats: A dict of checkpoint stats that will be reported. 266 "Failed to collect checkpoint stats: {}".format(
271 def load(self, epoch, path_prefix=None, path_type=None):
273 Build a Task that will be run by JobRunner when the job is to be 274 resumed from a given epoch. This task will run a Load op that will 275 load and deserialize all relevant blobs from a persistent storage. 280 db_type = path_type
or self.
_db_type 297 Builds a Task that loads only the necessary blobs from a checkpoint of 298 the given epoch. The necessary blobs are given in the blob_names 302 blob_names: A list of strings. Each string is the name of a 304 epoch: The checkpoint epoch to load from. 307 A Task which loads the specified blobs from the checkpoint of the 320 allow_incomplete=
True)
322 return self.
_timed_task(
'checkpoint_partial_load', add_op)
324 def check_db_exists(self, epoch):
325 logger.info(
'Check existence of %s' %
328 existence = ops.Const(
False)
335 task.add_output(existence)
340 Report checkpoint operation stats for current node. 343 action_name: A string of the name of checkpoint operation. 348 self._metadata_handler.report(action_name, all_stats)
352 Build a Task that is run once after `init_group` and after each 353 epoch is run. This will execute a Save ops to serialize and persist 354 blobs present in the global workspace. 370 Write metadata for checkpoint 373 epoch: An integer. The epoch-id for which checkpoint metadata is 377 self._metadata_handler.write(epoch=epoch)
381 Identify the epoch-id from which Job must resume 384 user_epoch: An integer. Optional parameter for user to explicitly 385 identify the epoch-id to load checkpoint from 387 epoch: the epoch-id to load checkpoints from 388 or None if no checkpoints were written 390 last_epoch = user_epoch
392 last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
395 def set_params(self, nodes, path_prefix=None, path_type=None):
396 """Set parameters associated with CP manager 399 nodes: An array of nodes where this checkpoint manager is running. 400 path_prefix: Used to construct db name or path where checkpoint files are 402 path_type: Indicate the type of path where checkpoint files are stored. 409 self._metadata_handler.set_params(
417 """Returns True if Checkpoint data is accessible 420 epoch: An integer. The epoch of the checkpoint. If None, 421 it implies we need to check if checkpoint directory is accessible 424 is_cp_accessible: A boolean. Returns True if Checkpoint data is accessible 427 return self._metadata_handler.cp_accessible(epoch)
434 Coordinates checkpointing and checkpointing across multiple nodes. 435 Each of `init`, `load` and `save` will build TaskGroups which will 436 trigger checkpointing on each of the nodes involved in a distributed job. 439 db_prefix: The prefix used to construct full db name. Since `absolute_path` 440 is set to True, this will be used as db_name in SaveOp. 441 db_type: Type of database to use for storing checkpoint. 442 metadata_handler: An optional object capable of reading/writing 443 checkpoint info in storage of choice. 445 def __init__(self, db_prefix, db_type, metadata_handler=None):
453 def _task_group(self, func, *args, **kw):
454 assert self.
_node_managers is not None,
'init must be called first.' 455 with TaskGroup(WorkspaceType.GLOBAL)
as task_group:
458 func(manager, *args, **kw)
463 nodes: An array of nodes where this checkpoint manager is running. 464 retrieve_from_epoch: Set to a number to load blobs from this epoch. 465 path_prefix: Used to construct db name or path where checkpoint files are 467 path_type: Indicate the type of path where checkpoint files are stored. 470 self, nodes, retrieve_from_epoch=
None, path_prefix=
None, path_type=
None 474 return TaskGroup(WorkspaceType.GLOBAL)
482 self._node_managers.append((node, manager))
484 CheckpointManager.init,
486 retrieve_from_epoch=retrieve_from_epoch,
487 path_prefix=path_prefix,
490 def load(self, epoch, path_prefix=None, path_type=None):
492 CheckpointManager.load,
494 path_prefix=path_prefix,
498 """Loads the necessary blobs from the checkpoints to the current node. 501 blob_names: A list of strings. Each string is the name of a 503 epoch: An integer. The checkpoint epoch to load from. 504 session: A Session object to execute the Load ops. 516 self._node_managers.append((node, manager))
517 assert self.
_node_managers is not None,
'must initialize node managers' 519 existence_task = manager.check_db_exists(epoch)
520 session.run(existence_task)
521 existence = existence_task.outputs()[0].fetch()
523 logger.info(
'DB %s does not exist!' %
524 db_name(epoch, manager._node_name, manager._db_prefix))
526 load_task = manager.load_blobs_from_checkpoint(blob_names, epoch)
527 session.run(load_task)
528 logger.info(
'Successfully loaded from checkpoints.')
532 """Returns the DB name of the given node and the given epoch. 534 The DB name is effectively the checkpoint path of the given node and 538 node_name: A string. The node name of interest. 539 epoch: An integer. The epoch of the checkpoint. 542 checkpoint_db_name: A string. The checkpoint path of the given 543 node and the given epoch. 546 if str(node) == node_name:
547 return db_name(epoch, manager._node_name, manager._db_prefix)
551 Report the checkpoint stats for all the nodes, we need to aggregate all 552 the node's stats together so that we know which node's checkpoint 556 action_name: A string of the name of checkpoint operation. 560 manager.collect_checkpoint_stats(all_stats)
561 logger.debug(
"checkpoint stats: {}".format(all_stats))
563 self._metadata_handler.report(action_name, all_stats)
567 Build a Task that will execute a Save ops to serialize and persist 568 blobs present in the global workspace. 570 return self.
_task_group(CheckpointManager.save, epoch)
574 Write metadata for checkpoint 577 epoch: An integer. The epoch-id for which checkpoint metadata is 581 self._metadata_handler.write(epoch=epoch)
585 Identify the epoch-id from which Job must resume 588 user_epoch: An integer. Optional parameter for user to explicitly 589 identify the epoch-id to load checkpoint from 591 epoch: the epoch-id to load checkpoints from 592 or None if no checkpoints were written 594 last_epoch = user_epoch
596 last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
599 def set_params(self, nodes, path_prefix=None, path_type=None):
600 """Set parameters associated with CP manager 603 nodes: An array of nodes where this checkpoint manager is running. 604 path_prefix: Used to construct db name or path where checkpoint files are 606 path_type: Indicate the type of path where checkpoint files are stored. 614 self._metadata_handler.set_params(
622 """Returns True if Checkpoint data is accessible 625 epoch: An integer. The epoch of the checkpoint. If None, 626 it implies we need to check if checkpoint directory is accessible 629 is_cp_accessible: A boolean. Returns True if Checkpoint data is accessible 632 return self._metadata_handler.cp_accessible(epoch)
"""A simple class to upload checkpoints."""
def build(self, epoch, checkpoint_manager):
    """Builds the task group to upload checkpoints.

    Args:
        epoch: An integer. The checkpoint epoch to be uploaded.
        checkpoint_manager: Can be a CheckpointManager for single machine
            or a MultiNodeCheckpointManager for multi-machine. The manager
            that initializes/saves/loads checkpoints.

    Raises:
        NotImplementedError: This base class only has the interface,
            the implementation will be in the subclasses.
    """
    raise NotImplementedError()
657 Implement the runtime logic for jobs with checkpointing at the level of 658 epoch. Can be used to run either single-host or distributed jobs. Job 659 runner is a callable to be called once from the master, passing a session 660 as an argument. This call will block until the Job execution is complete. 662 If a checkpoint_manager is passed, checkpoints will be taken after 663 initialization and after each epoch execution. If, in addition, 664 `resume_from_epoch` is an epoch number, the corresponding checkpoint will 665 be loaded and job execution will continue from the given epoch. In 666 this case, the job's init_group will not be run. 668 Refer to checkpoint_test.py for an example. 670 def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None,
671 upload_task_group_builder=
None):
672 """Initializes the JobRunner. 675 job: A Job object. The job to be executed. 676 checkpoint_manager: Can be a CheckpointManager for single machine 677 or a MultiNodeCheckpointManager for multi-machine. The manager 678 that initializes/saves/loads checkpoints. 679 resume_from_epoch: An integer. The epoch to resume from. 680 upload_task_group_builder: A subclass of the 681 UploadTaskGroupBuilder. Creates a task group to upload 690 """Runs the training flow. 693 session: A Session object. Valid choises are: LocalSession, 694 LocalHostScheduler, and DistributedSession. It is used to 695 execute one TaskGroup a time. 699 self.checkpoint_manager.set_params(nodes=self.job.nodes_to_checkpoint())
708 session.run(self.job.init_group)
711 logger.info(
'Preparing checkpoints ...')
712 session.run(self.checkpoint_manager.init(
713 self.job.nodes_to_checkpoint(),
720 logger.info(
'Loading checkpoints for epoch {} ...'.format(
724 self.checkpoint_manager.report_checkpoint_stats(
'checkpoint_load')
725 logger.info(
'Checkpoint loaded')
727 logger.info(
"Finished initializing")
732 logger.info(
'Starting epoch %d' % epoch)
733 session.run(self.job.epoch_group)
734 logger.info(
'Finished epoch %d' % epoch)
735 stop_conditions = [o.fetch()
for o
in self.job.stop_conditions]
740 if any(stop_conditions):
741 logger.info(
'Stopping')
744 logger.info(
'Finished training')
747 upload_task_group = self.upload_task_group_builder.build(
749 session.run(upload_task_group)
750 logger.info(
'Finished uploading the checkpoints')
753 session.run(self.job.download_group)
754 logger.info(
'Finished downloading the parameters')
757 session.run(self.job.exit_group)
758 logger.info(
'Finished running the exit group')
762 """Loads the necessary blobs from the checkpoints. 764 Checkpoints store the snapshots of the workspace in each node. 765 Sometimes we only need to load a subset of the blobs from the 766 checkpoints. One common scenario is to load only the model blobs from 767 the checkpoints for evaluation purpose. Given the names of the 768 necessary blobs, this function goes over all the checkpoints of all the 769 nodes, but only loads the blobs specified in the blob_names to the 773 blob_names: A list of strings. Each string is the name of a 775 epoch: An integer. The checkpoint epoch to load from. 776 session: A Session object to execute the load ops. 779 ValueError: When the checkpoint manager is invalid. 782 raise ValueError(
'Checkpoint manager is None')
783 logger.info(
'Loading checkpoint for epoch {} ...'.format(epoch))
784 result = self.checkpoint_manager.load_blobs_locally(
785 self.job.nodes_to_checkpoint(), blob_names, epoch, session)
786 self.checkpoint_manager.report_checkpoint_stats(
'checkpoint_partial_load')
790 """Triggers operation to save checkpoints 792 This method will trigger the Save ops to serialize and persist the 793 blobs present in the global workspaace. 796 epoch: An integer. The checkpoint epoch-id that we are saving. 797 session: A Session object to execute the save ops. 800 ValueError: When the checkpoint manager is invalid. 803 raise ValueError(
'Checkpoint manager is None')
805 is_accessible = self.checkpoint_manager.cp_accessible(epoch=
None)
807 logger.info(
'Saving checkpoints for epoch {}'.format(epoch))
808 session.run(self.checkpoint_manager.save(epoch))
809 self.checkpoint_manager.write_checkpoint_metadata(epoch)
810 logger.info(
'Checkpoints saved')
811 self.checkpoint_manager.report_checkpoint_stats(
'checkpoint_save')
813 logger.warning(
"Checkpoint files cannot be accessed!")
814 except Exception
as ex:
815 logger.warning(
"Unable to write checkpoint for epoch {}. Error={}".
def epoch_limiter(job, num_epochs):
    """
    Creates a task that will output True when a given
    number of epochs has finished.

    Args:
        job: The Job to attach the counter tasks to.
        num_epochs: An integer. Number of epochs to run before stopping.
    """
    # Restore the lost init_group context and Task registration: the
    # counter must be created once at startup, then counted down once
    # per epoch. init_count is num_epochs - 1 because CountDown reports
    # "finished" when the counter reaches zero, i.e. on the final epoch.
    with job.init_group:
        init_net = core.Net('epoch_counter_init')
        counter = init_net.CreateCounter([], init_count=num_epochs - 1)
        Task(step=init_net)

    with job.epoch_group:
        epoch_net = core.Net('epoch_countdown')
        finished = epoch_net.CountDown(counter)
        output = Task(step=epoch_net, outputs=finished).outputs()[0]
    job.add_stop_condition(output)
def set_params(self, nodes, path_prefix=None, path_type=None)
def collect_checkpoint_stats(self, stats)
_current_checkpoint_duration
def nodes_to_checkpoint(self)
def _timed_task(self, cp_op_name, add_op)
def init(self, nodes=None, retrieve_from_epoch=None, path_prefix=None, path_type=None)
def load_blobs_locally(self, nodes, blob_names, epoch, session)
def load_blobs_from_checkpoints(self, blob_names, epoch, session)
def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None, upload_task_group_builder=None)
upload_task_group_builder
def write_checkpoint_metadata(self, epoch)
def _task_group(self, func, args, kw)
def build(self, epoch, checkpoint_manager)
def get_ckpt_db_name(self, node_name, epoch)
def save_checkpoints(self, epoch, session)
def report_checkpoint_stats(self, action_name)
def get_resume_from_epoch_id(self, user_epoch=None)
def cp_accessible(self, epoch=None)
def load(self, epoch, path_prefix=None, path_type=None)
def load_blobs_from_checkpoint(self, blob_names, epoch)
def cp_accessible(self, epoch=None)
def set_params(self, nodes, path_prefix=None, path_type=None)
def get_resume_from_epoch_id(self, user_epoch=None)
def report_checkpoint_stats(self, action_name)
def write_checkpoint_metadata(self, epoch)