Caffe2 - Python API
A deep learning, cross-platform ML framework
app.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package app
17 # Module caffe2.python.mint.app
18 import argparse
19 import flask
20 import glob
21 import numpy as np
22 import nvd3
23 import os
24 import sys
25 import tornado.httpserver
26 import tornado.wsgi
27 
# Absolute directory of this module; the Flask templates and static assets
# are looked up relative to it.
__folder__ = os.path.abspath(os.path.dirname(__file__))

app = flask.Flask(
    __name__,
    static_folder=os.path.join(__folder__, "static"),
    template_folder=os.path.join(__folder__, "templates"),
)

# Parsed command-line options. Starts as None and is assigned by main();
# the request handlers read their settings from it.
args = None
36 
37 
38 def jsonify_nvd3(chart):
39  chart.buildcontent()
40  # Note(Yangqing): python-nvd3 does not seem to separate the built HTML part
41  # and the script part. Luckily, it seems to be the case that the HTML part is
42  # only a <div>, which can be accessed by chart.container; the script part,
43  # while the script part occupies the rest of the html content, which we can
44  # then find by chart.htmlcontent.find['<script>'].
45  script_start = chart.htmlcontent.find('<script>') + 8
46  script_end = chart.htmlcontent.find('</script>')
47  return flask.jsonify(
48  result=chart.container,
49  script=chart.htmlcontent[script_start:script_end].strip()
50  )
51 
52 
53 def visualize_summary(filename):
54  try:
55  data = np.loadtxt(filename)
56  except Exception as e:
57  return 'Cannot load file {}: {}'.format(filename, str(e))
58  chart_name = os.path.splitext(os.path.basename(filename))[0]
59  chart = nvd3.lineChart(
60  name=chart_name + '_summary_chart',
61  height=args.chart_height,
62  y_axis_format='.03g'
63  )
64  if args.sample < 0:
65  step = max(data.shape[0] / -args.sample, 1)
66  else:
67  step = args.sample
68  xdata = np.arange(0, data.shape[0], step)
69  # data should have 4 dimensions.
70  chart.add_serie(x=xdata, y=data[xdata, 0], name='min')
71  chart.add_serie(x=xdata, y=data[xdata, 1], name='max')
72  chart.add_serie(x=xdata, y=data[xdata, 2], name='mean')
73  chart.add_serie(x=xdata, y=data[xdata, 2] + data[xdata, 3], name='m+std')
74  chart.add_serie(x=xdata, y=data[xdata, 2] - data[xdata, 3], name='m-std')
75  return jsonify_nvd3(chart)
76 
77 
78 def visualize_print_log(filename):
79  try:
80  data = np.loadtxt(filename)
81  if data.ndim == 1:
82  data = data[:, np.newaxis]
83  except Exception as e:
84  return 'Cannot load file {}: {}'.format(filename, str(e))
85  chart_name = os.path.splitext(os.path.basename(filename))[0]
86  chart = nvd3.lineChart(
87  name=chart_name + '_log_chart',
88  height=args.chart_height,
89  y_axis_format='.03g'
90  )
91  if args.sample < 0:
92  step = max(data.shape[0] / -args.sample, 1)
93  else:
94  step = args.sample
95  xdata = np.arange(0, data.shape[0], step)
96  # if there is only one curve, we also show the running min and max
97  if data.shape[1] == 1:
98  # We also print the running min and max for the steps.
99  trunc_size = data.shape[0] / step
100  running_mat = data[:trunc_size * step].reshape((trunc_size, step))
101  chart.add_serie(
102  x=xdata[:trunc_size],
103  y=running_mat.min(axis=1),
104  name='running_min'
105  )
106  chart.add_serie(
107  x=xdata[:trunc_size],
108  y=running_mat.max(axis=1),
109  name='running_max'
110  )
111  chart.add_serie(x=xdata, y=data[xdata, 0], name=chart_name)
112  else:
113  for i in range(0, min(data.shape[1], args.max_curves)):
114  # data should have 4 dimensions.
115  chart.add_serie(
116  x=xdata,
117  y=data[xdata, i],
118  name='{}[{}]'.format(chart_name, i)
119  )
120 
121  return jsonify_nvd3(chart)
122 
123 
def visualize_file(filename):
    """Pick a chart renderer for ``filename`` (relative to ``args.root``).

    A name ending in 'summary' is passed to visualize_summary, and a name
    ending in 'log' to visualize_print_log; 'summary' is checked first.
    Any other name yields a JSON response whose ``result`` reports the file
    as unsupported, with an empty ``script``.
    """
    path = os.path.join(args.root, filename)
    renderers = (
        ('summary', visualize_summary),
        ('log', visualize_print_log),
    )
    for suffix, render in renderers:
        if filename.endswith(suffix):
            return render(path)
    return flask.jsonify(
        result='Unsupport file: {}'.format(filename),
        script=''
    )
135 
136 
@app.route('/')
def index():
    """Render index.html with the files found in ``args.root``.

    Files are those matching the glob ``*.*`` in ``args.root``, sorted by
    path. Their base names are passed to the template both as ``names`` and
    as ``debug_messages``.
    """
    pattern = os.path.join(args.root, "*.*")
    listed = [os.path.basename(path) for path in sorted(glob.glob(pattern))]
    return flask.render_template(
        'index.html',
        root=args.root,
        names=listed,
        debug_messages=listed,
    )
148 
149 
@app.route('/visualization/<string:name>')
def visualization(name):
    """Return the chart response produced by visualize_file for ``name``."""
    return visualize_file(name)
154 
155 
156 def main(argv):
157  parser = argparse.ArgumentParser("The mint visualizer.")
158  parser.add_argument(
159  '-p',
160  '--port',
161  type=int,
162  default=5000,
163  help="The flask port to use."
164  )
165  parser.add_argument(
166  '-r',
167  '--root',
168  type=str,
169  default='.',
170  help="The root folder to read files for visualization."
171  )
172  parser.add_argument(
173  '--max_curves',
174  type=int,
175  default=5,
176  help="The max number of curves to show in a dump tensor."
177  )
178  parser.add_argument(
179  '--chart_height',
180  type=int,
181  default=300,
182  help="The chart height for nvd3."
183  )
184  parser.add_argument(
185  '-s',
186  '--sample',
187  type=int,
188  default=-200,
189  help="Sample every given number of data points. A negative "
190  "number means the total points we will sample on the "
191  "whole curve. Default 100 points."
192  )
193  global args
194  args = parser.parse_args(argv)
195  server = tornado.httpserver.HTTPServer(tornado.wsgi.WSGIContainer(app))
196  server.listen(args.port)
197  print("Tornado server starting on port {}.".format(args.port))
198  tornado.ioloop.IOLoop.instance().start()
199 
200 
if __name__ == '__main__':
    # Hand everything after the script name to the argument parser.
    main(argv=sys.argv[1:])