// Test for thd::DataChannelTCP.
// One master channel and WORKERS_NUM worker channels are each built from the
// "env://" init config, and each checks its rank and process count. The test
// prints "OK" at the end.
1 #include <THD/base/data_channels/DataChannelTCP.hpp> 2 #include <THD/test/TestUtils.hpp> 4 #include <THPP/tensors/THTensor.hpp> 12 constexpr
// Number of worker ranks.
// - The world size exported for the channels is WORKERS_NUM + 1.
// - The master channel is rank 0.
// - Workers are spawned with ids 1..WORKERS_NUM and each expects rank == id.
int WORKERS_NUM = 2;
13 constexpr
// TCP port of the master.
// - It is exported through MASTER_PORT_ENV.
// - Workers build the address "127.0.0.1:<MASTER_PORT>" from it.
int MASTER_PORT = 45678;
// Worker threads spawned by the driver; the master routine iterates over this
// vector.
// NOTE(review): the driver starts the master thread before pushing workers into
// this vector. Confirm that the (not visible) code around the master's loop
// orders these accesses; otherwise this is a data race on the vector.
15 std::vector<std::thread> g_all_workers;
20 setenv(WORLD_SIZE_ENV, std::to_string((WORKERS_NUM + 1)).data(), 1);
21 setenv(RANK_ENV,
"0", 1);
22 setenv(MASTER_PORT_ENV, std::to_string(MASTER_PORT).data(), 1);
23 auto masterChannel = std::make_shared<thd::DataChannelTCP>(
24 thd::getInitConfig(
"env://"));
27 assert(masterChannel->init());
28 assert(masterChannel->getRank() == 0);
29 assert(masterChannel->getNumProcesses() == WORKERS_NUM + 1);
32 for (
auto& worker : g_all_workers) {
39 setenv(RANK_ENV, std::to_string(
id).data(), 1);
42 std::string(
"127.0.0.1:" + std::to_string(MASTER_PORT)).data(),
44 auto workerChannel = std::make_shared<thd::DataChannelTCP>(
45 thd::getInitConfig(
"env://"));
48 assert(workerChannel->init());
49 assert(workerChannel->getRank() == id);
50 assert(workerChannel->getNumProcesses() == WORKERS_NUM + 1);
55 std::thread master_thread(master);
58 for (
int id = 1;
id <= WORKERS_NUM; ++id) {
59 g_all_workers.push_back(std::thread(worker,
id));
63 std::cout <<
"OK" << std::endl;