@@ -52,12 +52,14 @@ namespace {
52
52
ChloLegalizeToHighLevelMhloPassOptions FromPassOptions (bool enableAcosh,
53
53
bool enableAcos,
54
54
bool enableAtanh,
55
- bool enableCosh) {
55
+ bool enableCosh,
56
+ bool enableSinh) {
56
57
ChloLegalizeToHighLevelMhloPassOptions options;
57
58
options.enable_acosh_ = enableAcosh;
58
59
options.enable_acos_ = enableAcos;
59
60
options.enable_atanh_ = enableAtanh;
60
61
options.enable_cosh_ = enableCosh;
62
+ options.enable_sinh_ = enableSinh;
61
63
return options;
62
64
}
63
65
@@ -77,6 +79,10 @@ static bool qualifiesForDirectMhloLoweringCosh(chlo::CoshOp op) {
77
79
return llvm::isa<FloatType>(getElementTypeOrSelf (op.getType ()));
78
80
}
79
81
82
+ static bool qualifiesForDirectMhloLoweringSinh (chlo::SinhOp op) {
83
+ return llvm::isa<FloatType>(getElementTypeOrSelf (op.getType ()));
84
+ }
85
+
80
86
struct ChloLegalizeToHighLevelMhloPass
81
87
: public impl::ChloLegalizeToHighLevelMhloPassBase<
82
88
ChloLegalizeToHighLevelMhloPass> {
@@ -94,7 +100,7 @@ struct ChloLegalizeToHighLevelMhloPass
94
100
chlo::populateChloToHighLevelMhloOpPatterns (
95
101
&context, &conversionPatterns,
96
102
FromPassOptions (enable_acosh_, enable_acos_, enable_atanh_,
97
- enable_cosh_));
103
+ enable_cosh_, enable_sinh_ ));
98
104
99
105
// Consider the mhlo dialect legal for tests. Also add helper dialects
100
106
// that are needed by the patterns.
@@ -121,6 +127,11 @@ struct ChloLegalizeToHighLevelMhloPass
121
127
return !qualifiesForDirectMhloLoweringCosh (op);
122
128
});
123
129
}
130
+ if (enable_sinh_) {
131
+ conversionTarget.addDynamicallyLegalOp <chlo::SinhOp>([](chlo::SinhOp op) {
132
+ return !qualifiesForDirectMhloLoweringSinh (op);
133
+ });
134
+ }
124
135
conversionTarget
125
136
.addIllegalOp <chlo::TopKOp, chlo::ErfOp, chlo::RaggedDotOp>();
126
137
@@ -254,6 +265,15 @@ LogicalResult convertCoshChloToMhlo(chlo::CoshOp op,
254
265
return success ();
255
266
}
256
267
268
+ LogicalResult convertSinhChloToMhlo (chlo::SinhOp op,
269
+ PatternRewriter& rewriter) {
270
+ if (!mhlo::qualifiesForDirectMhloLoweringSinh (op)) {
271
+ return failure ();
272
+ }
273
+ rewriter.replaceOpWithNewOp <mhlo::SinhOp>(op, op->getOperands ());
274
+ return success ();
275
+ }
276
+
257
277
} // namespace
258
278
259
279
ChloLegalizeToHighLevelMhloPassOptions getDefaultChloToHighLevelMhloOptions () {
@@ -266,6 +286,7 @@ ChloLegalizeToHighLevelMhloPassOptions getGpuChloToHighLevelMhloOptions() {
266
286
opts.enable_acos_ = true ;
267
287
opts.enable_atanh_ = true ;
268
288
opts.enable_cosh_ = true ;
289
+ opts.enable_sinh_ = true ;
269
290
return opts;
270
291
}
271
292
@@ -292,6 +313,9 @@ void populateChloToHighLevelMhloOpPatterns(
292
313
if (options.enable_cosh_ ) {
293
314
patterns->add (mhlo::convertCoshChloToMhlo, kBenefit );
294
315
}
316
+ if (options.enable_sinh_ ) {
317
+ patterns->add (mhlo::convertSinhChloToMhlo, kBenefit );
318
+ }
295
319
patterns->add (mhlo::convertRaggedDotChloToMhlo, kBenefit );
296
320
populateWithGenerated (*patterns);
297
321
}
0 commit comments