Caffe2 - Python API
A deep learning, cross-platform ML framework
experiment_util.py
1 ## @package experiment_util
2 # Module caffe2.python.experiment_util
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 import datetime
9 import time
10 import logging
11 import socket
12 import abc
13 import six
14 
15 from collections import OrderedDict
16 from future.utils import viewkeys, viewvalues
17 
18 '''
19 Utilities for logging experiment run stats, such as accuracy
20 and loss over time for different runs. Runtime arguments are stored
21 in the log.
22 
23 Optionally, ModelTrainerLog calls out to a logger to log to
24 an external log destination.
25 '''
26 
27 
@six.add_metaclass(abc.ABCMeta)
class ExternalLogger(object):
    """
    Interface for loggers that forward experiment stats to an external
    destination.

    Subclasses must implement both set_runtime_args() and log(). The class
    is built with ABCMeta (applied via the six decorator so it works on
    Python 2 and 3), so a subclass that leaves either method abstract
    cannot be instantiated.
    """

    @abc.abstractmethod
    def set_runtime_args(self, runtime_args):
        """
        Set runtime arguments for the logger.

        runtime_args: dict of runtime arguments.
        """
        raise NotImplementedError(
            'Must define set_runtime_args function to use this base class'
        )

    @abc.abstractmethod
    def log(self, log_dict):
        """
        Log a dict of key/values to an external destination.

        log_dict: input dict.
        """
        raise NotImplementedError(
            'Must define log function to use this base class'
        )
50 
51 
53 
54  def __init__(self, expname, runtime_args, external_loggers=None):
55  now = datetime.datetime.fromtimestamp(time.time())
56  self.experiment_id = \
57  "{}_{}".format(expname, now.strftime('%Y%m%d_%H%M%S'))
58  self.filename = "{}.log".format(self.experiment_id)
59  self.logstr("# %s" % str(runtime_args))
60  self.headers = None
61  self.start_time = time.time()
62  self.last_time = self.start_time
63  self.last_input_count = 0
64  self.external_loggers = None
65 
66  if external_loggers is not None:
67  self.external_loggers = external_loggers
68  if not isinstance(runtime_args, dict):
69  runtime_args = dict(vars(runtime_args))
70  runtime_args['experiment_id'] = self.experiment_id
71  runtime_args['hostname'] = socket.gethostname()
72  for logger in self.external_loggers:
73  logger.set_runtime_args(runtime_args)
74  else:
75  self.external_loggers = []
76 
77  def logstr(self, str):
78  with open(self.filename, "a") as f:
79  f.write(str + "\n")
80  f.close()
81  logging.getLogger("experiment_logger").info(str)
82 
83  def log(self, input_count, batch_count, additional_values):
84  logdict = OrderedDict()
85  delta_t = time.time() - self.last_time
86  delta_count = input_count - self.last_input_count
87  self.last_time = time.time()
88  self.last_input_count = input_count
89 
90  logdict['time_spent'] = delta_t
91  logdict['cumulative_time_spent'] = time.time() - self.start_time
92  logdict['input_count'] = delta_count
93  logdict['cumulative_input_count'] = input_count
94  logdict['cumulative_batch_count'] = batch_count
95  if delta_t > 0:
96  logdict['inputs_per_sec'] = delta_count / delta_t
97  else:
98  logdict['inputs_per_sec'] = 0.0
99 
100  for k in sorted(viewkeys(additional_values)):
101  logdict[k] = additional_values[k]
102 
103  # Write the headers if they are not written yet
104  if self.headers is None:
105  self.headers = list(viewkeys(logdict))
106  self.logstr(",".join(self.headers))
107 
108  self.logstr(",".join(str(v) for v in viewvalues(logdict)))
109 
110  for logger in self.external_loggers:
111  try:
112  logger.log(logdict)
113  except Exception as e:
114  logging.warn(
115  "Failed to call ExternalLogger: {}".format(e), e)