-
Notifications
You must be signed in to change notification settings - Fork 70
/
test_utils_hipgraphs.hpp
91 lines (73 loc) · 3.39 KB
/
test_utils_hipgraphs.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
// Copyright (c) 2021-2024 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
#ifndef ROCRAND_TEST_UTILS_HIPGRAPHS_HPP
#define ROCRAND_TEST_UTILS_HIPGRAPHS_HPP
#include <hip/hip_runtime.h>
#include "test_common.hpp"
// Helper functions for testing with hipGraph stream capture.
// Note: graphs will not work on the default stream.
namespace test_utils
{

// All helpers below wrap their HIP calls in HIP_CHECK_NON_VOID. That macro is
// defined outside this header, so its behaviour on failure is not visible here.

// Creates a hipGraph and, if requested, starts capturing work on a stream.
//   stream       - stream on which capture is started when beginCapture is true
//   beginCapture - when true, hipStreamBeginCapture is called on `stream`
//                  with hipStreamCaptureModeGlobal
// Returns the handle filled in by hipGraphCreate.
inline hipGraph_t createGraphHelper(hipStream_t& stream,
                                    const bool   beginCapture = true)
{
    hipGraph_t new_graph;
    HIP_CHECK_NON_VOID(hipGraphCreate(&new_graph, 0));

    if(beginCapture)
    {
        HIP_CHECK_NON_VOID(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
    }

    return new_graph;
}

// Destroys a graph and then its executable instance, in that order.
// Neither argument is reassigned, so both still hold the old handle values
// when this function returns.
inline void cleanupGraphHelper(hipGraph_t& graph, hipGraphExec_t& instance)
{
    HIP_CHECK_NON_VOID(hipGraphDestroy(graph));
    HIP_CHECK_NON_VOID(hipGraphExecDestroy(instance));
}

// Tears down an existing graph/instance pair and replaces `graph` with a
// freshly created one.
//   graph        - destroyed, then overwritten with the new graph handle
//   instance     - destroyed; it is NOT reassigned by this function
//   stream       - forwarded to createGraphHelper
//   beginCapture - forwarded to createGraphHelper
inline void resetGraphHelper(hipGraph_t&     graph,
                             hipGraphExec_t& instance,
                             hipStream_t&    stream,
                             const bool      beginCapture = true)
{
    // The old pair has to go before `graph` is overwritten below.
    cleanupGraphHelper(graph, instance);

    graph = createGraphHelper(stream, beginCapture);
}

// Launches an instantiated graph on a stream.
//   instance - executable graph passed to hipGraphLaunch
//   stream   - stream the graph is launched on
//   sync     - when true, hipStreamSynchronize(stream) is called after the launch
inline void launchGraphHelper(hipGraphExec_t& instance,
                              hipStream_t&    stream,
                              const bool      sync = false)
{
    HIP_CHECK_NON_VOID(hipGraphLaunch(instance, stream));

    if(sync)
    {
        HIP_CHECK_NON_VOID(hipStreamSynchronize(stream));
    }
}

// Ends capture on a stream, instantiates the captured graph and optionally
// launches it and/or synchronizes the stream.
//   graph       - receives the graph handle written by hipStreamEndCapture
//   stream      - stream whose capture is ended (and on which the graph is
//                 launched / which is synchronized, if requested)
//   launchGraph - when true, the new instance is launched on `stream`
//   sync        - when true, the stream is synchronized before returning;
//                 this is independent of launchGraph
// Returns the executable instance produced by hipGraphInstantiate.
//
// NOTE(review): hipStreamEndCapture stores its result in `graph`, replacing
// whatever handle the variable held before. If that was a handle obtained from
// createGraphHelper, it is no longer reachable through `graph` afterwards -
// confirm whether it needs to be destroyed separately.
inline hipGraphExec_t endCaptureGraphHelper(hipGraph_t&  graph,
                                            hipStream_t& stream,
                                            const bool   launchGraph = false,
                                            const bool   sync        = false)
{
    HIP_CHECK_NON_VOID(hipStreamEndCapture(stream, &graph));

    hipGraphExec_t exec;
    HIP_CHECK_NON_VOID(hipGraphInstantiate(&exec, graph, nullptr, nullptr, 0));

    if(launchGraph)
    {
        // Launch only; synchronization is decided separately below.
        launchGraphHelper(exec, stream, false);
    }

    if(sync)
    {
        HIP_CHECK_NON_VOID(hipStreamSynchronize(stream));
    }

    return exec;
}

} // end namespace test_utils
#endif //ROCRAND_TEST_UTILS_HIPGRAPHS_HPP