17 #include "caffe2/core/net_async_tracing.h" 19 #include "caffe2/utils/proto_utils.h" 20 #include "caffe2/utils/string_utils.h" 23 caffe2_net_async_tracing_filepath,
25 "Path to save tracing information");
28 caffe2_net_async_names_to_trace,
30 "Comma-separated list of net names to trace");
// Process-wide default for the tracing period (100 batches). It is used as the
// fallback when a net does not carry its own "trace_every_nth_batch" argument.
32 C10_DEFINE_int(caffe2_net_async_tracing_nth, 100,
"Trace every Nth batch");
37 caffe2_net_async_tracing_dumping_nth,
39 "Dump profiling result file every Nth batch");
// Bumps the counter kept for `net_name` in a function-local static map and
// stores the new value back. A name seen for the first time is
// default-inserted as 0 by operator[], so its first computed value is 1.
// The constructor further down uses the result as the "_id_" suffix of the
// trace file name. (The return statement is not part of this excerpt.)
44 int getCounterForNetName(
const std::string& net_name) {
48 static std::unordered_map<std::string, int> net_name_to_counter;
49 static std::mutex map_mutex;
// The lock is taken before the map is touched so the read-modify-write below
// is serialized across callers.
50 std::unique_lock<std::mutex> map_lock(map_mutex);
51 int counter = net_name_to_counter[net_name] + 1;
52 net_name_to_counter[net_name] = counter;
58 const std::string& net_name,
// Every '/' in filename_ is replaced by '_' before the path is assembled.
// NOTE(review): presumably so the name stays a single path component — confirm.
65 std::replace(filename_.begin(), filename_.end(),
'/',
'_');
// Resulting base name: <config filepath>/<sanitized name>_id_<counter>, where
// the counter comes from getCounterForNetName(net_name).
66 filename_ = this->config().filepath +
"/" + filename_ +
"_id_" +
67 c10::to_string(getCounterForNetName(net_name));
// Appends a copy of `event` to events_ while holding tracer_mutex_.
71 void Tracer::recordEvent(
const TracerEvent& event) {
72 std::lock_guard<std::mutex> lock(tracer_mutex_);
73 events_.push_back(event);
// Forward declaration: opTraceName() below calls this before its definition,
// which appears later in this file.
77 int getUniqueShardId(
const OperatorDef& op_def);
// Builds the display name used for an operator in the trace.
// When a shard id is available the name is "<op type>:<shard id>"; the
// fallback for the no-shard case is not visible in this excerpt.
80 std::string Tracer::opTraceName(
const OperatorBase* op) {
// -1 is the "no shard id" sentinel, also used when the op has no debug def.
82 op->has_debug_def() ? getUniqueShardId(op->debug_def()) : -1;
83 if (unique_shard_id != -1) {
84 return op->type() +
":" + c10::to_string(unique_shard_id);
// Produces a human-readable list of the operator's blob names.
// Only operators that carry a debug def contribute anything; otherwise the
// string built here stays empty.
90 std::string Tracer::opBlobsInfo(
const OperatorBase& op) {
91 std::string blobs_info;
92 if (op.has_debug_def()) {
94 const auto& op_def = op.debug_def();
// Each input name is appended followed by "; ".
95 for (
const auto& input : op_def.input()) {
96 blobs_info += input +
"; ";
// Outputs follow the inputs, using the same "; " terminator.
99 for (
const auto& output : op_def.output()) {
100 blobs_info += output +
"; ";
// Renders one TracerEvent as a JSON object string.
// Begin events get "ph": "B" plus name/cat and an optional "args" object;
// the other case gets "ph": "E".
// NOTE(review): the field set (ts/pid/tid/name/cat/ph/args) looks like the
// Chrome Trace Event format — confirm against the consumer of these files.
// NOTE(review): names, categories and blob strings are streamed verbatim
// between quotes; no JSON escaping is visible here, so a '"' or '\' in any of
// them would yield invalid JSON.
106 std::string Tracer::serializeEvent(
const TracerEvent& event) {
107 std::stringstream serialized_event;
// std::fixed keeps any floating-point value out of scientific notation.
108 serialized_event << std::fixed;
109 serialized_event <<
"{\n";
110 serialized_event <<
" \"ts\": " <<
event.timestamp_ <<
",\n";
// pid is hard-coded to 0 for every event.
111 serialized_event <<
" \"pid\": 0,\n";
// A non-negative thread_label_ is written as the tid.
112 if (event.thread_label_ >= 0) {
113 serialized_event <<
" \"tid\": " <<
event.thread_label_ <<
",\n";
// Alternative tid output using the event's tid_ field (the branch keyword
// falls on a line not shown in this excerpt).
115 serialized_event <<
" \"tid\": " <<
event.tid_ <<
",\n";
118 if (event.is_beginning_) {
// Arguments are collected into these two maps first and written out together
// in the "args" object further down.
119 std::unordered_map<std::string, int> int_args;
120 std::unordered_map<std::string, std::string> string_args;
// Name sources as visible here: the event's own name_, else the operator's
// trace name when op_id_ is set, else the literal "n/a".
122 serialized_event <<
" \"name\": \"" <<
event.name_ <<
"\",\n";
123 }
else if (event.op_id_ >= 0) {
// at() is bounds-checked, so an out-of-range op_id_ throws.
124 auto* op = net_->GetOperators().at(event.op_id_);
125 serialized_event <<
" \"name\": \"" << opTraceName(op) <<
"\",\n";
127 serialized_event <<
" \"name\": \"n/a\",\n";
// Category: the event's own category_ when it is set, with "net" as the other
// visible output.
130 if (event.category_) {
131 serialized_event <<
" \"cat\": \"" <<
event.category_ <<
"\",\n";
133 serialized_event <<
" \"cat\": \"net\",\n";
// Operator-backed events additionally report op id, device type, device id
// and the blob list.
136 if (event.op_id_ >= 0) {
137 auto* op = net_->GetOperators().at(event.op_id_);
138 int_args[
"op_id"] =
event.op_id_;
139 int_args[
"device_type"] = op->device_option().device_type();
140 int_args[
"device_id"] = DeviceId(op->device_option());
141 string_args[
"blobs"] = opBlobsInfo(*op);
// Negative task/stream ids are treated as "unset" and left out.
144 if (event.task_id_ >= 0) {
145 int_args[
"task_id"] =
event.task_id_;
148 if (event.stream_id_ >= 0) {
149 int_args[
"stream_id"] =
event.stream_id_;
152 serialized_event <<
" \"ph\": \"B\"";
// The "args" object is only written when at least one argument was collected.
153 if (!int_args.empty() || !string_args.empty()) {
154 serialized_event <<
",\n \"args\": {\n";
// left_to_output tracks how many entries remain, so the separator is written
// only between entries. NOTE(review): the decrement is presumably on lines
// not shown in this excerpt — confirm.
155 auto left_to_output = int_args.size() + string_args.size();
// Integer arguments are written unquoted.
156 for (
const auto& kv : int_args) {
157 serialized_event <<
" \"" << kv.first <<
"\": " << kv.second;
159 if (left_to_output > 0) {
160 serialized_event <<
",\n";
// String arguments are wrapped in double quotes.
163 for (
const auto& kv : string_args) {
164 serialized_event <<
" \"" << kv.first <<
"\": \"" << kv.second <<
"\"";
166 if (left_to_output > 0) {
167 serialized_event <<
",\n";
170 serialized_event <<
"\n }";
// End-of-interval marker for events that are not beginnings.
173 serialized_event <<
" \"ph\": \"E\"\n";
175 serialized_event <<
"\n}";
177 return serialized_event.str();
// Adjusts event timestamps in place, per thread key, so consecutive events on
// the same key never share a timestamp. Events that land within time_eps of
// the previous one are pushed forward; later, larger gaps are used to pay the
// accumulated shift back.
// Timestamps are filled from timer_.MicroSeconds() elsewhere in this file, so
// time_eps is one of those units.
181 void Tracer::linearizeEvents() {
// time_offsets: shift currently applied to each key's events.
// last_times: adjusted timestamp of the previous event seen for each key.
182 std::unordered_map<long, long> time_offsets;
183 std::unordered_map<long, long> last_times;
184 std::hash<std::thread::id> hasher;
185 const long time_eps = 1;
186 for (
auto& event : events_) {
// Key is the explicit thread label when set, otherwise a hash of the thread id.
188 (
event.thread_label_ >= 0) ? event.thread_label_ : hasher(event.tid_);
189 auto event_ts =
event.timestamp_;
190 if (last_times.count(tid)) {
191 event_ts += time_offsets[tid];
// After shifting, events of one key must already be in non-decreasing order;
// anything else aborts via CAFFE_ENFORCE.
192 CAFFE_ENFORCE(event_ts >= last_times[tid]);
// Too close to the previous event: move it forward by time_eps and remember
// the extra shift for the events that follow.
193 if (event_ts <= last_times[tid] + time_eps) {
194 event_ts += time_eps;
195 time_offsets[tid] += time_eps;
196 }
else if (event_ts > last_times[tid] + 2 * time_eps) {
// Gap of more than 2 * time_eps: pull the event back by up to
// time_eps * (eps_len - 1), reducing the accumulated offset.
197 long eps_len = (event_ts - last_times[tid]) / time_eps;
198 if (time_offsets[tid] >= time_eps * (eps_len - 1)) {
199 time_offsets[tid] -= time_eps * (eps_len - 1);
200 event_ts -= time_eps * (eps_len - 1);
// Otherwise the whole remaining offset is removed and reset to zero.
202 event_ts -= time_offsets[tid];
203 time_offsets[tid] = 0;
206 event.timestamp_ = event_ts;
207 last_times[tid] = event_ts;
// NOTE(review): presumably the first-sighting branch for this key (the branch
// keyword is on a line not shown): seeds both tables without shifting.
209 last_times[tid] = event_ts;
210 time_offsets[tid] = 0;
// Assigns thread_label_ values to events so that each label encodes the NUMA
// node of the operator together with a small per-node thread index:
//   label = numa_multiplier * (numa_node_id + 1) + <index of thread on node>
215 void Tracer::renameThreads() {
// tids: thread hash -> index of that thread on its node.
// numa_counters: next free index per NUMA node (starts at 1).
// tid_to_numa: NUMA node first observed for each thread hash.
216 std::unordered_map<long, int> tids;
217 std::unordered_map<int, int> numa_counters;
218 std::unordered_map<long, int> tid_to_numa;
219 std::hash<std::thread::id> hasher;
220 const long numa_multiplier = 1000000000;
221 for (
auto& event : events_) {
// Events that already have a label, or that are not tied to an operator, are
// handled by this branch (its body is not shown in this excerpt).
222 if (event.thread_label_ >= 0 || event.op_id_ < 0) {
225 auto* op = net_->GetOperators().at(event.op_id_);
// Operators without a NUMA node id take this branch (body not shown).
226 if (!op->device_option().has_numa_node_id()) {
229 int numa_node_id = op->device_option().numa_node_id();
230 CAFFE_ENFORCE_GE(numa_node_id, 0,
"Invalid NUMA node id: ", numa_node_id);
231 long tid = hasher(event.tid_);
// Record the node on first sight of this thread. The equality check below
// enforces that the recorded node matches this event's node.
233 if (!tid_to_numa.count(tid)) {
234 tid_to_numa[tid] = numa_node_id;
236 CAFFE_ENFORCE_EQ(tid_to_numa[tid], numa_node_id);
239 if (!numa_counters.count(numa_node_id)) {
240 numa_counters[numa_node_id] = 1;
// A thread gets the next index of its node the first time it is seen.
242 if (!tids.count(tid)) {
243 tids[tid] = numa_counters[numa_node_id]++;
// numa_node_id + 1 keeps node 0 from producing labels below numa_multiplier.
245 event.thread_label_ = numa_multiplier * (numa_node_id + 1) + tids[tid];
249 void Tracer::setEnabled(
bool enabled) {
253 bool Tracer::isEnabled()
const {
257 int Tracer::bumpIter() {
// Post-increments dumping_iter_ and returns the value it held before the
// increment.
261 int Tracer::bumpDumpingIter() {
262 return dumping_iter_++;
// Serializes the recorded events and writes them to
//   <filename_>_iter_<file_suffix>.json
// file_suffix: text placed after "_iter_" in the output file name.
// (The step that clears events_ is not part of this excerpt.)
265 void Tracer::dumpTracingResultAndClearEvents(
const std::string& file_suffix) {
// With no events or no target file name this branch is taken (body not shown).
266 if (events_.empty() || filename_.empty()) {
271 std::stringstream serialized;
273 for (
size_t idx = 0; idx < events_.size(); ++idx) {
274 serialized << serializeEvent(events_[idx]);
// Every event except the last takes this branch (its body, presumably the
// separator, is not shown).
275 if (idx != events_.size() - 1) {
279 serialized <<
"\n]\n";
281 auto output_file_name = filename_ +
"_iter_" + file_suffix +
".json";
282 LOG(INFO) <<
"Dumping profiling result file to " << output_file_name;
283 WriteStringToFile(serialized.str(), output_file_name.c_str());
288 dumpTracingResultAndClearEvents(
"final_batch");
291 void TracerGuard::init(Tracer* tracer) {
296 void TracerGuard::addArgument() {}
// String-valued overload: stores `value` into the pending event's name_ or
// category_ depending on `field`; any other field raises via CAFFE_THROW.
// NOTE(review): `value` is assigned directly and category_ is tested for
// truthiness in serializeEvent, so these look like stored raw pointers — the
// pointed-to string would then have to outlive the event. Confirm against the
// TracerEvent declaration.
298 void TracerGuard::addArgument(TracingField field,
const char* value) {
301 event_.name_ = value;
304 case TRACE_CATEGORY: {
305 event_.category_ = value;
309 CAFFE_THROW(
"Unexpected tracing string field ", field);
// Integer-valued overload: stores `value` into one of the pending event's
// op_id_, task_id_, stream_id_ or thread_label_ fields depending on `field`
// (the case labels are on lines not shown); any other field raises via
// CAFFE_THROW.
314 void TracerGuard::addArgument(TracingField field,
int value) {
317 event_.op_id_ = value;
321 event_.task_id_ = value;
325 event_.stream_id_ = value;
329 event_.thread_label_ = value;
333 CAFFE_THROW(
"Unexpected tracing int field ", field);
// Stamps the pending event as a "begin" event and hands it to the tracer.
338 void TracerGuard::recordEventStart() {
// Without an explicit thread label the calling thread's id identifies the
// event's thread.
340 if (event_.thread_label_ < 0) {
341 event_.tid_ = std::this_thread::get_id();
343 event_.is_beginning_ =
true;
// Timestamp: the tracer timer's microsecond reading, rounded and cast to long.
344 event_.timestamp_ = (long)caffe2::round(tracer_->timer_.MicroSeconds());
345 tracer_->recordEvent(event_);
// Emits the matching "end" event: the same event_ is reused with
// is_beginning_ cleared and a fresh timestamp, then recorded with the tracer.
// (Any guard condition around these statements is on lines not shown.)
349 TracerGuard::~TracerGuard() {
351 event_.is_beginning_ =
false;
352 event_.timestamp_ = (long)caffe2::round(tracer_->timer_.MicroSeconds());
353 tracer_->recordEvent(event_);
// Parses the integer that follows the last "shard:" marker in `name`.
// The result for names without the marker is produced on lines not shown in
// this excerpt; callers in this file treat -1 as "no shard id".
357 int extractShardId(
const std::string& name) {
358 const std::string kShard =
"shard:";
// rfind: when the marker occurs several times, the last occurrence wins.
362 auto pos = name.rfind(kShard);
363 if (pos != std::string::npos) {
// [left_pos, right_pos) is meant to cover the run of digits after the marker.
364 int left_pos = pos + kShard.length();
365 int right_pos = left_pos;
// NOTE(review): right_pos is int while name.length() is unsigned, so this is
// a signed/unsigned comparison.
366 while (right_pos < name.length() && isdigit(name[right_pos])) {
// NOTE(review): if no digit follows the marker the substring is empty;
// presumably c10::stoi then throws like std::stoi — confirm and guard if so.
369 return c10::stoi(name.substr(left_pos, right_pos - left_pos));
// Scans all input and output blob names of `op_def` for shard ids and returns
// unique_shard_id, which starts at -1 ("none found") and is set from the names
// that carry one.
376 int getUniqueShardId(
const OperatorDef& op_def) {
377 int unique_shard_id = -1;
// NOTE(review): the braced list builds an initializer_list, which copies both
// repeated fields before iterating.
378 for (
const auto& names : {op_def.input(), op_def.output()}) {
379 for (
const auto& name : names) {
380 int shard_id = extractShardId(name);
381 if (shard_id != -1) {
// A shard id was already recorded from an earlier name; the handling of that
// case is on a line not shown in this excerpt.
382 if (unique_shard_id != -1) {
385 unique_shard_id = shard_id;
389 return unique_shard_id;
// True when `net_name` is non-empty and is found, by std::find, among the
// comma-separated entries of the caffe2_net_async_names_to_trace flag.
// The flag is split again on every call.
392 bool isTraceableNetName(
const std::string& net_name) {
393 auto tracing_nets = caffe2::split(
',', FLAGS_caffe2_net_async_names_to_trace);
394 return !net_name.empty() &&
395 std::find(tracing_nets.begin(), tracing_nets.end(), net_name) !=
// Reads the "enable_tracing" argument from the net's debug definition.
// Nets without a debug def take the first branch (its body is not shown).
// NOTE(review): the trailing `false` is presumably the value used when the
// argument is absent — confirm against GetFlagArgument.
399 bool hasEnableTracingFlag(
const NetBase* net) {
400 if (!net->has_debug_def()) {
403 return GetFlagArgument(net->debug_def(),
"enable_tracing",
false);
// Builds a TracingConfig from the arguments stored in the net's debug def.
// NOTE(review): the second argument of each GetSingleArgument call is
// presumably the fallback used when the argument is absent — confirm.
406 TracingConfig getTracingConfigFromNet(
const NetBase* net) {
407 ArgumentHelper arg_helper(net->debug_def());
// "tracing_mode" selects GLOBAL_TIMESLICE when it equals a specific string
// (the literal is on a line not shown); anything else, including the ""
// default, selects EVERY_K_ITERATIONS.
410 cfg.mode = (arg_helper.GetSingleArgument<std::string>(
"tracing_mode",
"") ==
412 ? TracingMode::GLOBAL_TIMESLICE
413 : TracingMode::EVERY_K_ITERATIONS;
// Output path and the two batch periods take the corresponding command-line
// flags as their second argument.
415 cfg.filepath = arg_helper.GetSingleArgument<std::string>(
416 "tracing_filepath", FLAGS_caffe2_net_async_tracing_filepath);
418 cfg.trace_every_nth_batch = arg_helper.GetSingleArgument<
int>(
419 "trace_every_nth_batch", FLAGS_caffe2_net_async_tracing_nth);
420 cfg.dump_every_nth_batch = arg_helper.GetSingleArgument<
int>(
421 "dump_every_nth_batch", FLAGS_caffe2_net_async_tracing_dumping_nth);
// The two millisecond settings pass the value already held in cfg as their
// second argument.
424 arg_helper.GetSingleArgument<
int>(
"trace_for_n_ms", cfg.trace_for_n_ms);
425 cfg.trace_every_n_ms = arg_helper.GetSingleArgument<
int>(
426 "trace_every_n_ms", cfg.trace_every_n_ms);
// Factory for a net's Tracer. trace_net is true when the net carries the
// enable-tracing flag or its name is listed in the names-to-trace flag; the
// Tracer is built with the config read from the net.
// (The alternative result of the conditional is on a line not shown.)
431 std::shared_ptr<Tracer> create(
433 const std::string& net_name) {
437 bool trace_net = hasEnableTracingFlag(net) || isTraceableNetName(net_name);
439 ? std::make_shared<Tracer>(net, net_name, getTracingConfigFromNet(net))
// Called at the start of a net iteration: decides whether this iteration is
// traced and whether accumulated events should be dumped, then updates the
// tracer accordingly.
443 bool startIter(
const std::shared_ptr<Tracer>& tracer) {
447 auto iter = tracer->bumpIter();
// EVERY_K_ITERATIONS: trace and dump on iteration numbers divisible by the
// configured periods.
// NOTE(review): no guard against a configured period of 0 is visible here;
// the modulo operations would then be undefined.
450 if (tracer->config().mode == TracingMode::EVERY_K_ITERATIONS) {
451 is_enabled = iter % tracer->config().trace_every_nth_batch == 0;
452 should_dump = iter % tracer->config().dump_every_nth_batch == 0;
// Other mode: tracing is on during the first trace_for_n_ms milliseconds of
// every trace_every_n_ms window, measured on the system (wall) clock.
454 using namespace std::chrono;
456 duration_cast<milliseconds>(system_clock::now().time_since_epoch())
458 is_enabled = (ms % tracer->config().trace_every_n_ms) <
459 tracer->config().trace_for_n_ms;
// Dump exactly when tracing switches from enabled to disabled.
461 should_dump = tracer->isEnabled() && !is_enabled;
463 tracer->setEnabled(is_enabled);
// Dumps are numbered by their own counter, independent of the iteration count.
465 int dumping_iter = tracer->bumpDumpingIter();
466 tracer->dumpTracingResultAndClearEvents(c10::to_string(dumping_iter));
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...