Skip to content

Commit 4c61287

Browse files
ermilovmaximtensorflower-gardener
authored andcommitted
[XLA:GPU] move RocmComputeCapability into its own header
PiperOrigin-RevId: 816348852
1 parent 28b5e42 commit 4c61287

File tree

5 files changed

+228
-188
lines changed

5 files changed

+228
-188
lines changed

third_party/xla/xla/service/conditional_code_motion.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,14 @@ limitations under the License.
3030
#include "absl/log/log.h"
3131
#include "absl/status/status.h"
3232
#include "absl/status/statusor.h"
33+
#include "absl/strings/numbers.h"
3334
#include "absl/strings/str_cat.h"
35+
#include "absl/strings/str_split.h"
36+
#include "absl/strings/string_view.h"
3437
#include "absl/types/span.h"
3538
#include "xla/debug_options_flags.h"
3639
#include "xla/hlo/ir/hlo_casting_utils.h"
40+
#include "xla/hlo/ir/hlo_clone_context.h"
3741
#include "xla/hlo/ir/hlo_computation.h"
3842
#include "xla/hlo/ir/hlo_instruction.h"
3943
#include "xla/hlo/ir/hlo_instructions.h"

third_party/xla/xla/stream_executor/BUILD

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,12 @@ cc_library(
6363
":launch_dim",
6464
":semantic_version",
6565
"//xla/stream_executor/cuda:cuda_compute_capability",
66+
"//xla/stream_executor/rocm:rocm_compute_capability",
6667
"//xla/tsl/lib/math:math_util",
6768
"//xla/tsl/platform:statusor",
68-
"@com_google_absl//absl/algorithm:container",
6969
"@com_google_absl//absl/log",
7070
"@com_google_absl//absl/log:check",
7171
"@com_google_absl//absl/status:statusor",
72-
"@com_google_absl//absl/strings",
7372
],
7473
)
7574

third_party/xla/xla/stream_executor/device_description.h

Lines changed: 1 addition & 186 deletions
Original file line numberDiff line numberDiff line change
@@ -20,207 +20,22 @@ limitations under the License.
2020
#ifndef XLA_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
2121
#define XLA_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
2222

23-
#include <algorithm>
2423
#include <cassert>
2524
#include <cstdint>
26-
#include <cstring>
2725
#include <string>
2826
#include <type_traits>
2927
#include <utility>
3028
#include <variant>
31-
#include <vector>
3229

33-
#include "absl/algorithm/container.h"
3430
#include "absl/status/statusor.h"
35-
#include "absl/strings/match.h"
36-
#include "absl/strings/str_join.h"
37-
#include "absl/strings/str_split.h"
38-
#include "absl/strings/string_view.h"
3931
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
4032
#include "xla/stream_executor/device_description.pb.h"
4133
#include "xla/stream_executor/launch_dim.h"
34+
#include "xla/stream_executor/rocm/rocm_compute_capability.h"
4235
#include "xla/stream_executor/semantic_version.h"
4336

