Caffe2 - Python API
A deep learning, cross-platform ML framework
store_ops_test_util.py
1 # Copyright (c) 2016-present, Facebook, Inc.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 # http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 ##############################################################################
15 
16 ## @package store_ops_test_util
17 # Module caffe2.distributed.store_ops_test_util
18 from __future__ import absolute_import
19 from __future__ import division
20 from __future__ import print_function
21 from __future__ import unicode_literals
22 
23 from multiprocessing import Process, Queue
24 
25 import numpy as np
26 
27 from caffe2.python import core, workspace
28 
29 
class StoreOpsTests(object):
    """Multi-process tests for the Caffe2 ``StoreSet`` / ``StoreGet`` operators.

    Callers supply ``create_store_handler_fn``, a zero-argument callable that
    is invoked inside each worker process; its return value is passed as the
    first input of the store operators (presumably a store-handler blob name
    or reference -- confirm against callers).
    """

    @classmethod
    def _test_set_get(cls, queue, create_store_handler_fn, index, num_procs):
        """Worker body: optionally publish a blob to the store, then read it.

        Runs in a child process. The worker with ``index == num_procs - 1``
        writes the blob with ``StoreSet``; every worker (including that one)
        then reads it back with ``StoreGet`` and checks its contents.

        Args:
            queue: multiprocessing queue; an ``AssertionError`` from the value
                check is put here so the parent process can re-raise it.
            create_store_handler_fn: zero-argument callable returning the
                store handler used as the first operator input.
            index: this worker's index in ``range(num_procs)``.
            num_procs: total number of worker processes.

        Any other exception propagates and terminates the child with a
        non-zero exit code, which the parent checks.
        """
        store_handler = create_store_handler_fn()
        blob = "blob"
        value = np.full(1, 1, np.float32)

        try:
            # Let the last process do the write so that the other processes
            # are likely already waiting for the blob before it is set. This
            # is a scheduling heuristic, not a guaranteed ordering.
            if index == (num_procs - 1):
                workspace.FeedBlob(blob, value)
                workspace.RunOperatorOnce(
                    core.CreateOperator(
                        "StoreSet",
                        [store_handler, blob],
                        [],
                        blob_name=blob))

            output_blob = "output_blob"
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "StoreGet",
                    [store_handler],
                    [output_blob],
                    blob_name=blob))

            # Compare against the exact array that was stored (value and shape).
            try:
                np.testing.assert_array_equal(
                    workspace.FetchBlob(output_blob), value)
            except AssertionError as err:
                queue.put(err)
        finally:
            workspace.ResetWorkspace()

    @classmethod
    def test_set_get(cls, create_store_handler_fn):
        """Run ``_test_set_get`` concurrently in several processes.

        Args:
            create_store_handler_fn: zero-argument callable, invoked in each
                worker, returning the store handler the workers share.

        Raises:
            AssertionError: the first value-mismatch error reported by a
                worker; or, if none was reported, when any worker process
                exited with a non-zero exit code.
        """
        # Queue for assertion errors raised in subprocesses.
        queue = Queue()

        # Start N processes in the background.
        num_procs = 4
        procs = []
        for index in range(num_procs):
            proc = Process(
                target=cls._test_set_get,
                args=(queue, create_store_handler_fn, index, num_procs))
            proc.start()
            procs.append(proc)

        # Test complete, join background processes.
        for proc in procs:
            proc.join()

        # Prefer a worker's own assertion: it carries the mismatch details.
        if not queue.empty():
            raise queue.get()

        # Workers only report AssertionError through the queue; any other
        # failure (e.g. an operator raising) just kills the child. Without
        # checking exit codes such a failure would let the test pass silently.
        failed = [
            (i, proc.exitcode)
            for i, proc in enumerate(procs)
            if proc.exitcode != 0
        ]
        if failed:
            raise AssertionError(
                "store ops worker process(es) failed: " + ", ".join(
                    "index {} exit code {}".format(i, code)
                    for i, code in failed))
def _test_set_get(cls, queue, create_store_handler_fn, index, num_procs)