Skip to content

Commit adf2e98

Browse files
ermilovmaximtensorflower-gardener
authored andcommitted
[XLA:GPU] Enable chlo.sinh -> kSinh HloInstruction lowering.
PiperOrigin-RevId: 815903539
1 parent 3bc0da7 commit adf2e98

File tree

18 files changed

+166
-11
lines changed

18 files changed

+166
-11
lines changed

third_party/xla/xla/backends/gpu/codegen/emitters/transforms/optimize_loops.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ bool IsExpensiveToUnroll(mlir::Operation* op) {
5454
mlir::math::AcosOp,
5555
mlir::math::AcoshOp,
5656
mlir::math::AtanhOp,
57+
mlir::math::SinhOp,
5758
mlir::scf::ForOp
5859
// go/keep-sorted end
5960
// clang-format on

third_party/xla/xla/hlo/builder/lib/math.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1399,7 +1399,11 @@ XlaOp Cosh(XlaOp x, const std::optional<ResultAccuracy>& result_accuracy,
13991399
// +/-89.4159851, due to rounding error when computing x +/- log(1/2). The
14001400
// correct answer of 3.40281961e+38 (0x7f7fffec) is very close to max-float, so
14011401
// we deem this acceptable.
1402-
XlaOp Sinh(XlaOp x) {
1402+
XlaOp Sinh(XlaOp x, const std::optional<ResultAccuracy>& result_accuracy,
1403+
bool expand) {
1404+
if (!expand) {
1405+
return x.builder()->UnaryOp(HloOpcode::kSinh, x, result_accuracy);
1406+
}
14031407
XlaBuilder* b = x.builder();
14041408
auto do_it = [&](XlaOp x) -> absl::StatusOr<XlaOp> {
14051409
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));

third_party/xla/xla/hlo/builder/lib/math.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ XlaOp Cosh(XlaOp x,
115115
bool expand = true);
116116

117117
// Computes the hyperbolic sine of 'x'.
118-
XlaOp Sinh(XlaOp x);
118+
XlaOp Sinh(XlaOp x,
119+
const std::optional<ResultAccuracy>& result_accuracy = std::nullopt,
120+
bool expand = true);
119121

120122
// Applies a complex conjugation operation if 'a' is complex and 'conjugate'
121123
// is true, otherwise returns its argument.

third_party/xla/xla/hlo/builder/xla_builder.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1772,6 +1772,9 @@ class XlaBuilder {
17721772
bool expand);
17731773
friend XlaOp Sin(XlaOp operand,
17741774
const std::optional<ResultAccuracy>& result_accuracy);
1775+
friend XlaOp Sinh(XlaOp x,
1776+
const std::optional<ResultAccuracy>& result_accuracy,
1777+
bool expand);
17751778
friend XlaOp Tan(XlaOp operand,
17761779
const std::optional<ResultAccuracy>& result_accuracy);
17771780
friend XlaOp Tanh(XlaOp operand,

third_party/xla/xla/hlo/ir/hlo_instruction.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2759,6 +2759,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
27592759
case HloOpcode::kLogistic:
27602760
case HloOpcode::kSign:
27612761
case HloOpcode::kSin:
2762+
case HloOpcode::kSinh:
27622763
case HloOpcode::kSqrt:
27632764
case HloOpcode::kCbrt:
27642765
case HloOpcode::kTan:

third_party/xla/xla/hlo/translate/mhlo_to_hlo/gen_hlo_op_writer.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@ defvar CustomHloConverterOps = [
257257
MHLO_SendOp,
258258
MHLO_SetDimensionSizeOp,
259259
MHLO_SineOp,
260+
MHLO_SinhOp,
260261
MHLO_SortOp,
261262
MHLO_StochasticConvertOp,
262263
MHLO_SubtractOp,

third_party/xla/xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5246,6 +5246,17 @@ LogicalResult ExportXlaOp(CoshOp op, OpLoweringContext ctx) {
52465246
return success();
52475247
}
52485248

5249+
LogicalResult ExportXlaOp(SinhOp op, OpLoweringContext ctx) {
5250+
auto& value_map = *ctx.values;
5251+
xla::XlaOp operand;
5252+
if (failed(GetXlaOp(op.getOperand(), value_map, &operand, op))) {
5253+
return failure();
5254+
}
5255+
value_map[op] =
5256+
xla::Sinh(operand, /*result_accuracy=*/std::nullopt, /*expand=*/false);
5257+
return success();
5258+
}
5259+
52495260
LogicalResult ExportXlaOp(AcoshOp op, OpLoweringContext ctx) {
52505261
return ExportElementwiseXlaOp<AcoshOp, xla::Acosh>(op, ctx);
52515262
}

third_party/xla/xla/hlo/translate/mhlo_to_hlo/tests/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ lit_test_suite(
1919
"case.mlir",
2020
"composite.mlir",
2121
"cosh.mlir",
22+
"sinh.mlir",
2223
"dynamic.mlir",
2324
"export-with-layouts.mlir",
2425
"export.mlir",
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: xla-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
2+
3+
module {
4+
func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
5+
// CHECK: f32[4] sinh
6+
%0 = "mhlo.sinh"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
7+
func.return %0 : tensor<4xf32>
8+
}
9+
}

third_party/xla/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo_pass.cc

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,14 @@ namespace {
5252
ChloLegalizeToHighLevelMhloPassOptions FromPassOptions(bool enableAcosh,
5353
bool enableAcos,
5454
bool enableAtanh,
55-
bool enableCosh) {
55+
bool enableCosh,
56+
bool enableSinh) {
5657
ChloLegalizeToHighLevelMhloPassOptions options;
5758
options.enable_acosh_ = enableAcosh;
5859
options.enable_acos_ = enableAcos;
5960
options.enable_atanh_ = enableAtanh;
6061
options.enable_cosh_ = enableCosh;
62+
options.enable_sinh_ = enableSinh;
6163
return options;
6264
}
6365

@@ -77,6 +79,10 @@ static bool qualifiesForDirectMhloLoweringCosh(chlo::CoshOp op) {
7779
return llvm::isa<FloatType>(getElementTypeOrSelf(op.getType()));
7880
}
7981

82+
static bool qualifiesForDirectMhloLoweringSinh(chlo::SinhOp op) {
83+
return llvm::isa<FloatType>(getElementTypeOrSelf(op.getType()));
84+
}
85+
8086
struct ChloLegalizeToHighLevelMhloPass
8187
: public impl::ChloLegalizeToHighLevelMhloPassBase<
8288
ChloLegalizeToHighLevelMhloPass> {
@@ -94,7 +100,7 @@ struct ChloLegalizeToHighLevelMhloPass
94100
chlo::populateChloToHighLevelMhloOpPatterns(
95101
&context, &conversionPatterns,
96102
FromPassOptions(enable_acosh_, enable_acos_, enable_atanh_,
97-
enable_cosh_));
103+
enable_cosh_, enable_sinh_));
98104

99105
// Consider the mhlo dialect legal for tests. Also add helper dialects
100106
// that are needed by the patterns.
@@ -121,6 +127,11 @@ struct ChloLegalizeToHighLevelMhloPass
121127
return !qualifiesForDirectMhloLoweringCosh(op);
122128
});
123129
}
130+
if (enable_sinh_) {
131+
conversionTarget.addDynamicallyLegalOp<chlo::SinhOp>([](chlo::SinhOp op) {
132+
return !qualifiesForDirectMhloLoweringSinh(op);
133+
});
134+
}
124135
conversionTarget
125136
.addIllegalOp<chlo::TopKOp, chlo::ErfOp, chlo::RaggedDotOp>();
126137

@@ -254,6 +265,15 @@ LogicalResult convertCoshChloToMhlo(chlo::CoshOp op,
254265
return success();
255266
}
256267

268+
LogicalResult convertSinhChloToMhlo(chlo::SinhOp op,
269+
PatternRewriter& rewriter) {
270+
if (!mhlo::qualifiesForDirectMhloLoweringSinh(op)) {
271+
return failure();
272+
}
273+
rewriter.replaceOpWithNewOp<mhlo::SinhOp>(op, op->getOperands());
274+
return success();
275+
}
276+
257277
} // namespace
258278

259279
ChloLegalizeToHighLevelMhloPassOptions getDefaultChloToHighLevelMhloOptions() {
@@ -266,6 +286,7 @@ ChloLegalizeToHighLevelMhloPassOptions getGpuChloToHighLevelMhloOptions() {
266286
opts.enable_acos_ = true;
267287
opts.enable_atanh_ = true;
268288
opts.enable_cosh_ = true;
289+
opts.enable_sinh_ = true;
269290
return opts;
270291
}
271292

@@ -292,6 +313,9 @@ void populateChloToHighLevelMhloOpPatterns(
292313
if (options.enable_cosh_) {
293314
patterns->add(mhlo::convertCoshChloToMhlo, kBenefit);
294315
}
316+
if (options.enable_sinh_) {
317+
patterns->add(mhlo::convertSinhChloToMhlo, kBenefit);
318+
}
295319
patterns->add(mhlo::convertRaggedDotChloToMhlo, kBenefit);
296320
populateWithGenerated(*patterns);
297321
}

0 commit comments

Comments
 (0)