@@ -20,10 +20,21 @@ limitations under the License.
20
20
#include " tensorflow/core/framework/tracking_allocator.h"
21
21
#include " tensorflow/core/graph/costmodel.h"
22
22
#include " tensorflow/core/lib/core/stringpiece.h"
23
+ #include " tensorflow/core/lib/strings/numbers.h"
23
24
#include " tensorflow/core/lib/strings/scanner.h"
24
25
#include " tensorflow/core/platform/logging.h"
25
26
26
27
namespace tensorflow {
28
+ namespace {
29
+ const int kMaxAllocReportNodes = 100 ;
30
+ const float kMaxAllocReportFraction = 0.99 ;
31
+
32
+ struct AllocStats {
33
+ std::map<int64, std::vector<string>> nodes_by_size;
34
+ int64 total_bytes = 0 ;
35
+ int64 total_nodes = 0 ;
36
+ };
37
+ } // namespace
27
38
28
39
NodeExecStatsWrapper::NodeExecStatsWrapper ()
29
40
: NodeExecStatsWrapper(new NodeExecStats) {}
@@ -267,6 +278,85 @@ void StepStatsCollector::Save(const string& device,
267
278
}
268
279
}
269
280
281
+ string StepStatsCollector::ReportAllocsOnResourceExhausted (const string& err) {
282
+ mutex_lock l (mu_);
283
+ if (err.find (" OOM" ) == err.npos ) {
284
+ return " " ;
285
+ }
286
+ // <device, allocator> -> AllocStats
287
+ std::map<std::pair<string, string>, AllocStats> allocs_map;
288
+ string report = " \n " ;
289
+ for (const auto & dev_stat : dev_stats_) {
290
+ const string& device = dev_stat.first ;
291
+ // Only print the device that has OOM.
292
+ // TODO(xpan): Extract device from err first to speed it up.
293
+ if (err.find (device) == err.npos ) {
294
+ continue ;
295
+ }
296
+ // NodeExecStatsWrapper*
297
+ for (const auto & stats : dev_stat.second ) {
298
+ // std::pair<AllocatorMemoryUsed*, TrackingAllocator*>
299
+ for (const auto & alloc : stats->allocations_ ) {
300
+ // Only print the allocator that has OOM.
301
+ // TODO(xpan): Extract device from err first to speed it up.
302
+ if (err.find (alloc.first ->allocator_name ()) == err.npos ) {
303
+ continue ;
304
+ }
305
+ auto dev_allocator =
306
+ std::make_pair (dev_stat.first , alloc.first ->allocator_name ());
307
+ AllocStats& dev_allocs_stats = allocs_map[dev_allocator];
308
+ TrackingAllocator* tracking_alloc = alloc.second ;
309
+ gtl::InlinedVector<AllocRecord, 4 > cur_records =
310
+ tracking_alloc->GetCurrentRecords ();
311
+ int64 cur_bytes = 0 ;
312
+ for (const auto & r : cur_records) {
313
+ cur_bytes += r.alloc_bytes ;
314
+ }
315
+ if (cur_bytes > 0 ) {
316
+ dev_allocs_stats.total_bytes += cur_bytes;
317
+ dev_allocs_stats.total_nodes ++;
318
+ dev_allocs_stats.nodes_by_size [cur_bytes].push_back (
319
+ stats->stats ()->node_name ());
320
+ }
321
+ }
322
+ }
323
+ }
324
+
325
+ for (const auto & dev_allocs_it : allocs_map) {
326
+ const auto & dev = dev_allocs_it.first ;
327
+ const AllocStats& dev_allocs_stats = dev_allocs_it.second ;
328
+ int64 reported_bytes = 0 ;
329
+ int64 reported_nodes = 0 ;
330
+ bool done = false ;
331
+ strings::StrAppend (&report, " \n Current usage from device: " , dev.first ,
332
+ " , allocator: " , dev.second , " \n " );
333
+ // Print allocations stats of the <device, allocator> pair.
334
+ for (auto it = dev_allocs_stats.nodes_by_size .rbegin ();
335
+ it != dev_allocs_stats.nodes_by_size .rend (); ++it) {
336
+ for (const string& node_name : it->second ) {
337
+ reported_bytes += it->first ;
338
+ strings::StrAppend (&report, " " ,
339
+ strings::HumanReadableNumBytes (it->first ), " from " ,
340
+ node_name, " \n " );
341
+ if (++reported_nodes > kMaxAllocReportNodes ||
342
+ reported_bytes >=
343
+ dev_allocs_stats.total_bytes * kMaxAllocReportFraction ) {
344
+ done = true ;
345
+ break ;
346
+ }
347
+ }
348
+ if (done) break ;
349
+ }
350
+ int64 remain_nodes = dev_allocs_stats.total_nodes - reported_nodes;
351
+ int64 remain_bytes = dev_allocs_stats.total_bytes - reported_bytes;
352
+ if (remain_nodes > 0 ) {
353
+ strings::StrAppend (&report, " Remaining " , remain_nodes, " nodes with " ,
354
+ strings::HumanReadableNumBytes (remain_bytes), " \n " );
355
+ }
356
+ }
357
+ return report;
358
+ }
359
+
270
360
void StepStatsCollector::Finalize () {
271
361
mutex_lock l (mu_);
272
362
FinalizeInternal ();
0 commit comments