// Test: broadcast over thd::DataChannelTCP between one master (rank 0) and
// WORKERS_NUM workers, all run as std::threads inside this process.
// NOTE(review): as seen here, the original line numbers are fused into the
// text and several lines (includes, signatures, braces) are missing; verify
// against the real source file before building.
1 #include <THD/base/data_channels/DataChannelTCP.hpp> 2 #include <THD/test/TestUtils.hpp> 4 #include <THPP/tensors/THTensor.hpp> 12 constexpr
int WORKERS_NUM = 2;  // number of worker threads; world size is WORKERS_NUM + 1 (master included)
13 constexpr
int MASTER_PORT = 45679;  // exported via MASTER_PORT_ENV by the master; workers use "127.0.0.1:" + MASTER_PORT
15 std::vector<std::thread> g_all_workers;  // worker threads: filled by the entry point, iterated by the master fragment
// Master side (rank 0). The enclosing function signature is not visible in
// this extract; presumably this is `master`, which the entry point starts on
// its own std::thread.
// Publish world size, rank and port through the process environment, then
// build the channel from "env://" (presumably reads those variables).
// NOTE(review): setenv() modifies process-global state and is not thread-safe;
// worker threads call it too. Confirm these calls are serialized — no lock is
// visible in this extract.
20 setenv(WORLD_SIZE_ENV, std::to_string((WORKERS_NUM + 1)).data(), 1);
21 setenv(RANK_ENV,
"0", 1);
22 setenv(MASTER_PORT_ENV, std::to_string(MASTER_PORT).data(), 1);
23 auto masterChannel = std::make_shared<thd::DataChannelTCP>(
24 thd::getInitConfig(
"env://"));
// Fixed 4-second delay before init(); presumably gives the worker threads time
// to configure their channels — TODO confirm the intent.
28 std::this_thread::sleep_for(std::chrono::seconds(4));
// NOTE(review): init() is only called inside assert(); when NDEBUG is defined
// the whole expression is compiled out and the channel is never initialized.
// Hoist it:  const bool ok = masterChannel->init(); assert(ok); (void)ok;
30 assert(masterChannel->init());
// Root rank 0 broadcasts a float tensor built from {1, 2, 3} and 4 (presumably
// sizes and fill value — see TestUtils); the worker side expects to receive 4.
32 auto float_tensor = buildTensor<float>({1, 2, 3}, 4);
33 masterChannel->broadcast(*float_tensor, 0);
// Iterate the worker threads; the loop body is not visible in this extract
// (presumably joins each worker — TODO confirm).
36 for (
auto& worker : g_all_workers) {
// Worker side. The enclosing function signature is not visible in this
// extract; presumably this is `worker(int id)`, which the entry point runs on
// a std::thread for each id in 1..WORKERS_NUM.
// NOTE(review): RANK_ENV is process-wide. If several worker threads call
// setenv() concurrently, one can overwrite another's rank before
// getInitConfig("env://") (presumably) reads it. Confirm the
// setenv..getInitConfig sequence is serialized — no lock is visible here.
43 setenv(RANK_ENV, std::to_string(
id).data(), 1);
// Master address "127.0.0.1:<MASTER_PORT>". The start of this call, including
// the variable it sets, is on a line missing from this extract.
46 std::string(
"127.0.0.1:" + std::to_string(MASTER_PORT)).data(),
48 auto workerChannel = std::make_shared<thd::DataChannelTCP>(
49 thd::getInitConfig(
"env://"));
// NOTE(review): init() is only called inside assert(); when NDEBUG is defined
// the whole expression is compiled out and the channel is never initialized.
// Hoist it:  const bool ok = workerChannel->init(); assert(ok); (void)ok;
52 assert(workerChannel->init());
// Receive root rank 0's broadcast into a tensor built from {1, 2, 3} and -1
// (presumably sizes and fill value), then presumably check that every element
// now holds 4, the value the master side broadcasts.
54 auto float_tensor = buildTensor<float>({1, 2, 3}, -1);
55 workerChannel->broadcast(*float_tensor, 0);
56 ASSERT_TENSOR_VALUE(
float, *float_tensor, 4)
// Entry-point body (presumably main(); its signature is not visible in this
// extract): start the master thread, then one worker thread per id in
// 1..WORKERS_NUM, storing them in g_all_workers.
// NOTE(review): the master fragment iterates g_all_workers while this loop may
// still be push_back()-ing into it. push_back can reallocate and invalidate
// the master's iterators. Nothing in this extract synchronizes the two; it
// appears to rely on the master's sleep finishing first. Consider filling
// g_all_workers before starting the master, or guarding the vector with a lock.
61 std::thread master_thread(master);
64 for (
int id = 1;
id <= WORKERS_NUM; ++id) {
65 g_all_workers.push_back(std::thread(worker,
id));
// Lines between the loop and this message are not visible in this extract.
69 std::cout <<
"OK" << std::endl;