4437
namespace stream_executor {
4538

46-
// ROCm compute capability, as reported by the device description.
47-
class RocmComputeCapability {
48-
public:
49-
// gcn_arch_name example -- gfx90a:sramecc+:xnack-
50-
// gfx_version is the "gfx90a" part of the gcn_arch_name
51-
explicit RocmComputeCapability(std::string gcn_arch_name)
52-
: gcn_arch_name_(std::move(gcn_arch_name)) {}
53-
54-
explicit RocmComputeCapability(const RocmComputeCapabilityProto& proto)
55-
: gcn_arch_name_(proto.gcn_arch_name()) {}
56-
57-
RocmComputeCapability() = default;
58-
59-
std::string gcn_arch_name() const { return gcn_arch_name_; }
60-
61-
std::string ToString() const { return gcn_arch_name(); }
62-
63-
RocmComputeCapabilityProto ToProto() const {
64-
RocmComputeCapabilityProto proto;
65-
proto.set_gcn_arch_name(gcn_arch_name_);
66-
return proto;
67-
}
68-
69-
bool operator==(const RocmComputeCapability& other) const {
70-
return gcn_arch_name_ == other.gcn_arch_name_;
71-
}
72-
73-
bool operator!=(const RocmComputeCapability& other) const {
74-
return !this->operator==(other);
75-
}
76-
77-
std::string gfx_version() const {
78-
// std::strchr() is faster for the case than std::string::find()
79-
const char* const p_colon = std::strchr(gcn_arch_name_.c_str(), ':');
80-
if (nullptr == p_colon) {
81-
return gcn_arch_name_; // likely it's the default invalid value
82-
}
83-
return std::string(gcn_arch_name_.c_str(), p_colon);
84-
}
85-
86-
// note, while there's no particular reason to make the lists public, it won't
87-
// hurt since they are immutable, but keeping them close to methods simplifies
88-
// maintanance.
89-
static constexpr absl::string_view kSupportedGfxVersions[]{
90-
"gfx900", // MI25
91-
"gfx906", // MI50 / MI60
92-
"gfx908", // MI100
93-
"gfx90a", // MI200
94-
"gfx942", // MI300
95-
"gfx950", // MI350
96-
"gfx1030", // RX68xx / RX69xx
97-
"gfx1100", // RX7900
98-
"gfx1101", // RX7700 / RX7800
99-
"gfx1103", "gfx1150", "gfx1151", "gfx1200", "gfx1201",
100-
};
101-
102-
bool is_supported_gfx_version() const {
103-
return IsThisGfxInAnyList(kSupportedGfxVersions);
104-
}
105-
106-
std::string supported_gfx_versions_str() const {
107-
return absl::StrJoin(kSupportedGfxVersions, ", ");
108-
}
109-
110-
bool gfx9_mi100() const { return gfx_version() == "gfx908"; }
111-
112-
static constexpr absl::string_view kMI100Series[] = {"gfx908"};
113-
114-
bool gfx9_mi200() const { return gfx_version() == "gfx90a"; }
115-
116-
static constexpr absl::string_view kMI200Series[] = {"gfx90a"};
117-
118-
bool gfx9_mi300() const { return gfx_version() == "gfx942"; }
119-
120-
bool gfx9_mi350() const { return gfx_version() == "gfx950"; }
121-
122-
static constexpr absl::string_view kMI300Series[] = {"gfx942", "gfx950"};
123-
bool gfx9_mi300_series() const { return IsThisGfxInAnyList(kMI300Series); }
124-
125-
bool gfx9_mi100_or_later() const {
126-
return IsThisGfxInAnyList(kMI300Series, kMI200Series, kMI100Series);
127-
}
128-
129-
bool gfx9_mi200_or_later() const {
130-
return IsThisGfxInAnyList(kMI300Series, kMI200Series);
131-
}
132-
133-
bool gfx10_rx68xx() const { return gfx_version() == "gfx1030"; }
134-
135-
bool gfx10_rx69xx() const { return gfx_version() == "gfx1030"; }
136-
137-
bool gfx11() const { return absl::StartsWith(gfx_version(), "gfx11"); }
138-
139-
static constexpr absl::string_view kGfx11Discrete[] = {"gfx1100", "gfx1101"};
140-
bool gfx11_discrete() const { return IsThisGfxInAnyList(kGfx11Discrete); }
141-
142-
static constexpr absl::string_view kGfx11Apu[] = {"gfx1103", "gfx1150",
143-
"gfx1151"};
144-
bool gfx11_apu() const { return IsThisGfxInAnyList(kGfx11Apu); }
145-
146-
static constexpr absl::string_view kGfx11Rx7900[] = {"gfx1100", "gfx1101",
147-
"gfx1102"};
148-
bool gfx11_rx7900() const {
149-
// TODO(AMD/TF): instead of this, other gfx11*() methods might be better
150-
return IsThisGfxInAnyList(kGfx11Rx7900);
151-
}
152-
153-
bool gfx12() const { return absl::StartsWith(gfx_version(), "gfx12"); }
154-
155-
static constexpr absl::string_view kGfx12Discrete[] = {"gfx1200", "gfx1201"};
156-
bool gfx12_discrete() const { return IsThisGfxInAnyList(kGfx12Discrete); }
157-
158-
bool gfx12_rx8900() const { return gfx12_discrete(); }
159-
160-
bool has_nhwc_layout_support() const { return gfx9_mi100_or_later(); }
161-
162-
bool has_bf16_dtype_support() const {
163-
return gfx9_mi100_or_later() || gfx12() || gfx11();
164-
}
165-
166-
bool has_fast_fp16_support() const {
167-
return gfx9_mi100_or_later() || gfx11() || gfx10_rx68xx() || gfx10_rx69xx();
168-
}
169-
170-
bool has_mfma_instr_support() const { return gfx9_mi100_or_later(); }
171-
172-
bool has_amd_matrix_core() const {
173-
return gfx9_mi100_or_later() || gfx12() || gfx11();
174-
}
175-
176-
bool has_packed_fp16_atomics_support() const { return gfx9_mi100_or_later(); }
177-
178-
bool has_packed_bf16_atomics_support() const { return gfx9_mi300_series(); }
179-
180-
bool fence_before_barrier() const {
181-
static constexpr absl::string_view kList[] = {"gfx900", "gfx906"};
182-
return !IsThisGfxInAnyList(kList);
183-
}
184-
185-
bool has_hipblaslt() const {
186-
return IsThisGfxInAnyList(kMI300Series, kMI200Series, kGfx12Discrete,
187-
kGfx11Discrete, kGfx11Apu);
188-
}
189-
190-
bool has_fp8_support() const {
191-
return has_ocp_fp8_support() || has_nanoo_fp8_support();
192-
}
193-
194-
bool has_ocp_fp8_support() const { return gfx9_mi350() || gfx12_discrete(); }
195-
196-
bool has_nanoo_fp8_support() const { return gfx9_mi300(); }
197-
198-
/// \brief Invalid gfx id for default gcn_arch_name_ value and testing
199-
static constexpr absl::string_view kInvalidGfx = "gfx000";
200-
201-
private:
202-
/// \brief Takes one or more arrays of string-like objects and tests if the
203-
/// result of `gfx_version()` matches to any string in any of the arrays.
204-
template <typename... ArrayOfStrings>
205-
bool IsThisGfxInAnyList(ArrayOfStrings&&... arr) const {
206-
static_assert(sizeof...(arr) >= 1);
207-
const auto gfx = gfx_version();
208-
return (implIsThisGfxInAnyList(std::begin(arr), std::end(arr), gfx) || ...);
209-
}
210-
211-
/// \brief Template-less implementation of IsThisGfxInAnyList().
212-
/// \warning Don't use directly!
213-
bool implIsThisGfxInAnyList(const absl::string_view* beg,
214-
const absl::string_view* end,
215-
const std::string& gfx) const {
216-
return std::any_of(beg, end, [&gfx = gfx](const absl::string_view& s) {
217-
return gfx == s;
218-
});
219-
}
220-
221-
std::string gcn_arch_name_{kInvalidGfx}; // default to invalid arch.
222-
};
223-
22439
using GpuComputeCapability =
22540
std::variant<CudaComputeCapability, RocmComputeCapability>;
22641

third_party/xla/xla/stream_executor/rocm/BUILD

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,15 @@ package_group(
3030
packages = stream_executor_friends(),
3131
)
3232

33+
cc_library(
34+
name = "rocm_compute_capability",
35+
hdrs = ["rocm_compute_capability.h"],
36+
deps = [
37+
"//xla/stream_executor:device_description_proto_cc",
38+
"@com_google_absl//absl/strings",
39+
],
40+
)
41+
3342
cc_library(
3443
name = "rocm_diagnostics",
3544
srcs = ["rocm_diagnostics.cc"],

0 commit comments

Comments
 (0)