Caffe2 - C++ API
A deep learning, cross platform ML framework
data_channel_tcp_smoke.cpp
1 #include <THD/base/data_channels/DataChannelTCP.hpp>
2 #include <THD/test/TestUtils.hpp>
3 
4 #include <THPP/tensors/THTensor.hpp>
5 
6 #include <cassert>
7 #include <iostream>
8 #include <memory>
9 #include <mutex>
10 #include <thread>
11 
// Number of worker ranks (ranks 1..WORKERS_NUM); rank 0 is the master.
constexpr int WORKERS_NUM{2};
// Port exported through MASTER_PORT_ENV by the master and used by the workers
// to build the "127.0.0.1:<port>" master address.
constexpr int MASTER_PORT{45678};

// Worker threads spawned by main() and joined by master().
std::vector<std::thread> g_all_workers;
// Serializes the "set environment variables + read init config" sections,
// because the process environment is shared by every thread.
std::mutex g_mutex;
17 
18 void master() {
19  g_mutex.lock();
20  setenv(WORLD_SIZE_ENV, std::to_string((WORKERS_NUM + 1)).data(), 1);
21  setenv(RANK_ENV, "0", 1);
22  setenv(MASTER_PORT_ENV, std::to_string(MASTER_PORT).data(), 1);
23  auto masterChannel = std::make_shared<thd::DataChannelTCP>(
24  thd::getInitConfig("env://")); // reads all env variable
25  g_mutex.unlock();
26 
27  assert(masterChannel->init());
28  assert(masterChannel->getRank() == 0);
29  assert(masterChannel->getNumProcesses() == WORKERS_NUM + 1);
30 
31  // wait for all workers to finish
32  for (auto& worker : g_all_workers) {
33  worker.join();
34  }
35 }
36 
37 void worker(int id) {
38  g_mutex.lock();
39  setenv(RANK_ENV, std::to_string(id).data(), 1);
40  setenv(
41  MASTER_ADDR_ENV,
42  std::string("127.0.0.1:" + std::to_string(MASTER_PORT)).data(),
43  1);
44  auto workerChannel = std::make_shared<thd::DataChannelTCP>(
45  thd::getInitConfig("env://")); // reads all env variable
46  g_mutex.unlock();
47 
48  assert(workerChannel->init());
49  assert(workerChannel->getRank() == id);
50  assert(workerChannel->getNumProcesses() == WORKERS_NUM + 1);
51 }
52 
53 int main() {
54  // start master
55  std::thread master_thread(master);
56 
57  // start worker
58  for (int id = 1; id <= WORKERS_NUM; ++id) {
59  g_all_workers.push_back(std::thread(worker, id));
60  }
61 
62  master_thread.join();
63  std::cout << "OK" << std::endl;
64  return 0;
65 }