Caffe2 - Python API
A deep learning, cross-platform ML framework
store_ops_test_util.py
1 ## @package store_ops_test_util
2 # Module caffe2.distributed.store_ops_test_util
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from multiprocessing import Process, Queue
9 
10 import numpy as np
11 
12 from caffe2.python import core, workspace
13 
14 
class StoreOpsTests(object):
    """Reusable test routines for store handler implementations.

    Every public test takes ``create_store_handler_fn``, a zero-argument
    callable; its return value is passed as the store handler input of the
    ``StoreSet`` / ``StoreGet`` operators.
    """

    @classmethod
    def _test_set_get(cls, queue, create_store_handler_fn, index, num_procs):
        """Worker body executed in each subprocess of ``test_set_get``.

        Args:
            queue: multiprocessing queue used to ship an ``AssertionError``
                back to the parent process.
            create_store_handler_fn: zero-argument callable returning the
                store handler used as operator input.
            index: position of this worker, in ``range(num_procs)``.
            num_procs: total number of workers started by the parent.

        Only ``AssertionError`` is forwarded through ``queue``. Any other
        exception propagates out of this function; the parent notices it
        through the worker's exit code.
        """
        store_handler = create_store_handler_fn()
        blob = "blob"
        value = np.full(1, 1, np.float32)

        # Use last process to set blob to make sure other processes
        # are waiting for the blob before it is set.
        if index == (num_procs - 1):
            workspace.FeedBlob(blob, value)
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "StoreSet",
                    [store_handler, blob],
                    [],
                    blob_name=blob))

        # Every worker (including the one that set it) reads the blob back.
        output_blob = "output_blob"
        workspace.RunOperatorOnce(
            core.CreateOperator(
                "StoreGet",
                [store_handler],
                [output_blob],
                blob_name=blob))

        try:
            np.testing.assert_array_equal(workspace.FetchBlob(output_blob), 1)
        except AssertionError as err:
            # Ship the failure to the parent instead of dying silently.
            queue.put(err)

        workspace.ResetWorkspace()

    @classmethod
    def test_set_get(cls, create_store_handler_fn):
        """Run a set/get round trip across several subprocesses.

        Starts four workers running ``_test_set_get``; the last one stores
        the blob and all of them fetch it and compare it against 1.

        Args:
            create_store_handler_fn: zero-argument callable returning the
                store handler; it is invoked inside each worker.

        Raises:
            AssertionError: the first value mismatch reported by a worker.
            RuntimeError: if no mismatch was reported but at least one
                worker exited with a non-zero exit code (for example
                because an operator or the handler factory raised).
        """
        # Queue for assertion errors on subprocesses
        queue = Queue()

        # Start N processes in the background
        num_procs = 4
        procs = []
        for index in range(num_procs):
            proc = Process(
                target=cls._test_set_get,
                args=(queue, create_store_handler_fn, index, num_procs, ))
            proc.start()
            procs.append(proc)

        # Test complete, join background processes
        for proc in procs:
            proc.join()

        # Raise first error we find, if any
        if not queue.empty():
            raise queue.get()

        # A worker that failed with anything other than an AssertionError
        # never wrote to the queue; without this check such a failure
        # would be reported as a passing test.
        failed = [
            (index, proc.exitcode)
            for index, proc in enumerate(procs)
            if proc.exitcode != 0
        ]
        if failed:
            raise RuntimeError(
                "test_set_get subprocess(es) exited abnormally "
                "(index, exit code): {}".format(failed))

    @classmethod
    def test_get_timeout(cls, create_store_handler_fn):
        """Run a ``StoreGet`` for a blob that this method never sets.

        Builds a net named ``get_missing_blob`` containing a single
        ``StoreGet`` for ``blob_name='blob'`` and runs it once.

        NOTE(review): judging by the method name, the run is presumably
        expected to fail with a timeout that the caller asserts on --
        confirm against the callers and the store handler implementation.

        Args:
            create_store_handler_fn: zero-argument callable returning the
                store handler used as operator input.
        """
        store_handler = create_store_handler_fn()
        net = core.Net('get_missing_blob')
        # The literal 1 is the outputs argument; presumably it requests one
        # auto-named output blob -- TODO confirm against the Net API.
        net.StoreGet([store_handler], 1, blob_name='blob')
        workspace.RunNetOnce(net)
def _test_set_get(cls, queue, create_store_handler_fn, index, num_procs)