@@ -20,207 +20,22 @@ limitations under the License.
20
20
#ifndef XLA_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
21
21
#define XLA_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
22
22
23
- #include < algorithm>
24
23
#include < cassert>
25
24
#include < cstdint>
26
- #include < cstring>
27
25
#include < string>
28
26
#include < type_traits>
29
27
#include < utility>
30
28
#include < variant>
31
- #include < vector>
32
29
33
- #include " absl/algorithm/container.h"
34
30
#include " absl/status/statusor.h"
35
- #include " absl/strings/match.h"
36
- #include " absl/strings/str_join.h"
37
- #include " absl/strings/str_split.h"
38
- #include " absl/strings/string_view.h"
39
31
#include " xla/stream_executor/cuda/cuda_compute_capability.h"
40
32
#include " xla/stream_executor/device_description.pb.h"
41
33
#include " xla/stream_executor/launch_dim.h"
34
+ #include " xla/stream_executor/rocm/rocm_compute_capability.h"
42
35
#include " xla/stream_executor/semantic_version.h"
43
36
44
37
namespace stream_executor {
45
38
46
- // ROCm compute capability, as reported by the device description.
47
- class RocmComputeCapability {
48
- public:
49
- // gcn_arch_name example -- gfx90a:sramecc+:xnack-
50
- // gfx_version is the "gfx90a" part of the gcn_arch_name
51
- explicit RocmComputeCapability (std::string gcn_arch_name)
52
- : gcn_arch_name_(std::move(gcn_arch_name)) {}
53
-
54
- explicit RocmComputeCapability (const RocmComputeCapabilityProto& proto)
55
- : gcn_arch_name_(proto.gcn_arch_name()) {}
56
-
57
- RocmComputeCapability () = default ;
58
-
59
- std::string gcn_arch_name () const { return gcn_arch_name_; }
60
-
61
- std::string ToString () const { return gcn_arch_name (); }
62
-
63
- RocmComputeCapabilityProto ToProto () const {
64
- RocmComputeCapabilityProto proto;
65
- proto.set_gcn_arch_name (gcn_arch_name_);
66
- return proto;
67
- }
68
-
69
- bool operator ==(const RocmComputeCapability& other) const {
70
- return gcn_arch_name_ == other.gcn_arch_name_ ;
71
- }
72
-
73
- bool operator !=(const RocmComputeCapability& other) const {
74
- return !this ->operator ==(other);
75
- }
76
-
77
- std::string gfx_version () const {
78
- // std::strchr() is faster for the case than std::string::find()
79
- const char * const p_colon = std::strchr (gcn_arch_name_.c_str (), ' :' );
80
- if (nullptr == p_colon) {
81
- return gcn_arch_name_; // likely it's the default invalid value
82
- }
83
- return std::string (gcn_arch_name_.c_str (), p_colon);
84
- }
85
-
86
- // note, while there's no particular reason to make the lists public, it won't
87
- // hurt since they are immutable, but keeping them close to methods simplifies
88
- // maintanance.
89
- static constexpr absl::string_view kSupportedGfxVersions []{
90
- " gfx900" , // MI25
91
- " gfx906" , // MI50 / MI60
92
- " gfx908" , // MI100
93
- " gfx90a" , // MI200
94
- " gfx942" , // MI300
95
- " gfx950" , // MI350
96
- " gfx1030" , // RX68xx / RX69xx
97
- " gfx1100" , // RX7900
98
- " gfx1101" , // RX7700 / RX7800
99
- " gfx1103" , " gfx1150" , " gfx1151" , " gfx1200" , " gfx1201" ,
100
- };
101
-
102
- bool is_supported_gfx_version () const {
103
- return IsThisGfxInAnyList (kSupportedGfxVersions );
104
- }
105
-
106
- std::string supported_gfx_versions_str () const {
107
- return absl::StrJoin (kSupportedGfxVersions , " , " );
108
- }
109
-
110
- bool gfx9_mi100 () const { return gfx_version () == " gfx908" ; }
111
-
112
- static constexpr absl::string_view kMI100Series [] = {" gfx908" };
113
-
114
- bool gfx9_mi200 () const { return gfx_version () == " gfx90a" ; }
115
-
116
- static constexpr absl::string_view kMI200Series [] = {" gfx90a" };
117
-
118
- bool gfx9_mi300 () const { return gfx_version () == " gfx942" ; }
119
-
120
- bool gfx9_mi350 () const { return gfx_version () == " gfx950" ; }
121
-
122
- static constexpr absl::string_view kMI300Series [] = {" gfx942" , " gfx950" };
123
- bool gfx9_mi300_series () const { return IsThisGfxInAnyList (kMI300Series ); }
124
-
125
- bool gfx9_mi100_or_later () const {
126
- return IsThisGfxInAnyList (kMI300Series , kMI200Series , kMI100Series );
127
- }
128
-
129
- bool gfx9_mi200_or_later () const {
130
- return IsThisGfxInAnyList (kMI300Series , kMI200Series );
131
- }
132
-
133
- bool gfx10_rx68xx () const { return gfx_version () == " gfx1030" ; }
134
-
135
- bool gfx10_rx69xx () const { return gfx_version () == " gfx1030" ; }
136
-
137
- bool gfx11 () const { return absl::StartsWith (gfx_version (), " gfx11" ); }
138
-
139
- static constexpr absl::string_view kGfx11Discrete [] = {" gfx1100" , " gfx1101" };
140
- bool gfx11_discrete () const { return IsThisGfxInAnyList (kGfx11Discrete ); }
141
-
142
- static constexpr absl::string_view kGfx11Apu [] = {" gfx1103" , " gfx1150" ,
143
- " gfx1151" };
144
- bool gfx11_apu () const { return IsThisGfxInAnyList (kGfx11Apu ); }
145
-
146
- static constexpr absl::string_view kGfx11Rx7900 [] = {" gfx1100" , " gfx1101" ,
147
- " gfx1102" };
148
- bool gfx11_rx7900 () const {
149
- // TODO(AMD/TF): instead of this, other gfx11*() methods might be better
150
- return IsThisGfxInAnyList (kGfx11Rx7900 );
151
- }
152
-
153
- bool gfx12 () const { return absl::StartsWith (gfx_version (), " gfx12" ); }
154
-
155
- static constexpr absl::string_view kGfx12Discrete [] = {" gfx1200" , " gfx1201" };
156
- bool gfx12_discrete () const { return IsThisGfxInAnyList (kGfx12Discrete ); }
157
-
158
- bool gfx12_rx8900 () const { return gfx12_discrete (); }
159
-
160
- bool has_nhwc_layout_support () const { return gfx9_mi100_or_later (); }
161
-
162
- bool has_bf16_dtype_support () const {
163
- return gfx9_mi100_or_later () || gfx12 () || gfx11 ();
164
- }
165
-
166
- bool has_fast_fp16_support () const {
167
- return gfx9_mi100_or_later () || gfx11 () || gfx10_rx68xx () || gfx10_rx69xx ();
168
- }
169
-
170
- bool has_mfma_instr_support () const { return gfx9_mi100_or_later (); }
171
-
172
- bool has_amd_matrix_core () const {
173
- return gfx9_mi100_or_later () || gfx12 () || gfx11 ();
174
- }
175
-
176
- bool has_packed_fp16_atomics_support () const { return gfx9_mi100_or_later (); }
177
-
178
- bool has_packed_bf16_atomics_support () const { return gfx9_mi300_series (); }
179
-
180
- bool fence_before_barrier () const {
181
- static constexpr absl::string_view kList [] = {" gfx900" , " gfx906" };
182
- return !IsThisGfxInAnyList (kList );
183
- }
184
-
185
- bool has_hipblaslt () const {
186
- return IsThisGfxInAnyList (kMI300Series , kMI200Series , kGfx12Discrete ,
187
- kGfx11Discrete , kGfx11Apu );
188
- }
189
-
190
- bool has_fp8_support () const {
191
- return has_ocp_fp8_support () || has_nanoo_fp8_support ();
192
- }
193
-
194
- bool has_ocp_fp8_support () const { return gfx9_mi350 () || gfx12_discrete (); }
195
-
196
- bool has_nanoo_fp8_support () const { return gfx9_mi300 (); }
197
-
198
- // / \brief Invalid gfx id for default gcn_arch_name_ value and testing
199
- static constexpr absl::string_view kInvalidGfx = " gfx000" ;
200
-
201
- private:
202
- // / \brief Takes one or more arrays of string-like objects and tests if the
203
- // / result of `gfx_version()` matches to any string in any of the arrays.
204
- template <typename ... ArrayOfStrings>
205
- bool IsThisGfxInAnyList (ArrayOfStrings&&... arr) const {
206
- static_assert (sizeof ...(arr) >= 1 );
207
- const auto gfx = gfx_version ();
208
- return (implIsThisGfxInAnyList (std::begin (arr), std::end (arr), gfx) || ...);
209
- }
210
-
211
- // / \brief Template-less implementation of IsThisGfxInAnyList().
212
- // / \warning Don't use directly!
213
- bool implIsThisGfxInAnyList (const absl::string_view* beg,
214
- const absl::string_view* end,
215
- const std::string& gfx) const {
216
- return std::any_of (beg, end, [&gfx = gfx](const absl::string_view& s) {
217
- return gfx == s;
218
- });
219
- }
220
-
221
- std::string gcn_arch_name_{kInvalidGfx }; // default to invalid arch.
222
- };
223
-
224
39
using GpuComputeCapability =
225
40
std::variant<CudaComputeCapability, RocmComputeCapability>;
226
41
0 commit comments