Caffe2 - C++ API
A deep learning, cross platform ML framework
net_async_tracing.cc
1 
#include "caffe2/core/net_async_tracing.h"

#include <exception>
#include <limits>

#include "caffe2/utils/proto_utils.h"
#include "caffe2/utils/string_utils.h"
21 
22 C10_DEFINE_string(
23  caffe2_net_async_tracing_filepath,
24  "/tmp",
25  "Path to save tracing information");
26 
27 C10_DEFINE_string(
28  caffe2_net_async_names_to_trace,
29  "",
30  "Comma-separated list of net names to trace");
31 
32 C10_DEFINE_int(caffe2_net_async_tracing_nth, 100, "Trace every Nth batch");
33 
34 // For every Nth iterations, we will dump the tracing results to a json file
35 // The file is appended with the iteration number.
36 C10_DEFINE_int(
37  caffe2_net_async_tracing_dumping_nth,
38  10000,
39  "Dump profiling result file every Nth batch");
40 
41 namespace caffe2 {
42 namespace tracing {
43 
// Returns how many times this function has been called with `net_name`,
// counting the current call (so the first call for a name returns 1).
//
// Several instances of the same net can exist at once; the returned number is
// used as a suffix so that each instance gets its own profiling trace.
// The counters are shared process-wide and protected by a mutex.
int getCounterForNetName(const std::string& net_name) {
  static std::mutex counters_mutex;
  static std::unordered_map<std::string, int> instances_per_net;

  std::lock_guard<std::mutex> hold(counters_mutex);
  // operator[] value-initializes a missing entry to 0 before the increment.
  return ++instances_per_net[net_name];
}
55 
56 Tracer::Tracer(
57  const NetBase* net,
58  const std::string& net_name,
59  TracingConfig config)
60  : net_(net),
61  filename_(net_name),
62  iter_(0),
63  dumping_iter_(0),
64  config_(config) {
65  std::replace(filename_.begin(), filename_.end(), '/', '_');
66  filename_ = this->config().filepath + "/" + filename_ + "_id_" +
67  c10::to_string(getCounterForNetName(net_name));
68  timer_.Start();
69 }
70 
71 void Tracer::recordEvent(const TracerEvent& event) {
72  std::lock_guard<std::mutex> lock(tracer_mutex_);
73  events_.push_back(event);
74 }
75 
76 // Forward
77 int getUniqueShardId(const OperatorDef& op_def);
78 
79 // Special handling of shard blob annotations
80 std::string Tracer::opTraceName(const OperatorBase* op) {
81  int unique_shard_id =
82  op->has_debug_def() ? getUniqueShardId(op->debug_def()) : -1;
83  if (unique_shard_id != -1) {
84  return op->type() + ":" + c10::to_string(unique_shard_id);
85  } else {
86  return op->type();
87  }
88 }
89 
90 std::string Tracer::opBlobsInfo(const OperatorBase& op) {
91  std::string blobs_info;
92  if (op.has_debug_def()) {
93  blobs_info += "I: ";
94  const auto& op_def = op.debug_def();
95  for (const auto& input : op_def.input()) {
96  blobs_info += input + "; ";
97  }
98  blobs_info += "O: ";
99  for (const auto& output : op_def.output()) {
100  blobs_info += output + "; ";
101  }
102  }
103  return blobs_info;
104 }
105 
106 std::string Tracer::serializeEvent(const TracerEvent& event) {
107  std::stringstream serialized_event;
108  serialized_event << std::fixed;
109  serialized_event << "{\n";
110  serialized_event << " \"ts\": " << event.timestamp_ << ",\n";
111  serialized_event << " \"pid\": 0,\n"; // not using pid field
112  if (event.thread_label_ >= 0) {
113  serialized_event << " \"tid\": " << event.thread_label_ << ",\n";
114  } else {
115  serialized_event << " \"tid\": " << event.tid_ << ",\n";
116  }
117 
118  if (event.is_beginning_) {
119  std::unordered_map<std::string, int> int_args;
120  std::unordered_map<std::string, std::string> string_args;
121  if (event.name_) {
122  serialized_event << " \"name\": \"" << event.name_ << "\",\n";
123  } else if (event.op_id_ >= 0) {
124  auto* op = net_->GetOperators().at(event.op_id_);
125  serialized_event << " \"name\": \"" << opTraceName(op) << "\",\n";
126  } else {
127  serialized_event << " \"name\": \"n/a\",\n";
128  }
129 
130  if (event.category_) {
131  serialized_event << " \"cat\": \"" << event.category_ << "\",\n";
132  } else {
133  serialized_event << " \"cat\": \"net\",\n";
134  }
135 
136  if (event.op_id_ >= 0) {
137  auto* op = net_->GetOperators().at(event.op_id_);
138  int_args["op_id"] = event.op_id_;
139  int_args["device_type"] = op->device_option().device_type();
140  int_args["device_id"] = DeviceId(op->device_option());
141  string_args["blobs"] = opBlobsInfo(*op);
142  }
143 
144  if (event.task_id_ >= 0) {
145  int_args["task_id"] = event.task_id_;
146  }
147 
148  if (event.stream_id_ >= 0) {
149  int_args["stream_id"] = event.stream_id_;
150  }
151 
152  serialized_event << " \"ph\": \"B\"";
153  if (!int_args.empty() || !string_args.empty()) {
154  serialized_event << ",\n \"args\": {\n";
155  auto left_to_output = int_args.size() + string_args.size();
156  for (const auto& kv : int_args) {
157  serialized_event << " \"" << kv.first << "\": " << kv.second;
158  --left_to_output;
159  if (left_to_output > 0) {
160  serialized_event << ",\n";
161  }
162  }
163  for (const auto& kv : string_args) {
164  serialized_event << " \"" << kv.first << "\": \"" << kv.second << "\"";
165  --left_to_output;
166  if (left_to_output > 0) {
167  serialized_event << ",\n";
168  }
169  }
170  serialized_event << "\n }";
171  }
172  } else {
173  serialized_event << " \"ph\": \"E\"\n";
174  }
175  serialized_event << "\n}";
176 
177  return serialized_event.str();
178 }
179 
180 // fix occasional cases with zero duration events
181 void Tracer::linearizeEvents() {
182  std::unordered_map<long, long> time_offsets;
183  std::unordered_map<long, long> last_times;
184  std::hash<std::thread::id> hasher;
185  const long time_eps = 1; // us
186  for (auto& event : events_) {
187  long tid =
188  (event.thread_label_ >= 0) ? event.thread_label_ : hasher(event.tid_);
189  auto event_ts = event.timestamp_;
190  if (last_times.count(tid)) {
191  event_ts += time_offsets[tid];
192  CAFFE_ENFORCE(event_ts >= last_times[tid]);
193  if (event_ts <= last_times[tid] + time_eps) {
194  event_ts += time_eps;
195  time_offsets[tid] += time_eps;
196  } else if (event_ts > last_times[tid] + 2 * time_eps) {
197  long eps_len = (event_ts - last_times[tid]) / time_eps;
198  if (time_offsets[tid] >= time_eps * (eps_len - 1)) {
199  time_offsets[tid] -= time_eps * (eps_len - 1);
200  event_ts -= time_eps * (eps_len - 1);
201  } else {
202  event_ts -= time_offsets[tid];
203  time_offsets[tid] = 0;
204  }
205  }
206  event.timestamp_ = event_ts;
207  last_times[tid] = event_ts;
208  } else {
209  last_times[tid] = event_ts;
210  time_offsets[tid] = 0;
211  }
212  }
213 }
214 
215 void Tracer::renameThreads() {
216  std::unordered_map<long, int> tids;
217  std::unordered_map<int, int> numa_counters;
218  std::unordered_map<long, int> tid_to_numa;
219  std::hash<std::thread::id> hasher;
220  const long numa_multiplier = 1000000000;
221  for (auto& event : events_) {
222  if (event.thread_label_ >= 0 || event.op_id_ < 0) {
223  continue;
224  }
225  auto* op = net_->GetOperators().at(event.op_id_);
226  if (!op->device_option().has_numa_node_id()) {
227  continue;
228  }
229  int numa_node_id = op->device_option().numa_node_id();
230  CAFFE_ENFORCE_GE(numa_node_id, 0, "Invalid NUMA node id: ", numa_node_id);
231  long tid = hasher(event.tid_);
232 
233  if (!tid_to_numa.count(tid)) {
234  tid_to_numa[tid] = numa_node_id;
235  } else {
236  CAFFE_ENFORCE_EQ(tid_to_numa[tid], numa_node_id);
237  }
238 
239  if (!numa_counters.count(numa_node_id)) {
240  numa_counters[numa_node_id] = 1;
241  }
242  if (!tids.count(tid)) {
243  tids[tid] = numa_counters[numa_node_id]++;
244  }
245  event.thread_label_ = numa_multiplier * (numa_node_id + 1) + tids[tid];
246  }
247 }
248 
249 void Tracer::setEnabled(bool enabled) {
250  enabled_ = enabled;
251 }
252 
253 bool Tracer::isEnabled() const {
254  return enabled_;
255 }
256 
257 int Tracer::bumpIter() {
258  return iter_++;
259 }
260 
261 int Tracer::bumpDumpingIter() {
262  return dumping_iter_++;
263 }
264 
265 void Tracer::dumpTracingResultAndClearEvents(const std::string& file_suffix) {
266  if (events_.empty() || filename_.empty()) {
267  return;
268  }
269  linearizeEvents();
270  renameThreads();
271  std::stringstream serialized;
272  serialized << "[\n";
273  for (size_t idx = 0; idx < events_.size(); ++idx) {
274  serialized << serializeEvent(events_[idx]);
275  if (idx != events_.size() - 1) {
276  serialized << ",\n";
277  }
278  }
279  serialized << "\n]\n";
280 
281  auto output_file_name = filename_ + "_iter_" + file_suffix + ".json";
282  LOG(INFO) << "Dumping profiling result file to " << output_file_name;
283  WriteStringToFile(serialized.str(), output_file_name.c_str());
284  events_.clear();
285 }
286 
287 Tracer::~Tracer() {
288  dumpTracingResultAndClearEvents("final_batch");
289 }
290 
291 void TracerGuard::init(Tracer* tracer) {
292  enabled_ = true;
293  tracer_ = tracer;
294 }
295 
296 void TracerGuard::addArgument() {}
297 
298 void TracerGuard::addArgument(TracingField field, const char* value) {
299  switch (field) {
300  case TRACE_NAME: {
301  event_.name_ = value;
302  break;
303  }
304  case TRACE_CATEGORY: {
305  event_.category_ = value;
306  break;
307  }
308  default: {
309  CAFFE_THROW("Unexpected tracing string field ", field);
310  }
311  }
312 }
313 
314 void TracerGuard::addArgument(TracingField field, int value) {
315  switch (field) {
316  case TRACE_OP: {
317  event_.op_id_ = value;
318  break;
319  }
320  case TRACE_TASK: {
321  event_.task_id_ = value;
322  break;
323  }
324  case TRACE_STREAM: {
325  event_.stream_id_ = value;
326  break;
327  }
328  case TRACE_THREAD: {
329  event_.thread_label_ = value;
330  break;
331  }
332  default: {
333  CAFFE_THROW("Unexpected tracing int field ", field);
334  }
335  }
336 }
337 
338 void TracerGuard::recordEventStart() {
339  if (enabled_) {
340  if (event_.thread_label_ < 0) {
341  event_.tid_ = std::this_thread::get_id();
342  }
343  event_.is_beginning_ = true;
344  event_.timestamp_ = (long)caffe2::round(tracer_->timer_.MicroSeconds());
345  tracer_->recordEvent(event_);
346  }
347 }
348 
349 TracerGuard::~TracerGuard() {
350  if (enabled_) {
351  event_.is_beginning_ = false;
352  event_.timestamp_ = (long)caffe2::round(tracer_->timer_.MicroSeconds());
353  tracer_->recordEvent(event_);
354  }
355 }
356 
// Extracts the shard id from a blob name containing "shard:<digits>".
//
// A name can mention several shards; the last occurrence is the one we need,
// hence rfind. Hacky, but it works until the shard id is passed in graph
// metadata.
//
// Returns the parsed id, or -1 when the name has no "shard:" marker, when the
// last marker is not followed by at least one decimal digit, or when the
// number does not fit in an int. The function never throws on malformed names.
int extractShardId(const std::string& name) {
  static const std::string kShard = "shard:";

  const std::string::size_type marker_pos = name.rfind(kShard);
  if (marker_pos == std::string::npos) {
    return -1;
  }

  const std::string::size_type digits_begin = marker_pos + kShard.length();
  std::string::size_type cursor = digits_begin;
  int shard_id = 0;
  // Plain range comparison avoids the undefined behaviour of passing a
  // negative char to isdigit().
  while (cursor < name.length() && name[cursor] >= '0' &&
         name[cursor] <= '9') {
    const int digit = name[cursor] - '0';
    // Reject values that would overflow int instead of wrapping.
    if (shard_id > (std::numeric_limits<int>::max() - digit) / 10) {
      return -1;
    }
    shard_id = shard_id * 10 + digit;
    ++cursor;
  }

  // "shard:" with no digits after it is not a shard annotation.
  if (cursor == digits_begin) {
    return -1;
  }
  return shard_id;
}
374 
375 // Return unique shard id, or -1 if it is not unique.
376 int getUniqueShardId(const OperatorDef& op_def) {
377  int unique_shard_id = -1;
378  for (const auto& names : {op_def.input(), op_def.output()}) {
379  for (const auto& name : names) {
380  int shard_id = extractShardId(name);
381  if (shard_id != -1) {
382  if (unique_shard_id != -1) {
383  return -1;
384  }
385  unique_shard_id = shard_id;
386  }
387  }
388  }
389  return unique_shard_id;
390 }
391 
392 bool isTraceableNetName(const std::string& net_name) {
393  auto tracing_nets = caffe2::split(',', FLAGS_caffe2_net_async_names_to_trace);
394  return !net_name.empty() &&
395  std::find(tracing_nets.begin(), tracing_nets.end(), net_name) !=
396  tracing_nets.end();
397 }
398 
399 bool hasEnableTracingFlag(const NetBase* net) {
400  if (!net->has_debug_def()) {
401  return false;
402  }
403  return GetFlagArgument(net->debug_def(), "enable_tracing", false);
404 }
405 
406 TracingConfig getTracingConfigFromNet(const NetBase* net) {
407  ArgumentHelper arg_helper(net->debug_def());
408  TracingConfig cfg;
409 
410  cfg.mode = (arg_helper.GetSingleArgument<std::string>("tracing_mode", "") ==
411  "GLOBAL_TIMESLICE")
412  ? TracingMode::GLOBAL_TIMESLICE
413  : TracingMode::EVERY_K_ITERATIONS;
414 
415  cfg.filepath = arg_helper.GetSingleArgument<std::string>(
416  "tracing_filepath", FLAGS_caffe2_net_async_tracing_filepath);
417 
418  cfg.trace_every_nth_batch = arg_helper.GetSingleArgument<int>(
419  "trace_every_nth_batch", FLAGS_caffe2_net_async_tracing_nth);
420  cfg.dump_every_nth_batch = arg_helper.GetSingleArgument<int>(
421  "dump_every_nth_batch", FLAGS_caffe2_net_async_tracing_dumping_nth);
422 
423  cfg.trace_for_n_ms =
424  arg_helper.GetSingleArgument<int>("trace_for_n_ms", cfg.trace_for_n_ms);
425  cfg.trace_every_n_ms = arg_helper.GetSingleArgument<int>(
426  "trace_every_n_ms", cfg.trace_every_n_ms);
427 
428  return cfg;
429 };
430 
431 std::shared_ptr<Tracer> create(
432  const NetBase* net,
433  const std::string& net_name) {
434  // Enable the tracer if the net has the "enable_tracing" argument set OR
435  // if the command line option includes the net name option in the list of
436  // tracable nets.
437  bool trace_net = hasEnableTracingFlag(net) || isTraceableNetName(net_name);
438  return trace_net
439  ? std::make_shared<Tracer>(net, net_name, getTracingConfigFromNet(net))
440  : nullptr;
441 }
442 
443 bool startIter(const std::shared_ptr<Tracer>& tracer) {
444  if (!tracer) {
445  return false;
446  }
447  auto iter = tracer->bumpIter();
448  bool is_enabled;
449  bool should_dump;
450  if (tracer->config().mode == TracingMode::EVERY_K_ITERATIONS) {
451  is_enabled = iter % tracer->config().trace_every_nth_batch == 0;
452  should_dump = iter % tracer->config().dump_every_nth_batch == 0;
453  } else {
454  using namespace std::chrono;
455  auto ms =
456  duration_cast<milliseconds>(system_clock::now().time_since_epoch())
457  .count();
458  is_enabled = (ms % tracer->config().trace_every_n_ms) <
459  tracer->config().trace_for_n_ms;
460  // dump just after disabled tracing
461  should_dump = tracer->isEnabled() && !is_enabled;
462  }
463  tracer->setEnabled(is_enabled);
464  if (should_dump) {
465  int dumping_iter = tracer->bumpDumpingIter();
466  tracer->dumpTracingResultAndClearEvents(c10::to_string(dumping_iter));
467  }
468  return is_enabled;
469 }
470 
471 } // namespace tracing
472 
473 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13