18 from __future__
import absolute_import
19 from __future__
import division
20 from __future__
import print_function
21 from __future__
import unicode_literals
33 Simple benchmark that creates a data-parallel resnet-50 model 39 log = logging.getLogger(
"net_construct_bench")
40 log.setLevel(logging.DEBUG)
43 def AddMomentumParameterUpdate(train_model, LR):
45 Add the momentum-SGD update. 47 params = train_model.GetParams()
48 assert(len(params) > 0)
49 ONE = train_model.param_init_net.ConstantFill(
50 [],
"ONE", shape=[1], value=1.0,
52 NEGONE = train_model.param_init_net.ConstantFill(
53 [],
'NEGONE', shape=[1], value=-1.0,
57 param_grad = train_model.param_to_grad[param]
58 param_momentum = train_model.param_init_net.ConstantFill(
59 [param], param +
'_momentum', value=0.0
63 train_model.net.MomentumSGD(
64 [param_grad, param_momentum, LR],
65 [param_grad, param_momentum],
71 train_model.WeightedSum(
72 [param, ONE, param_grad, NEGONE],
78 gpus = list(range(args.num_gpus))
79 log.info(
"Running on gpus: {}".format(gpus))
82 train_model = cnn.CNNModelHelper(
86 cudnn_exhaustive_search=
False 90 def create_resnet50_model_ops(model, loss_scale):
91 [softmax, loss] = resnet.create_resnet50(
98 model.Accuracy([softmax,
"label"],
"accuracy")
102 def add_parameter_update_ops(model):
103 model.AddWeightDecay(1e-4)
104 ITER = model.Iter(
"ITER")
106 LR = model.net.LearningRate(
114 AddMomentumParameterUpdate(model, LR)
116 def add_image_input(model):
119 start_time = time.time()
122 data_parallel_model.Parallelize_GPU(
124 input_builder_fun=add_image_input,
125 forward_pass_builder_fun=create_resnet50_model_ops,
126 param_update_builder_fun=add_parameter_update_ops,
130 ct = time.time() - start_time
131 train_model.net._CheckLookupTables()
133 log.info(
"Model create for {} gpus took: {} secs".format(len(gpus), ct))
138 parser = argparse.ArgumentParser(
139 description=
"Caffe2: Benchmark for net construction" 141 parser.add_argument(
"--num_gpus", type=int, default=1,
142 help=
"Number of GPUs.")
143 args = parser.parse_args()
148 if __name__ ==
'__main__':
149 workspace.GlobalInit([
'caffe2',
'--caffe2_log_level=2'])
152 cProfile.run(
'main()', sort=
"cumulative")