Caffe2 - Python API
A deep learning, cross-platform ML framework
lmdb_create_example.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package lmdb_create_example
17 # Module caffe2.python.examples.lmdb_create_example
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 import argparse
24 import numpy as np
25 
26 import lmdb
27 from caffe2.proto import caffe2_pb2
28 from caffe2.python import workspace, model_helper
29 
30 '''
31 Simple example to create an lmdb database of random image data and labels.
32 This can be used as a skeleton to write your own data import.
33 
34 It also runs a dummy model with Caffe2 that reads the data back and
35 validates that its checksum matches the one computed while writing.
36 '''
37 
38 
def create_db(output_file):
    """Write an LMDB database of 128 random images with integer labels.

    Each record is stored under a zero-padded ASCII key and holds a
    serialized ``caffe2_pb2.TensorProtos`` containing two tensors:
    a FLOAT image of shape (3, height, width) and a single INT32 label.

    Args:
        output_file: Path of the LMDB environment to create/open.

    Returns:
        The checksum ``sum(np.sum(image) * label)`` over all written
        records, intended to be compared against the value recomputed
        when reading the database back.
    """
    print(">>> Write database...")
    # MODIFY: maximum size the database may grow to (passed to lmdb.open).
    LMDB_MAP_SIZE = 1 << 40
    env = lmdb.open(output_file, map_size=LMDB_MAP_SIZE)

    checksum = 0
    try:
        with env.begin(write=True) as txn:
            for j in range(0, 128):
                # MODIFY: add your own data reader / creator
                label = j % 10
                width = 64
                height = 32

                # Channels-first layout: (channels, height, width).
                img_data = np.random.rand(3, height, width)
                # ...

                # Create TensorProtos
                tensor_protos = caffe2_pb2.TensorProtos()
                img_tensor = tensor_protos.protos.add()
                img_tensor.dims.extend(img_data.shape)
                img_tensor.data_type = caffe2_pb2.TensorProto.FLOAT
                # Proto stores the tensor as a flat list; dims carry the shape.
                img_tensor.float_data.extend(img_data.ravel())

                label_tensor = tensor_protos.protos.add()
                label_tensor.data_type = caffe2_pb2.TensorProto.INT32
                label_tensor.int32_data.append(label)

                # Zero-pad keys so LMDB's lexicographic byte ordering matches
                # insertion order (otherwise b"10" sorts before b"2").
                txn.put(
                    '{:08d}'.format(j).encode('ascii'),
                    tensor_protos.SerializeToString()
                )

                checksum += np.sum(img_data) * label
                if (j + 1) % 16 == 0:
                    print("Inserted {} rows".format(j + 1))
    finally:
        # Release the environment's file handles and memory map even if
        # writing fails part-way.
        env.close()

    print("Checksum/write: {}".format(int(checksum)))
    return checksum
78 
79 
def read_db_with_caffe2(db_file, expected_checksum):
    """Read an LMDB database through a Caffe2 net and verify its checksum.

    Adds a ``TensorProtosDBInput`` op producing the ``data`` and ``label``
    blobs, runs the net for 4 batches of 32 records, and recomputes
    ``sum(np.sum(image) * label)`` over everything fetched.

    Args:
        db_file: Path of the LMDB database to read.
        expected_checksum: Checksum computed when the database was written.

    Raises:
        AssertionError: If the read checksum is not within 0.1 of
            ``expected_checksum`` (including when either is NaN).
    """
    print(">>> Read database...")
    model = model_helper.ModelHelper(name="lmdbtest")
    batch_size = 32
    # The blobs are fetched by name below, so the returned references
    # are not needed.
    model.TensorProtosDBInput(
        [], ["data", "label"], batch_size=batch_size,
        db=db_file, db_type="lmdb")

    checksum = 0

    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net)

    # 4 batches * 32 records: assumes the database holds the 128 records
    # written by the example writer.
    for _ in range(0, 4):
        workspace.RunNet(model.net.Proto().name)

        img_datas = workspace.FetchBlob("data")
        labels = workspace.FetchBlob("label")
        for j in range(batch_size):
            # Accumulate in float64: the fetched images are presumably
            # float32, and a float32 running sum loses precision.
            checksum += np.sum(img_datas[j], dtype=np.float64) * labels[j]

    print("Checksum/read: {}".format(int(checksum)))
    # Explicit raise instead of `assert` so the check survives `python -O`;
    # `not (... < 0.1)` also rejects NaN checksums.
    if not np.abs(expected_checksum - checksum) < 0.1:
        raise AssertionError("Read/write checksums dont match")
104 
105 
def main():
    """Command-line entry point: write a random LMDB, then read it back.

    Requires ``--output_file``. The database is written there and then
    immediately re-read to compare read and write checksums.
    """
    arg_parser = argparse.ArgumentParser(description="Example LMDB creation")
    arg_parser.add_argument(
        "--output_file",
        type=str,
        default=None,
        required=True,
        help="Path to write the database to",
    )
    parsed = arg_parser.parse_args()
    db_path = parsed.output_file

    written_checksum = create_db(db_path)

    # Round-trip check: read back what was just written.
    read_db_with_caffe2(db_path, written_checksum)
120 
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()