Caffe2 - C++ API
A deep learning, cross platform ML framework
data_channel_gloo_cache.cpp
1 #include <THD/base/data_channels/DataChannelGloo.hpp>
2 #include <THD/test/TestUtils.hpp>
3 
4 #include <THPP/tensors/THTensor.hpp>
5 
6 #include <unistd.h>
7 #include <array>
8 #include <cassert>
9 #include <iostream>
10 #include <memory>
11 #include <mutex>
12 #include <thread>
13 #include <vector>
14 
// Worker counts to exercise; main() performs one complete master + workers
// run for every entry of this array.
15 constexpr std::array<int, 1> WORKERS_NUM = {10};
// Port number the master exports through MASTER_PORT_ENV and that the workers
// append to "127.0.0.1:" when they set MASTER_ADDR_ENV.
16 constexpr int MASTER_PORT = 45678;
17 
// Worker threads spawned by main() for the current run; main() joins them and
// then clears the vector before the next run.
18 std::vector<std::thread> g_all_workers;
// Held by init_gloo_master() / init_gloo_worker() while they set environment
// variables and construct their data channel from the "env://" configuration.
19 std::mutex g_mutex;
20 
21 void test(std::shared_ptr<thd::DataChannel> data_channel) {
22  for (size_t dest = 0; dest < data_channel->getNumProcesses(); ++dest) {
23  if (data_channel->getRank() == dest) {
24  auto float_tensor = buildTensor<float>({1, 2, 3, 4, 5}, 10.123);
25  data_channel->broadcast(*float_tensor, dest);
26  } else {
27  auto float_tensor = buildTensor<float>({1, 2, 3, 4, 5}, -1.0);
28  data_channel->broadcast(*float_tensor, dest);
29  ASSERT_TENSOR_VALUE(float, *float_tensor, 10.123)
30  }
31  }
32 }
33 
34 void run_all_tests(
35  std::shared_ptr<thd::DataChannel> data_channel,
36  int workers) {
37  // NOTE: without properly working GlooCache this test would create
38  // about (1000 * WORKERS ^ 3) connections what is over 'normal' system
39  // configuration
40  for (size_t i = 0; i < 1000; ++i) {
41  test(data_channel);
42  }
43 }
44 
45 void init_gloo_master(int workers) {
46  g_mutex.lock();
47  setenv(WORLD_SIZE_ENV, std::to_string((workers + 1)).data(), 1);
48  setenv(RANK_ENV, "0", 1);
49  setenv(MASTER_PORT_ENV, std::to_string(MASTER_PORT).data(), 1);
50  auto masterChannel = std::make_shared<thd::DataChannelGloo>(
51  thd::getInitConfig("env://")); // reads all env variable
52  g_mutex.unlock();
53 
54  assert(masterChannel->init());
55  run_all_tests(masterChannel, workers);
56 }
57 
58 void init_gloo_worker(unsigned int id, int workers) {
59  g_mutex.lock();
60  setenv(RANK_ENV, std::to_string(id).data(), 1);
61  setenv(
62  MASTER_ADDR_ENV,
63  std::string("127.0.0.1:" + std::to_string(MASTER_PORT)).data(),
64  1);
65  auto worker_channel = std::make_shared<thd::DataChannelGloo>(
66  thd::getInitConfig("env://")); // reads all env variable
67  g_mutex.unlock();
68 
69  assert(worker_channel->init());
70  run_all_tests(worker_channel, workers);
71 }
72 
73 int main(void) {
74  for (auto workers : WORKERS_NUM) {
75  std::cout << "Gloo (workers: " << workers << "):" << std::endl;
76  // start gloo master
77  std::thread gloo_master_thread(init_gloo_master, workers);
78 
79  // start gloo worker
80  for (int id = 1; id <= workers; ++id) {
81  g_all_workers.push_back(std::thread(init_gloo_worker, id, workers));
82  }
83 
84  // wait for all workers to finish
85  for (auto& worker : g_all_workers) {
86  worker.join();
87  }
88 
89  gloo_master_thread.join();
90  g_all_workers.clear();
91 
92  std::cout << "Gloo - OK" << std::endl;
93  }
94 }
Definition: module.cpp:17