Caffe2 - Python API
A deep learning, cross-platform ML framework
embedding_generation_benchmark.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package embedding_generation_benchmark
17 # Module caffe2.python.embedding_generation_benchmark
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from caffe2.proto import caffe2_pb2
24 from caffe2.python import workspace, core, utils, model_helper
25 
26 import argparse
27 import numpy as np
28 import time
29 
30 import logging
31 
# Module-wide logger. The level is DEBUG so that the log.info() progress
# messages emitted by the functions in this file are not filtered out.
log = logging.getLogger("embedding_generation_benchmark")
log.setLevel(logging.DEBUG)

# basicConfig() with no arguments sets up the logging module's default
# configuration so the messages above have somewhere to go.
logging.basicConfig()
35 
36 
def generate_data(T, batch_size, max_seq_length):
    '''
    Fill a queue with input data.

    Creates a blobs queue named "inputqueue" and enqueues T identical batches
    into it. Every batch is an integer array of shape
    [max_seq_length, batch_size] in which row i holds the value i
    (np.arange(max_seq_length) tiled over the batch and transposed).

    Args:
        T: number of batches to enqueue; also used as the queue capacity.
        batch_size: number of columns in each batch.
        max_seq_length: number of rows in each batch.

    Returns:
        The queue blob returned by CreateBlobsQueue.
    '''
    log.info("Generating T={} batches".format(T))

    generate_input_init_net = core.Net('generate_input_init')
    queue = generate_input_init_net.CreateBlobsQueue(
        [], "inputqueue", num_blobs=1, capacity=T,
    )
    workspace.RunNetOnce(generate_input_init_net)

    generate_input_net = core.Net('generate_input')
    generate_input_net.EnqueueBlobs([queue, "scratch"], ["scratch"])
    # Seeds NumPy's global RNG. Nothing in this function draws random
    # numbers; the call is kept so the global RNG state is unchanged for
    # any code that runs afterwards.
    np.random.seed(2603)

    # The batch contents and the logging interval do not depend on the loop
    # index, so they are computed once instead of once per batch.
    X = np.tile(np.arange(max_seq_length), [batch_size, 1]).transpose()
    log_interval = max(10, T // 10)

    for t in range(T):
        if t % log_interval == 0:
            log.info("Generating data {}/{}".format(t, T))
        # "scratch" is both the input and the output of the enqueue op, so it
        # is fed again before every run (presumably the enqueue consumes the
        # blob's contents -- TODO confirm).
        workspace.FeedBlob("scratch", X)
        workspace.RunNetOnce(generate_input_net.Proto())

    log.info("Finished data generation")
    return queue
62 
63 
def generate_embedding_table(vocab_size, embedding_size):
    '''
    Create the 'embedding_table' blob and fill it with a GaussianFill op.

    Args:
        vocab_size: first dimension of the table.
        embedding_size: second dimension of the table.

    Returns:
        The blob returned by GaussianFill for 'embedding_table'.
    '''
    table_shape = [vocab_size, embedding_size]
    log.info(
        "Generating embedding table with dimensions {}".format(table_shape)
    )

    fill_net = core.Net('generate_table')
    embedding_blob = fill_net.GaussianFill(
        [], ['embedding_table'], shape=table_shape,
    )

    # Run the single-op net once so the blob exists in the workspace.
    workspace.RunNetOnce(fill_net)
    return embedding_blob
77 
78 
def create_model(args, queue, embedding_table, embedding_size):
    '''
    Build the benchmark model.

    The model's net dequeues a blob from `queue` into 'input_data' and
    produces 'output' with one of two operators:

    * args.implementation == 'sinusoid': SinusoidPositionEncoding on the
      dequeued blob, with the given embedding_size.
    * anything else: Gather from `embedding_table`, indexed by the dequeued
      blob.

    Returns:
        The ModelHelper instance holding the net.
    '''
    model = model_helper.ModelHelper(name='embedding_generation_bench')
    positions = model.net.DequeueBlobs(queue, 'input_data')

    use_sinusoid = args.implementation == 'sinusoid'
    if not use_sinusoid:
        # Table lookup path.
        model.net.Gather([embedding_table, positions], ['output'])
        return model

    # Closed-form sinusoid path.
    model.net.SinusoidPositionEncoding(
        [positions], ['output'], embedding_size=embedding_size
    )
    return model
96 
97 
def Caffe2EmbeddingGeneration(args):
    '''
    Run the embedding generation benchmark and log its throughput.

    T = data_size // batch_size batches are queued up front. One net run is
    spent on warm-up; the remaining T - 1 runs are timed in chunks of
    args.iters_to_report, logging the throughput of each chunk and, at the
    end, of the whole timed section.

    Args:
        args: parsed arguments providing data_size, batch_size, seq_length,
            embedding_size, iters_to_report and implementation.

    Returns:
        Wall-clock seconds spent in the timed section (warm-up excluded).

    Raises:
        ValueError: if batch_size or iters_to_report is below 1, or if
            data_size is too small to form a single batch.
    '''
    # Fail fast on arguments that cannot produce a meaningful run, before
    # any net is built or any data is generated.
    if args.batch_size < 1:
        raise ValueError(
            "batch_size must be >= 1, got {}".format(args.batch_size)
        )
    if args.iters_to_report < 1:
        raise ValueError(
            "iters_to_report must be >= 1, got {}".format(
                args.iters_to_report
            )
        )

    T = args.data_size // args.batch_size
    if T < 1:
        # With no batch queued, the warm-up run below would have nothing to
        # dequeue.
        raise ValueError(
            "data_size ({}) must be at least batch_size ({}) so that at "
            "least one batch is generated".format(
                args.data_size, args.batch_size
            )
        )

    queue = generate_data(T, args.batch_size, args.seq_length)

    embedding_table = None
    if args.implementation == 'table':
        # One table row per position index produced by generate_data.
        embedding_table = generate_embedding_table(
            args.seq_length,
            args.embedding_size,
        )

    model = create_model(args, queue, embedding_table, args.embedding_size)

    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net)
    net_name = model.net.Proto().name

    num_iters = T
    embeddings_per_iter = args.batch_size * args.seq_length

    # Run the Benchmark
    log.info("------ Warming up ------")
    workspace.RunNet(net_name)

    log.info("------ Starting benchmark ------")
    start_time = time.time()
    last_time = start_time
    # Starts at 1 because the warm-up run already consumed one queued batch.
    for iteration in range(1, num_iters, args.iters_to_report):
        iters_once = min(args.iters_to_report, num_iters - iteration)
        workspace.RunNet(net_name, iters_once)

        new_time = time.time()
        log.info(
            "Iter: {} / {}. Embeddings Generated Per Second: {}k.".format(
                iteration,
                num_iters,
                # "// 100 / 10" expresses the rate in thousands, truncated
                # to one decimal place.
                (iters_once * embeddings_per_iter) /
                (new_time - last_time) // 100 / 10,
            )
        )
        last_time = new_time

    # Measure once so the logged throughput and the returned duration are
    # based on the same interval.
    elapsed = time.time() - start_time
    total_per_sec = (num_iters - 1) * embeddings_per_iter
    total_per_sec = total_per_sec / elapsed // 100 / 10

    log.info("Done. Total embeddings generated per second " +
             "excluding 1st iteration: {}k".format(total_per_sec))

    return elapsed
