Caffe2 - Python API
A deep learning, cross-platform ML framework
app.py
1 ## @package app
2 # Module caffe2.python.mint.app
3 import argparse
4 import flask
5 import glob
6 import numpy as np
7 import nvd3
8 import os
9 import sys
10 import tornado.httpserver
11 import tornado.wsgi
12 
# Absolute path of the directory containing this module; the Flask template
# and static folders below are resolved relative to it.
__folder__ = os.path.abspath(os.path.dirname(__file__))

app = flask.Flask(
    __name__,
    static_folder=os.path.join(__folder__, "static"),
    template_folder=os.path.join(__folder__, "templates"),
)

# Parsed command-line options; assigned by main() via `global args`.
args = None
21 
22 
def jsonify_nvd3(chart):
    """Build an nvd3 chart and return it as a JSON response.

    Note(Yangqing): python-nvd3 does not seem to separate the built HTML part
    from the script part. The HTML part appears to be only a <div>, which is
    available as ``chart.container``; the script is the body of the <script>
    element inside ``chart.htmlcontent``.

    Args:
        chart: a python-nvd3 chart object; ``buildcontent()`` is called on it.

    Returns:
        A flask JSON response with ``result`` set to ``chart.container`` and
        ``script`` set to the stripped contents of the first <script> element
        in ``chart.htmlcontent`` (empty string if there is no such element).
    """
    chart.buildcontent()
    html = chart.htmlcontent
    open_tag = '<script>'
    close_tag = '</script>'
    script = ''
    start = html.find(open_tag)
    # The original used find(...) + 8, which silently produced html[7:...]
    # when no <script> tag was present. Only slice when the tag exists.
    if start != -1:
        start += len(open_tag)
        # Look for the closing tag *after* the opening one; if it is missing,
        # take everything to the end of the content.
        end = html.find(close_tag, start)
        if end == -1:
            end = len(html)
        script = html[start:end].strip()
    return flask.jsonify(
        result=chart.container,
        script=script
    )
