Caffe2 - Python API
A deep learning, cross platform ML framework
experiment_util.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package experiment_util
17 # Module caffe2.python.experiment_util
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 import datetime
24 import time
25 import logging
26 import socket
27 import abc
28 import six
29 
30 from collections import OrderedDict
31 from future.utils import viewkeys, viewvalues
32 
33 '''
34 Utilities for logging experiment run stats, such as accuracy
35 and loss over time for different runs. Runtime arguments are stored
36 in the log.
37 
38 Optionally, ModelTrainerLog calls out to a logger to log to
39 an external log destination.
40 '''
41 
42 
@six.add_metaclass(abc.ABCMeta)
class ExternalLogger(object):
    """Abstract interface for an external destination of experiment stats.

    ModelTrainerLog calls ``set_runtime_args`` once at construction and
    ``log`` once per logged row. Subclasses must implement both methods.
    The class is built with ``abc.ABCMeta`` (applied via the
    ``six.add_metaclass`` decorator for Python 2/3 compatibility), so
    instantiating a subclass that leaves either method abstract raises
    ``TypeError``.
    """

    @abc.abstractmethod
    def set_runtime_args(self, runtime_args):
        """
        Set runtime arguments for the logger.
        runtime_args: dict of runtime arguments. When supplied by
            ModelTrainerLog it also contains 'experiment_id' and 'hostname'.
        """
        raise NotImplementedError(
            'Must define set_runtime_args function to use this base class'
        )

    @abc.abstractmethod
    def log(self, log_dict):
        """
        Log a dict of key/values to an external destination.
        log_dict: input dict (ModelTrainerLog passes an OrderedDict whose
            keys match the CSV header columns).
        """
        raise NotImplementedError(
            'Must define log function to use this base class'
        )
62  raise NotImplementedError(
63  'Must define log function to use this base class'
64  )
65 
66 
68 
69  def __init__(self, expname, runtime_args, external_loggers=None):
70  now = datetime.datetime.fromtimestamp(time.time())
71  self.experiment_id = \
72  "{}_{}".format(expname, now.strftime('%Y%m%d_%H%M%S'))
73  self.filename = "{}.log".format(self.experiment_id)
74  self.logstr("# %s" % str(runtime_args))
75  self.headers = None
76  self.start_time = time.time()
77  self.last_time = self.start_time
78  self.last_input_count = 0
79  self.external_loggers = None
80 
81  if external_loggers is not None:
82  self.external_loggers = external_loggers
83  if not isinstance(runtime_args, dict):
84  runtime_args = dict(vars(runtime_args))
85  runtime_args['experiment_id'] = self.experiment_id
86  runtime_args['hostname'] = socket.gethostname()
87  for logger in self.external_loggers:
88  logger.set_runtime_args(runtime_args)
89  else:
90  self.external_loggers = []
91 
92  def logstr(self, str):
93  with open(self.filename, "a") as f:
94  f.write(str + "\n")
95  f.close()
96  logging.getLogger("experiment_logger").info(str)
97 
98  def log(self, input_count, batch_count, additional_values):
99  logdict = OrderedDict()
100  delta_t = time.time() - self.last_time
101  delta_count = input_count - self.last_input_count
102  self.last_time = time.time()
103  self.last_input_count = input_count
104 
105  logdict['time_spent'] = delta_t
106  logdict['cumulative_time_spent'] = time.time() - self.start_time
107  logdict['input_count'] = delta_count
108  logdict['cumulative_input_count'] = input_count
109  logdict['cumulative_batch_count'] = batch_count
110  if delta_t > 0:
111  logdict['inputs_per_sec'] = delta_count / delta_t
112  else:
113  logdict['inputs_per_sec'] = 0.0
114 
115  for k in sorted(viewkeys(additional_values)):
116  logdict[k] = additional_values[k]
117 
118  # Write the headers if they are not written yet
119  if self.headers is None:
120  self.headers = list(viewkeys(logdict))
121  self.logstr(",".join(self.headers))
122 
123  self.logstr(",".join(str(v) for v in viewvalues(logdict)))
124 
125  for logger in self.external_loggers:
126  try:
127  logger.log(logdict)
128  except Exception as e:
129  logging.warn(
130  "Failed to call ExternalLogger: {}".format(e), e)