149 
150 
@utils.debug
def Benchmark(args):
    '''
    Entry point for the benchmark: runs Caffe2EmbeddingGeneration(args) and
    returns whatever it returns.

    The function is wrapped by utils.debug (from caffe2.python); what that
    decorator does is not visible in this file.
    '''
    elapsed = Caffe2EmbeddingGeneration(args)
    return elapsed
154 
155 
def GetArgumentParser():
    '''
    Build the command-line parser for the embedding generation benchmark.

    Options (all optional):
        --embedding_size   int, default 512
        --batch_size       int, default 16
        --data_size        int, default 10000
        --seq_length       int, default 128
        --iters_to_report  int, default 20
        --implementation   'sinusoid' (default) or 'table'

    Returns:
        The configured argparse.ArgumentParser.
    '''
    parser = argparse.ArgumentParser(
        description="Embedding generation benchmark."
    )

    parser.add_argument(
        "--embedding_size",
        type=int,
        default=512,
        help="Embedding size",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="The batch size."
    )
    parser.add_argument(
        "--data_size",
        type=int,
        default=10000,
        help="Number of sequences to generate"
    )
    parser.add_argument(
        "--seq_length",
        type=int,
        default=128,
        help="Max sequence length"
    )
    parser.add_argument(
        "--iters_to_report",
        type=int,
        default=20,
        help="Number of iterations to report progress"
    )
    parser.add_argument(
        "--implementation",
        type=str,
        default="sinusoid",
        # Only these two values are handled by the benchmark. Rejecting
        # anything else here gives a clear usage error instead of a model
        # built with a missing embedding table.
        choices=["table", "sinusoid"],
        help="'table' or 'sinusoid'",
    )
    return parser
198 
199 
if __name__ == '__main__':
    # Options the benchmark parser does not recognise are collected in
    # extra_args and forwarded to Caffe2's GlobalInit.
    args, extra_args = GetArgumentParser().parse_known_args()

    init_flags = [
        'caffe2',
        '--caffe2_log_level=0',
        '--caffe2_print_blob_sizes_at_exit=0',
    ]
    init_flags.extend(extra_args)
    workspace.GlobalInit(init_flags)

    # Run the whole benchmark under a CPU device scope.
    with core.DeviceScope(core.DeviceOption(caffe2_pb2.CPU)):
        Benchmark(args)