36 
37 
def visualize_summary(filename):
    """Render a summary file as an nvd3 line chart.

    Each row of the file holds whitespace-separated columns read as
    min, max, mean and std (columns 0-3). The chart plots min, max, mean and
    mean +/- std at the rows selected by ``args.sample``.

    Args:
        filename: path of the summary file to load.

    Returns:
        The response from ``jsonify_nvd3`` on success. If the file cannot be
        loaded, or has fewer than 4 columns, a JSON response whose ``result``
        is an error message and whose ``script`` is empty (the same shape
        ``visualize_file`` uses for unsupported files).
    """
    try:
        # ndmin=2 keeps a single-row file two-dimensional so that the
        # data[rows, column] indexing below always works.
        data = np.loadtxt(filename, ndmin=2)
    except Exception as e:
        return flask.jsonify(
            result='Cannot load file {}: {}'.format(filename, str(e)),
            script=''
        )
    if data.shape[1] < 4:
        return flask.jsonify(
            result='Cannot load file {}: expected 4 columns '
                   '(min, max, mean, std), got {}'.format(
                       filename, data.shape[1]),
            script=''
        )
    chart_name = os.path.splitext(os.path.basename(filename))[0]
    chart = nvd3.lineChart(
        name=chart_name + '_summary_chart',
        height=args.chart_height,
        y_axis_format='.03g'
    )
    if args.sample < 0:
        # Negative: aim for roughly -args.sample points over the whole curve.
        # Floor division keeps step an int so xdata can index into data.
        step = max(data.shape[0] // -args.sample, 1)
    else:
        # Positive: take every args.sample-th row; 0 would make np.arange
        # raise, so clamp to at least 1.
        step = max(args.sample, 1)
    xdata = np.arange(0, data.shape[0], step)
    # data has (at least) 4 columns: min, max, mean, std.
    mean = data[xdata, 2]
    std = data[xdata, 3]
    chart.add_serie(x=xdata, y=data[xdata, 0], name='min')
    chart.add_serie(x=xdata, y=data[xdata, 1], name='max')
    chart.add_serie(x=xdata, y=mean, name='mean')
    chart.add_serie(x=xdata, y=mean + std, name='m+std')
    chart.add_serie(x=xdata, y=mean - std, name='m-std')
    return jsonify_nvd3(chart)
61 
62 
def visualize_print_log(filename):
    """Render a print-log file as an nvd3 line chart.

    Rows of the file are treated as successive steps and columns as separate
    curves; at most ``args.max_curves`` columns are plotted. When the file
    has exactly one column, the min and max of each window of ``step``
    consecutive values are plotted as well (named 'running_min' and
    'running_max').

    Args:
        filename: path of the log file to load.

    Returns:
        The response from ``jsonify_nvd3`` on success. If the file cannot be
        loaded, a JSON response whose ``result`` is an error message and whose
        ``script`` is empty (the same shape ``visualize_file`` uses for
        unsupported files).
    """
    try:
        # ndmin=2 turns a one-column file into shape (N, 1) and a one-row
        # file into shape (1, M), so axis 0 is always the step axis. (The old
        # ndim == 1 fix-up mis-shaped one-row files and failed on scalars.)
        data = np.loadtxt(filename, ndmin=2)
    except Exception as e:
        return flask.jsonify(
            result='Cannot load file {}: {}'.format(filename, str(e)),
            script=''
        )
    chart_name = os.path.splitext(os.path.basename(filename))[0]
    chart = nvd3.lineChart(
        name=chart_name + '_log_chart',
        height=args.chart_height,
        y_axis_format='.03g'
    )
    if args.sample < 0:
        # Negative: aim for roughly -args.sample points over the whole curve.
        # Floor division keeps step an int so xdata can index into data.
        step = max(data.shape[0] // -args.sample, 1)
    else:
        # Positive: take every args.sample-th row; 0 would make np.arange
        # raise, so clamp to at least 1.
        step = max(args.sample, 1)
    xdata = np.arange(0, data.shape[0], step)
    if data.shape[1] == 1:
        # With a single curve, also show the min and max within each window
        # of `step` consecutive values. An incomplete trailing window is
        # dropped; each window is plotted at the x of its first element.
        trunc_size = data.shape[0] // step
        running_mat = data[:trunc_size * step, 0].reshape((trunc_size, step))
        chart.add_serie(
            x=xdata[:trunc_size],
            y=running_mat.min(axis=1),
            name='running_min'
        )
        chart.add_serie(
            x=xdata[:trunc_size],
            y=running_mat.max(axis=1),
            name='running_max'
        )
        chart.add_serie(x=xdata, y=data[xdata, 0], name=chart_name)
    else:
        for i in range(0, min(data.shape[1], args.max_curves)):
            chart.add_serie(
                x=xdata,
                y=data[xdata, i],
                name='{}[{}]'.format(chart_name, i)
            )

    return jsonify_nvd3(chart)
107 
108 
def visualize_file(filename):
    """Pick a visualizer for `filename` (relative to args.root) by suffix.

    Names ending in 'summary' go to visualize_summary and names ending in
    'log' go to visualize_print_log, both called with the joined full path.
    Any other name yields a JSON payload carrying an error message in
    `result` and an empty `script`.
    """
    path = os.path.join(args.root, filename)
    handlers = (
        ('summary', visualize_summary),
        ('log', visualize_print_log),
    )
    for suffix, handler in handlers:
        if filename.endswith(suffix):
            return handler(path)
    return flask.jsonify(
        result='Unsupport file: {}'.format(filename),
        script=''
    )
120 
121 
@app.route('/')
def index():
    """Render index.html with the sorted basenames of args.root/*.* files."""
    pattern = os.path.join(args.root, "*.*")
    names = [os.path.basename(path) for path in sorted(glob.glob(pattern))]
    return flask.render_template(
        'index.html',
        root=args.root,
        names=names,
        debug_messages=names,
    )
133 
134 
@app.route('/visualization/<string:name>')
def visualization(name):
    """Flask endpoint: return whatever visualize_file produces for `name`."""
    return visualize_file(name)
139 
140 
def main(argv):
    """Parse command-line options and serve the visualizer with Tornado.

    The parsed options are stored in the module-level ``args`` (read by the
    request handlers). The function then wraps the Flask ``app`` in a Tornado
    WSGI container, listens on ``args.port`` and blocks running the IOLoop.

    Args:
        argv: command-line arguments, excluding the program name.
    """
    # tornado.ioloop is used below but was never imported explicitly; relying
    # on it being loaded as a side effect of the other tornado imports is
    # fragile. This import binds `tornado` as a local name for the whole
    # function, so it must come before any other use of `tornado` here.
    import tornado.ioloop

    global args
    # The first positional parameter of ArgumentParser is `prog`, not the
    # description; pass the text as `description` so usage shows the script.
    parser = argparse.ArgumentParser(description="The mint visualizer.")
    parser.add_argument(
        '-p',
        '--port',
        type=int,
        default=5000,
        help="The flask port to use."
    )
    parser.add_argument(
        '-r',
        '--root',
        type=str,
        default='.',
        help="The root folder to read files for visualization."
    )
    parser.add_argument(
        '--max_curves',
        type=int,
        default=5,
        help="The max number of curves to show in a dump tensor."
    )
    parser.add_argument(
        '--chart_height',
        type=int,
        default=300,
        help="The chart height for nvd3."
    )
    parser.add_argument(
        '-s',
        '--sample',
        type=int,
        default=-200,
        help="Sample every given number of data points. A negative "
             "number means the total points we will sample on the "
             "whole curve. Default about 200 points."
    )
    args = parser.parse_args(argv)
    server = tornado.httpserver.HTTPServer(tornado.wsgi.WSGIContainer(app))
    server.listen(args.port)
    print("Tornado server starting on port {}.".format(args.port))
    # IOLoop.instance() is deprecated since Tornado 5; current() returns the
    # same loop when called from the main thread.
    tornado.ioloop.IOLoop.current().start()
185 
if __name__ == '__main__':
    # Drop the program name; main() parses only the user-supplied flags.
    main(argv=sys.argv[1:])