Skip to content

Commit

Permalink
added the new columns for --benchmark_format=csv as well
Browse files Browse the repository at this point in the history
  • Loading branch information
NguyenNhuDi committed Aug 19, 2024
1 parent ac1a167 commit b56cb56
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 203 deletions.
53 changes: 16 additions & 37 deletions benchmark/benchmark_rocrand_device_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@
#include <rocrand/rocrand_kernel.h>
#include <rocrand/rocrand_mtgp32_11213.h>

#include "custom_csv_formater.hpp"
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>
#include <fstream>
#include "custom_csv_formater.hpp"

#ifndef DEFAULT_RAND_N
#define DEFAULT_RAND_N (1024 * 1024 * 128)
Expand Down Expand Up @@ -723,26 +723,13 @@ void add_benchmarks(const benchmark_context &ctx, const hipStream_t stream,
}

int main(int argc, char *argv[]) {

// get the out format and out file name thats being passed into
// get paramaters before they are passed into
// benchmark::Initialize()
std::string outFormat = "";
std::string outFile = "";
std::string filter = "";
for (int i = 1; i < argc; i++) {
std::string input(argv[i]);

int equalPos = input.find("=");
std::string arg = std::string(input.begin() + 2, input.begin() + equalPos);
std::string argVal = std::string(input.begin() + 1 + equalPos, input.end());
std::string consoleFormat = "";

if (arg == "benchmark_out_format")
outFormat = argVal;
else if (arg == "benchmark_out")
outFile = argVal;
else if (arg == "benchmark_filter")
filter = argVal;
}
getFormats(argc, argv, outFormat, filter, consoleFormat);

benchmark::Initialize(&argc, argv);

Expand Down Expand Up @@ -818,27 +805,19 @@ int main(int argc, char *argv[]) {
b->Unit(benchmark::kMillisecond);
}

if (outFormat == "csv") {
std::string spec = (filter == "" || filter == "all") ? "." : filter;
std::ofstream output_file;
benchmark::BenchmarkReporter *console_reporter =
getConsoleReporter(consoleFormat);
benchmark::BenchmarkReporter *out_file_reporter =
getOutFileReporter(outFormat);

benchmark::ConsoleReporter console_reporter;
benchmark::customCSVReporter csv_reporter;
std::string spec = (filter == "" || filter == "all") ? "." : filter;

auto &Err = console_reporter.GetErrorStream();

csv_reporter.SetOutputStream(&output_file);
csv_reporter.SetErrorStream(&Err);

benchmark::BenchmarkReporter *console_ptr = &console_reporter;
benchmark::BenchmarkReporter *csv_ptr = &csv_reporter;

benchmark::RunSpecifiedBenchmarks(console_ptr, csv_ptr, spec);

} else {
// Run benchmarks
benchmark::RunSpecifiedBenchmarks();
}
// Run benchmarks
if (outFormat == "") // default case
benchmark::RunSpecifiedBenchmarks(console_reporter, spec);
else
benchmark::RunSpecifiedBenchmarks(console_reporter, out_file_reporter,
spec);
HIP_CHECK(hipStreamDestroy(stream));

return 0;
Expand Down
53 changes: 15 additions & 38 deletions benchmark/benchmark_rocrand_host_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,28 +117,13 @@ void run_benchmark(benchmark::State &state, generate_func_type<T> generate_func,

int main(int argc, char *argv[]) {

// get the out format and out file name thats being passed into
// get paramaters before they are passed into
// benchmark::Initialize()
std::string outFormat = "";
std::string outFile = "";
std::string filter = "";
for (int i = 1; i < argc; i++) {
std::string input(argv[i]);
int equalPos = input.find("=");

if(equalPos < 0)
continue;

std::string arg = std::string(input.begin() + 2, input.begin() + equalPos);
std::string argVal = std::string(input.begin() + 1 + equalPos, input.end());

if (arg == "benchmark_out_format")
outFormat = argVal;
else if (arg == "benchmark_out")
outFile = argVal;
else if (arg == "benchmark_filter")
filter = argVal;
}
std::string consoleFormat = "";

getFormats(argc, argv, outFormat, filter, consoleFormat);

// Parse argv
benchmark::Initialize(&argc, argv);
Expand Down Expand Up @@ -376,27 +361,19 @@ int main(int argc, char *argv[]) {
b->Unit(benchmark::kMillisecond);
}

if (outFormat == "csv") {
std::string spec = (filter == "" || filter == "all") ? "." : filter;
std::ofstream output_file;

benchmark::ConsoleReporter console_reporter;
benchmark::customCSVReporter csv_reporter;

auto &Err = console_reporter.GetErrorStream();
benchmark::BenchmarkReporter *console_reporter =
getConsoleReporter(consoleFormat);
benchmark::BenchmarkReporter *out_file_reporter =
getOutFileReporter(outFormat);

csv_reporter.SetOutputStream(&output_file);
csv_reporter.SetErrorStream(&Err);
std::string spec = (filter == "" || filter == "all") ? "." : filter;

benchmark::BenchmarkReporter *console_ptr = &console_reporter;
benchmark::BenchmarkReporter *csv_ptr = &csv_reporter;

benchmark::RunSpecifiedBenchmarks(console_ptr, csv_ptr, spec);

} else {
// Run benchmarks
benchmark::RunSpecifiedBenchmarks();
}
// Run benchmarks
if (outFormat == "") // default case
benchmark::RunSpecifiedBenchmarks(console_reporter, spec);
else
benchmark::RunSpecifiedBenchmarks(console_reporter, out_file_reporter,
spec);

HIP_CHECK(hipStreamDestroy(stream));

Expand Down
Loading

0 comments on commit b56cb56

Please sign in to comment.